diff --git a/example/gpt2/config.h b/example/gpt2/config.h index 078f9fd5..71cc0a56 100644 --- a/example/gpt2/config.h +++ b/example/gpt2/config.h @@ -14,7 +14,7 @@ inline nn::TransformerConfig GPT2Config() { .n_head = 12, .n_kv_head = 12, .n_embd = 768, - .attention_type = nn::AttentionType::kStandard, + .position_embedding_type = nn::PositionEmbeddingType::kLearnedAbsolute, .activation_type = nn::MLPType::kGELU, .norm_type = nn::NormType::kLayerNorm, .add_bias_linear = true, @@ -34,7 +34,8 @@ inline void SanitizeGPT2Config(const nn::TransformerConfig &c) { CHECK_GT(c.n_embd, 0); CHECK_EQ(c.n_embd % c.n_head, 0) << "n_embd must be divisible by n_head"; CHECK_EQ(c.n_kv_head, c.n_head) << "GPT-2 does not use GQA; n_kv_head must equal n_head"; - CHECK(c.attention_type == nn::AttentionType::kStandard) << "GPT-2 requires standard attention"; + CHECK(c.position_embedding_type == nn::PositionEmbeddingType::kLearnedAbsolute) + << "GPT-2 requires learned absolute position embedding"; CHECK(c.activation_type == nn::MLPType::kGELU) << "GPT-2 requires GELU activation"; CHECK(c.norm_type == nn::NormType::kLayerNorm) << "GPT-2 requires LayerNorm"; } diff --git a/example/llama3/config.h b/example/llama3/config.h index 6bc9124d..67ebd31f 100644 --- a/example/llama3/config.h +++ b/example/llama3/config.h @@ -14,7 +14,7 @@ inline nn::TransformerConfig LLaMA3Config() { .n_head = 32, .n_kv_head = 8, .n_embd = 2048, - .attention_type = nn::AttentionType::kRoPE, + .position_embedding_type = nn::PositionEmbeddingType::kRoPE, .activation_type = nn::MLPType::kSwiGLU, .norm_type = nn::NormType::kRMSNorm, .add_bias_linear = false, @@ -36,7 +36,7 @@ inline void SanitizeLLaMA3Config(const nn::TransformerConfig &c) { CHECK_EQ(c.n_head % c.n_kv_head, 0) << "n_head must be divisible by n_kv_head for GQA"; CHECK_GT(c.n_embd, 0); CHECK_EQ(c.n_embd % c.n_head, 0) << "n_embd must be divisible by n_head"; - CHECK(c.attention_type == nn::AttentionType::kRoPE) << "LLaMA-3 requires RoPE attention"; + CHECK(c.position_embedding_type == nn::PositionEmbeddingType::kRoPE) << "LLaMA-3 requires RoPE position embedding"; CHECK(c.activation_type == nn::MLPType::kSwiGLU) << "LLaMA-3 requires SwiGLU activation"; CHECK(c.norm_type == nn::NormType::kRMSNorm) << "LLaMA-3 requires RMSNorm"; CHECK(!c.add_bias_linear) << "LLaMA-3 has no bias in linear layers"; diff --git a/infini_train/include/nn/modules/transformer/causal_self_attention.h b/infini_train/include/nn/modules/transformer/causal_self_attention.h index 5ac55e31..60373414 100644 --- a/infini_train/include/nn/modules/transformer/causal_self_attention.h +++ b/infini_train/include/nn/modules/transformer/causal_self_attention.h @@ -1,7 +1,6 @@ #pragma once #include -#include #include #include "infini_train/include/nn/modules/module.h" @@ -35,20 +34,6 @@ class CausalSelfAttention : public infini_train::nn::CloneableModule> - ForwardStandard(const std::vector> &x); - - // RoPE-aware attention forward (LLaMA3 style: with RoPE, optional GQA) - std::vector> - ForwardWithRoPE(const std::vector> &x); - - // RoPE helper methods - std::tuple, std::shared_ptr> - ApplyRotaryEmbedding(const std::shared_ptr &xq, - const std::shared_ptr &xk, - const std::shared_ptr &freqs_cis); - // GQA helper method std::shared_ptr RepeatKV(const std::shared_ptr &x, int64_t n_rep); }; diff --git a/infini_train/include/nn/modules/transformer/mla_self_attention.h b/infini_train/include/nn/modules/transformer/mla_self_attention.h new file mode 100644 index 00000000..63177cc6 --- /dev/null +++ b/infini_train/include/nn/modules/transformer/mla_self_attention.h @@ -0,0 +1,51 @@ +#pragma once + +#include +#include + +#include "infini_train/include/nn/modules/module.h" +#include "infini_train/include/nn/modules/transformer/transformer_config.h" + +namespace infini_train::nn { + +class MLASelfAttention : public infini_train::nn::CloneableModule { +public: + static constexpr char kType[] = "MLASelfAttention"; + + static constexpr char kLinearQProjLayerName[] = "linear_q_proj"; + static constexpr char kLinearQDownProjLayerName[] = "linear_q_down_proj"; + static constexpr char kQLayerNormLayerName[] = "q_layernorm"; + static constexpr char kLinearQUpProjLayerName[] = "linear_q_up_proj"; + static constexpr char kLinearKVDownProjLayerName[] = "linear_kv_down_proj"; + static constexpr char kKVLayerNormLayerName[] = "kv_layernorm"; + static constexpr char kLinearKVUpProjLayerName[] = "linear_kv_up_proj"; + static constexpr char kLinearProjLayerName[] = "linear_proj"; + + static constexpr char kParamBiasName[] = "bias"; + + explicit MLASelfAttention(const TransformerConfig &config); + + std::vector> + Forward(const std::vector> &x) override; + +private: + TransformerConfig config_; + int64_t n_head_ = 0; + int64_t n_embd_ = 0; + int64_t local_n_head_ = 0; + + int64_t q_lora_rank_ = 0; + int64_t kv_lora_rank_ = 0; + int64_t qk_nope_head_dim_ = 0; + int64_t qk_rope_head_dim_ = 0; + int64_t qk_head_dim_ = 0; + int64_t v_head_dim_ = 0; + + bool use_q_lora_ = true; + bool q_down_proj_use_tp_ = false; + bool kv_down_proj_use_tp_ = false; + + void SetupAttention(const TransformerConfig &config); +}; + +} // namespace infini_train::nn diff --git a/infini_train/include/nn/modules/transformer/transformer_config.h b/infini_train/include/nn/modules/transformer/transformer_config.h index 62379666..5a440e60 100644 --- a/infini_train/include/nn/modules/transformer/transformer_config.h +++ b/infini_train/include/nn/modules/transformer/transformer_config.h @@ -10,9 +10,13 @@ enum class ModelType { kLLaMA3, // LLaMA3 }; -enum class AttentionType { - kStandard, // Standard attention - kRoPE // Rotary Position Embedding +enum class PositionEmbeddingType { + kLearnedAbsolute, // Megatron: learned_absolute + kRoPE, // Megatron: rope + kYarn, // Megatron: yarn + kMRoPE, // Megatron: mrope + kRelative, // Megatron: relative + kNone // Megatron: none }; enum class MLPType { @@ -34,9 +38,9 @@ struct TransformerConfig { int64_t n_kv_head = 12; // Num of Key/Value heads (<= n_head, < n_head if using GQA) int64_t n_embd = 768; // Hidden size - AttentionType attention_type = AttentionType::kStandard; // Attention mechanism type - MLPType activation_type = MLPType::kGELU; // MLP activation type - NormType norm_type = NormType::kLayerNorm; // Normalization type + PositionEmbeddingType position_embedding_type = PositionEmbeddingType::kLearnedAbsolute; // Position embedding type. + MLPType activation_type = MLPType::kGELU; // MLP activation type + NormType norm_type = NormType::kLayerNorm; // Normalization type bool add_bias_linear = true; // Whether to add learnable bias to all Linear layers in the Transformer block, // including: attention QKV projection, attention output projection, MLP FC layers (and @@ -53,6 +57,16 @@ struct TransformerConfig { float rope_theta = 500000.0f; // theta in RoPE bool use_scaled_rope = false; // scaled RoPE + // MLA config + bool multi_latent_attention = false; // Use MLA instead of standard causal self-attention. + std::optional q_lora_rank = std::nullopt; // nullopt means direct linear_q_proj path. + int64_t kv_lora_rank = 0; // 0 falls back to n_embd in MLASelfAttention. + int64_t qk_nope_head_dim = 0; // 0 falls back to n_embd / n_head. + int64_t qk_rope_head_dim = 0; // 0 falls back to n_embd / n_head. + int64_t v_head_dim = 0; // 0 falls back to n_embd / n_head. + bool q_down_proj_use_tp = false; // Use ColumnParallelLinear for linear_q_down_proj. + bool kv_down_proj_use_tp = false; // Use ColumnParallelLinear for linear_kv_down_proj. + // Normalization float norm_eps = 1e-5f; // epsilon in RMSNorm diff --git a/infini_train/include/nn/modules/transformer/utils.h b/infini_train/include/nn/modules/transformer/utils.h index d3a62c63..30db08e6 100644 --- a/infini_train/include/nn/modules/transformer/utils.h +++ b/infini_train/include/nn/modules/transformer/utils.h @@ -1,6 +1,8 @@ #pragma once #include +#include +#include #include "infini_train/include/tensor.h" @@ -8,4 +10,8 @@ namespace infini_train { // RoPE helper method std::shared_ptr PrecomputeFreqsCis(int64_t dim, int64_t end, float theta = 10000.0f, bool use_scaled = false, Device device = Device()); + +std::tuple, std::shared_ptr> +ApplyRotaryEmbedding(const std::shared_ptr &xq, const std::shared_ptr &xk, + const std::shared_ptr &freqs_cis); } // namespace infini_train diff --git a/infini_train/include/nn/parallel/utils.h b/infini_train/include/nn/parallel/utils.h index 4dc737e7..2cd09ce7 100644 --- a/infini_train/include/nn/parallel/utils.h +++ b/infini_train/include/nn/parallel/utils.h @@ -23,6 +23,7 @@ std::vector GetPipelineParallelGroupRanks(int global_rank); // TP/SP Communication Helper Functions std::vector> GatherFromTPRegionFunc(const std::shared_ptr &input); +std::vector> ScatterToSPRegionFunc(const std::shared_ptr &input); std::vector> ReduceScatterToSPRegionFunc(const std::shared_ptr &input); std::vector> GatherFromSPRegionFunc(const std::shared_ptr &input); std::vector> ScatterToTPRegionFunc(const std::shared_ptr &input); diff --git a/infini_train/src/nn/modules/transformer/causal_self_attention.cc b/infini_train/src/nn/modules/transformer/causal_self_attention.cc index 5ea9eec5..43e4ca51 100644 --- a/infini_train/src/nn/modules/transformer/causal_self_attention.cc +++ b/infini_train/src/nn/modules/transformer/causal_self_attention.cc @@ -1,6 +1,7 @@ #include "infini_train/include/nn/modules/transformer/causal_self_attention.h" #include +#include #include #include #include @@ -12,6 +13,7 @@ #include "infini_train/include/nn/modules/normalization.h" #include "infini_train/include/nn/modules/sparse.h" #include "infini_train/include/nn/modules/transformer/transformer_config.h" +#include "infini_train/include/nn/modules/transformer/utils.h" #include "infini_train/include/nn/parallel/global.h" #include "infini_train/include/nn/parallel/tensor_parallel.h" #include "infini_train/include/tensor.h" @@ -42,12 +44,9 @@ CausalSelfAttention::CausalSelfAttention(const TransformerConfig &config) : Clon /*skip_bias_add=*/false, /*sequence_parallel=*/nn::parallel::global::GetSequenceParallelEnabled()); - // For standard attention (GPT2 style), precompute causal mask - if (config_.attention_type == AttentionType::kStandard) { - // causal mask: (1, 1, block_size, block_size) - buffers_[kParamBiasName] = function::Tril(nn::function::Ones({config_.block_size, config_.block_size})) - ->View({1, 1, config_.block_size, config_.block_size}); - } + // causal mask: (1, 1, block_size, block_size) + buffers_[kParamBiasName] = function::Tril(nn::function::Ones({config_.block_size, config_.block_size})) + ->View({1, 1, config_.block_size, config_.block_size}); } void CausalSelfAttention::SetupAttention(const TransformerConfig &config) { @@ -76,125 +75,21 @@ void CausalSelfAttention::SetupAttention(const TransformerConfig &config) { std::vector> CausalSelfAttention::Forward(const std::vector> &x) { - if (config_.attention_type == AttentionType::kRoPE) { - return ForwardWithRoPE(x); - } else { - return ForwardStandard(x); - } -} - -std::vector> -CausalSelfAttention::ForwardStandard(const std::vector> &x) { - auto tp_world_size = parallel::global::GetTensorParallelSize(); - - const auto B = x[0]->Dims()[0]; // bs - const auto C = x[0]->Dims()[2]; // n_embd - const int64_t head_dim = n_embd_ / n_head_; // per-head dim (global) - const int64_t local_C = n_embd_ / tp_world_size; // per-rank hidden - - // (B, T, C) -> ColumnParallelLinear(C, 3*C) -> (B, T, 3 * local_C) - // -> Split -> (3, B, T, local_C) - auto qkv = (*modules_[kCAttnLayerName])(x)[0]->Split(local_C, 2); - - // (B, T, local_C) - auto q = qkv[0]; - auto k = qkv[1]; - auto v = qkv[2]; - - // NOTE(zbl): Acquire full T after AllGather is performed in ColumnParallelLinear - const auto T = q->Dims()[1]; - - // View to multi-head: local_n_head * head_dim == local_C - // (B, T, local_C) -> (B, T, h_l, Dh) -> (B, h_l, T, Dh) - k = k->View({B, T, local_n_head_, head_dim})->Transpose(1, 2); - q = q->View({B, T, local_n_head_, head_dim})->Transpose(1, 2); - v = v->View({B, T, local_n_head_, head_dim})->Transpose(1, 2); - - // (B, h_l, T, T) - auto att = q->Matmul(k->Transpose(-2, -1)) * (1.0 / std::sqrt(head_dim)); - // (1, 1, T, T) - auto mask = buffers_[kParamBiasName]->Slice({0, 0, 0, 0}, {1, 1, T, T}, {1, 1, 1, 1}); - // (1, 1, T, T) -> eq 0 -> (1, 1, T, T) -> masked_fill -> (B, h_l, T, T) - att = att->MaskedFill(mask == 0, -std::numeric_limits::infinity()); - // (B, h_l, T, T) - att = nn::function::Softmax(att, -1); - // (B, h_l, T, Dh) - auto y = att->Matmul(v); - // (B, h_l, T, Dh) -> (B, T, h_l, Dh) -> (B, T, local_C) - y = y->Transpose(1, 2)->Contiguous()->View({B, T, local_C}); - - // Get full tensor - // (B, T, local_C) -> RowParallelLinear(n_embd, n_embd) -> (B, T, C) - y = (*modules_[kCProjLayerName])({y})[0]; - // (B, T, C) == (bs, seq_len, n_embd) - return {y}; -} - -// RoPE helper methods -std::tuple, std::shared_ptr> -CausalSelfAttention::ApplyRotaryEmbedding(const std::shared_ptr &xq, - const std::shared_ptr &xk, - const std::shared_ptr &freqs_cis) { - // Reshape freqs_cis for broadcasting - const auto &x_shape = xq->Dims(); // (B, T, H, D) - const int64_t T = x_shape[1]; - const int64_t D = x_shape[3]; - - std::vector target_shape = {1, T, 1, D / 2, 2}; - auto cos_sin = freqs_cis->View(target_shape); // -> (1, T, 1, D/2, 2) - - auto cos = cos_sin->Slice(-1, 0, 1, 1)->Squeeze(-1); // (1, T, 1, D/2) - auto sin = cos_sin->Slice(-1, 1, 2, 1)->Squeeze(-1); // (1, T, 1, D/2) - - auto slice_pair = [](const std::shared_ptr &x) { - auto even = x->Slice(-1, 0, x->Dims().back(), 2); - auto odd = x->Slice(-1, 1, x->Dims().back(), 2); - return std::make_pair(even, odd); - }; - - auto [q_even, q_odd] = slice_pair(xq); - auto q_rotated_left = q_even * cos - q_odd * sin; - auto q_rotated_right = q_even * sin + q_odd * cos; - auto q_rotated - = nn::function::Stack(std::vector>{q_rotated_left, q_rotated_right}, -1)->Flatten(-2); - - auto [k_even, k_odd] = slice_pair(xk); - auto k_rotated_left = k_even * cos - k_odd * sin; - auto k_rotated_right = k_even * sin + k_odd * cos; - auto k_rotated - = nn::function::Stack(std::vector>{k_rotated_left, k_rotated_right}, -1)->Flatten(-2); - - return {q_rotated, k_rotated}; -} - -std::shared_ptr CausalSelfAttention::RepeatKV(const std::shared_ptr &x, - int64_t n_rep) { - const auto &shape = x->Dims(); - const int64_t B = shape[0], T = shape[1], H = shape[2], D = shape[3]; - - if (n_rep == 1) { - return x; - } - - return x->View({B, T, H, 1, D})->RepeatInterleave(n_rep, 3)->Contiguous()->View({B, T, H * n_rep, D}); -} - -std::vector> -CausalSelfAttention::ForwardWithRoPE(const std::vector> &x) { const auto B = x[0]->Dims()[0]; // bs const auto C = x[0]->Dims()[2]; // n_embd const auto tp_size = nn::parallel::global::GetTensorParallelSize(); const auto C_local = C / tp_size; - const auto H_local = n_head_ / tp_size; + const auto H_local = local_n_head_; const auto KV_local = n_kv_head_ / tp_size; const auto D = head_dim_; // n_embd / n_head const auto freqs_cis = x.size() > 1 ? x[1] : nullptr; - const auto start_pos = x.size() > 2 ? x[2] : nullptr; const auto mask = x.size() > 3 ? x[3] : nullptr; - CHECK(freqs_cis != nullptr) << "freqs_cis is null."; + if (config_.position_embedding_type == PositionEmbeddingType::kRoPE) { + CHECK(freqs_cis != nullptr) << "freqs_cis is null."; + } // (B, T, C) -> (B, T, (H + 2 * n_kv_head) * D) auto qkv = (*modules_[kCAttnLayerName])({x[0]})[0]; @@ -212,10 +107,10 @@ CausalSelfAttention::ForwardWithRoPE(const std::vectorSlice(2, q_size_local + kv_size_local, q_size_local + 2 * kv_size_local)->View({B, T, KV_local, D}); - // -> RoPE on q, k - // q: (B, T, H_local, D) - // k: (B, T, KV_local, D) - std::tie(q, k) = ApplyRotaryEmbedding(q, k, freqs_cis); + if (config_.position_embedding_type == PositionEmbeddingType::kRoPE) { + // q: (B, T, H_local, D), k: (B, T, KV_local, D) + std::tie(q, k) = ApplyRotaryEmbedding(q, k, freqs_cis); + } // TODO(zbl): use kv cache during inference // if (use_kv_) { ... } @@ -243,6 +138,10 @@ CausalSelfAttention::ForwardWithRoPE(const std::vectorMaskedFill(mask, std::numeric_limits::lowest()); + } else { + // fallback causal mask: (1, 1, T, T) + auto causal_mask = buffers_[kParamBiasName]->Slice({0, 0, 0, 0}, {1, 1, T, T}, {1, 1, 1, 1}); + att = att->MaskedFill(causal_mask == 0, -std::numeric_limits::infinity()); } // (B, H_local, T, T) att = nn::function::Softmax(att, -1); @@ -257,4 +156,16 @@ CausalSelfAttention::ForwardWithRoPE(const std::vector CausalSelfAttention::RepeatKV(const std::shared_ptr &x, + int64_t n_rep) { + const auto &shape = x->Dims(); + const int64_t B = shape[0], T = shape[1], H = shape[2], D = shape[3]; + + if (n_rep == 1) { + return x; + } + + return x->View({B, T, H, 1, D})->RepeatInterleave(n_rep, 3)->Contiguous()->View({B, T, H * n_rep, D}); +} + } // namespace infini_train::nn diff --git a/infini_train/src/nn/modules/transformer/mla_self_attention.cc b/infini_train/src/nn/modules/transformer/mla_self_attention.cc new file mode 100644 index 00000000..536d077d --- /dev/null +++ b/infini_train/src/nn/modules/transformer/mla_self_attention.cc @@ -0,0 +1,270 @@ +#include "infini_train/include/nn/modules/transformer/mla_self_attention.h" + +#include +#include +#include +#include +#include + +#include "glog/logging.h" + +#include "infini_train/include/nn/functional.h" +#include "infini_train/include/nn/modules/linear.h" +#include "infini_train/include/nn/modules/normalization.h" +#include "infini_train/include/nn/modules/transformer/transformer_config.h" +#include "infini_train/include/nn/modules/transformer/utils.h" +#include "infini_train/include/nn/parallel/global.h" +#include "infini_train/include/nn/parallel/tensor_parallel.h" +#include "infini_train/include/nn/parallel/utils.h" +#include "infini_train/include/tensor.h" + +namespace infini_train::nn { + +MLASelfAttention::MLASelfAttention(const TransformerConfig &config) : CloneableModule(kType), config_(config) { + SetupAttention(config); + + if (use_q_lora_) { + if (q_down_proj_use_tp_) { + modules_[kLinearQDownProjLayerName] = std::make_shared( + /*in_features=*/n_embd_, + /*out_features=*/q_lora_rank_, + /*bias=*/config_.add_bias_linear, + /*gather_output=*/false, + /*input_is_parallel=*/false, + /*skip_bias_add=*/false, + /*sequence_parallel=*/nn::parallel::global::GetSequenceParallelEnabled()); + } else { + modules_[kLinearQDownProjLayerName] = std::make_shared( + /*in_features=*/n_embd_, + /*out_features=*/q_lora_rank_, + /*bias=*/config_.add_bias_linear); + } + modules_[kQLayerNormLayerName] = std::make_shared(q_lora_rank_, config_.norm_eps); + modules_[kLinearQUpProjLayerName] = std::make_shared( + /*in_features=*/q_lora_rank_, + /*out_features=*/n_head_ * qk_head_dim_, + /*bias=*/config_.add_bias_linear, + /*gather_output=*/false, + /*input_is_parallel=*/false, + /*skip_bias_add=*/false, + /*sequence_parallel=*/nn::parallel::global::GetSequenceParallelEnabled()); + } else { + modules_[kLinearQProjLayerName] = std::make_shared( + /*in_features=*/n_embd_, + /*out_features=*/n_head_ * qk_head_dim_, + /*bias=*/config_.add_bias_linear, + /*gather_output=*/false, + /*input_is_parallel=*/false, + /*skip_bias_add=*/false, + /*sequence_parallel=*/nn::parallel::global::GetSequenceParallelEnabled()); + } + + if (kv_down_proj_use_tp_) { + modules_[kLinearKVDownProjLayerName] = std::make_shared( + /*in_features=*/n_embd_, + /*out_features=*/kv_lora_rank_ + qk_rope_head_dim_, + /*bias=*/config_.add_bias_linear, + /*gather_output=*/false, + /*input_is_parallel=*/false, + /*skip_bias_add=*/false, + /*sequence_parallel=*/nn::parallel::global::GetSequenceParallelEnabled()); + } else { + modules_[kLinearKVDownProjLayerName] = std::make_shared( + /*in_features=*/n_embd_, + /*out_features=*/kv_lora_rank_ + qk_rope_head_dim_, + /*bias=*/config_.add_bias_linear); + } + modules_[kKVLayerNormLayerName] = std::make_shared(kv_lora_rank_, config_.norm_eps); + modules_[kLinearKVUpProjLayerName] = std::make_shared( + /*in_features=*/kv_lora_rank_, + /*out_features=*/n_head_ * (qk_nope_head_dim_ + v_head_dim_), + /*bias=*/config_.add_bias_linear, + /*gather_output=*/false, + /*input_is_parallel=*/false, + /*skip_bias_add=*/false, + /*sequence_parallel=*/nn::parallel::global::GetSequenceParallelEnabled()); + + modules_[kLinearProjLayerName] = std::make_shared( + /*in_features=*/n_head_ * v_head_dim_, + /*out_features=*/n_embd_, + /*bias=*/config_.add_bias_linear, + /*reduce_output=*/true, + /*input_is_parallel=*/true, + /*skip_bias_add=*/false, + /*sequence_parallel=*/nn::parallel::global::GetSequenceParallelEnabled()); + + buffers_[kParamBiasName] = function::Tril(nn::function::Ones({config_.block_size, config_.block_size})) + ->View({1, 1, config_.block_size, config_.block_size}); +} + +void MLASelfAttention::SetupAttention(const TransformerConfig &config) { + auto tp_world_size = nn::parallel::global::GetTensorParallelSize(); + + CHECK_EQ(config.n_embd % config.n_head, 0) << "n_embd must be divisible by n_head"; + CHECK_EQ(config.n_head % tp_world_size, 0) << "n_head must be divisible by TP world size"; + CHECK(!config.q_lora_rank.has_value() || config.q_lora_rank.value() > 0) << "q_lora_rank must be positive when set"; + + const auto default_head_dim = config.n_embd / config.n_head; + const int64_t kv_lora_rank = config.kv_lora_rank > 0 ? config.kv_lora_rank : config.n_embd; + const int64_t qk_nope_head_dim = config.qk_nope_head_dim > 0 ? config.qk_nope_head_dim : default_head_dim; + const int64_t qk_rope_head_dim = config.qk_rope_head_dim > 0 ? config.qk_rope_head_dim : default_head_dim; + const int64_t v_head_dim = config.v_head_dim > 0 ? config.v_head_dim : default_head_dim; + + CHECK_GT(qk_nope_head_dim, 0) << "qk_nope_head_dim must be positive"; + CHECK_GT(qk_rope_head_dim, 0) << "qk_rope_head_dim must be positive"; + CHECK_GT(v_head_dim, 0) << "v_head_dim must be positive"; + CHECK_EQ(qk_rope_head_dim % 2, 0) << "qk_rope_head_dim must be even for RoPE"; + + n_head_ = config.n_head; + n_embd_ = config.n_embd; + local_n_head_ = n_head_ / tp_world_size; + + use_q_lora_ = config.q_lora_rank.has_value(); + q_lora_rank_ = config.q_lora_rank.value_or(0); + kv_lora_rank_ = kv_lora_rank; + qk_nope_head_dim_ = qk_nope_head_dim; + qk_rope_head_dim_ = qk_rope_head_dim; + qk_head_dim_ = qk_nope_head_dim_ + qk_rope_head_dim_; + v_head_dim_ = v_head_dim; + q_down_proj_use_tp_ = config.q_down_proj_use_tp; + kv_down_proj_use_tp_ = config.kv_down_proj_use_tp; +} + +std::vector> +MLASelfAttention::Forward(const std::vector> &x) { + CHECK_GE(x.size(), 1) << "MLASelfAttention expects at least hidden states"; + + // x[0]: (B, T_local, C) + const auto B = x[0]->Dims()[0]; + const auto C = x[0]->Dims()[2]; + CHECK_EQ(C, n_embd_) << "hidden size must match n_embd"; + + // freqs_cis: (T, D_rope / 2, 2) + const auto freqs_cis = x.size() > 1 ? x[1] : nullptr; + // external_mask: (1, 1, T, T) + const auto external_mask = x.size() > 3 ? x[3] : nullptr; + if (config_.position_embedding_type == PositionEmbeddingType::kRoPE) { + CHECK(freqs_cis != nullptr) << "freqs_cis is null."; + } + + const bool sequence_parallel_enabled = nn::parallel::global::GetSequenceParallelEnabled(); + + // ----------- Q PATH ----------- + // Q path, align with Megatron: + // - q_lora_rank == nullopt -> linear_q_proj directly; + // - otherwise linear_q_down_proj -> q_layernorm -> linear_q_up_proj. + std::shared_ptr q; + if (use_q_lora_) { + // linear_q_down_proj: + // non-TP path: (B, T_local, C) -> (B, T_local, R_q) + // TP path before gather: (B, T, C) -> (B, T, R_q / TP) + // - Note that ColumnParallelLinear would perform a GatherFromSPRegion in the beginning + auto q_compressed = (*modules_[kLinearQDownProjLayerName])({x[0]})[0]; + if (q_down_proj_use_tp_ && q_compressed->Dims().back() != q_lora_rank_) { + // Gather the sharded latent dimension: (B, T, R_q / TP) -> (B, T, R_q). + q_compressed = nn::parallel::GatherFromTPRegionFunc(q_compressed)[0]; + if (sequence_parallel_enabled) { + // Keep the q_up input sequence-sharded: (B, T_full, R_q) -> (B, T_local, R_q). + q_compressed = nn::parallel::ScatterToSPRegionFunc(q_compressed)[0]; + } + } + // q_layernorm preserves shape: (B, T_local, R_q) + q_compressed = (*modules_[kQLayerNormLayerName])({q_compressed})[0]; + // linear_q_up_proj: (B, T_local, R_q) -> (B, T, H_local * (D_nope + D_rope)). + q = (*modules_[kLinearQUpProjLayerName])({q_compressed})[0]; + } else { + // linear_q_proj direct path: (B, T, C) -> (B, T, H_local * (D_nope + D_rope)). + q = (*modules_[kLinearQProjLayerName])({x[0]})[0]; + } + + // T should be the full seqlen after the q projection path gathers sequence-parallel input. + const auto T = q->Dims()[1]; + // q: (B, T, H_local * D_qk) -> (B, T, H_local, D_qk) + // qk_head_dim_ = qk_nope_head_dim_ + qk_rope_head_dim_ + q = q->View({B, T, local_n_head_, qk_head_dim_}); + + // q_nope: (B, T, H_local, D_nope), q_pos_emb: (B, T, H_local, D_rope) + auto q_nope = q->Slice(-1, 0, qk_nope_head_dim_); + auto q_pos_emb = q->Slice(-1, qk_nope_head_dim_, qk_head_dim_); + + // ----------- KV PATH ----------- + // linear_kv_down_proj: + // non-TP path: (B, T_local, C) -> (B, T_local, R_kv + D_rope) + // TP path before gather: (B, T, C) -> (B, T, (R_kv + D_rope) / TP) + auto compressed_kv_with_pe = (*modules_[kLinearKVDownProjLayerName])({x[0]})[0]; + const auto kv_down_proj_out_dim = kv_lora_rank_ + qk_rope_head_dim_; + const bool kv_down_proj_output_is_sharded = compressed_kv_with_pe->Dims().back() != kv_down_proj_out_dim; + if (kv_down_proj_use_tp_ && kv_down_proj_output_is_sharded) { + // Gather latent+RoPE dim: (B, T, (R_kv + D_rope) / TP) -> (B, T, R_kv + D_rope) + compressed_kv_with_pe = nn::parallel::GatherFromTPRegionFunc(compressed_kv_with_pe)[0]; + } + + // compressed_kv: (B, T_local, R_kv), k_pos_emb: (B, T_local, D_rope) + auto compressed_kv = compressed_kv_with_pe->Slice(-1, 0, kv_lora_rank_); + auto k_pos_emb = compressed_kv_with_pe->Slice(-1, kv_lora_rank_, kv_lora_rank_ + qk_rope_head_dim_)->Contiguous(); + const bool k_pos_emb_has_full_sequence + = kv_down_proj_use_tp_ && kv_down_proj_output_is_sharded && sequence_parallel_enabled; + if (k_pos_emb_has_full_sequence) { + // k_pos_emb already has full T; keep only compressed_kv sequence-sharded for linear_kv_up_proj. + // compressed_kv: (B, T, R_kv) -> (B, T_local, R_kv) + compressed_kv = nn::parallel::ScatterToSPRegionFunc(compressed_kv)[0]; + } else if (sequence_parallel_enabled) { + // Replicated down-proj path produces local k_pos_emb; gather it for attention. + // k_pos_emb: (B, T_local, D_rope) -> (B, T, D_rope) + k_pos_emb = nn::parallel::GatherFromSPRegionFunc(k_pos_emb)[0]; + } + // k_pos_emb: (B, T, D_rope) -> (B, T, 1, D_rope), shared across local heads. + k_pos_emb = k_pos_emb->View({B, T, 1, qk_rope_head_dim_}); + + // (B, T, R_kv) -> kv_layernorm -> linear_kv_up_proj -> (B, T, H_local * (D_nope + D_v)) + // kv_layernorm preserves compressed_kv shape: (B, T_local, R_kv) + auto kv = (*modules_[kKVLayerNormLayerName])({compressed_kv})[0]; + // linear_kv_up_proj: (B, T_local, R_kv) -> (B, T, H_local * (D_nope + D_v)) + kv = (*modules_[kLinearKVUpProjLayerName])({kv})[0]; + // kv: (B, T, H_local * (D_nope + D_v)) -> (B, T, H_local, D_nope + D_v) + kv = kv->View({B, T, local_n_head_, qk_nope_head_dim_ + v_head_dim_}); + // k_nope: (B, T, H_local, D_nope), v: (B, T, H_local, D_v) + auto k_nope = kv->Slice(-1, 0, qk_nope_head_dim_); + auto v = kv->Slice(-1, qk_nope_head_dim_, qk_nope_head_dim_ + v_head_dim_); + + if (config_.position_embedding_type == PositionEmbeddingType::kRoPE) { + // q_pos_emb: (B, T, H_local, D_rope), k_pos_emb: (B, T, 1, D_rope) + std::tie(q_pos_emb, k_pos_emb) = ApplyRotaryEmbedding(q_pos_emb, k_pos_emb, freqs_cis); + } + + // k_pos_emb: (B, T, 1, D_rope) -> (B, T, H_local, D_rope) + k_pos_emb = k_pos_emb->RepeatInterleave(local_n_head_, 2); + // q: (B, T, H_local, D_qk), k: (B, T, H_local, D_qk) + q = nn::function::Concat(std::vector>{q_nope, q_pos_emb}, -1); + auto k = nn::function::Concat(std::vector>{k_nope, k_pos_emb}, -1); + + // ----------- CORE ATTN ----------- + // q/k: (B, T, H_local, D_qk) -> (B, H_local, T, D_qk) + // v: (B, T, H_local, D_v) -> (B, H_local, T, D_v) + q = q->Transpose(1, 2); + k = k->Transpose(1, 2); + v = v->Transpose(1, 2); + + // att: (B, H_local, T, T) + auto att = q->Matmul(k->Transpose(-2, -1)) * (1.0 / std::sqrt(static_cast(qk_head_dim_))); + if (external_mask) { + att = att->MaskedFill(external_mask, std::numeric_limits::lowest()); + } else { + // mask: (1, 1, T, T) + auto mask = buffers_[kParamBiasName]->Slice({0, 0, 0, 0}, {1, 1, T, T}, {1, 1, 1, 1}); + att = att->MaskedFill(mask == 0, -std::numeric_limits::infinity()); + } + // att: (B, H_local, T, T) + att = nn::function::Softmax(att, -1); + + // y: (B, H_local, T, D_v) + auto y = att->Matmul(v); + // y: (B, H_local, T, D_v) -> (B, T, H_local, D_v) -> (B, T, H_local * D_v) + y = y->Transpose(1, 2)->Contiguous()->View({B, T, local_n_head_ * v_head_dim_}); + // linear_proj: (B, T, H_local * D_v) -> (B, T, C) + y = (*modules_[kLinearProjLayerName])({y})[0]; + + return {y}; +} + +} // namespace infini_train::nn diff --git a/infini_train/src/nn/modules/transformer/transformer.cc b/infini_train/src/nn/modules/transformer/transformer.cc index c7e0f28c..9d3cf35e 100644 --- a/infini_train/src/nn/modules/transformer/transformer.cc +++ b/infini_train/src/nn/modules/transformer/transformer.cc @@ -14,6 +14,7 @@ #include "infini_train/include/nn/modules/normalization.h" #include "infini_train/include/nn/modules/sparse.h" #include "infini_train/include/nn/modules/transformer/causal_self_attention.h" +#include "infini_train/include/nn/modules/transformer/mla_self_attention.h" #include "infini_train/include/nn/modules/transformer/mlp.h" #include "infini_train/include/nn/modules/transformer/utils.h" #include "infini_train/include/nn/parallel/global.h" @@ -28,9 +29,11 @@ TransformerFirstStage::TransformerFirstStage(const TransformerConfig &config) modules_[kWTELayerName] = std::make_shared( config_.vocab_size, config_.n_embd, parallel::global::GetSequenceParallelEnabled()); - // LLaMA3 use RoPE, so they don't need position embedding - if (config_.activation_type == MLPType::kGELU) { + // Only learned absolute position embedding uses a trainable WPE table. + if (config_.position_embedding_type == PositionEmbeddingType::kLearnedAbsolute) { modules_[kWPELayerName] = std::make_shared(config_.block_size, config_.n_embd); + } else if (config_.position_embedding_type != PositionEmbeddingType::kRoPE) { + LOG(FATAL) << "Unsupported position embedding type"; } } @@ -44,7 +47,7 @@ std::vector> TransformerFirstStage::Forward(const std::v // (B, T) -> Embedding(V_local, C) -> (B, T, C) auto tok_emb = (*modules_[kWTELayerName])({x1}); - // Add position embedding only for models that use absolute position encoding + // Add position embedding only for models that use learned absolute position encoding. if (modules_.contains(kWPELayerName)) { // (T_local) // NOTE(zbl): Slice pos sequence when SP is enabled @@ -65,7 +68,7 @@ std::vector> TransformerFirstStage::Forward(const std::v // (B, T, C) return {tok_emb[0] + pos_emb[0]}; } else { - // For RoPE-based models (LLaMA3), no position embedding needed + // For RoPE-based models (LLaMA3), no absolute position embedding is needed. // (B, T, C) return tok_emb; } @@ -85,7 +88,11 @@ TransformerLayer::TransformerLayer(const nn::TransformerConfig &config) : Clonea LOG(FATAL) << "Unsupported norm type"; } - modules_[kAttnLayerName] = std::make_shared(config); + if (config.multi_latent_attention) { + modules_[kAttnLayerName] = std::make_shared(config); + } else { + modules_[kAttnLayerName] = std::make_shared(config); + } modules_[kMlpLayerName] = std::make_shared(config); } @@ -128,15 +135,17 @@ TransformerChunk::TransformerChunk(const TransformerConfig &config, int start_la std::vector> TransformerChunk::Forward(const std::vector> &x) { auto x1 = x[0]; - // Check if we need to pass RoPE parameters (for LLaMA3 style models) - if (config_.attention_type == AttentionType::kRoPE) { + // Check if we need to pass RoPE parameters (for LLaMA3 style models). + if (config_.position_embedding_type == PositionEmbeddingType::kRoPE) { // For RoPE models, we need to prepare freqs_cis and potentially other parameters const auto device = x1->GetDevice(); // Init freqs_cis on device only once if (buffers_[kFreqsCisName] == nullptr) { - int64_t head_dim = config_.n_embd / config_.n_head; - buffers_[kFreqsCisName] = PrecomputeFreqsCis(head_dim, config_.block_size * 2, config_.rope_theta, + int64_t rope_head_dim = config_.multi_latent_attention && config_.qk_rope_head_dim > 0 + ? config_.qk_rope_head_dim + : config_.n_embd / config_.n_head; + buffers_[kFreqsCisName] = PrecomputeFreqsCis(rope_head_dim, config_.block_size * 2, config_.rope_theta, config_.use_scaled_rope, device); } @@ -156,9 +165,11 @@ std::vector> TransformerChunk::Forward(const std::vector for (auto &h : *std::dynamic_pointer_cast(modules_[kHLayerName])) { x1 = (*h)({x1, freqs_view, start_pos_ptr, mask})[0]; } - } else { - // Standard attention (GPT2 style) + } else if (config_.position_embedding_type == PositionEmbeddingType::kLearnedAbsolute) { + // Learned absolute position embedding models (GPT-2 style). for (auto &h : *std::dynamic_pointer_cast(modules_[kHLayerName])) { x1 = (*h)({x1})[0]; } + } else { + LOG(FATAL) << "Unsupported position embedding type"; } return {x1}; @@ -212,7 +223,7 @@ TransformerModel::TransformerModel(const TransformerConfig config) modules_[kPPFirstStageName] = std::make_shared(config_); transformer[TransformerFirstStage::kWTELayerName] = modules_[kPPFirstStageName]->mutable_module(TransformerFirstStage::kWTELayerName); - if (config_.attention_type == AttentionType::kStandard) { + if (config_.position_embedding_type == PositionEmbeddingType::kLearnedAbsolute) { transformer[TransformerFirstStage::kWPELayerName] = modules_[kPPFirstStageName]->mutable_module(TransformerFirstStage::kWPELayerName); } diff --git a/infini_train/src/nn/modules/transformer/utils.cc b/infini_train/src/nn/modules/transformer/utils.cc index 98505fd0..4ec11f2d 100644 --- a/infini_train/src/nn/modules/transformer/utils.cc +++ b/infini_train/src/nn/modules/transformer/utils.cc @@ -1,5 +1,9 @@ #include "infini_train/include/nn/modules/transformer/utils.h" +#include +#include +#include + #include "glog/logging.h" #include "infini_train/include/nn/functional.h" @@ -27,4 +31,38 @@ std::shared_ptr PrecomputeFreqsCis(int64_t dim, int64_t end, float theta return freqs_cis; } + +std::tuple, std::shared_ptr> +ApplyRotaryEmbedding(const std::shared_ptr &xq, const std::shared_ptr &xk, + const std::shared_ptr &freqs_cis) { + const auto &x_shape = xq->Dims(); // (B, T, H, D) + const int64_t T = x_shape[1]; + const int64_t D = x_shape[3]; + + std::vector target_shape = {1, T, 1, D / 2, 2}; + auto cos_sin = freqs_cis->View(target_shape); // -> (1, T, 1, D/2, 2) + + auto cos = cos_sin->Slice(-1, 0, 1, 1)->Squeeze(-1); // (1, T, 1, D/2) + auto sin = cos_sin->Slice(-1, 1, 2, 1)->Squeeze(-1); // (1, T, 1, D/2) + + auto slice_pair = [](const std::shared_ptr &x) { + auto even = x->Slice(-1, 0, x->Dims().back(), 2); + auto odd = x->Slice(-1, 1, x->Dims().back(), 2); + return std::make_pair(even, odd); + }; + + auto [q_even, q_odd] = slice_pair(xq); + auto q_rotated_left = q_even * cos - q_odd * sin; + auto q_rotated_right = q_even * sin + q_odd * cos; + auto q_rotated + = nn::function::Stack(std::vector>{q_rotated_left, q_rotated_right}, -1)->Flatten(-2); + + auto [k_even, k_odd] = slice_pair(xk); + auto k_rotated_left = k_even * cos - k_odd * sin; + auto k_rotated_right = k_even * sin + k_odd * cos; + auto k_rotated + = nn::function::Stack(std::vector>{k_rotated_left, k_rotated_right}, -1)->Flatten(-2); + + return {q_rotated, k_rotated}; +} } // namespace infini_train diff --git a/infini_train/src/nn/parallel/tensor_parallel.cc b/infini_train/src/nn/parallel/tensor_parallel.cc index 44ab8189..b83c5e52 100644 --- a/infini_train/src/nn/parallel/tensor_parallel.cc +++ b/infini_train/src/nn/parallel/tensor_parallel.cc @@ -45,6 +45,24 @@ std::shared_ptr GatherAlongFirstDim(const std::shared_ptr &tenso return gathered_output; } +std::shared_ptr ScatterAlongFirstDim(const std::shared_ptr &tensor) { + int world_size = global::GetTensorParallelSize(); + CHECK_GT(world_size, 0) << "Tensor Parallel group not initialized"; + if (world_size == 1) { + return tensor; + } + + auto device = tensor->GetDevice(); + auto tp_group = ProcessGroupFactory::Instance(device.type()) + ->Get(GetTensorParallelProcessGroupName(device.Rank().GlobalRank())); + auto rank = tp_group->GetGroupRank(device.Rank().GlobalRank()); + + CHECK_EQ(tensor->Dims()[0] % world_size, 0) << "First dimension must be divisible by TP world size"; + auto first_dim_size = tensor->Dims()[0] / world_size; + auto shards = tensor->Split(first_dim_size, 0); + return shards[rank]->Contiguous(); +} + std::shared_ptr GatherAlongLastDim(const std::shared_ptr &tensor) { int world_size = global::GetTensorParallelSize(); CHECK_GT(world_size, 0) << "Tensor Parallel group not initialized"; @@ -214,6 +232,21 @@ class ReduceScatterToSPRegion : public autograd::Function { }; }; +class ScatterToSPRegion : public autograd::Function { +public: + static constexpr char kType[] = "ScatterToSPRegionFunction"; + + explicit ScatterToSPRegion() : autograd::Function(kType) {} + + std::vector> Forward(const std::vector> &input_tensors) override { + return {ScatterAlongFirstDim(input_tensors[0]->Transpose(0, 1))->Transpose(0, 1)}; + }; + + std::vector> Backward(const std::vector> &grad_outputs) override { + return {GatherAlongFirstDim(grad_outputs[0]->Transpose(0, 1))->Transpose(0, 1)}; + }; +}; + class GatherFromSPRegion : public autograd::Function { public: static constexpr char kType[] = "GatherFromSPRegionFunction"; @@ -263,6 +296,10 @@ std::vector> ReduceScatterToSPRegionFunc(const std::shar return std::make_shared()->Apply({input}); } +std::vector> ScatterToSPRegionFunc(const std::shared_ptr &input) { + return std::make_shared()->Apply({input}); +} + std::vector> GatherFromSPRegionFunc(const std::shared_ptr &input) { return std::make_shared()->Apply({input}); } diff --git a/tests/transformer/test_transformer_architecture.cc b/tests/transformer/test_transformer_architecture.cc index ba62e1e3..73c684f5 100644 --- a/tests/transformer/test_transformer_architecture.cc +++ b/tests/transformer/test_transformer_architecture.cc @@ -1,16 +1,20 @@ #include #include +#include #include #include "gtest/gtest.h" +#include "infini_train/include/nn/modules/linear.h" #include "infini_train/include/nn/modules/normalization.h" #include "infini_train/include/nn/modules/sparse.h" #include "infini_train/include/nn/modules/transformer/causal_self_attention.h" +#include "infini_train/include/nn/modules/transformer/mla_self_attention.h" #include "infini_train/include/nn/modules/transformer/mlp.h" #include "infini_train/include/nn/modules/transformer/transformer.h" #include "infini_train/include/nn/modules/transformer/transformer_config.h" #include "infini_train/include/nn/modules/transformer/utils.h" +#include "infini_train/include/nn/parallel/tensor_parallel.h" #include "infini_train/include/tensor.h" #include "tests/common/test_utils.h" @@ -98,7 +102,7 @@ TEST_P(TransformerModuleTest, StandardAttention) { config.n_embd = 64; config.n_head = 4; config.n_kv_head = 4; - config.attention_type = nn::AttentionType::kStandard; + config.position_embedding_type = nn::PositionEmbeddingType::kLearnedAbsolute; config.add_bias_linear = true; auto attn = std::make_shared(config); @@ -110,6 +114,53 @@ TEST_P(TransformerModuleTest, StandardAttention) { EXPECT_EQ(output[0]->Dims(), input->Dims()); } +TEST_P(TransformerModuleTest, MLAAttention) { + SKIP_CPU(); + nn::TransformerConfig config; + config.n_embd = 64; + config.n_head = 4; + config.block_size = 16; + config.position_embedding_type = nn::PositionEmbeddingType::kLearnedAbsolute; + config.add_bias_linear = true; + config.multi_latent_attention = true; + config.q_lora_rank = 32; + config.kv_lora_rank = 32; + config.qk_nope_head_dim = 8; + config.qk_rope_head_dim = 8; + config.v_head_dim = 16; + + auto attn = std::make_shared(config); + attn->To(GetDevice()); + EXPECT_FALSE(attn->Parameters().empty()); + EXPECT_EQ(attn->module(nn::MLASelfAttention::kLinearQDownProjLayerName).type(), nn::Linear::kType); + EXPECT_EQ(attn->module(nn::MLASelfAttention::kLinearKVDownProjLayerName).type(), nn::Linear::kType); + + auto input = std::make_shared(std::vector{2, 8, 64}, DataType::kFLOAT32, GetDevice()); + auto output = (*attn)({input}); + EXPECT_EQ(output[0]->Dims(), input->Dims()); + + auto tp_down_config = config; + tp_down_config.q_down_proj_use_tp = true; + tp_down_config.kv_down_proj_use_tp = true; + auto tp_down_attn = std::make_shared(tp_down_config); + tp_down_attn->To(GetDevice()); + EXPECT_EQ(tp_down_attn->module(nn::MLASelfAttention::kLinearQDownProjLayerName).type(), + nn::parallel::ColumnParallelLinear::kType); + EXPECT_EQ(tp_down_attn->module(nn::MLASelfAttention::kLinearKVDownProjLayerName).type(), + nn::parallel::ColumnParallelLinear::kType); + output = (*tp_down_attn)({input}); + EXPECT_EQ(output[0]->Dims(), input->Dims()); + + auto direct_q_config = config; + direct_q_config.q_lora_rank = std::nullopt; + auto direct_q_attn = std::make_shared(direct_q_config); + direct_q_attn->To(GetDevice()); + EXPECT_EQ(direct_q_attn->module(nn::MLASelfAttention::kLinearQProjLayerName).type(), + nn::parallel::ColumnParallelLinear::kType); + output = (*direct_q_attn)({input}); + EXPECT_EQ(output[0]->Dims(), input->Dims()); +} + TEST_P(TransformerModuleTest, GPT2TransformerLayer) { SKIP_CPU(); nn::TransformerConfig config; @@ -147,7 +198,7 @@ TEST_P(TransformerModuleTest, LLaMA3Model) { config.n_head = 4; config.n_kv_head = 2; config.n_embd = 64; - config.attention_type = nn::AttentionType::kRoPE; + config.position_embedding_type = nn::PositionEmbeddingType::kRoPE; config.activation_type = nn::MLPType::kSwiGLU; config.norm_type = nn::NormType::kRMSNorm; config.add_bias_linear = false; @@ -174,7 +225,7 @@ TEST_P(TransformerModuleTest, StateDict) { config.n_kv_head = 2; config.n_embd = 32; config.vocab_size = 1000; - config.attention_type = nn::AttentionType::kStandard; + config.position_embedding_type = nn::PositionEmbeddingType::kLearnedAbsolute; config.activation_type = nn::MLPType::kGELU; config.norm_type = nn::NormType::kLayerNorm; config.add_bias_linear = true;