
Commit 752584d

ngxson and CISC authored
model: support GLM MoE DSA arch (NOTE: indexer is not yet supported) (ggml-org#19460)
* model: support GLM MoE DSA arch
* working version
* pyright
* keep indexer tensors
* add indexer gguf params
* loaded now
* Apply suggestions from code review
Co-authored-by: Sigbjørn Skjæret <sigbjorn.skjaeret@scala.com>
* update
* Update src/llama-model.cpp
Co-authored-by: Sigbjørn Skjæret <sigbjorn.skjaeret@scala.com>
* minor fix and cleanup

---------

Co-authored-by: Sigbjørn Skjæret <sigbjorn.skjaeret@scala.com>
1 parent cc2aa81 commit 752584d

10 files changed: 361 additions & 41 deletions


convert_hf_to_gguf.py

Lines changed: 54 additions & 38 deletions
```diff
@@ -1608,6 +1608,23 @@ def _set_vocab_glmedge(self):
         special_vocab._set_special_token("bos", tokenizer.get_added_vocab()["<|endoftext|>"])
         special_vocab.add_to_gguf(self.gguf_writer)
 
+    def _set_vocab_glm(self):
+        from transformers import AutoTokenizer
+        tokenizer = AutoTokenizer.from_pretrained(self.dir_model)
+        special_vocab = gguf.SpecialVocab(self.dir_model, load_merges=True)
+        tokens, toktypes, tokpre = self.get_vocab_base()
+        self.gguf_writer.add_tokenizer_model("gpt2")
+        self.gguf_writer.add_tokenizer_pre(tokpre)
+        self.gguf_writer.add_token_list(tokens)
+        self.gguf_writer.add_token_types(toktypes)
+        # Special tokens
+        # Note: Using <|endoftext|> (151329) for eot causes endless generation
+        special_vocab._set_special_token("bos", tokenizer.get_added_vocab()["[gMASK]"]) # 151331
+        special_vocab._set_special_token("eot", tokenizer.get_added_vocab()["<|user|>"]) # 151336
+        special_vocab._set_special_token("unk", tokenizer.get_added_vocab()["<|endoftext|>"]) # 151329
+        special_vocab._set_special_token("eom", tokenizer.get_added_vocab()["<|observation|>"]) # 151338
+        special_vocab.add_to_gguf(self.gguf_writer)
+
     def _set_vocab_interns1(self):
         tokens: list[str] = []
         toktypes: list[int] = []
@@ -7710,6 +7727,9 @@ def prepare_tensors(self):
 class DeepseekV2Model(TextModel):
     model_arch = gguf.MODEL_ARCH.DEEPSEEK2
 
+    # TODO @ngxson : remove this when we support MTP for deepseek models
+    skip_mtp = True
+
     def set_vocab(self):
         try:
             self._set_vocab_gpt2()
@@ -7841,10 +7861,11 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter
         name = name.replace("e_score_correction_bias", "e_score_correction.bias")
 
         # skip Multi-Token Prediction (MTP) layers
-        block_count = self.hparams["num_hidden_layers"]
-        match = re.match(r"model.layers.(\d+)", name)
-        if match and int(match.group(1)) >= block_count:
-            return
+        if self.skip_mtp:
+            block_count = self.hparams["num_hidden_layers"]
+            match = re.match(r"model.layers.(\d+)", name)
+            if match and int(match.group(1)) >= block_count:
+                return
 
         # process the experts separately
         if name.find("mlp.experts") != -1:
@@ -8684,24 +8705,7 @@ def __init__(self, *args, **kwargs):
         self.tensor_map = gguf.get_tensor_name_map(self.model_arch, self.block_count)
 
     def set_vocab(self):
-        from transformers import AutoTokenizer
-
-        tokenizer = AutoTokenizer.from_pretrained(self.dir_model)
-        special_vocab = gguf.SpecialVocab(self.dir_model, load_merges=True)
-        tokens, toktypes, tokpre = self.get_vocab_base()
-        self.gguf_writer.add_tokenizer_model("gpt2")
-        self.gguf_writer.add_tokenizer_pre(tokpre)
-        self.gguf_writer.add_token_list(tokens)
-        self.gguf_writer.add_token_types(toktypes)
-
-        # Special tokens
-        # Note: Using <|endoftext|> (151329) for eot causes endless generation
-        special_vocab._set_special_token("bos", tokenizer.get_added_vocab()["[gMASK]"]) # 151331
-        special_vocab._set_special_token("eot", tokenizer.get_added_vocab()["<|user|>"]) # 151336
-        special_vocab._set_special_token("unk", tokenizer.get_added_vocab()["<|endoftext|>"]) # 151329
-        special_vocab._set_special_token("eom", tokenizer.get_added_vocab()["<|observation|>"]) # 151338
-
-        special_vocab.add_to_gguf(self.gguf_writer)
+        return self._set_vocab_glm()
 
     def set_gguf_parameters(self):
         super().set_gguf_parameters()
@@ -8801,26 +8805,38 @@ def prepare_tensors(self):
 class Glm4MoeLiteModel(DeepseekV2Model):
     model_arch = gguf.MODEL_ARCH.DEEPSEEK2
 
-    # copied from Glm4MoeModel
     def set_vocab(self):
-        from transformers import AutoTokenizer
+        return self._set_vocab_glm()
 
-        tokenizer = AutoTokenizer.from_pretrained(self.dir_model)
-        special_vocab = gguf.SpecialVocab(self.dir_model, load_merges=True)
-        tokens, toktypes, tokpre = self.get_vocab_base()
-        self.gguf_writer.add_tokenizer_model("gpt2")
-        self.gguf_writer.add_tokenizer_pre(tokpre)
-        self.gguf_writer.add_token_list(tokens)
-        self.gguf_writer.add_token_types(toktypes)
 
-        # Special tokens
-        # Note: Using <|endoftext|> (151329) for eot causes endless generation
-        special_vocab._set_special_token("bos", tokenizer.get_added_vocab()["[gMASK]"]) # 151331
-        special_vocab._set_special_token("eot", tokenizer.get_added_vocab()["<|user|>"]) # 151336
-        special_vocab._set_special_token("unk", tokenizer.get_added_vocab()["<|endoftext|>"]) # 151329
-        special_vocab._set_special_token("eom", tokenizer.get_added_vocab()["<|observation|>"]) # 151338
+@ModelBase.register("GlmMoeDsaForCausalLM")
+class GlmMoeDsaModel(DeepseekV2Model):
+    model_arch = gguf.MODEL_ARCH.GLM_DSA
+    skip_mtp = False
 
-        special_vocab.add_to_gguf(self.gguf_writer)
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        self.block_count = self.hparams["num_hidden_layers"] + self.hparams.get("num_nextn_predict_layers", 0)
+        self.tensor_map = gguf.get_tensor_name_map(self.model_arch, self.block_count)
+
+    def set_vocab(self):
+        return self._set_vocab_glm()
+
+    def set_gguf_parameters(self):
+        super().set_gguf_parameters()
+
+        rope_dim = self.hparams["qk_rope_head_dim"]
+        partial_rotary_factor = self.hparams.get("partial_rotary_factor", 1.0)
+        self.gguf_writer.add_rope_dimension_count(int(rope_dim * partial_rotary_factor))
+
+        # NextN/MTP prediction layers
+        if (num_nextn_predict_layers := self.hparams.get("num_nextn_predict_layers")) is not None:
+            self.gguf_writer.add_nextn_predict_layers(num_nextn_predict_layers)
+
+        # DSA indexer parameters
+        self.gguf_writer.add_indexer_head_count(self.hparams["index_n_heads"])
+        self.gguf_writer.add_indexer_key_length(self.hparams["index_head_dim"])
+        self.gguf_writer.add_indexer_top_k(self.hparams["index_topk"])
 
 
 @ModelBase.register("GlmForCausalLM", "ChatGLMModel", "ChatGLMForConditionalGeneration")
```

gguf-py/gguf/constants.py

Lines changed: 56 additions & 0 deletions
```diff
@@ -181,6 +181,11 @@ class Attention:
         SLIDING_WINDOW_PATTERN = "{arch}.attention.sliding_window_pattern"
         TEMPERATURE_SCALE = "{arch}.attention.temperature_scale"
 
+        class Indexer:
+            HEAD_COUNT = "{arch}.attention.indexer.head_count"
+            KEY_LENGTH = "{arch}.attention.indexer.key_length"
+            TOP_K = "{arch}.attention.indexer.top_k"
+
     class Rope:
         DIMENSION_COUNT = "{arch}.rope.dimension_count"
         DIMENSION_SECTIONS = "{arch}.rope.dimension_sections"
@@ -425,6 +430,7 @@ class MODEL_ARCH(IntEnum):
     CHATGLM = auto()
     GLM4 = auto()
     GLM4_MOE = auto()
+    GLM_DSA = auto()
     BITNET = auto()
     T5 = auto()
     T5ENCODER = auto()
@@ -670,6 +676,10 @@ class MODEL_TENSOR(IntEnum):
     VISEXP_GATE = auto()
     VISEXP_DOWN = auto()
     VISEXP_UP = auto()
+    INDEXER_K_NORM = auto()
+    INDEXER_PROJ = auto()
+    INDEXER_ATTN_K = auto()
+    INDEXER_ATTN_Q_B = auto()
     # vision
     V_MMPROJ = auto()
     V_MMPROJ_FC = auto()
@@ -858,6 +868,7 @@ class MODEL_TENSOR(IntEnum):
     MODEL_ARCH.CHATGLM: "chatglm",
     MODEL_ARCH.GLM4: "glm4",
     MODEL_ARCH.GLM4_MOE: "glm4moe",
+    MODEL_ARCH.GLM_DSA: "glm-dsa",
     MODEL_ARCH.BITNET: "bitnet",
     MODEL_ARCH.T5: "t5",
     MODEL_ARCH.T5ENCODER: "t5encoder",
@@ -1101,6 +1112,10 @@ class MODEL_TENSOR(IntEnum):
     MODEL_TENSOR.VISEXP_GATE: "blk.{bid}.vis_gate",
     MODEL_TENSOR.VISEXP_DOWN: "blk.{bid}.vis_down",
     MODEL_TENSOR.VISEXP_UP: "blk.{bid}.vis_up",
+    MODEL_TENSOR.INDEXER_K_NORM: "blk.{bid}.indexer.k_norm",
+    MODEL_TENSOR.INDEXER_PROJ: "blk.{bid}.indexer.proj",
+    MODEL_TENSOR.INDEXER_ATTN_K: "blk.{bid}.indexer.attn_k",
+    MODEL_TENSOR.INDEXER_ATTN_Q_B: "blk.{bid}.indexer.attn_q_b",
     # vision
     MODEL_TENSOR.V_MMPROJ: "mm.{bid}",
     MODEL_TENSOR.V_MMPROJ_FC: "mm.model.fc",
@@ -2677,6 +2692,47 @@ class MODEL_TENSOR(IntEnum):
         MODEL_TENSOR.NEXTN_SHARED_HEAD_HEAD,
         MODEL_TENSOR.NEXTN_SHARED_HEAD_NORM,
     ],
+    MODEL_ARCH.GLM_DSA: [
+        MODEL_TENSOR.TOKEN_EMBD,
+        MODEL_TENSOR.OUTPUT_NORM,
+        MODEL_TENSOR.OUTPUT,
+        MODEL_TENSOR.ROPE_FREQS,
+        MODEL_TENSOR.ATTN_NORM,
+        MODEL_TENSOR.ATTN_Q,
+        MODEL_TENSOR.ATTN_Q_A,
+        MODEL_TENSOR.ATTN_Q_B,
+        MODEL_TENSOR.ATTN_KV_A_MQA,
+        MODEL_TENSOR.ATTN_KV_B,
+        MODEL_TENSOR.ATTN_K_B,
+        MODEL_TENSOR.ATTN_V_B,
+        MODEL_TENSOR.ATTN_Q_A_NORM,
+        MODEL_TENSOR.ATTN_KV_A_NORM,
+        MODEL_TENSOR.ATTN_OUT,
+        MODEL_TENSOR.ATTN_ROT_EMBD,
+        MODEL_TENSOR.FFN_GATE_INP,
+        MODEL_TENSOR.FFN_NORM,
+        MODEL_TENSOR.FFN_GATE,
+        MODEL_TENSOR.FFN_DOWN,
+        MODEL_TENSOR.FFN_UP,
+        MODEL_TENSOR.FFN_GATE_EXP,
+        MODEL_TENSOR.FFN_DOWN_EXP,
+        MODEL_TENSOR.FFN_UP_EXP,
+        MODEL_TENSOR.FFN_GATE_SHEXP,
+        MODEL_TENSOR.FFN_DOWN_SHEXP,
+        MODEL_TENSOR.FFN_UP_SHEXP,
+        MODEL_TENSOR.FFN_EXP_PROBS_B,
+        MODEL_TENSOR.INDEXER_K_NORM,
+        MODEL_TENSOR.INDEXER_PROJ,
+        MODEL_TENSOR.INDEXER_ATTN_K,
+        MODEL_TENSOR.INDEXER_ATTN_Q_B,
+        # NextN/MTP tensors - preserved but unused
+        MODEL_TENSOR.NEXTN_EH_PROJ,
+        MODEL_TENSOR.NEXTN_EMBED_TOKENS,
+        MODEL_TENSOR.NEXTN_ENORM,
+        MODEL_TENSOR.NEXTN_HNORM,
+        MODEL_TENSOR.NEXTN_SHARED_HEAD_HEAD,
+        MODEL_TENSOR.NEXTN_SHARED_HEAD_NORM,
+    ],
     MODEL_ARCH.BITNET: [
         MODEL_TENSOR.ATTN_Q,
         MODEL_TENSOR.ATTN_K,
```
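For orientation, a small illustration of how the new `Keys.Attention.Indexer` key templates and the `blk.{bid}.indexer.*` tensor names expand for this architecture (the block index is arbitrary):

```python
from gguf.constants import MODEL_ARCH, MODEL_TENSOR, MODEL_ARCH_NAMES, TENSOR_NAMES, Keys

arch = MODEL_ARCH_NAMES[MODEL_ARCH.GLM_DSA]                     # "glm-dsa"

print(Keys.Attention.Indexer.HEAD_COUNT.format(arch=arch))      # glm-dsa.attention.indexer.head_count
print(Keys.Attention.Indexer.KEY_LENGTH.format(arch=arch))      # glm-dsa.attention.indexer.key_length
print(Keys.Attention.Indexer.TOP_K.format(arch=arch))           # glm-dsa.attention.indexer.top_k

# per-block indexer tensor names, e.g. for block 3
print(TENSOR_NAMES[MODEL_TENSOR.INDEXER_K_NORM].format(bid=3))  # blk.3.indexer.k_norm
print(TENSOR_NAMES[MODEL_TENSOR.INDEXER_PROJ].format(bid=3))    # blk.3.indexer.proj
```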

gguf-py/gguf/gguf_writer.py

Lines changed: 9 additions & 0 deletions
```diff
@@ -771,6 +771,15 @@ def add_key_length_mla(self, length: int) -> None:
     def add_value_length_mla(self, length: int) -> None:
         self.add_uint32(Keys.Attention.VALUE_LENGTH_MLA.format(arch=self.arch), length)
 
+    def add_indexer_head_count(self, count: int) -> None:
+        self.add_uint32(Keys.Attention.Indexer.HEAD_COUNT.format(arch=self.arch), count)
+
+    def add_indexer_key_length(self, length: int) -> None:
+        self.add_uint32(Keys.Attention.Indexer.KEY_LENGTH.format(arch=self.arch), length)
+
+    def add_indexer_top_k(self, top_k: int) -> None:
+        self.add_uint32(Keys.Attention.Indexer.TOP_K.format(arch=self.arch), top_k)
+
     def add_max_alibi_bias(self, bias: float) -> None:
         self.add_float32(Keys.Attention.MAX_ALIBI_BIAS.format(arch=self.arch), bias)
 
```
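A converted file can be spot-checked for the three new uint32 fields with `gguf.GGUFReader`; the path below is a placeholder, and the field-access pattern shown is just one way to read scalar values:

```python
import gguf

reader = gguf.GGUFReader("glm-dsa-example.gguf")  # placeholder path

for key in ("glm-dsa.attention.indexer.head_count",
            "glm-dsa.attention.indexer.key_length",
            "glm-dsa.attention.indexer.top_k"):
    field = reader.get_field(key)
    if field is None:
        print(f"{key}: missing")
        continue
    # scalar fields keep their value in the part referenced by data[0]
    print(f"{key} = {int(field.parts[field.data[0]][0])}")
```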

gguf-py/gguf/tensor_mapping.py

Lines changed: 16 additions & 0 deletions
```diff
@@ -1206,6 +1206,22 @@ class TensorNameMap:
             "model.layers.{bid}.self_attn.vision_expert_query_key_value", # cogvlm
         ),
 
+        MODEL_TENSOR.INDEXER_K_NORM: (
+            "model.layers.{bid}.self_attn.indexer.k_norm", # DSA
+        ),
+
+        MODEL_TENSOR.INDEXER_PROJ: (
+            "model.layers.{bid}.self_attn.indexer.weights_proj", # DSA
+        ),
+
+        MODEL_TENSOR.INDEXER_ATTN_K: (
+            "model.layers.{bid}.self_attn.indexer.wk", # DSA
+        ),
+
+        MODEL_TENSOR.INDEXER_ATTN_Q_B: (
+            "model.layers.{bid}.self_attn.indexer.wq_b", # DSA
+        ),
+
         ############################################################################
         # TODO: these do not belong to block_mappings_cfg - move them to mappings_cfg
         MODEL_TENSOR.ENC_OUTPUT_NORM: (
```
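These entries feed the usual `TensorNameMap` lookup that the converter relies on; a sketch with an arbitrary block index and block count (as in `GlmMoeDsaModel.__init__`, the count also has to cover the NextN/MTP layers):

```python
import gguf

# block count is illustrative; it must include the NextN/MTP layer(s)
tensor_map = gguf.get_tensor_name_map(gguf.MODEL_ARCH.GLM_DSA, 48)

hf_name = "model.layers.3.self_attn.indexer.wk.weight"
print(tensor_map.get_name(hf_name, try_suffixes=(".weight", ".bias")))
# -> blk.3.indexer.attn_k.weight
```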

src/llama-arch.cpp

Lines changed: 52 additions & 0 deletions
```diff
@@ -74,6 +74,7 @@ static const std::map<llm_arch, const char *> LLM_ARCH_NAMES = {
     { LLM_ARCH_CHATGLM, "chatglm" },
     { LLM_ARCH_GLM4, "glm4" },
     { LLM_ARCH_GLM4_MOE, "glm4moe" },
+    { LLM_ARCH_GLM_DSA, "glm-dsa" },
     { LLM_ARCH_BITNET, "bitnet" },
     { LLM_ARCH_T5, "t5" },
     { LLM_ARCH_T5ENCODER, "t5encoder" },
@@ -225,6 +226,9 @@ static const std::map<llm_kv, const char *> LLM_KV_NAMES = {
     { LLM_KV_ATTENTION_TEMPERATURE_SCALE, "%s.attention.temperature_scale" },
     { LLM_KV_ATTENTION_KEY_LENGTH_MLA, "%s.attention.key_length_mla" },
     { LLM_KV_ATTENTION_VALUE_LENGTH_MLA, "%s.attention.value_length_mla" },
+    { LLM_KV_ATTENTION_INDEXER_HEAD_COUNT, "%s.attention.indexer.head_count" },
+    { LLM_KV_ATTENTION_INDEXER_KEY_LENGTH, "%s.attention.indexer.key_length" },
+    { LLM_KV_ATTENTION_INDEXER_TOP_K, "%s.attention.indexer.top_k" },
 
     { LLM_KV_ROPE_DIMENSION_COUNT, "%s.rope.dimension_count" },
     { LLM_KV_ROPE_DIMENSION_SECTIONS, "%s.rope.dimension_sections" },
@@ -516,6 +520,10 @@ static const std::map<llm_tensor, const char *> LLM_TENSOR_NAMES = {
     { LLM_TENSOR_VISEXP_FFN_GATE, "blk.%d.vis_gate" },
     { LLM_TENSOR_VISEXP_FFN_DOWN, "blk.%d.vis_down" },
     { LLM_TENSOR_VISEXP_FFN_UP, "blk.%d.vis_up" },
+    { LLM_TENSOR_INDEXER_K_NORM, "blk.%d.indexer.k_norm" },
+    { LLM_TENSOR_INDEXER_PROJ, "blk.%d.indexer.proj" },
+    { LLM_TENSOR_INDEXER_ATTN_K, "blk.%d.indexer.attn_k" },
+    { LLM_TENSOR_INDEXER_ATTN_Q_B, "blk.%d.indexer.attn_q_b" },
 };
 
 static std::set<llm_tensor> llm_get_tensor_names(llm_arch arch) {
@@ -1657,6 +1665,46 @@ static std::set<llm_tensor> llm_get_tensor_names(llm_arch arch) {
                 LLM_TENSOR_NEXTN_SHARED_HEAD_HEAD,
                 LLM_TENSOR_NEXTN_SHARED_HEAD_NORM,
             };
+        case LLM_ARCH_GLM_DSA:
+            return {
+                LLM_TENSOR_TOKEN_EMBD,
+                LLM_TENSOR_OUTPUT_NORM,
+                LLM_TENSOR_OUTPUT,
+                LLM_TENSOR_ATTN_NORM,
+                LLM_TENSOR_ATTN_Q_A_NORM,
+                LLM_TENSOR_ATTN_KV_A_NORM,
+                LLM_TENSOR_ATTN_Q,
+                LLM_TENSOR_ATTN_Q_A,
+                LLM_TENSOR_ATTN_Q_B,
+                LLM_TENSOR_ATTN_KV_A_MQA,
+                LLM_TENSOR_ATTN_KV_B,
+                LLM_TENSOR_ATTN_K_B,
+                LLM_TENSOR_ATTN_V_B,
+                LLM_TENSOR_ATTN_OUT,
+                LLM_TENSOR_FFN_NORM,
+                LLM_TENSOR_FFN_GATE,
+                LLM_TENSOR_FFN_UP,
+                LLM_TENSOR_FFN_DOWN,
+                LLM_TENSOR_FFN_GATE_INP,
+                LLM_TENSOR_FFN_GATE_EXPS,
+                LLM_TENSOR_FFN_DOWN_EXPS,
+                LLM_TENSOR_FFN_UP_EXPS,
+                LLM_TENSOR_FFN_GATE_INP_SHEXP,
+                LLM_TENSOR_FFN_GATE_SHEXP,
+                LLM_TENSOR_FFN_DOWN_SHEXP,
+                LLM_TENSOR_FFN_UP_SHEXP,
+                LLM_TENSOR_FFN_EXP_PROBS_B,
+                LLM_TENSOR_INDEXER_K_NORM,
+                LLM_TENSOR_INDEXER_PROJ,
+                LLM_TENSOR_INDEXER_ATTN_K,
+                LLM_TENSOR_INDEXER_ATTN_Q_B,
+                LLM_TENSOR_NEXTN_EH_PROJ,
+                LLM_TENSOR_NEXTN_EMBED_TOKENS,
+                LLM_TENSOR_NEXTN_ENORM,
+                LLM_TENSOR_NEXTN_HNORM,
+                LLM_TENSOR_NEXTN_SHARED_HEAD_HEAD,
+                LLM_TENSOR_NEXTN_SHARED_HEAD_NORM,
+            };
         case LLM_ARCH_BITNET:
             return {
                 LLM_TENSOR_TOKEN_EMBD,
@@ -2643,6 +2691,10 @@ static const std::map<llm_tensor, llm_tensor_info> LLM_TENSOR_INFOS = {
     {LLM_TENSOR_VISEXP_FFN_GATE, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
     {LLM_TENSOR_VISEXP_FFN_DOWN, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
     {LLM_TENSOR_VISEXP_FFN_UP, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
+    {LLM_TENSOR_INDEXER_K_NORM, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},
+    {LLM_TENSOR_INDEXER_PROJ, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
+    {LLM_TENSOR_INDEXER_ATTN_K, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
+    {LLM_TENSOR_INDEXER_ATTN_Q_B, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
     // NextN/MTP tensors are currently ignored (reserved for future MTP support)
     // These tensors only exist in the last layer(s) and are treated as output tensors
     {LLM_TENSOR_NEXTN_EH_PROJ, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL_MAT}},
```
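The C-side format strings have to stay in sync with the gguf-py templates above; a trivial Python consistency check, purely illustrative:

```python
from gguf.constants import Keys

arch = "glm-dsa"  # LLM_ARCH_NAMES[LLM_ARCH_GLM_DSA] in src/llama-arch.cpp

# these mirror the new LLM_KV_NAMES entries ("%s.attention.indexer.*")
assert Keys.Attention.Indexer.HEAD_COUNT.format(arch=arch) == "%s.attention.indexer.head_count" % arch
assert Keys.Attention.Indexer.KEY_LENGTH.format(arch=arch) == "%s.attention.indexer.key_length" % arch
assert Keys.Attention.Indexer.TOP_K.format(arch=arch) == "%s.attention.indexer.top_k" % arch
```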

src/llama-arch.h

Lines changed: 8 additions & 0 deletions
```diff
@@ -78,6 +78,7 @@ enum llm_arch {
     LLM_ARCH_CHATGLM,
     LLM_ARCH_GLM4,
     LLM_ARCH_GLM4_MOE,
+    LLM_ARCH_GLM_DSA,
     LLM_ARCH_BITNET,
     LLM_ARCH_T5,
     LLM_ARCH_T5ENCODER,
@@ -229,6 +230,9 @@ enum llm_kv {
     LLM_KV_ATTENTION_TEMPERATURE_SCALE,
     LLM_KV_ATTENTION_KEY_LENGTH_MLA,
     LLM_KV_ATTENTION_VALUE_LENGTH_MLA,
+    LLM_KV_ATTENTION_INDEXER_HEAD_COUNT,
+    LLM_KV_ATTENTION_INDEXER_KEY_LENGTH,
+    LLM_KV_ATTENTION_INDEXER_TOP_K,
 
     LLM_KV_ROPE_DIMENSION_COUNT,
     LLM_KV_ROPE_DIMENSION_SECTIONS,
@@ -517,6 +521,10 @@ enum llm_tensor {
     LLM_TENSOR_VISEXP_FFN_GATE,
     LLM_TENSOR_VISEXP_FFN_DOWN,
     LLM_TENSOR_VISEXP_FFN_UP,
+    LLM_TENSOR_INDEXER_K_NORM,
+    LLM_TENSOR_INDEXER_PROJ,
+    LLM_TENSOR_INDEXER_ATTN_K,
+    LLM_TENSOR_INDEXER_ATTN_Q_B,
     LLM_TENSOR_NEXTN_EH_PROJ,
     LLM_TENSOR_NEXTN_EMBED_TOKENS,
     LLM_TENSOR_NEXTN_ENORM,
```

src/llama-hparams.h

Lines changed: 5 additions & 0 deletions
```diff
@@ -193,6 +193,11 @@ struct llama_hparams {
     std::array<float, LLAMA_MAX_LAYERS> xielu_beta;
     std::array<float, LLAMA_MAX_LAYERS> xielu_eps;
 
+    // DSA (deepseek sparse attention)
+    uint32_t indexer_n_head = 0;
+    uint32_t indexer_head_size = 0;
+    uint32_t indexer_top_k = 0;
+
     // qwen3vl deepstack
     uint32_t n_deepstack_layers = 0;
 
```
