
Commit d34ff7e

ngxson and CISC authored
model: mistral small 4 support (ggml-org#20649)
* model: mistral small 4 support

* fix test

* fix test (2)

* Apply suggestions from code review

Co-authored-by: Sigbjørn Skjæret <sigbjorn.skjaeret@scala.com>

* Update convert_hf_to_gguf.py

Co-authored-by: Sigbjørn Skjæret <sigbjorn.skjaeret@scala.com>

* change newline

---------

Co-authored-by: Sigbjørn Skjæret <sigbjorn.skjaeret@scala.com>
1 parent 45172df commit d34ff7e

6 files changed

Lines changed: 133 additions & 42 deletions


convert_hf_to_gguf.py

Lines changed: 83 additions & 39 deletions
@@ -298,11 +298,16 @@ def dequant_simple(weight: Tensor, scale: Tensor, block_size: Sequence[int] | No
             scale = scale.float()
 
             if block_size is not None:
+                dim_offset = scale.ndim - len(block_size)
                 for i, size in enumerate(block_size):
-                    scale = scale.repeat_interleave(size, i)
+                    scale = scale.repeat_interleave(size, dim_offset + i)
                 # unpad the scale (e.g. when the tensor size isn't a multiple of the block size)
                 scale = scale[tuple(slice(0, size) for size in weight.shape)]
 
+            # align scale dims to weight for correct broadcasting (e.g. [128] -> [128, 1, 1])
+            while scale.ndim < weight.ndim:
+                scale = scale.unsqueeze(-1)
+
             return weight.float() * scale
 
         # ref: https://github.com/ModelCloud/GPTQModel/blob/037c5c0f6c9e33c500d975b038d02e7ca437546d/gptqmodel/nn_modules/qlinear/__init__.py#L437-L476
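Note: the hunk above fixes two broadcasting gaps in dequant_simple: block scales are now expanded starting from the trailing dims (so a 2D block spec also works against stacked 3D expert scales), and a low-rank scale gains trailing singleton dims before the multiply. A minimal standalone sketch of the same expansion, with invented shapes:

import torch

# a [4, 6] weight quantized in [2, 3] blocks carries a [2, 2] scale (one per block)
weight = torch.ones(4, 6)
scale = torch.tensor([[1.0, 2.0], [3.0, 4.0]])
block_size = (2, 3)

dim_offset = scale.ndim - len(block_size)  # 0 here; > 0 for stacked 3D experts
for i, size in enumerate(block_size):
    scale = scale.repeat_interleave(size, dim_offset + i)

while scale.ndim < weight.ndim:  # e.g. a per-row [128] scale -> [128, 1]
    scale = scale.unsqueeze(-1)

print((weight * scale)[0])  # tensor([1., 1., 1., 2., 2., 2.])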
@@ -393,14 +398,16 @@ def dequant_packed(w: Tensor, scale: Tensor, shape_tensor: Tensor, zero_point: T
         elif quant_method == "fp8":
             block_size = quant_config.get("weight_block_size")
             for name in self.model_tensors.keys():
-                if name.endswith(".weight_scale_inv"):
+                if name.endswith("_scale_inv"):
                     weight_name = name.removesuffix("_scale_inv")
                     w = self.model_tensors[weight_name]
                     s = self.model_tensors[name]
                     self.model_tensors[weight_name] = lambda w=w, s=s, bs=block_size: dequant_simple(w(), s(), bs)
                     tensors_to_remove.append(name)
                 if name.endswith(".activation_scale"):  # unused
                     tensors_to_remove.append(name)
+                if name.endswith("_activation_scale"):  # Mistral-Small-4-119B-2602, unused
+                    tensors_to_remove.append(name)
                 # mistral format
                 if name.endswith(".qscale_weight"):
                     weight_name = name.removesuffix("qscale_weight") + "weight"
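Note: the lambda w=w, s=s, bs=block_size defaults kept by this hunk are load-bearing. Lambdas created in a loop close over variables, not values, so without the defaults every dequant closure would see the last iteration's tensors. A self-contained illustration of the pitfall:

buggy = [lambda: n for n in range(3)]
fixed = [lambda n=n: n for n in range(3)]  # default arg snapshots each n

print([f() for f in buggy])  # [2, 2, 2]
print([f() for f in fixed])  # [0, 1, 2]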
@@ -3031,10 +3038,16 @@ def __init__(self, *args, **kwargs):
     def get_token_id(self, token: str) -> int:
         tokenizer_config_file = self.dir_model / 'tokenizer_config.json'
         with open(tokenizer_config_file, "r", encoding="utf-8") as f:
-            added_tokens_decoder = json.load(f)['added_tokens_decoder']
+            added_tokens_decoder = json.load(f).get('added_tokens_decoder') or {}
         for id_, token_data in added_tokens_decoder.items():
-            if token_data["content"] == token:
+            if token_data.get("content") == token:
                 return int(id_)
+        # fallthrough to tokenizer.json
+        with open(self.dir_model / "tokenizer.json", "r", encoding="utf-8") as f:
+            tokenizer_json = json.load(f)
+        for token_data in tokenizer_json["added_tokens"]:
+            if token_data["content"] == token:
+                return int(token_data["id"])
         raise ValueError(f"Token '{token}' not found in tokenizer config.")
 
     def set_gguf_parameters(self):
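Note: token lookup now has two stages: tokenizer_config.json's added_tokens_decoder (keyed by id), then tokenizer.json's added_tokens list (entries carry their own "id" field). A standalone sketch of the same two-stage lookup; find_token_id is a hypothetical helper name, assuming only the standard Hugging Face tokenizer file layout:

import json
from pathlib import Path

def find_token_id(dir_model: Path, token: str) -> int:
    # stage 1: tokenizer_config.json maps id -> {"content": ...}
    cfg = json.loads((dir_model / "tokenizer_config.json").read_text(encoding="utf-8"))
    for id_, data in (cfg.get("added_tokens_decoder") or {}).items():
        if data.get("content") == token:
            return int(id_)
    # stage 2: tokenizer.json lists {"id": ..., "content": ...} entries
    tok = json.loads((dir_model / "tokenizer.json").read_text(encoding="utf-8"))
    for data in tok.get("added_tokens", []):
        if data.get("content") == token:
            return int(data["id"])
    raise ValueError(f"Token '{token}' not found in either tokenizer file.")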
@@ -3198,40 +3211,6 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter
         yield from super().modify_tensors(data_torch, name, bid)
 
 
-@ModelBase.register(
-    "Mistral3ForConditionalGeneration",
-    "Ministral3ForCausalLM",
-)
-class Mistral3Model(LlamaModel):
-    model_arch = gguf.MODEL_ARCH.MISTRAL3
-
-    def __init__(self, *args, **kwargs):
-        super().__init__(*args, **kwargs)
-        # for compatibility, we use LLAMA arch for older models
-        # TODO: remove this once everyone has migrated to newer version of llama.cpp
-        if self.hparams.get("model_type") != "ministral3":
-            self.model_arch = gguf.MODEL_ARCH.LLAMA
-            self.gguf_writer.arch = gguf.MODEL_ARCH_NAMES[self.model_arch]
-            self.gguf_writer.add_architecture()
-            self.tensor_map = gguf.get_tensor_name_map(self.model_arch, self.block_count)
-
-    def set_gguf_parameters(self):
-        super().set_gguf_parameters()
-        rope_params = self.rope_parameters
-        if self.hparams.get("model_type") == "ministral3":
-            assert rope_params, "ministral3 must have 'rope_parameters' config"
-            assert rope_params["rope_type"] == "yarn", "ministral3 rope_type must be 'yarn'"
-            self.gguf_writer.add_rope_scaling_yarn_log_mul(rope_params["mscale_all_dim"])
-            self.gguf_writer.add_attn_temperature_scale(rope_params["llama_4_scaling_beta"])
-
-    def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None):
-        name = name.replace("language_model.", "")
-        if "multi_modal_projector" in name or "vision_tower" in name:
-            return
-
-        yield from super().modify_tensors(data_torch, name, bid)
-
-
 @ModelBase.register("DeciLMForCausalLM")
 class DeciModel(TextModel):
     model_arch = gguf.MODEL_ARCH.DECI
@@ -8271,6 +8250,8 @@ class DeepseekV2Model(TextModel):
     # TODO @ngxson : remove this when we support MTP for deepseek models
     skip_mtp = True
 
+    merge_expert = True
+
     def set_vocab(self):
         try:
             self._set_vocab_gpt2()
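Note: merge_expert gates the expert-merging branch in modify_tensors (next hunk). DeepSeek-style checkpoints ship one 2D weight per expert, which the converter stacks into the single 3D tensor GGUF expects; Mistral Small 4 ships its experts pre-stacked, so its subclass further down sets merge_expert = False. A toy sketch of the merge itself (shapes invented for illustration):

import torch

n_experts = 4
per_expert = [torch.randn(8, 16) for _ in range(n_experts)]  # one 2D weight each
stacked = torch.stack(per_expert, dim=0)                     # -> [4, 8, 16]
assert stacked.shape == (n_experts, 8, 16)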
@@ -8409,7 +8390,7 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter
             return
 
         # process the experts separately
-        if name.find("mlp.experts") != -1:
+        if self.merge_expert and name.find("mlp.experts") != -1:
             n_experts = self.hparams["n_routed_experts"]
             assert bid is not None
 
@@ -8468,6 +8449,69 @@ def prepare_tensors(self):
             raise ValueError(f"Unprocessed experts: {experts}")
 
 
+@ModelBase.register(
+    "Mistral3ForConditionalGeneration",
+    "Ministral3ForCausalLM",
+)
+class Mistral3Model(TextModel):
+    class Ministral3Model(LlamaModel):
+        model_arch = gguf.MODEL_ARCH.MISTRAL3
+
+        def set_gguf_parameters(self):
+            super().set_gguf_parameters()
+            rope_params = self.rope_parameters
+            if self.hparams.get("model_type") == "ministral3":
+                assert rope_params, "ministral3 must have 'rope_parameters' config"
+                assert rope_params["rope_type"] == "yarn", "ministral3 rope_type must be 'yarn'"
+                self.gguf_writer.add_rope_scaling_yarn_log_mul(rope_params["mscale_all_dim"])
+                self.gguf_writer.add_attn_temperature_scale(rope_params["llama_4_scaling_beta"])
+
+        def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None):
+            name = name.replace("language_model.", "")
+            if "multi_modal_projector" in name or "vision_tower" in name:
+                return
+
+            yield from super().modify_tensors(data_torch, name, bid)
+
+    class Mistral4Model(DeepseekV2Model):
+        model_arch = gguf.MODEL_ARCH.MISTRAL4
+        skip_mtp = False  # model contains no MTP layers, so no need to skip
+        merge_expert = False  # experts are already stacked as 3D
+
+        def modify_tensors(self, data_torch, name, bid):
+            if name.endswith(".down_proj") or name.endswith(".gate_up_proj"):
+                name = name + ".weight"
+            yield from super().modify_tensors(data_torch, name, bid)
+
+    model_arch = gguf.MODEL_ARCH.MISTRAL3  # unused
+    impl: TextModel
+
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        if self.hparams.get("model_type") == "mistral4":
+            self.impl = Mistral3Model.Mistral4Model(*args, **kwargs)
+        else:
+            self.impl = Mistral3Model.Ministral3Model(*args, **kwargs)
+
+    def set_vocab(self):
+        self.impl.set_vocab()
+
+    def set_gguf_parameters(self):
+        self.impl.set_gguf_parameters()
+
+    def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None):
+        yield from self.impl.modify_tensors(data_torch, name, bid)
+
+    def prepare_tensors(self):
+        self.impl.prepare_tensors()
+
+    def write_vocab(self):
+        self.impl.write_vocab()
+
+    def write(self):
+        self.impl.write()
+
+
 @ModelBase.register("MiniMaxM2ForCausalLM")
 class MiniMaxM2Model(TextModel):
     model_arch = gguf.MODEL_ARCH.MINIMAXM2
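Note: the registered Mistral3Model is now a pure dispatcher: it inspects model_type once, instantiates the matching nested implementation, and forwards every conversion hook to it. Stripped to its shape (class and method names below are placeholders, not the converter's API):

class Base:
    def __init__(self, hparams):
        self.hparams = hparams

    def write(self):
        print(f"writing as {type(self).__name__}")

class ImplA(Base):
    pass

class ImplB(Base):
    pass

class Dispatcher(Base):
    def __init__(self, hparams):
        super().__init__(hparams)
        impl_cls = ImplB if hparams.get("model_type") == "b" else ImplA
        self.impl = impl_cls(hparams)

    def write(self):  # every public hook forwards to the chosen impl
        self.impl.write()

Dispatcher({"model_type": "b"}).write()  # writing as ImplB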

gguf-py/gguf/constants.py

Lines changed: 33 additions & 0 deletions
@@ -478,6 +478,7 @@ class MODEL_ARCH(IntEnum):
     RND1 = auto()
     PANGU_EMBED = auto()
     MISTRAL3 = auto()
+    MISTRAL4 = auto()
     PADDLEOCR = auto()
     MIMO2 = auto()
     STEP35 = auto()
@@ -924,6 +925,7 @@ class MODEL_TENSOR(IntEnum):
     MODEL_ARCH.RND1: "rnd1",
     MODEL_ARCH.PANGU_EMBED: "pangu-embedded",
     MODEL_ARCH.MISTRAL3: "mistral3",
+    MODEL_ARCH.MISTRAL4: "mistral4",
     MODEL_ARCH.PADDLEOCR: "paddleocr",
     MODEL_ARCH.MIMO2: "mimo2",
     MODEL_ARCH.STEP35: "step35",
@@ -3538,6 +3540,37 @@ class MODEL_TENSOR(IntEnum):
         MODEL_TENSOR.FFN_DOWN_EXP,
         MODEL_TENSOR.FFN_UP_EXP,
     ],
+    MODEL_ARCH.MISTRAL4: [
+        MODEL_TENSOR.TOKEN_EMBD,
+        MODEL_TENSOR.OUTPUT_NORM,
+        MODEL_TENSOR.OUTPUT,
+        MODEL_TENSOR.ROPE_FREQS,
+        MODEL_TENSOR.ATTN_NORM,
+        MODEL_TENSOR.ATTN_Q,
+        MODEL_TENSOR.ATTN_Q_A,
+        MODEL_TENSOR.ATTN_Q_B,
+        MODEL_TENSOR.ATTN_KV_A_MQA,
+        MODEL_TENSOR.ATTN_KV_B,
+        MODEL_TENSOR.ATTN_K_B,
+        MODEL_TENSOR.ATTN_V_B,
+        MODEL_TENSOR.ATTN_Q_A_NORM,
+        MODEL_TENSOR.ATTN_KV_A_NORM,
+        MODEL_TENSOR.ATTN_OUT,
+        MODEL_TENSOR.ATTN_ROT_EMBD,
+        MODEL_TENSOR.FFN_GATE_INP,
+        MODEL_TENSOR.FFN_NORM,
+        MODEL_TENSOR.FFN_GATE,
+        MODEL_TENSOR.FFN_DOWN,
+        MODEL_TENSOR.FFN_UP,
+        MODEL_TENSOR.FFN_GATE_EXP,
+        MODEL_TENSOR.FFN_DOWN_EXP,
+        MODEL_TENSOR.FFN_UP_EXP,
+        MODEL_TENSOR.FFN_GATE_UP_EXP,
+        MODEL_TENSOR.FFN_GATE_SHEXP,
+        MODEL_TENSOR.FFN_DOWN_SHEXP,
+        MODEL_TENSOR.FFN_UP_SHEXP,
+        MODEL_TENSOR.FFN_EXP_PROBS_B,
+    ],
     MODEL_ARCH.MIMO2: [
         MODEL_TENSOR.TOKEN_EMBD,
         MODEL_TENSOR.OUTPUT_NORM,
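Note: MODEL_TENSOR.FFN_GATE_UP_EXP is the main addition relative to the DeepSeek2 tensor list: it covers checkpoints whose expert gate and up projections arrive fused in a single tensor (Mistral4Model above renames bare .gate_up_proj tensors to .gate_up_proj.weight so they map onto it). If the halves are ever needed separately, a fused tensor splits along the feature dim; a toy sketch with invented shapes and an assumed gate-first layout:

import torch

n_expert, n_ff, n_embd = 2, 16, 8
gate_up = torch.randn(n_expert, 2 * n_ff, n_embd)  # fused, stacked per expert
gate, up = gate_up.chunk(2, dim=1)                 # assumes gate comes first
assert gate.shape == up.shape == (n_expert, n_ff, n_embd)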

src/llama-arch.cpp

Lines changed: 2 additions & 0 deletions
@@ -123,6 +123,7 @@ static const std::map<llm_arch, const char *> LLM_ARCH_NAMES = {
     { LLM_ARCH_RND1, "rnd1" },
     { LLM_ARCH_PANGU_EMBED, "pangu-embedded" },
     { LLM_ARCH_MISTRAL3, "mistral3" },
+    { LLM_ARCH_MISTRAL4, "mistral4" },
     { LLM_ARCH_PADDLEOCR, "paddleocr" },
     { LLM_ARCH_MIMO2, "mimo2" },
     { LLM_ARCH_STEP35, "step35" },
@@ -1589,6 +1590,7 @@ static std::set<llm_tensor> llm_get_tensor_names(llm_arch arch) {
             LLM_TENSOR_FFN_UP_SHEXP,
         };
     case LLM_ARCH_DEEPSEEK2:
+    case LLM_ARCH_MISTRAL4:
         return {
             LLM_TENSOR_TOKEN_EMBD,
             LLM_TENSOR_OUTPUT_NORM,

src/llama-arch.h

Lines changed: 1 addition & 0 deletions
@@ -127,6 +127,7 @@ enum llm_arch {
     LLM_ARCH_RND1,
     LLM_ARCH_PANGU_EMBED,
     LLM_ARCH_MISTRAL3,
+    LLM_ARCH_MISTRAL4,
     LLM_ARCH_PADDLEOCR,
     LLM_ARCH_MIMO2,
     LLM_ARCH_STEP35,

src/llama-model.cpp

Lines changed: 5 additions & 1 deletion
@@ -1587,6 +1587,7 @@ void llama_model::load_hparams(llama_model_loader & ml) {
                 }
             } break;
         case LLM_ARCH_DEEPSEEK2:
+        case LLM_ARCH_MISTRAL4:
             {
                 // lite variants include DeepSeek-V2-Lite, GigaChat3-10B-A1.8B, Kanana-2-30B-A3B
                 const bool is_lite = (hparams.n_layer == 27 || hparams.n_layer == 26 || (hparams.n_layer == 48 && n_vocab == 128256));
@@ -4883,6 +4884,7 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
                 }
             } break;
         case LLM_ARCH_DEEPSEEK2:
+        case LLM_ARCH_MISTRAL4:
             {
                 const bool is_mla = hparams.is_mla();
 
@@ -7850,7 +7852,7 @@ void llama_model::print_info() const {
         LLAMA_LOG_INFO("%s: expert_weights_scale = %.1f\n", __func__, hparams.expert_weights_scale);
     }
 
-    if (arch == LLM_ARCH_DEEPSEEK2 || arch == LLM_ARCH_GLM_DSA) {
+    if (arch == LLM_ARCH_DEEPSEEK2 || arch == LLM_ARCH_GLM_DSA || arch == LLM_ARCH_MISTRAL4) {
         LLAMA_LOG_INFO("%s: n_layer_dense_lead = %d\n", __func__, hparams.n_layer_dense_lead);
         LLAMA_LOG_INFO("%s: n_lora_q = %d\n", __func__, hparams.n_lora_q);
         LLAMA_LOG_INFO("%s: n_lora_kv = %d\n", __func__, hparams.n_lora_kv);
@@ -8428,6 +8430,7 @@ ggml_cgraph * llama_model::build_graph(const llm_graph_params & params) const {
             } break;
         case LLM_ARCH_DEEPSEEK2:
        case LLM_ARCH_GLM_DSA:
+        case LLM_ARCH_MISTRAL4:
            {
                 llm = std::make_unique<llm_build_deepseek2>(*this, params);
             } break;
@@ -8839,6 +8842,7 @@ llama_rope_type llama_model_rope_type(const llama_model * model) {
         case LLM_ARCH_ERNIE4_5:
         case LLM_ARCH_ERNIE4_5_MOE:
         case LLM_ARCH_MISTRAL3:
+        case LLM_ARCH_MISTRAL4:
         case LLM_ARCH_LLAMA_EMBED:
         case LLM_ARCH_MAINCODER:
         case LLM_ARCH_GLM_DSA:

tests/test-llama-archs.cpp

Lines changed: 9 additions & 2 deletions
@@ -90,7 +90,10 @@ static gguf_context_ptr get_gguf_ctx(const llm_arch arch, const bool moe) {
         n_embd = 64;
         n_head = 1;
         n_ff = 96;
-    } else if (arch == LLM_ARCH_DEEPSEEK2 || arch == LLM_ARCH_GLM_DSA || arch == LLM_ARCH_KIMI_LINEAR) {
+    } else if (arch == LLM_ARCH_DEEPSEEK2
+            || arch == LLM_ARCH_GLM_DSA
+            || arch == LLM_ARCH_KIMI_LINEAR
+            || arch == LLM_ARCH_MISTRAL4) {
         n_embd = 128;
         n_head = 1;
         n_ff = 192;
@@ -145,7 +148,10 @@ static gguf_context_ptr get_gguf_ctx(const llm_arch arch, const bool moe) {
     }
 
     ms.add_kv(LLM_KV_ATTENTION_MAX_ALIBI_BIAS, 8.0f);
-    if (arch == LLM_ARCH_DEEPSEEK2 || arch == LLM_ARCH_GLM_DSA || arch == LLM_ARCH_KIMI_LINEAR) {
+    if (arch == LLM_ARCH_DEEPSEEK2
+            || arch == LLM_ARCH_GLM_DSA
+            || arch == LLM_ARCH_KIMI_LINEAR
+            || arch == LLM_ARCH_MISTRAL4) {
         ms.add_kv(LLM_KV_ATTENTION_KEY_LENGTH, uint32_t(576));
         ms.add_kv(LLM_KV_ATTENTION_VALUE_LENGTH, uint32_t(512));
         ms.add_kv(LLM_KV_ROPE_DIMENSION_COUNT, uint32_t(64));
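Note: the 576/512/64 triple mirrors the DeepSeek2-style MLA layout these archs share: the cached per-head key is the 512-wide compressed KV vector concatenated with a 64-wide RoPE part, hence key length 576 = 512 + 64, while the value length stays 512.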
@@ -319,6 +325,7 @@ static bool moe_mandatory(const llm_arch arch) {
         case LLM_ARCH_MIMO2:
         case LLM_ARCH_KIMI_LINEAR:
         case LLM_ARCH_STEP35:
+        case LLM_ARCH_MISTRAL4:
             return true;
         default:
             return false;
