diff --git a/csrc/models/gpt2/gpt2_for_causal_lm.cpp b/csrc/models/gpt2/gpt2_for_causal_lm.cpp new file mode 100644 index 00000000..aaab6786 --- /dev/null +++ b/csrc/models/gpt2/gpt2_for_causal_lm.cpp @@ -0,0 +1,261 @@ +#include "gpt2_for_causal_lm.hpp" +#include "../../global_state/global_state.hpp" +#include "../../layers/attention/attention.hpp" +#include "../models_registry.hpp" +#include +#include + +namespace infinilm::models::gpt2 { + +std::shared_ptr +create_gpt2_model_config(std::shared_ptr config) { + const std::string &model_type = config->get("model_type"); + if ("gpt2" != model_type) { + throw std::runtime_error( + "infinilm::models::gpt2::create_gpt2_model_config: model_type is not gpt2"); + } + + auto &j = config->get_config_json(); + + j["hidden_size"] = j.value("hidden_size", j.value("n_embd", 768)); + j["num_hidden_layers"] = j.value("num_hidden_layers", j.value("n_layer", 12)); + j["num_attention_heads"] = j.value("num_attention_heads", j.value("n_head", 12)); + j["num_key_value_heads"] = j["num_attention_heads"]; + j["head_dim"] = j["hidden_size"].get() / j["num_attention_heads"].get(); + j["max_position_embeddings"] = j.value("max_position_embeddings", j.value("n_positions", 1024)); + j["intermediate_size"] = j.value("n_inner", 4 * j["hidden_size"].get()); + j["layer_norm_eps"] = j.value("layer_norm_epsilon", 1e-5); + j["attention_bias"] = true; + j["attention_output_bias"] = true; + j["mlp_bias"] = true; + + return config; +} + +GPT2Attention::GPT2Attention(std::shared_ptr config, + size_t layer_idx, + const infinicore::Device &device) + : layer_idx_(layer_idx) { + const auto &dtype = config->get_dtype(); + hidden_size_ = config->get("hidden_size"); + num_heads_ = config->get("num_attention_heads"); + num_kv_heads_ = config->get("num_key_value_heads"); + head_dim_ = config->get("head_dim"); + + const bool use_bias = config->get_or("attention_bias", true); + const bool use_output_bias = config->get_or("attention_output_bias", true); + auto quantization_method = config->get_quantization_method(); + const auto &rank_info = infinilm::global_state::get_tensor_model_parallel_rank_info(); + const int tp_rank = infinilm::global_state::get_tensor_model_parallel_rank(); + const int tp_size = infinilm::global_state::get_tensor_model_parallel_world_size(); + const size_t total_num_heads = num_heads_; + const size_t total_num_kv_heads = num_kv_heads_; + + num_heads_ = total_num_heads / tp_size; + num_kv_heads_ = total_num_kv_heads < static_cast(tp_size) + ? 1 + : total_num_kv_heads / tp_size; + + auto register_fn = [this](const std::string &name, infinicore::nn::Parameter param) { + this->register_parameter(name, std::move(param)); + }; + qkv_proj_ = std::make_shared( + hidden_size_, + head_dim_, + total_num_heads, + total_num_kv_heads, + "q_proj", + "k_proj", + "v_proj", + register_fn, + quantization_method, + use_bias, + dtype, + device, + rank_info); + INFINICORE_NN_MODULE_INIT( + o_proj, + total_num_heads * head_dim_, + hidden_size_, + quantization_method, + use_output_bias, + dtype, + device, + tp_rank, + tp_size, + rank_info.comm); + + infinilm::layers::attention::init_kv_cache_quant_params( + register_fn, device, kv_cache_k_scale_, kv_cache_v_scale_); + + const float scale = 1.0f / std::sqrt(static_cast(head_dim_)); + attention_backend_ = infinilm::global_state::get_infinilm_config().attention_backend; + attn_ = std::make_shared( + num_heads_, + head_dim_, + scale, + num_kv_heads_, + layer_idx_, + kv_cache_k_scale_, + kv_cache_v_scale_, + attention_backend_); +} + +infinicore::Tensor GPT2Attention::forward(const infinicore::Tensor &positions, + const infinicore::Tensor &hidden_states) const { + (void)positions; + auto hidden_states_mutable = hidden_states; + auto shape = hidden_states->shape(); + size_t batch_size = shape[0]; + size_t seq_len = shape[1]; + + auto [q, k, v] = qkv_proj_->forward_split(hidden_states_mutable); + + if (attention_backend_ == infinilm::backends::AttentionBackend::PAGED_ATTN + || attention_backend_ == infinilm::backends::AttentionBackend::FLASH_ATTN) { + auto q_reshaped = q->view({seq_len, num_heads_, head_dim_}); + auto k_reshaped = k->view({seq_len, num_kv_heads_, head_dim_}); + auto v_reshaped = v->view({seq_len, num_kv_heads_, head_dim_}); + auto attn_output = attn_->forward(q_reshaped, k_reshaped, v_reshaped); + return o_proj_->forward(attn_output); + } + + auto q_reshaped = q->view({batch_size, seq_len, num_heads_, head_dim_}); + auto k_reshaped = k->view({batch_size, seq_len, num_kv_heads_, head_dim_}); + auto v_reshaped = v->view({batch_size, seq_len, num_kv_heads_, head_dim_}); + auto attn_output = attn_->forward(q_reshaped, k_reshaped, v_reshaped); + return o_proj_->forward(attn_output); +} + +GPT2MLP::GPT2MLP(std::shared_ptr config, + const infinicore::Device &device) { + const auto &dtype = config->get_dtype(); + const size_t hidden_size = config->get("hidden_size"); + const size_t intermediate_size = config->get("intermediate_size"); + const bool use_bias = config->get_or("mlp_bias", true); + activation_ = config->get_or("activation_function", "gelu_new"); + auto quantization_method = config->get_quantization_method(); + const auto &rank_info = infinilm::global_state::get_tensor_model_parallel_rank_info(); + + INFINICORE_NN_MODULE_INIT( + c_fc, + hidden_size, + intermediate_size, + quantization_method, + use_bias, + dtype, + device, + rank_info.tp_rank, + rank_info.tp_size); + INFINICORE_NN_MODULE_INIT( + c_proj, + intermediate_size, + hidden_size, + quantization_method, + use_bias, + dtype, + device, + rank_info.tp_rank, + rank_info.tp_size, + rank_info.comm); +} + +infinicore::Tensor GPT2MLP::forward(const infinicore::Tensor &hidden_states) const { + auto x = const_cast(hidden_states); + x = c_fc_->forward(x); + if (activation_ == "gelu_new" || activation_ == "gelu_tanh") { + x = infinicore::op::gelu_tanh(x); + } else if (activation_ == "gelu") { + x = infinicore::op::gelu(x); + } else { + throw std::runtime_error("infinilm::models::gpt2::GPT2MLP: unsupported activation " + activation_); + } + return c_proj_->forward(x); +} + +GPT2Block::GPT2Block(std::shared_ptr config, + size_t layer_idx, + const infinicore::Device &device) { + const auto &dtype = config->get_dtype(); + const size_t hidden_size = config->get("hidden_size"); + const double layer_norm_eps = config->get("layer_norm_eps"); + + INFINICORE_NN_MODULE_INIT(ln_1, hidden_size, layer_norm_eps, dtype, device); + INFINICORE_NN_MODULE_INIT(attn, config, layer_idx, device); + INFINICORE_NN_MODULE_INIT(ln_2, hidden_size, layer_norm_eps, dtype, device); + INFINICORE_NN_MODULE_INIT(mlp, config, device); +} + +infinicore::Tensor GPT2Block::forward(const infinicore::Tensor &positions, + const infinicore::Tensor &hidden_states) const { + auto residual = hidden_states; + auto x = ln_1_->forward(hidden_states); + x = attn_->forward(positions, x); + x = infinicore::op::add(x, residual); + + residual = x; + x = ln_2_->forward(x); + x = mlp_->forward(x); + return infinicore::op::add(x, residual); +} + +GPT2Model::GPT2Model(std::shared_ptr config, + const infinicore::Device &device) { + const auto &dtype = config->get_dtype(); + const size_t vocab_size = config->get("vocab_size"); + const size_t hidden_size = config->get("hidden_size"); + const size_t max_position_embeddings = config->get("max_position_embeddings"); + const size_t num_hidden_layers = config->get("num_hidden_layers"); + const double layer_norm_eps = config->get("layer_norm_eps"); + + INFINICORE_NN_MODULE_INIT(embed_tokens, vocab_size, hidden_size, std::nullopt, dtype, device); + INFINICORE_NN_MODULE_INIT(embed_positions, max_position_embeddings, hidden_size, std::nullopt, dtype, device); + layers_.reserve(num_hidden_layers); + for (size_t i = 0; i < num_hidden_layers; ++i) { + layers_.push_back(this->register_module("layers." + std::to_string(i), config, i, device)); + } + INFINICORE_NN_MODULE_INIT(norm, hidden_size, layer_norm_eps, dtype, device); +} + +infinicore::Tensor GPT2Model::forward(const infinilm::InfinilmModel::Input &input) const { + auto input_ids = input.input_ids.value(); + auto position_ids = input.position_ids.value(); + + auto hidden_states = infinicore::op::add( + embed_tokens_->forward(input_ids), + embed_positions_->forward(position_ids)); + + for (const auto &layer : layers_) { + hidden_states = layer->forward(position_ids, hidden_states); + } + + return norm_->forward(hidden_states); +} + +GPT2ForCausalLM::GPT2ForCausalLM(std::shared_ptr config, + const infinicore::Device &device) { + model_config_ = config; + const auto &dtype = config->get_dtype(); + const size_t hidden_size = config->get("hidden_size"); + const size_t vocab_size = config->get("vocab_size"); + + INFINICORE_NN_MODULE_INIT(model, config, device); + INFINICORE_NN_MODULE_INIT(lm_head, hidden_size, vocab_size, false, dtype, device); +} + +InfinilmModel::Output GPT2ForCausalLM::forward(const InfinilmModel::Input &input) const { + auto hidden_states = model_->forward(input); + auto logits = lm_head_->forward(hidden_states); + return {logits}; +} + +} // namespace infinilm::models::gpt2 + +namespace { + +INFINILM_REGISTER_CAUSAL_LM_MODEL( + gpt2, + infinilm::models::gpt2::GPT2ForCausalLM, + infinilm::models::gpt2::create_gpt2_model_config); + +} // namespace diff --git a/csrc/models/gpt2/gpt2_for_causal_lm.hpp b/csrc/models/gpt2/gpt2_for_causal_lm.hpp new file mode 100644 index 00000000..26553fac --- /dev/null +++ b/csrc/models/gpt2/gpt2_for_causal_lm.hpp @@ -0,0 +1,100 @@ +#pragma once + +#include "../../models/infinilm_model.hpp" +#include "../../backends/attention_backends.hpp" +#include "../../layers/linear/linear.hpp" +#include "../../layers/linear/fused_linear.hpp" +#include "../../layers/attention/backends/attention_layer.hpp" +#include "../../config/model_config.hpp" +#include "infinicore/nn/embedding.hpp" +#include "infinicore/nn/layer_norm.hpp" +#include "infinicore/nn/parameter.hpp" +#include "infinicore/ops.hpp" +#include +#include + +namespace infinilm::models::gpt2 { + +class GPT2Attention : public infinicore::nn::Module { +public: + GPT2Attention(std::shared_ptr config, + size_t layer_idx, + const infinicore::Device &device); + + infinicore::Tensor forward(const infinicore::Tensor &positions, + const infinicore::Tensor &hidden_states) const; + +private: + std::shared_ptr qkv_proj_; + std::shared_ptr attn_; + INFINICORE_NN_MODULE(infinilm::layers::linear::RowParallelLinear, o_proj); + INFINICORE_NN_PARAMETER(kv_cache_k_scale); + INFINICORE_NN_PARAMETER(kv_cache_v_scale); + + size_t layer_idx_; + size_t hidden_size_; + size_t num_heads_; + size_t num_kv_heads_; + size_t head_dim_; + infinilm::backends::AttentionBackend attention_backend_; +}; + +class GPT2MLP : public infinicore::nn::Module { +public: + GPT2MLP(std::shared_ptr config, + const infinicore::Device &device); + + infinicore::Tensor forward(const infinicore::Tensor &hidden_states) const; + +private: + INFINICORE_NN_MODULE(infinilm::layers::linear::ColumnParallelLinear, c_fc); + INFINICORE_NN_MODULE(infinilm::layers::linear::RowParallelLinear, c_proj); + std::string activation_; +}; + +class GPT2Block : public infinicore::nn::Module { +public: + GPT2Block(std::shared_ptr config, + size_t layer_idx, + const infinicore::Device &device); + + infinicore::Tensor forward(const infinicore::Tensor &positions, + const infinicore::Tensor &hidden_states) const; + +private: + INFINICORE_NN_MODULE(infinicore::nn::LayerNorm, ln_1); + INFINICORE_NN_MODULE(GPT2Attention, attn); + INFINICORE_NN_MODULE(infinicore::nn::LayerNorm, ln_2); + INFINICORE_NN_MODULE(GPT2MLP, mlp); +}; + +class GPT2Model : public infinicore::nn::Module { +public: + GPT2Model(std::shared_ptr config, + const infinicore::Device &device); + + infinicore::Tensor forward(const infinilm::InfinilmModel::Input &input) const; + +private: + INFINICORE_NN_MODULE(infinicore::nn::Embedding, embed_tokens); + INFINICORE_NN_MODULE(infinicore::nn::Embedding, embed_positions); + INFINICORE_NN_MODULE_VEC(GPT2Block, layers); + INFINICORE_NN_MODULE(infinicore::nn::LayerNorm, norm); +}; + +class GPT2ForCausalLM : public infinilm::InfinilmModel { +public: + GPT2ForCausalLM(std::shared_ptr config, + const infinicore::Device &device); + + Output forward(const Input &input) const override; + +private: + INFINICORE_NN_MODULE(GPT2Model, model); + INFINICORE_NN_MODULE(infinilm::layers::linear::ReplicatedLinear, lm_head); +}; + +std::shared_ptr +create_gpt2_model_config(std::shared_ptr config); + +} // namespace infinilm::models::gpt2 diff --git a/examples/test_infer.py b/examples/test_infer.py index a3ce5e6d..28590742 100644 --- a/examples/test_infer.py +++ b/examples/test_infer.py @@ -18,6 +18,7 @@ def test( attn_backend="default", image_path=None, skip_load=False, + dtype="float16", ): model_path = os.path.expanduser(model_path) # ---------------------------------------------------------------------------- # @@ -36,6 +37,7 @@ def test( temperature=temperature, top_k=top_k, top_p=top_p, + dtype=dtype, enable_graph=enable_graph, attn_backend=attn_backend, skip_load=skip_load, @@ -54,9 +56,14 @@ def test( t1 = time.time() print("=================== start generate ====================") - outputs = model.chat( - messages=conversations, - ) + if getattr(model.engine.tokenizer, "chat_template", None): + outputs = model.chat( + messages=conversations, + ) + else: + outputs = model.generate( + prompts=prompts, + ) t2 = time.time() for i, output in enumerate(outputs): @@ -103,4 +110,5 @@ def test( attn_backend=cfg.attn, image_path=cfg.image, skip_load=cfg.skip_load, + dtype=cfg.dtype, ) diff --git a/python/infinilm/infer_engine.py b/python/infinilm/infer_engine.py index 31d65de9..9b49a216 100644 --- a/python/infinilm/infer_engine.py +++ b/python/infinilm/infer_engine.py @@ -54,6 +54,7 @@ def __init__( self, model_path, device=None, + dtype="float16", distributed_config=DistConfig(1), cache_config=None, enable_graph_compiling=False, @@ -61,6 +62,8 @@ def __init__( kv_cache_dtype=None, ): self.hf_config = read_hf_config(model_path) + if self.hf_config.get("torch_dtype") is None and self.hf_config.get("dtype") is None: + self.hf_config["torch_dtype"] = dtype self.hf_generation_config = read_hf_generation_config(model_path) if device is None: diff --git a/python/infinilm/llm/llm.py b/python/infinilm/llm/llm.py index cba3af83..ebd4c316 100644 --- a/python/infinilm/llm/llm.py +++ b/python/infinilm/llm/llm.py @@ -89,6 +89,7 @@ def __init__(self, config: EngineConfig): self.model_engine = InferEngine( model_path=config.model_path, device=self.device, + dtype=config.dtype, distributed_config=DistConfig(config.tensor_parallel_size), enable_graph_compiling=config.enable_graph, attention_backend=config.attn_backend, diff --git a/python/infinilm/modeling_utils.py b/python/infinilm/modeling_utils.py index 4d2bedb4..222e8d7e 100644 --- a/python/infinilm/modeling_utils.py +++ b/python/infinilm/modeling_utils.py @@ -105,7 +105,7 @@ def load_state_dict( ) for k in f.keys(): - state_dict[k] = f.get_tensor(k).to(device=device) + state_dict[k] = f.get_tensor(k).to(device=device, dtype=dtype) return state_dict @@ -567,9 +567,65 @@ def _remap_baichuan(state_dict, config=None): return state_dict +def _remap_gpt2(state_dict, config=None): + """Remap HuggingFace GPT-2 weights to InfiniLM GPT-2 module names. + + HuggingFace GPT-2 uses Conv1D modules whose weights are stored as + [in_features, out_features]. InfiniLM Linear expects [out_features, + in_features], so projection weights must be transposed. + """ + renamed = {} + for key, tensor in state_dict.items(): + if key.endswith((".attn.bias", ".attn.masked_bias")): + continue + + new_key = key + new_key = new_key.replace("transformer.wte", "model.embed_tokens") + new_key = new_key.replace("transformer.wpe", "model.embed_positions") + new_key = new_key.replace("transformer.h.", "model.layers.") + new_key = new_key.replace(".attn.c_proj.", ".attn.o_proj.") + new_key = new_key.replace("transformer.ln_f", "model.norm") + + if new_key.startswith("wte."): + new_key = new_key.replace("wte.", "model.embed_tokens.", 1) + elif new_key.startswith("wpe."): + new_key = new_key.replace("wpe.", "model.embed_positions.", 1) + elif new_key.startswith("h."): + new_key = new_key.replace("h.", "model.layers.", 1) + elif new_key.startswith("ln_f."): + new_key = new_key.replace("ln_f.", "model.norm.", 1) + + if key.endswith(("attn.c_proj.weight", "mlp.c_fc.weight", "mlp.c_proj.weight")): + tensor = tensor.t().contiguous() + renamed[new_key] = tensor + + remapped = {} + for key, tensor in renamed.items(): + if key.endswith(".attn.c_attn.weight"): + q, k, v = tensor.t().contiguous().chunk(3, dim=0) + prefix = key.removesuffix(".attn.c_attn.weight") + remapped[f"{prefix}.attn.q_proj.weight"] = q + remapped[f"{prefix}.attn.k_proj.weight"] = k + remapped[f"{prefix}.attn.v_proj.weight"] = v + elif key.endswith(".attn.c_attn.bias"): + q, k, v = tensor.chunk(3, dim=0) + prefix = key.removesuffix(".attn.c_attn.bias") + remapped[f"{prefix}.attn.q_proj.bias"] = q + remapped[f"{prefix}.attn.k_proj.bias"] = k + remapped[f"{prefix}.attn.v_proj.bias"] = v + else: + remapped[key] = tensor + + if "lm_head.weight" not in remapped and "model.embed_tokens.weight" in remapped: + remapped["lm_head.weight"] = remapped["model.embed_tokens.weight"] + + return remapped + + # Model type → remap function mapping _WEIGHT_REMAPPER = { "glm4": _remap_glm4, "chatglm": _remap_chatglm, "baichuan": _remap_baichuan, + "gpt2": _remap_gpt2, }