@@ -65,7 +65,7 @@ struct ConfigRecordV1 {
 
 // For compatibility
 struct ConfigRecordV1GQA : public ConfigRecordV1 {
-    int num_kv_heads;
+    int num_key_value_heads;
 };
 
 // TODO: use json to serialize config
@@ -109,15 +109,15 @@ class ModelConfig {
     ModelConfig() = default;
 
     ModelConfig(ModelType model_type, ggml_type dtype, int vocab_size, int hidden_size, int num_attention_heads,
-                int num_kv_heads, int num_hidden_layers, int intermediate_size, float norm_eps, float rope_theta,
+                int num_key_value_heads, int num_hidden_layers, int intermediate_size, float norm_eps, float rope_theta,
                 int num_virtual_tokens, int max_length, int bos_token_id, int eos_token_id, int pad_token_id,
                 int sep_token_id, std::vector<int> extra_eos_token_ids)
         : model_type(model_type), dtype(dtype), vocab_size(vocab_size), hidden_size(hidden_size),
-          num_attention_heads(num_attention_heads), num_kv_heads(num_kv_heads), num_hidden_layers(num_hidden_layers),
-          intermediate_size(intermediate_size), norm_eps(norm_eps), rope_theta(rope_theta),
-          num_virtual_tokens(num_virtual_tokens), max_length(max_length), bos_token_id(bos_token_id),
-          eos_token_id(eos_token_id), pad_token_id(pad_token_id), sep_token_id(sep_token_id),
-          extra_eos_token_ids(std::move(extra_eos_token_ids)) {
+          num_attention_heads(num_attention_heads), num_key_value_heads(num_key_value_heads),
+          num_hidden_layers(num_hidden_layers), intermediate_size(intermediate_size), norm_eps(norm_eps),
+          rope_theta(rope_theta), num_virtual_tokens(num_virtual_tokens), max_length(max_length),
+          bos_token_id(bos_token_id), eos_token_id(eos_token_id), pad_token_id(pad_token_id),
+          sep_token_id(sep_token_id), extra_eos_token_ids(std::move(extra_eos_token_ids)) {
         if (model_type == ModelType::CHATGLM) {
             hidden_act = ActivationType::GELU;
             use_qkv_bias = true;
@@ -146,9 +146,10 @@ class ModelConfig {
 
     ModelConfig(ModelType model_type, const ConfigRecordV1GQA &rec, float norm_eps, float rope_theta,
                 int num_virtual_tokens)
-        : ModelConfig(model_type, rec.dtype, rec.vocab_size, rec.hidden_size, rec.num_attention_heads, rec.num_kv_heads,
-                      rec.num_hidden_layers, rec.intermediate_size, norm_eps, rope_theta, num_virtual_tokens,
-                      rec.max_length, rec.bos_token_id, rec.eos_token_id, rec.pad_token_id, rec.sep_token_id, {}) {}
+        : ModelConfig(model_type, rec.dtype, rec.vocab_size, rec.hidden_size, rec.num_attention_heads,
+                      rec.num_key_value_heads, rec.num_hidden_layers, rec.intermediate_size, norm_eps, rope_theta,
+                      num_virtual_tokens, rec.max_length, rec.bos_token_id, rec.eos_token_id, rec.pad_token_id,
+                      rec.sep_token_id, {}) {}
 
     ModelConfig(ModelType model_type, const ConfigRecordV2 &rec)
         : ModelConfig(model_type, rec.dtype, rec.vocab_size, rec.hidden_size, rec.num_attention_heads,
@@ -158,13 +159,33 @@ class ModelConfig {
 
     std::string model_type_name() const { return to_string(model_type); }
 
+    friend std::ostream &operator<<(std::ostream &os, const ModelConfig &self) {
+        os << "ModelConfig(model_type=" << (int)self.model_type << ", dtype=" << self.dtype
+           << ", vocab_size=" << self.vocab_size << ", hidden_size=" << self.hidden_size
+           << ", num_attention_heads=" << self.num_attention_heads
+           << ", num_key_value_heads=" << self.num_key_value_heads << ", num_hidden_layers=" << self.num_hidden_layers
+           << ", intermediate_size=" << self.intermediate_size << ", norm_eps=" << self.norm_eps
+           << ", hidden_act=" << (int)self.hidden_act << ", use_qkv_bias=" << self.use_qkv_bias
+           << ", use_dense_bias=" << self.use_dense_bias << ", interleaved_qkv=" << self.interleaved_qkv
+           << ", tie_word_embeddings=" << self.tie_word_embeddings << ", rope_type=" << (int)self.rope_type
+           << ", rope_theta=" << self.rope_theta << ", attn_mask_type=" << (int)self.attn_mask_type
+           << ", num_virtual_tokens=" << self.num_virtual_tokens << ", max_length=" << self.max_length
+           << ", bos_token_id=" << self.bos_token_id << ", eos_token_id=" << self.eos_token_id
+           << ", pad_token_id=" << self.pad_token_id << ", sep_token_id=" << self.sep_token_id
+           << ", extra_eos_token_ids={";
+        for (size_t i = 0; i < self.extra_eos_token_ids.size(); i++) {
+            os << (i > 0 ? ", " : "") << self.extra_eos_token_ids[i];
+        }
+        return os << "})";
+    }
+
  public:
     ModelType model_type;
     ggml_type dtype;
     int vocab_size;
     int hidden_size;
     int num_attention_heads;
-    int num_kv_heads;
+    int num_key_value_heads;
     int num_hidden_layers;
     int intermediate_size;
     float norm_eps;
@@ -419,26 +440,26 @@ class BasicGLU {
 class BasicAttention {
  public:
     BasicAttention() = default;
-    BasicAttention(ModelContext *mctx, int hidden_size, int num_attention_heads, int num_kv_heads, int max_length,
-                   bool use_qkv_bias, bool use_dense_bias, bool interleaved_qkv, RopeType rope_type, float rope_theta,
-                   AttentionMaskType attn_mask_type, int num_virtual_tokens)
-        : num_attention_heads(num_attention_heads), num_kv_heads(num_kv_heads), interleaved_qkv(interleaved_qkv),
-          rope_type(rope_type), rope_theta(rope_theta), attn_mask_type(attn_mask_type),
-          num_virtual_tokens(num_virtual_tokens),
-          query_key_value(mctx, hidden_size, hidden_size + 2 * (hidden_size / num_attention_heads) * num_kv_heads,
-                          use_qkv_bias),
+    BasicAttention(ModelContext *mctx, int hidden_size, int num_attention_heads, int num_key_value_heads,
+                   int max_length, bool use_qkv_bias, bool use_dense_bias, bool interleaved_qkv, RopeType rope_type,
+                   float rope_theta, AttentionMaskType attn_mask_type, int num_virtual_tokens)
+        : num_attention_heads(num_attention_heads), num_key_value_heads(num_key_value_heads),
+          interleaved_qkv(interleaved_qkv), rope_type(rope_type), rope_theta(rope_theta),
+          attn_mask_type(attn_mask_type), num_virtual_tokens(num_virtual_tokens),
+          query_key_value(mctx, hidden_size,
+                          hidden_size + 2 * (hidden_size / num_attention_heads) * num_key_value_heads, use_qkv_bias),
           dense(mctx, hidden_size, hidden_size, use_dense_bias),
           k_cache(ggml_new_tensor_3d(mctx->ctx_kv.get(), GGML_TYPE_F16, hidden_size / num_attention_heads,
-                                     max_length + num_virtual_tokens, num_kv_heads)),
+                                     max_length + num_virtual_tokens, num_key_value_heads)),
           v_cache(ggml_new_tensor_3d(mctx->ctx_kv.get(), GGML_TYPE_F16, max_length + num_virtual_tokens,
-                                     hidden_size / num_attention_heads, num_kv_heads)) {}
+                                     hidden_size / num_attention_heads, num_key_value_heads)) {}
 
     ggml_tensor *forward(ModelContext *mctx, ggml_tensor *hidden_states, ggml_tensor *attention_mask,
                          ggml_tensor *position_ids, int n_past) const;
 
  public:
     int num_attention_heads;
-    int num_kv_heads;
+    int num_key_value_heads;
     bool interleaved_qkv;
     RopeType rope_type;
     float rope_theta;
@@ -454,13 +475,13 @@ template <typename Norm, typename MLP>
 class BasicBlock {
  public:
     BasicBlock() = default;
-    BasicBlock(ModelContext *mctx, int hidden_size, int num_attention_heads, int num_kv_heads, int intermediate_size,
-               int max_length, float norm_eps, ActivationType hidden_act, bool use_qkv_bias, bool use_dense_bias,
-               bool interleaved_qkv, RopeType rope_type, float rope_theta, AttentionMaskType attn_mask_type,
-               int num_virtual_tokens)
+    BasicBlock(ModelContext *mctx, int hidden_size, int num_attention_heads, int num_key_value_heads,
+               int intermediate_size, int max_length, float norm_eps, ActivationType hidden_act, bool use_qkv_bias,
+               bool use_dense_bias, bool interleaved_qkv, RopeType rope_type, float rope_theta,
+               AttentionMaskType attn_mask_type, int num_virtual_tokens)
         : input_layernorm(mctx, hidden_size, false, norm_eps),
-          attention(mctx, hidden_size, num_attention_heads, num_kv_heads, max_length, use_qkv_bias, use_dense_bias,
-                    interleaved_qkv, rope_type, rope_theta, attn_mask_type, num_virtual_tokens),
+          attention(mctx, hidden_size, num_attention_heads, num_key_value_heads, max_length, use_qkv_bias,
+                    use_dense_bias, interleaved_qkv, rope_type, rope_theta, attn_mask_type, num_virtual_tokens),
           post_attention_layernorm(mctx, hidden_size, false, norm_eps),
           mlp(mctx, hidden_size, intermediate_size, hidden_act) {}
 
@@ -572,20 +593,20 @@ class BasicModel {
             auto &attn = layers[i].attention;
             ggml_tensor *virtual_key =
                 ggml_view_3d(mctx.ctx_b.get(), past_key_values, head_size, config.num_virtual_tokens,
-                             config.num_kv_heads, past_key_values->nb[1], past_key_values->nb[2],
+                             config.num_key_value_heads, past_key_values->nb[1], past_key_values->nb[2],
                              i * 2 * past_key_values->nb[3]); // [#h, v, d]
             ggml_tensor *k_cache_view =
-                ggml_view_3d(mctx.ctx_b.get(), attn.k_cache, head_size, config.num_virtual_tokens, config.num_kv_heads,
-                             attn.k_cache->nb[1], attn.k_cache->nb[2], 0); // [#h, v, d]
+                ggml_view_3d(mctx.ctx_b.get(), attn.k_cache, head_size, config.num_virtual_tokens,
+                             config.num_key_value_heads, attn.k_cache->nb[1], attn.k_cache->nb[2], 0); // [#h, v, d]
             ggml_build_forward_expand(mctx.gf, ggml_cpy(mctx.ctx_b.get(), virtual_key, k_cache_view));
 
             ggml_tensor *virtual_value = ggml_view_3d(
-                mctx.ctx_b.get(), past_key_values, head_size, config.num_virtual_tokens, config.num_kv_heads,
+                mctx.ctx_b.get(), past_key_values, head_size, config.num_virtual_tokens, config.num_key_value_heads,
                 past_key_values->nb[1], past_key_values->nb[2], (i * 2 + 1) * past_key_values->nb[3]); // [#h, v, d]
             virtual_value = ggml_permute(mctx.ctx_b.get(), virtual_value, 1, 0, 2, 3); // [#h, d, v]
             ggml_tensor *v_cache_view =
-                ggml_view_3d(mctx.ctx_b.get(), attn.v_cache, config.num_virtual_tokens, head_size, config.num_kv_heads,
-                             attn.v_cache->nb[1], attn.v_cache->nb[2], 0); // [#h, d, v]
+                ggml_view_3d(mctx.ctx_b.get(), attn.v_cache, config.num_virtual_tokens, head_size,
+                             config.num_key_value_heads, attn.v_cache->nb[1], attn.v_cache->nb[2], 0); // [#h, d, v]
             ggml_build_forward_expand(mctx.gf, ggml_cpy(mctx.ctx_b.get(), virtual_value, v_cache_view));
         }
 
@@ -598,7 +619,7 @@ class BasicModel {
         std::vector<Block> layers;
         layers.reserve(config.num_hidden_layers);
         for (int layer_id = 0; layer_id < config.num_hidden_layers; layer_id++) {
-            layers.emplace_back(mctx, config.hidden_size, config.num_attention_heads, config.num_kv_heads,
+            layers.emplace_back(mctx, config.hidden_size, config.num_attention_heads, config.num_key_value_heads,
                                 config.intermediate_size, config.max_length, config.norm_eps, config.hidden_act,
                                 config.use_qkv_bias, config.use_dense_bias, config.interleaved_qkv, config.rope_type,
                                 config.rope_theta, config.attn_mask_type, config.num_virtual_tokens);
@@ -858,10 +879,10 @@ class ChatGLMTokenizer : public BaseTokenizer {
 class GLMBlock : public BasicBlock<LayerNorm, BasicMLP> {
  public:
     GLMBlock() = default;
-    GLMBlock(ModelContext *mctx, int hidden_size, int num_attention_heads, int num_kv_heads, int intermediate_size,
-             int max_length, float norm_eps, ActivationType hidden_act, bool use_qkv_bias, bool use_dense_bias,
-             bool interleaved_qkv, RopeType rope_type, float rope_theta, AttentionMaskType attn_mask_type,
-             int num_virtual_tokens)
+    GLMBlock(ModelContext *mctx, int hidden_size, int num_attention_heads, int num_key_value_heads,
+             int intermediate_size, int max_length, float norm_eps, ActivationType hidden_act, bool use_qkv_bias,
+             bool use_dense_bias, bool interleaved_qkv, RopeType rope_type, float rope_theta,
+             AttentionMaskType attn_mask_type, int num_virtual_tokens)
         : BasicBlock(LayerNorm(mctx, hidden_size, false, norm_eps),
                      BasicAttention(mctx, hidden_size, num_attention_heads, num_attention_heads, max_length,
                                     use_qkv_bias, use_dense_bias, interleaved_qkv, rope_type, rope_theta,
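For context on the renamed field: below is a minimal standalone sketch, not part of this patch, showing how the query_key_value projection width used in the BasicAttention constructor follows from num_key_value_heads under grouped-query attention. The file name and model dimensions are hypothetical examples.

// gqa_sizing_sketch.cpp — evaluates hidden_size + 2 * head_dim * num_key_value_heads,
// the same expression used for the query_key_value projection in BasicAttention.
#include <cstdio>

int main() {
    const int hidden_size = 4096;        // hypothetical
    const int num_attention_heads = 32;  // hypothetical
    const int num_key_value_heads = 2;   // GQA: fewer KV heads than query heads

    const int head_dim = hidden_size / num_attention_heads;
    // Query rows (hidden_size) plus one key block and one value block,
    // each of size head_dim * num_key_value_heads.
    const int qkv_out = hidden_size + 2 * head_dim * num_key_value_heads;

    std::printf("head_dim=%d qkv_out=%d\n", head_dim, qkv_out); // prints head_dim=128 qkv_out=4608
    return 0;
}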