
Commit f86777c

Fix nan by rescheduling attention scaling (#322)
1 parent e9989b5 commit f86777c

9 files changed

Lines changed: 115 additions & 97 deletions


README.md

Lines changed: 11 additions & 9 deletions
@@ -59,13 +59,15 @@ The original model (`-i <model_name_or_path>`) can be a Hugging Face model name
* CodeGeeX2: `THUDM/codegeex2-6b`, `THUDM/codegeex2-6b-int4`

You are free to try any of the below quantization types by specifying `-t <type>`:
-* `q4_0`: 4-bit integer quantization with fp16 scales.
-* `q4_1`: 4-bit integer quantization with fp16 scales and minimum values.
-* `q5_0`: 5-bit integer quantization with fp16 scales.
-* `q5_1`: 5-bit integer quantization with fp16 scales and minimum values.
-* `q8_0`: 8-bit integer quantization with fp16 scales.
-* `f16`: half precision floating point weights without quantization.
-* `f32`: single precision floating point weights without quantization.
+| type   | precision | symmetric |
+| ------ | --------- | --------- |
+| `q4_0` | int4      | true      |
+| `q4_1` | int4      | false     |
+| `q5_0` | int5      | true      |
+| `q5_1` | int5      | false     |
+| `q8_0` | int8      | true      |
+| `f16`  | half      |           |
+| `f32`  | float     |           |


For LoRA models, add `-l <lora_model_name_or_path>` flag to merge your LoRA weights into the base model. For example, run `python3 chatglm_cpp/convert.py -i THUDM/chatglm3-6b -t q4_0 -o models/chatglm3-ggml-lora.bin -l shibing624/chatglm3-6b-csc-chinese-lora` to merge public LoRA weights from Hugging Face.
@@ -551,8 +553,8 @@ Download and unzip the dataset from [link](https://s3.amazonaws.com/research.met

|                         | Q4_0  | Q4_1  | Q5_0  | Q5_1  | Q8_0  | F16   |
|-------------------------|-------|-------|-------|-------|-------|-------|
-| [ChatGLM3-6B-Base][1]   | 6.215 | 6.184 | 5.997 | 6.015 | 5.965 | 5.971 |
-| [ChatGLM4-9B-Base][2]   | 6.851 | 6.793 | 6.652 | 6.635 | 6.582 | 6.586 |
+| [ChatGLM3-6B-Base][1]   | 6.215 | 6.188 | 6.006 | 6.022 | 5.971 | 5.972 |
+| [ChatGLM4-9B-Base][2]   | 6.834 | 6.780 | 6.645 | 6.624 | 6.576 | 6.577 |

[1]: https://huggingface.co/THUDM/chatglm3-6b-base
[2]: https://huggingface.co/THUDM/glm-4-9b
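For context on the two README changes above: in ggml's block formats, the "symmetric" types (`q4_0`, `q5_0`, `q8_0`) store only a per-block scale d and reconstruct each weight as roughly x ≈ d * q, while the asymmetric types (`q4_1`, `q5_1`) also store a per-block minimum m and reconstruct x ≈ d * q + m. The extra offset costs a few bits per block but tends to track the original weights slightly better, which is consistent with the Q4_1 column coming in below Q4_0 in the updated perplexity table.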

chatglm.cpp

Lines changed: 13 additions & 12 deletions
@@ -624,7 +624,7 @@ ggml_tensor *BasicAttention::forward(ModelContext *mctx, ggml_tensor *hidden_sta
    const int hidden_size = hidden_states->ne[0];
    const int qlen = hidden_states->ne[1];
    const int head_size = hidden_size / num_attention_heads;
-   const int num_shared_q_heads = num_attention_heads / num_kv_heads;
+   const int num_shared_q_heads = num_attention_heads / num_key_value_heads;

    ggml_tensor *qkv = query_key_value.forward(mctx, hidden_states); // [sq, (#h + 2 * #kvh) * d]

@@ -645,10 +645,11 @@ ggml_tensor *BasicAttention::forward(ModelContext *mctx, ggml_tensor *hidden_sta
    } else {
        query_layer = ggml_view_3d(ctx, qkv, head_size, num_attention_heads, qlen, head_size * ggml_element_size(qkv),
                                   qkv->nb[1], 0);
-       key_layer = ggml_view_3d(ctx, qkv, head_size, num_kv_heads, qlen, head_size * ggml_element_size(qkv),
+       key_layer = ggml_view_3d(ctx, qkv, head_size, num_key_value_heads, qlen, head_size * ggml_element_size(qkv),
                                 qkv->nb[1], hidden_size * ggml_element_size(qkv));
-       value_layer = ggml_view_3d(ctx, qkv, head_size, num_kv_heads, qlen, head_size * ggml_element_size(qkv),
-                                  qkv->nb[1], (hidden_size + head_size * num_kv_heads) * ggml_element_size(qkv));
+       value_layer =
+           ggml_view_3d(ctx, qkv, head_size, num_key_value_heads, qlen, head_size * ggml_element_size(qkv), qkv->nb[1],
+                        (hidden_size + head_size * num_key_value_heads) * ggml_element_size(qkv));
    }

    query_layer = apply_rotary_emb(mctx, query_layer, position_ids, rope_type, rope_theta);

@@ -657,33 +658,33 @@ ggml_tensor *BasicAttention::forward(ModelContext *mctx, ggml_tensor *hidden_sta
    query_layer = ggml_cont(ctx, ggml_permute(ctx, query_layer, 0, 2, 1, 3)); // [#h, s, d]
    if (num_shared_q_heads > 1) {
        query_layer = ggml_reshape_3d(ctx, query_layer, head_size, num_shared_q_heads * qlen,
-                                     num_kv_heads); // [#kvh, (#h/#kvh) * s, d]
+                                     num_key_value_heads); // [#kvh, (#h/#kvh) * s, d]
    }

    key_layer = ggml_permute(ctx, key_layer, 0, 2, 1, 3);     // [#kvh, s, d]
    value_layer = ggml_permute(ctx, value_layer, 1, 2, 0, 3); // [#kvh, d, s]

    // store key & value to cache
    ggml_tensor *k_cache_view =
-       ggml_view_3d(ctx, k_cache, head_size, qlen, num_kv_heads, k_cache->nb[1], k_cache->nb[2],
+       ggml_view_3d(ctx, k_cache, head_size, qlen, num_key_value_heads, k_cache->nb[1], k_cache->nb[2],
                     (num_virtual_tokens + n_past) * head_size * ggml_element_size(k_cache)); // [#kvh, s, d]
    ggml_build_forward_expand(mctx->gf, ggml_cpy(ctx, key_layer, k_cache_view));
    ggml_tensor *v_cache_view =
-       ggml_view_3d(ctx, v_cache, qlen, head_size, num_kv_heads, v_cache->nb[1], v_cache->nb[2],
+       ggml_view_3d(ctx, v_cache, qlen, head_size, num_key_value_heads, v_cache->nb[1], v_cache->nb[2],
                     (num_virtual_tokens + n_past) * ggml_element_size(v_cache)); // [#kvh, d, s]
    ggml_build_forward_expand(mctx->gf, ggml_cpy(ctx, value_layer, v_cache_view));

    // concat key & value with past kv
-   key_layer = ggml_view_3d(ctx, k_cache, head_size, num_virtual_tokens + n_past + qlen, num_kv_heads, k_cache->nb[1],
-                            k_cache->nb[2],
+   key_layer = ggml_view_3d(ctx, k_cache, head_size, num_virtual_tokens + n_past + qlen, num_key_value_heads,
+                            k_cache->nb[1], k_cache->nb[2],
                             0); // [#kvh, kvs, d]
-   value_layer = ggml_view_3d(ctx, v_cache, num_virtual_tokens + n_past + qlen, head_size, num_kv_heads,
+   value_layer = ggml_view_3d(ctx, v_cache, num_virtual_tokens + n_past + qlen, head_size, num_key_value_heads,
                               v_cache->nb[1], v_cache->nb[2],
                               0); // [#kvh, d, kvs]

    // attention
+   query_layer = ggml_scale_inplace(ctx, query_layer, 1.f / std::sqrt(head_size));
    ggml_tensor *attn_scores = ggml_mul_mat(ctx, key_layer, query_layer); // [#kvh, (#h/#kvh) * s, kvs]
-   attn_scores = ggml_scale_inplace(ctx, attn_scores, 1.f / std::sqrt(head_size));

    if (n_past == 0) {
        // build attention mask for context input

@@ -701,7 +702,7 @@ ggml_tensor *BasicAttention::forward(ModelContext *mctx, ggml_tensor *hidden_sta
        if (num_shared_q_heads > 1) {
            attn_scores =
                ggml_reshape_3d(ctx, attn_scores, num_virtual_tokens + n_past + qlen, num_shared_q_heads * qlen,
-                               num_kv_heads); // [#kvh, (#h/#kvh) * s, kvs]
+                               num_key_value_heads); // [#kvh, (#h/#kvh) * s, kvs]
        }
    }
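The behavioral change in this file is in the last two hunks of BasicAttention::forward: the 1/sqrt(head_size) factor is now applied to the query before the QK^T matmul instead of to the attention scores afterwards. Both schedules compute the same softmax(QK^T / sqrt(d)); what changes is the magnitude of the intermediate product. A plausible reading of the NaN fix is that on half-precision paths the unscaled scores can overflow the fp16 range and propagate inf/NaN through softmax, while pre-scaled queries keep the products in range. A minimal standalone sketch of that numeric argument (plain C++ with illustrative values, not the ggml graph API):

// Sketch only: why scaling Q before Q.K^T helps when intermediates live in fp16.
// The values are illustrative; the real change is in BasicAttention::forward above.
#include <cmath>
#include <cstdio>
#include <vector>

int main() {
    const int head_size = 128;
    const float fp16_max = 65504.0f; // largest finite half-precision value

    // Query/key rows with plausible activation magnitudes.
    std::vector<float> q(head_size, 24.0f), k(head_size, 24.0f);

    // Old schedule: raw dot product first, scale afterwards.
    float raw = 0.0f;
    for (int i = 0; i < head_size; i++) raw += q[i] * k[i]; // 128 * 576 = 73728 > fp16 max
    std::printf("raw score %.1f, overflows fp16: %s\n", raw, raw > fp16_max ? "yes" : "no");

    // New schedule: scale the query by 1/sqrt(head_size) before the matmul.
    const float scale = 1.0f / std::sqrt((float)head_size);
    float scaled = 0.0f;
    for (int i = 0; i < head_size; i++) scaled += (q[i] * scale) * k[i]; // about 6517, in range
    std::printf("scaled score %.1f, overflows fp16: %s\n", scaled, scaled > fp16_max ? "yes" : "no");

    // Both schedules agree once the raw score is scaled; only the intermediate differs.
    std::printf("raw * scale = %.3f vs pre-scaled = %.3f\n", raw * scale, scaled);
    return 0;
}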

chatglm.h

Lines changed: 60 additions & 39 deletions
@@ -65,7 +65,7 @@ struct ConfigRecordV1 {

// For compatibility
struct ConfigRecordV1GQA : public ConfigRecordV1 {
-   int num_kv_heads;
+   int num_key_value_heads;
};

// TODO: use json to serialize config

@@ -109,15 +109,15 @@ class ModelConfig {
    ModelConfig() = default;

    ModelConfig(ModelType model_type, ggml_type dtype, int vocab_size, int hidden_size, int num_attention_heads,
-               int num_kv_heads, int num_hidden_layers, int intermediate_size, float norm_eps, float rope_theta,
+               int num_key_value_heads, int num_hidden_layers, int intermediate_size, float norm_eps, float rope_theta,
                int num_virtual_tokens, int max_length, int bos_token_id, int eos_token_id, int pad_token_id,
                int sep_token_id, std::vector<int> extra_eos_token_ids)
        : model_type(model_type), dtype(dtype), vocab_size(vocab_size), hidden_size(hidden_size),
-         num_attention_heads(num_attention_heads), num_kv_heads(num_kv_heads), num_hidden_layers(num_hidden_layers),
-         intermediate_size(intermediate_size), norm_eps(norm_eps), rope_theta(rope_theta),
-         num_virtual_tokens(num_virtual_tokens), max_length(max_length), bos_token_id(bos_token_id),
-         eos_token_id(eos_token_id), pad_token_id(pad_token_id), sep_token_id(sep_token_id),
-         extra_eos_token_ids(std::move(extra_eos_token_ids)) {
+         num_attention_heads(num_attention_heads), num_key_value_heads(num_key_value_heads),
+         num_hidden_layers(num_hidden_layers), intermediate_size(intermediate_size), norm_eps(norm_eps),
+         rope_theta(rope_theta), num_virtual_tokens(num_virtual_tokens), max_length(max_length),
+         bos_token_id(bos_token_id), eos_token_id(eos_token_id), pad_token_id(pad_token_id),
+         sep_token_id(sep_token_id), extra_eos_token_ids(std::move(extra_eos_token_ids)) {
        if (model_type == ModelType::CHATGLM) {
            hidden_act = ActivationType::GELU;
            use_qkv_bias = true;

@@ -146,9 +146,10 @@ class ModelConfig {

    ModelConfig(ModelType model_type, const ConfigRecordV1GQA &rec, float norm_eps, float rope_theta,
                int num_virtual_tokens)
-       : ModelConfig(model_type, rec.dtype, rec.vocab_size, rec.hidden_size, rec.num_attention_heads, rec.num_kv_heads,
-                     rec.num_hidden_layers, rec.intermediate_size, norm_eps, rope_theta, num_virtual_tokens,
-                     rec.max_length, rec.bos_token_id, rec.eos_token_id, rec.pad_token_id, rec.sep_token_id, {}) {}
+       : ModelConfig(model_type, rec.dtype, rec.vocab_size, rec.hidden_size, rec.num_attention_heads,
+                     rec.num_key_value_heads, rec.num_hidden_layers, rec.intermediate_size, norm_eps, rope_theta,
+                     num_virtual_tokens, rec.max_length, rec.bos_token_id, rec.eos_token_id, rec.pad_token_id,
+                     rec.sep_token_id, {}) {}

    ModelConfig(ModelType model_type, const ConfigRecordV2 &rec)
        : ModelConfig(model_type, rec.dtype, rec.vocab_size, rec.hidden_size, rec.num_attention_heads,

@@ -158,13 +159,33 @@ class ModelConfig {

    std::string model_type_name() const { return to_string(model_type); }

+   friend std::ostream &operator<<(std::ostream &os, const ModelConfig &self) {
+       os << "ModelConfig(model_type=" << (int)self.model_type << ", dtype=" << self.dtype
+          << ", vocab_size=" << self.vocab_size << ", hidden_size=" << self.hidden_size
+          << ", num_attention_heads=" << self.num_attention_heads
+          << ", num_key_value_heads=" << self.num_key_value_heads << ", num_hidden_layers=" << self.num_hidden_layers
+          << ", intermediate_size=" << self.intermediate_size << ", norm_eps=" << self.norm_eps
+          << ", hidden_act=" << (int)self.hidden_act << ", use_qkv_bias=" << self.use_qkv_bias
+          << ", use_dense_bias=" << self.use_dense_bias << ", interleaved_qkv=" << self.interleaved_qkv
+          << ", tie_word_embeddings=" << self.tie_word_embeddings << ", rope_type=" << (int)self.rope_type
+          << ", rope_theta=" << self.rope_theta << ", attn_mask_type=" << (int)self.attn_mask_type
+          << ", num_virtual_tokens=" << self.num_virtual_tokens << ", max_length=" << self.max_length
+          << ", bos_token_id=" << self.bos_token_id << ", eos_token_id=" << self.eos_token_id
+          << ", pad_token_id=" << self.pad_token_id << ", sep_token_id=" << self.sep_token_id
+          << ", extra_eos_token_ids={";
+       for (size_t i = 0; i < self.extra_eos_token_ids.size(); i++) {
+           os << (i > 0 ? ", " : "") << self.extra_eos_token_ids[i];
+       }
+       return os << "})";
+   }
+
public:
    ModelType model_type;
    ggml_type dtype;
    int vocab_size;
    int hidden_size;
    int num_attention_heads;
-   int num_kv_heads;
+   int num_key_value_heads;
    int num_hidden_layers;
    int intermediate_size;
    float norm_eps;

@@ -419,26 +440,26 @@ class BasicGLU {
class BasicAttention {
public:
    BasicAttention() = default;
-   BasicAttention(ModelContext *mctx, int hidden_size, int num_attention_heads, int num_kv_heads, int max_length,
-                  bool use_qkv_bias, bool use_dense_bias, bool interleaved_qkv, RopeType rope_type, float rope_theta,
-                  AttentionMaskType attn_mask_type, int num_virtual_tokens)
-       : num_attention_heads(num_attention_heads), num_kv_heads(num_kv_heads), interleaved_qkv(interleaved_qkv),
-         rope_type(rope_type), rope_theta(rope_theta), attn_mask_type(attn_mask_type),
-         num_virtual_tokens(num_virtual_tokens),
-         query_key_value(mctx, hidden_size, hidden_size + 2 * (hidden_size / num_attention_heads) * num_kv_heads,
-                         use_qkv_bias),
+   BasicAttention(ModelContext *mctx, int hidden_size, int num_attention_heads, int num_key_value_heads,
+                  int max_length, bool use_qkv_bias, bool use_dense_bias, bool interleaved_qkv, RopeType rope_type,
+                  float rope_theta, AttentionMaskType attn_mask_type, int num_virtual_tokens)
+       : num_attention_heads(num_attention_heads), num_key_value_heads(num_key_value_heads),
+         interleaved_qkv(interleaved_qkv), rope_type(rope_type), rope_theta(rope_theta),
+         attn_mask_type(attn_mask_type), num_virtual_tokens(num_virtual_tokens),
+         query_key_value(mctx, hidden_size,
+                         hidden_size + 2 * (hidden_size / num_attention_heads) * num_key_value_heads, use_qkv_bias),
         dense(mctx, hidden_size, hidden_size, use_dense_bias),
         k_cache(ggml_new_tensor_3d(mctx->ctx_kv.get(), GGML_TYPE_F16, hidden_size / num_attention_heads,
-                                   max_length + num_virtual_tokens, num_kv_heads)),
+                                   max_length + num_virtual_tokens, num_key_value_heads)),
         v_cache(ggml_new_tensor_3d(mctx->ctx_kv.get(), GGML_TYPE_F16, max_length + num_virtual_tokens,
-                                   hidden_size / num_attention_heads, num_kv_heads)) {}
+                                   hidden_size / num_attention_heads, num_key_value_heads)) {}

    ggml_tensor *forward(ModelContext *mctx, ggml_tensor *hidden_states, ggml_tensor *attention_mask,
                         ggml_tensor *position_ids, int n_past) const;

public:
    int num_attention_heads;
-   int num_kv_heads;
+   int num_key_value_heads;
    bool interleaved_qkv;
    RopeType rope_type;
    float rope_theta;

@@ -454,13 +475,13 @@ template <typename Norm, typename MLP>
class BasicBlock {
public:
    BasicBlock() = default;
-   BasicBlock(ModelContext *mctx, int hidden_size, int num_attention_heads, int num_kv_heads, int intermediate_size,
-              int max_length, float norm_eps, ActivationType hidden_act, bool use_qkv_bias, bool use_dense_bias,
-              bool interleaved_qkv, RopeType rope_type, float rope_theta, AttentionMaskType attn_mask_type,
-              int num_virtual_tokens)
+   BasicBlock(ModelContext *mctx, int hidden_size, int num_attention_heads, int num_key_value_heads,
+              int intermediate_size, int max_length, float norm_eps, ActivationType hidden_act, bool use_qkv_bias,
+              bool use_dense_bias, bool interleaved_qkv, RopeType rope_type, float rope_theta,
+              AttentionMaskType attn_mask_type, int num_virtual_tokens)
        : input_layernorm(mctx, hidden_size, false, norm_eps),
-         attention(mctx, hidden_size, num_attention_heads, num_kv_heads, max_length, use_qkv_bias, use_dense_bias,
-                   interleaved_qkv, rope_type, rope_theta, attn_mask_type, num_virtual_tokens),
+         attention(mctx, hidden_size, num_attention_heads, num_key_value_heads, max_length, use_qkv_bias,
+                   use_dense_bias, interleaved_qkv, rope_type, rope_theta, attn_mask_type, num_virtual_tokens),
         post_attention_layernorm(mctx, hidden_size, false, norm_eps),
         mlp(mctx, hidden_size, intermediate_size, hidden_act) {}

@@ -572,20 +593,20 @@ class BasicModel {
        auto &attn = layers[i].attention;
        ggml_tensor *virtual_key =
            ggml_view_3d(mctx.ctx_b.get(), past_key_values, head_size, config.num_virtual_tokens,
-                        config.num_kv_heads, past_key_values->nb[1], past_key_values->nb[2],
+                        config.num_key_value_heads, past_key_values->nb[1], past_key_values->nb[2],
                         i * 2 * past_key_values->nb[3]); // [#h, v, d]
        ggml_tensor *k_cache_view =
-           ggml_view_3d(mctx.ctx_b.get(), attn.k_cache, head_size, config.num_virtual_tokens, config.num_kv_heads,
-                        attn.k_cache->nb[1], attn.k_cache->nb[2], 0); // [#h, v, d]
+           ggml_view_3d(mctx.ctx_b.get(), attn.k_cache, head_size, config.num_virtual_tokens,
+                        config.num_key_value_heads, attn.k_cache->nb[1], attn.k_cache->nb[2], 0); // [#h, v, d]
        ggml_build_forward_expand(mctx.gf, ggml_cpy(mctx.ctx_b.get(), virtual_key, k_cache_view));

        ggml_tensor *virtual_value = ggml_view_3d(
-           mctx.ctx_b.get(), past_key_values, head_size, config.num_virtual_tokens, config.num_kv_heads,
+           mctx.ctx_b.get(), past_key_values, head_size, config.num_virtual_tokens, config.num_key_value_heads,
            past_key_values->nb[1], past_key_values->nb[2], (i * 2 + 1) * past_key_values->nb[3]); // [#h, v, d]
        virtual_value = ggml_permute(mctx.ctx_b.get(), virtual_value, 1, 0, 2, 3); // [#h, d, v]
        ggml_tensor *v_cache_view =
-           ggml_view_3d(mctx.ctx_b.get(), attn.v_cache, config.num_virtual_tokens, head_size, config.num_kv_heads,
-                        attn.v_cache->nb[1], attn.v_cache->nb[2], 0); // [#h, d, v]
+           ggml_view_3d(mctx.ctx_b.get(), attn.v_cache, config.num_virtual_tokens, head_size,
+                        config.num_key_value_heads, attn.v_cache->nb[1], attn.v_cache->nb[2], 0); // [#h, d, v]
        ggml_build_forward_expand(mctx.gf, ggml_cpy(mctx.ctx_b.get(), virtual_value, v_cache_view));
    }

@@ -598,7 +619,7 @@ class BasicModel {
        std::vector<Block> layers;
        layers.reserve(config.num_hidden_layers);
        for (int layer_id = 0; layer_id < config.num_hidden_layers; layer_id++) {
-           layers.emplace_back(mctx, config.hidden_size, config.num_attention_heads, config.num_kv_heads,
+           layers.emplace_back(mctx, config.hidden_size, config.num_attention_heads, config.num_key_value_heads,
                                config.intermediate_size, config.max_length, config.norm_eps, config.hidden_act,
                                config.use_qkv_bias, config.use_dense_bias, config.interleaved_qkv, config.rope_type,
                                config.rope_theta, config.attn_mask_type, config.num_virtual_tokens);

@@ -858,10 +879,10 @@ class ChatGLMTokenizer : public BaseTokenizer {
class GLMBlock : public BasicBlock<LayerNorm, BasicMLP> {
public:
    GLMBlock() = default;
-   GLMBlock(ModelContext *mctx, int hidden_size, int num_attention_heads, int num_kv_heads, int intermediate_size,
-            int max_length, float norm_eps, ActivationType hidden_act, bool use_qkv_bias, bool use_dense_bias,
-            bool interleaved_qkv, RopeType rope_type, float rope_theta, AttentionMaskType attn_mask_type,
-            int num_virtual_tokens)
+   GLMBlock(ModelContext *mctx, int hidden_size, int num_attention_heads, int num_key_value_heads,
+            int intermediate_size, int max_length, float norm_eps, ActivationType hidden_act, bool use_qkv_bias,
+            bool use_dense_bias, bool interleaved_qkv, RopeType rope_type, float rope_theta,
+            AttentionMaskType attn_mask_type, int num_virtual_tokens)
        : BasicBlock(LayerNorm(mctx, hidden_size, false, norm_eps),
                     BasicAttention(mctx, hidden_size, num_attention_heads, num_attention_heads, max_length,
                                    use_qkv_bias, use_dense_bias, interleaved_qkv, rope_type, rope_theta,
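Besides propagating the num_kv_heads to num_key_value_heads rename through the header, this diff adds a friend operator<< so a ModelConfig can be streamed for debugging. A small usage sketch, assuming ModelConfig lives in the chatglm namespace as elsewhere in the codebase (the dump_config helper is illustrative, not part of the patch):

#include <iostream>

#include "chatglm.h"

// Illustrative helper: prints the loaded configuration on one line using the
// operator<< introduced in this commit, e.g. to verify num_key_value_heads.
void dump_config(const chatglm::ModelConfig &config) {
    std::cout << config << std::endl;
}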

chatglm_cpp/_C.pyi

Lines changed: 1 addition & 1 deletion
@@ -104,7 +104,7 @@ class ModelConfig:
    def num_hidden_layers(self) -> int:
        ...
    @property
-   def num_kv_heads(self) -> int:
+   def num_key_value_heads(self) -> int:
        ...
    @property
    def pad_token_id(self) -> int:
