
Commit d23355a

model : wire up Qwen3.5/Qwen3.5MoE tensors for NVFP4 support (ggml-org#20506)
1 parent b30a5fd commit d23355a

4 files changed with 62 additions and 25 deletions
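
Context for the new tensors: NVFP4 is a block-scaled 4-bit float format, and besides the block scales stored inside the quantized data it also carries a tensor-level scale, which this codebase keeps as a separate small "scale" tensor next to each weight (shape {1} per tensor, or {n_expert} for stacked routed-expert weights). This commit threads those tensors through the Qwen3.5/Qwen3.5MoE graph builders. A minimal sketch of the idea, assuming the optional scale is simply folded into the matmul output (the helper below is illustrative, not the actual build_lora_mm implementation):

// Illustrative only: applying an optional per-tensor scale to a projection.
// Since the dequantized weight is W = s * W_q, the scale can be applied to
// the matmul result instead: y = W * x = s * (W_q * x).
#include "ggml.h"

static ggml_tensor * mm_with_scale(ggml_context * ctx,
                                   ggml_tensor * w,       // quantized weight
                                   ggml_tensor * x,       // activations
                                   ggml_tensor * scale) { // {1} scale tensor or nullptr
    ggml_tensor * y = ggml_mul_mat(ctx, w, x);
    if (scale) {
        y = ggml_mul(ctx, y, scale); // a {1} tensor broadcasts over all elements of y
    }
    return y;
}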


src/llama-model.cpp

Lines changed: 26 additions & 0 deletions
@@ -7462,6 +7462,12 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
            if (!layer.wo_s && layer.wo) {
                layer.wo_s = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "scale", i), {1}, TENSOR_NOT_REQUIRED);
            }
+           if (!layer.wqkv_s && layer.wqkv) {
+               layer.wqkv_s = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "scale", i), {1}, TENSOR_NOT_REQUIRED);
+           }
+           if (!layer.wqkv_gate_s && layer.wqkv_gate) {
+               layer.wqkv_gate_s = create_tensor(tn(LLM_TENSOR_ATTN_GATE, "scale", i), {1}, TENSOR_NOT_REQUIRED);
+           }

            // dense FFN weight scales (per-tensor, shape {1})
            if (!layer.ffn_gate_s && layer.ffn_gate) {
@@ -7473,6 +7479,15 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
            if (!layer.ffn_up_s && layer.ffn_up) {
                layer.ffn_up_s = create_tensor(tn(LLM_TENSOR_FFN_UP, "scale", i), {1}, TENSOR_NOT_REQUIRED);
            }
+           if (!layer.ffn_gate_shexp_s && layer.ffn_gate_shexp) {
+               layer.ffn_gate_shexp_s = create_tensor(tn(LLM_TENSOR_FFN_GATE_SHEXP, "scale", i), {1}, TENSOR_NOT_REQUIRED);
+           }
+           if (!layer.ffn_down_shexp_s && layer.ffn_down_shexp) {
+               layer.ffn_down_shexp_s = create_tensor(tn(LLM_TENSOR_FFN_DOWN_SHEXP, "scale", i), {1}, TENSOR_NOT_REQUIRED);
+           }
+           if (!layer.ffn_up_shexp_s && layer.ffn_up_shexp) {
+               layer.ffn_up_shexp_s = create_tensor(tn(LLM_TENSOR_FFN_UP_SHEXP, "scale", i), {1}, TENSOR_NOT_REQUIRED);
+           }

            // MoE expert weight scales (per-expert, shape {n_expert})
            if (!layer.ffn_gate_exps_s && layer.ffn_gate_exps) {
@@ -7484,6 +7499,17 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
            if (!layer.ffn_up_exps_s && layer.ffn_up_exps) {
                layer.ffn_up_exps_s = create_tensor(tn(LLM_TENSOR_FFN_UP_EXPS, "scale", i), {n_expert}, TENSOR_NOT_REQUIRED);
            }
+
+           // recurrent / linear-attention weight scales (per-tensor, shape {1})
+           if (!layer.ssm_out_s && layer.ssm_out) {
+               layer.ssm_out_s = create_tensor(tn(LLM_TENSOR_SSM_OUT, "scale", i), {1}, TENSOR_NOT_REQUIRED);
+           }
+           if (!layer.ssm_alpha_s && layer.ssm_alpha) {
+               layer.ssm_alpha_s = create_tensor(tn(LLM_TENSOR_SSM_ALPHA, "scale", i), {1}, TENSOR_NOT_REQUIRED);
+           }
+           if (!layer.ssm_beta_s && layer.ssm_beta) {
+               layer.ssm_beta_s = create_tensor(tn(LLM_TENSOR_SSM_BETA, "scale", i), {1}, TENSOR_NOT_REQUIRED);
+           }
        }
    }
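
All of these loads use TENSOR_NOT_REQUIRED, so for GGUF files without NVFP4 scale tensors the pointers simply stay nullptr and existing quantizations load unchanged. To see which scale tensors a given GGUF actually contains, the tensor names can be listed with ggml's gguf reader; a small standalone sketch (the header name, default file name, and ".scale" suffix filter are assumptions based on the tn(..., "scale", i) naming above):

// Sketch: list all tensors ending in ".scale" stored in a GGUF file.
#include "gguf.h"   // ggml's GGUF reader (older trees declare these in ggml.h)
#include <cstdio>
#include <cstring>

int main(int argc, char ** argv) {
    const char * fname = argc > 1 ? argv[1] : "model-nvfp4.gguf"; // illustrative
    gguf_init_params params = { /*no_alloc =*/ true, /*ctx =*/ nullptr };
    gguf_context * ctx = gguf_init_from_file(fname, params);
    if (!ctx) {
        fprintf(stderr, "failed to open %s\n", fname);
        return 1;
    }
    const int64_t n_tensors = gguf_get_n_tensors(ctx);
    for (int64_t i = 0; i < n_tensors; ++i) {
        const char * name = gguf_get_tensor_name(ctx, i);
        const size_t len = strlen(name);
        if (len > 6 && strcmp(name + len - 6, ".scale") == 0) {
            printf("%s\n", name); // names follow the tn(..., "scale", i) pattern used above
        }
    }
    gguf_free(ctx);
    return 0;
}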

src/llama-model.h

Lines changed: 8 additions & 0 deletions
@@ -401,9 +401,17 @@ struct llama_layer {
    struct ggml_tensor * wk_s = nullptr;
    struct ggml_tensor * wv_s = nullptr;
    struct ggml_tensor * wo_s = nullptr;
+   struct ggml_tensor * wqkv_s = nullptr;
+   struct ggml_tensor * wqkv_gate_s = nullptr;
    struct ggml_tensor * ffn_gate_s = nullptr;
    struct ggml_tensor * ffn_up_s = nullptr;
    struct ggml_tensor * ffn_down_s = nullptr;
+   struct ggml_tensor * ffn_gate_shexp_s = nullptr;
+   struct ggml_tensor * ffn_up_shexp_s = nullptr;
+   struct ggml_tensor * ffn_down_shexp_s = nullptr;
+   struct ggml_tensor * ssm_out_s = nullptr;
+   struct ggml_tensor * ssm_alpha_s = nullptr;
+   struct ggml_tensor * ssm_beta_s = nullptr;

    // altup & laurel
    struct ggml_tensor * per_layer_inp_gate = nullptr;

src/models/qwen35.cpp

Lines changed: 12 additions & 12 deletions
@@ -90,11 +90,11 @@ std::pair<ggml_tensor *, ggml_tensor *> llm_build_qwen35::build_qkvz(
    const int64_t n_seqs = ubatch.n_seqs;
    const int64_t n_seq_tokens = ubatch.n_seq_tokens;

-   ggml_tensor * qkv_mixed = build_lora_mm(model.layers[il].wqkv, input);
+   ggml_tensor * qkv_mixed = build_lora_mm(model.layers[il].wqkv, input, model.layers[il].wqkv_s);
    qkv_mixed = ggml_reshape_3d(ctx0, qkv_mixed, qkv_mixed->ne[0], n_seq_tokens, n_seqs);
    cb(qkv_mixed, "linear_attn_qkv_mixed", il);

-   ggml_tensor * z = build_lora_mm(model.layers[il].wqkv_gate, input);
+   ggml_tensor * z = build_lora_mm(model.layers[il].wqkv_gate, input, model.layers[il].wqkv_gate_s);
    cb(z, "z", il);

    return { qkv_mixed, z };
@@ -123,7 +123,7 @@ ggml_tensor * llm_build_qwen35::build_layer_attn(
    // Order: joint QG projection, QG split, Q norm, KV projection, K norm, RoPE, attention

    // Qwen3Next uses a single Q projection that outputs query + gate
-   ggml_tensor * Qcur_full = build_lora_mm(model.layers[il].wq, cur); // [ (n_embd_head * 2) * n_head, n_tokens ]
+   ggml_tensor * Qcur_full = build_lora_mm(model.layers[il].wq, cur, model.layers[il].wq_s); // [ (n_embd_head * 2) * n_head, n_tokens ]
    cb(Qcur_full, "Qcur_full", il);

    ggml_tensor * Qcur = ggml_view_3d(ctx0, Qcur_full, n_embd_head, n_head, n_tokens,
@@ -135,10 +135,10 @@ ggml_tensor * llm_build_qwen35::build_layer_attn(
    Qcur = build_norm(Qcur, model.layers[il].attn_q_norm, nullptr, LLM_NORM_RMS, il);
    cb(Qcur, "Qcur_normed", il);

-   ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur);
+   ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur, model.layers[il].wk_s);
    cb(Kcur, "Kcur", il);

-   ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur);
+   ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur, model.layers[il].wv_s);
    cb(Vcur, "Vcur", il);

    // Apply K normalization
@@ -186,7 +186,7 @@ ggml_tensor * llm_build_qwen35::build_layer_attn(
    cur = ggml_mul(ctx0, cur, gate_sigmoid);
    cb(cur, "attn_gated", il);

-   cur = build_lora_mm(model.layers[il].wo, cur);
+   cur = build_lora_mm(model.layers[il].wo, cur, model.layers[il].wo_s);
    cb(cur, "attn_output", il);

    return cur;
@@ -217,13 +217,13 @@ ggml_tensor * llm_build_qwen35::build_layer_attn_linear(
    ggml_tensor * qkv_mixed = qkvz.first;
    ggml_tensor * z = qkvz.second;

-   ggml_tensor * beta = build_lora_mm(model.layers[il].ssm_beta, cur);
+   ggml_tensor * beta = build_lora_mm(model.layers[il].ssm_beta, cur, model.layers[il].ssm_beta_s);
    beta = ggml_reshape_4d(ctx0, beta, 1, num_v_heads, n_seq_tokens, n_seqs);
    cb(beta, "beta", il);

    beta = ggml_sigmoid(ctx0, beta);

-   ggml_tensor * alpha = build_lora_mm(model.layers[il].ssm_alpha, cur);
+   ggml_tensor * alpha = build_lora_mm(model.layers[il].ssm_alpha, cur, model.layers[il].ssm_alpha_s);
    alpha = ggml_cont_3d(ctx0, alpha, num_v_heads, n_seq_tokens, n_seqs);
    cb(alpha, "alpha", il);

@@ -356,7 +356,7 @@ ggml_tensor * llm_build_qwen35::build_layer_attn_linear(
    cb(final_output, "final_output", il);

    // Output projection
-   cur = build_lora_mm(model.layers[il].ssm_out, final_output);
+   cur = build_lora_mm(model.layers[il].ssm_out, final_output, model.layers[il].ssm_out_s);
    cb(cur, "linear_attn_out", il);

    // Reshape back to original dimensions
@@ -370,9 +370,9 @@ ggml_tensor * llm_build_qwen35::build_layer_ffn(ggml_tensor * cur, const int il)
    GGML_ASSERT(model.layers[il].ffn_gate_inp == nullptr);

    cur = build_ffn(cur,
-           model.layers[il].ffn_up, NULL, NULL,
-           model.layers[il].ffn_gate, NULL, NULL,
-           model.layers[il].ffn_down, NULL, NULL,
+           model.layers[il].ffn_up, NULL, model.layers[il].ffn_up_s,
+           model.layers[il].ffn_gate, NULL, model.layers[il].ffn_gate_s,
+           model.layers[il].ffn_down, NULL, model.layers[il].ffn_down_s,
            NULL,
            LLM_FFN_SILU, LLM_FFN_PAR, il);
    cb(cur, "ffn_out", il);
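
In the dense FFN path the per-tensor scales now travel through build_ffn next to each projection weight. Since each scale is part of that weight's dequantization, it should multiply the corresponding projection output before the SiLU/gating nonlinearity; a minimal sketch of the scaled SwiGLU math under that assumption (build_ffn's internals are not shown in this diff):

// Sketch of a SwiGLU FFN with optional per-tensor weight scales:
//   gate = s_gate * (W_gate * x),  up = s_up * (W_up * x)
//   out  = s_down * (W_down * (SiLU(gate) * up))
#include "ggml.h"

static ggml_tensor * scaled_swiglu_ffn(ggml_context * ctx, ggml_tensor * x,
        ggml_tensor * w_up,   ggml_tensor * s_up,
        ggml_tensor * w_gate, ggml_tensor * s_gate,
        ggml_tensor * w_down, ggml_tensor * s_down) {
    ggml_tensor * up   = ggml_mul_mat(ctx, w_up, x);
    if (s_up)   { up   = ggml_mul(ctx, up,   s_up);   }

    ggml_tensor * gate = ggml_mul_mat(ctx, w_gate, x);
    if (s_gate) { gate = ggml_mul(ctx, gate, s_gate); }

    // scale before the nonlinearity: SiLU(s * y) != s * SiLU(y)
    ggml_tensor * h   = ggml_mul(ctx, ggml_silu(ctx, gate), up);
    ggml_tensor * out = ggml_mul_mat(ctx, w_down, h);
    if (s_down) { out = ggml_mul(ctx, out, s_down); }
    return out;
}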

src/models/qwen35moe.cpp

Lines changed: 16 additions & 13 deletions
@@ -90,11 +90,11 @@ std::pair<ggml_tensor *, ggml_tensor *> llm_build_qwen35moe::build_qkvz(
    const int64_t n_seqs = ubatch.n_seqs;
    const int64_t n_seq_tokens = ubatch.n_seq_tokens;

-   ggml_tensor * qkv_mixed = build_lora_mm(model.layers[il].wqkv, input);
+   ggml_tensor * qkv_mixed = build_lora_mm(model.layers[il].wqkv, input, model.layers[il].wqkv_s);
    qkv_mixed = ggml_reshape_3d(ctx0, qkv_mixed, qkv_mixed->ne[0], n_seq_tokens, n_seqs);
    cb(qkv_mixed, "linear_attn_qkv_mixed", il);

-   ggml_tensor * z = build_lora_mm(model.layers[il].wqkv_gate, input);
+   ggml_tensor * z = build_lora_mm(model.layers[il].wqkv_gate, input, model.layers[il].wqkv_gate_s);
    cb(z, "z", il);

    return { qkv_mixed, z };
@@ -123,7 +123,7 @@ ggml_tensor * llm_build_qwen35moe::build_layer_attn(
    // Order: joint QG projection, QG split, Q norm, KV projection, K norm, RoPE, attention

    // Qwen3Next uses a single Q projection that outputs query + gate
-   ggml_tensor * Qcur_full = build_lora_mm(model.layers[il].wq, cur); // [ (n_embd_head * 2) * n_head, n_tokens ]
+   ggml_tensor * Qcur_full = build_lora_mm(model.layers[il].wq, cur, model.layers[il].wq_s); // [ (n_embd_head * 2) * n_head, n_tokens ]
    cb(Qcur_full, "Qcur_full", il);

    ggml_tensor * Qcur = ggml_view_3d(ctx0, Qcur_full, n_embd_head, n_head, n_tokens,
@@ -135,10 +135,10 @@ ggml_tensor * llm_build_qwen35moe::build_layer_attn(
    Qcur = build_norm(Qcur, model.layers[il].attn_q_norm, nullptr, LLM_NORM_RMS, il);
    cb(Qcur, "Qcur_normed", il);

-   ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur);
+   ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur, model.layers[il].wk_s);
    cb(Kcur, "Kcur", il);

-   ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur);
+   ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur, model.layers[il].wv_s);
    cb(Vcur, "Vcur", il);

    // Apply K normalization
@@ -186,7 +186,7 @@ ggml_tensor * llm_build_qwen35moe::build_layer_attn(
    cur = ggml_mul(ctx0, cur, gate_sigmoid);
    cb(cur, "attn_gated", il);

-   cur = build_lora_mm(model.layers[il].wo, cur);
+   cur = build_lora_mm(model.layers[il].wo, cur, model.layers[il].wo_s);
    cb(cur, "attn_output", il);

    return cur;
@@ -217,13 +217,13 @@ ggml_tensor * llm_build_qwen35moe::build_layer_attn_linear(
    ggml_tensor * qkv_mixed = qkvz.first;
    ggml_tensor * z = qkvz.second;

-   ggml_tensor * beta = build_lora_mm(model.layers[il].ssm_beta, cur);
+   ggml_tensor * beta = build_lora_mm(model.layers[il].ssm_beta, cur, model.layers[il].ssm_beta_s);
    beta = ggml_reshape_4d(ctx0, beta, 1, num_v_heads, n_seq_tokens, n_seqs);
    cb(beta, "beta", il);

    beta = ggml_sigmoid(ctx0, beta);

-   ggml_tensor * alpha = build_lora_mm(model.layers[il].ssm_alpha, cur);
+   ggml_tensor * alpha = build_lora_mm(model.layers[il].ssm_alpha, cur, model.layers[il].ssm_alpha_s);
    alpha = ggml_cont_3d(ctx0, alpha, num_v_heads, n_seq_tokens, n_seqs);
    cb(alpha, "alpha", il);

@@ -356,7 +356,7 @@ ggml_tensor * llm_build_qwen35moe::build_layer_attn_linear(
    cb(final_output, "final_output", il);

    // Output projection
-   cur = build_lora_mm(model.layers[il].ssm_out, final_output);
+   cur = build_lora_mm(model.layers[il].ssm_out, final_output, model.layers[il].ssm_out_s);
    cb(cur, "linear_attn_out", il);

    // Reshape back to original dimensions
@@ -380,16 +380,19 @@ ggml_tensor * llm_build_qwen35moe::build_layer_ffn(ggml_tensor * cur, const int il)
        LLM_FFN_SILU, true,
        hparams.expert_weights_scale,
        LLAMA_EXPERT_GATING_FUNC_TYPE_SOFTMAX, il,
-       nullptr, model.layers[il].ffn_gate_up_exps);
+       nullptr, model.layers[il].ffn_gate_up_exps,
+       model.layers[il].ffn_up_exps_s,
+       model.layers[il].ffn_gate_exps_s,
+       model.layers[il].ffn_down_exps_s);
    cb(moe_out, "ffn_moe_out", il);

    // Add shared experts if present - following Qwen3Next reference implementation
    if (model.layers[il].ffn_up_shexp != nullptr) {
        ggml_tensor * ffn_shexp =
            build_ffn(cur,
-               model.layers[il].ffn_up_shexp, NULL, NULL,
-               model.layers[il].ffn_gate_shexp, NULL, NULL,
-               model.layers[il].ffn_down_shexp, NULL, NULL,
+               model.layers[il].ffn_up_shexp, NULL, model.layers[il].ffn_up_shexp_s,
+               model.layers[il].ffn_gate_shexp, NULL, model.layers[il].ffn_gate_shexp_s,
+               model.layers[il].ffn_down_shexp, NULL, model.layers[il].ffn_down_shexp_s,
                NULL,
                LLM_FFN_SILU, LLM_FFN_PAR, il);
        cb(ffn_shexp, "ffn_shexp", il);
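
The MoE builder differs mainly in shape: the routed-expert scales passed along with ffn_gate_up_exps are per-expert ({n_expert}), while the shared-expert FFN reuses the per-tensor {1} path shown above. A quick sanity check one might add while bringing up an NVFP4 GGUF (hypothetical helper, not part of this patch):

// Hypothetical bring-up check for the loaded scale tensors: per-expert scales
// carry one value per expert, per-tensor scales carry a single element, and a
// null pointer just means the tensor was not present in the GGUF.
#include "ggml.h"

static void check_nvfp4_scale_shapes(const ggml_tensor * exps_scale,   // e.g. layer.ffn_up_exps_s
                                     const ggml_tensor * dense_scale,  // e.g. layer.ffn_up_shexp_s
                                     int64_t n_expert) {
    if (exps_scale)  { GGML_ASSERT(exps_scale->ne[0] == n_expert); }
    if (dense_scale) { GGML_ASSERT(ggml_nelements(dense_scale) == 1); }
}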
