Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions example/gpt2/config.h
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ inline nn::TransformerConfig GPT2Config() {
.n_head = 12,
.n_kv_head = 12,
.n_embd = 768,
.attention_type = nn::AttentionType::kStandard,
.position_embedding_type = nn::PositionEmbeddingType::kLearnedAbsolute,
.activation_type = nn::MLPType::kGELU,
.norm_type = nn::NormType::kLayerNorm,
.add_bias_linear = true,
Expand All @@ -34,7 +34,8 @@ inline void SanitizeGPT2Config(const nn::TransformerConfig &c) {
CHECK_GT(c.n_embd, 0);
CHECK_EQ(c.n_embd % c.n_head, 0) << "n_embd must be divisible by n_head";
CHECK_EQ(c.n_kv_head, c.n_head) << "GPT-2 does not use GQA; n_kv_head must equal n_head";
CHECK(c.attention_type == nn::AttentionType::kStandard) << "GPT-2 requires standard attention";
CHECK(c.position_embedding_type == nn::PositionEmbeddingType::kLearnedAbsolute)
<< "GPT-2 requires learned absolute position embedding";
CHECK(c.activation_type == nn::MLPType::kGELU) << "GPT-2 requires GELU activation";
CHECK(c.norm_type == nn::NormType::kLayerNorm) << "GPT-2 requires LayerNorm";
}
Expand Down
4 changes: 2 additions & 2 deletions example/llama3/config.h
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ inline nn::TransformerConfig LLaMA3Config() {
.n_head = 32,
.n_kv_head = 8,
.n_embd = 2048,
.attention_type = nn::AttentionType::kRoPE,
.position_embedding_type = nn::PositionEmbeddingType::kRoPE,
.activation_type = nn::MLPType::kSwiGLU,
.norm_type = nn::NormType::kRMSNorm,
.add_bias_linear = false,
Expand All @@ -36,7 +36,7 @@ inline void SanitizeLLaMA3Config(const nn::TransformerConfig &c) {
CHECK_EQ(c.n_head % c.n_kv_head, 0) << "n_head must be divisible by n_kv_head for GQA";
CHECK_GT(c.n_embd, 0);
CHECK_EQ(c.n_embd % c.n_head, 0) << "n_embd must be divisible by n_head";
CHECK(c.attention_type == nn::AttentionType::kRoPE) << "LLaMA-3 requires RoPE attention";
CHECK(c.position_embedding_type == nn::PositionEmbeddingType::kRoPE) << "LLaMA-3 requires RoPE position embedding";
CHECK(c.activation_type == nn::MLPType::kSwiGLU) << "LLaMA-3 requires SwiGLU activation";
CHECK(c.norm_type == nn::NormType::kRMSNorm) << "LLaMA-3 requires RMSNorm";
CHECK(!c.add_bias_linear) << "LLaMA-3 has no bias in linear layers";
Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
#pragma once

#include <memory>
#include <tuple>
#include <vector>

#include "infini_train/include/nn/modules/module.h"
Expand Down Expand Up @@ -35,20 +34,6 @@ class CausalSelfAttention : public infini_train::nn::CloneableModule<CausalSelfA
// Setup method for different attention modes
void SetupAttention(const TransformerConfig &config);

// Standard attention forward (GPT2 style: no RoPE, no GQA)
std::vector<std::shared_ptr<infini_train::Tensor>>
ForwardStandard(const std::vector<std::shared_ptr<infini_train::Tensor>> &x);

// RoPE-aware attention forward (LLaMA3 style: with RoPE, optional GQA)
std::vector<std::shared_ptr<infini_train::Tensor>>
ForwardWithRoPE(const std::vector<std::shared_ptr<infini_train::Tensor>> &x);

// RoPE helper methods
std::tuple<std::shared_ptr<infini_train::Tensor>, std::shared_ptr<infini_train::Tensor>>
ApplyRotaryEmbedding(const std::shared_ptr<infini_train::Tensor> &xq,
const std::shared_ptr<infini_train::Tensor> &xk,
const std::shared_ptr<infini_train::Tensor> &freqs_cis);

// GQA helper method
std::shared_ptr<infini_train::Tensor> RepeatKV(const std::shared_ptr<infini_train::Tensor> &x, int64_t n_rep);
};
Expand Down
51 changes: 51 additions & 0 deletions infini_train/include/nn/modules/transformer/mla_self_attention.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
#pragma once

#include <memory>
#include <vector>

#include "infini_train/include/nn/modules/module.h"
#include "infini_train/include/nn/modules/transformer/transformer_config.h"

namespace infini_train::nn {

// Self-attention module for the multi-latent attention (MLA) path, i.e. the module intended for
// configurations where TransformerConfig::multi_latent_attention is set ("Use MLA instead of standard
// causal self-attention" per the config comment).
// NOTE(review): only the declaration is visible here; the behaviour described below is taken from the
// member names and the TransformerConfig field comments and should be confirmed against the .cc file.
class MLASelfAttention : public infini_train::nn::CloneableModule<MLASelfAttention> {
public:
// Module type identifier string.
static constexpr char kType[] = "MLASelfAttention";

// Names of the sub-layers. Presumably used as keys when registering/looking up sub-modules, in the same
// way CausalSelfAttention uses its layer-name constants — TODO confirm.
// Direct query projection (the path taken when q_lora_rank is nullopt, per the config comment).
static constexpr char kLinearQProjLayerName[] = "linear_q_proj";
// Query down-projection / norm / up-projection names (presumably the low-rank query path).
static constexpr char kLinearQDownProjLayerName[] = "linear_q_down_proj";
static constexpr char kQLayerNormLayerName[] = "q_layernorm";
static constexpr char kLinearQUpProjLayerName[] = "linear_q_up_proj";
// Key/value down-projection / norm / up-projection names.
static constexpr char kLinearKVDownProjLayerName[] = "linear_kv_down_proj";
static constexpr char kKVLayerNormLayerName[] = "kv_layernorm";
static constexpr char kLinearKVUpProjLayerName[] = "linear_kv_up_proj";
// Output projection name.
static constexpr char kLinearProjLayerName[] = "linear_proj";

// Buffer name "bias". NOTE(review): CausalSelfAttention stores its causal mask under the same name;
// presumably this serves the same purpose here — confirm.
static constexpr char kParamBiasName[] = "bias";

// Builds the module from a transformer configuration. `explicit` prevents implicit conversion from
// TransformerConfig.
explicit MLASelfAttention(const TransformerConfig &config);

// Forward pass. Takes and returns a vector of tensors, matching the Module interface it overrides.
// NOTE(review): the expected contents of `x` beyond x[0] are not visible from this header.
std::vector<std::shared_ptr<infini_train::Tensor>>
Forward(const std::vector<std::shared_ptr<infini_train::Tensor>> &x) override;

private:
// Copy of the configuration passed to the constructor.
TransformerConfig config_;
// Head count and hidden size; all default to 0 until set up.
int64_t n_head_ = 0;
int64_t n_embd_ = 0;
// Presumably the number of heads owned by this tensor-parallel rank — TODO confirm.
int64_t local_n_head_ = 0;

// Cached MLA dimensions. The names mirror the TransformerConfig fields q_lora_rank, kv_lora_rank,
// qk_nope_head_dim, qk_rope_head_dim and v_head_dim, whose comments describe a fallback when left at 0
// (kv_lora_rank -> n_embd, the head dims -> n_embd / n_head).
int64_t q_lora_rank_ = 0;
int64_t kv_lora_rank_ = 0;
int64_t qk_nope_head_dim_ = 0;
int64_t qk_rope_head_dim_ = 0;
// Presumably qk_nope_head_dim_ + qk_rope_head_dim_ — TODO confirm in the implementation.
int64_t qk_head_dim_ = 0;
int64_t v_head_dim_ = 0;

// Whether the low-rank query path is used; defaults to true. The config comment states that a nullopt
// q_lora_rank selects the direct linear_q_proj path instead.
bool use_q_lora_ = true;
// Mirror the config flags that request ColumnParallelLinear for the q / kv down-projections.
bool q_down_proj_use_tp_ = false;
bool kv_down_proj_use_tp_ = false;

// Setup helper taking the configuration; presumably resolves the cached dimensions and flags above.
void SetupAttention(const TransformerConfig &config);
};

} // namespace infini_train::nn
26 changes: 20 additions & 6 deletions infini_train/include/nn/modules/transformer/transformer_config.h
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,13 @@ enum class ModelType {
kLLaMA3, // LLaMA3
};

enum class AttentionType {
kStandard, // Standard attention
kRoPE // Rotary Position Embedding
// Position embedding scheme selected through TransformerConfig::position_embedding_type.
// The trailing comment on each enumerator gives the corresponding Megatron option name.
// In the code visible alongside this enum, GPT-2 configs require kLearnedAbsolute, LLaMA-3 configs require
// kRoPE, and CausalSelfAttention::Forward applies rotary embedding to q/k only when the value is kRoPE.
// NOTE(review): no handling of kYarn, kMRoPE, kRelative or kNone is visible here — confirm whether they are
// implemented elsewhere before relying on them.
enum class PositionEmbeddingType {
kLearnedAbsolute, // Megatron: learned_absolute
kRoPE, // Megatron: rope
kYarn, // Megatron: yarn
kMRoPE, // Megatron: mrope
kRelative, // Megatron: relative
kNone // Megatron: none
};

enum class MLPType {
Expand All @@ -34,9 +38,9 @@ struct TransformerConfig {
int64_t n_kv_head = 12; // Num of Key/Value heads (<= n_head, < n_head if using GQA)
int64_t n_embd = 768; // Hidden size

AttentionType attention_type = AttentionType::kStandard; // Attention mechanism type
MLPType activation_type = MLPType::kGELU; // MLP activation type
NormType norm_type = NormType::kLayerNorm; // Normalization type
PositionEmbeddingType position_embedding_type = PositionEmbeddingType::kLearnedAbsolute; // Position embedding type.
MLPType activation_type = MLPType::kGELU; // MLP activation type
NormType norm_type = NormType::kLayerNorm; // Normalization type

bool add_bias_linear = true; // Whether to add learnable bias to all Linear layers in the Transformer block,
// including: attention QKV projection, attention output projection, MLP FC layers (and
Expand All @@ -53,6 +57,16 @@ struct TransformerConfig {
float rope_theta = 500000.0f; // theta in RoPE
bool use_scaled_rope = false; // scaled RoPE

// MLA config
bool multi_latent_attention = false; // Use MLA instead of standard causal self-attention.
std::optional<int64_t> q_lora_rank = std::nullopt; // nullopt means direct linear_q_proj path.
int64_t kv_lora_rank = 0; // 0 falls back to n_embd in MLASelfAttention.
int64_t qk_nope_head_dim = 0; // 0 falls back to n_embd / n_head.
int64_t qk_rope_head_dim = 0; // 0 falls back to n_embd / n_head.
int64_t v_head_dim = 0; // 0 falls back to n_embd / n_head.
bool q_down_proj_use_tp = false; // Use ColumnParallelLinear for linear_q_down_proj.
bool kv_down_proj_use_tp = false; // Use ColumnParallelLinear for linear_kv_down_proj.

// Normalization
float norm_eps = 1e-5f; // epsilon in RMSNorm

Expand Down
6 changes: 6 additions & 0 deletions infini_train/include/nn/modules/transformer/utils.h
Original file line number Diff line number Diff line change
@@ -1,11 +1,17 @@
#pragma once

#include <cstdint>
#include <memory>
#include <tuple>

#include "infini_train/include/tensor.h"

namespace infini_train {
// RoPE helper: builds the rotary-embedding frequency tensor.
//   dim        - size argument for the table; presumably the per-head dimension — TODO confirm.
//   end        - second size argument; presumably the number of positions to precompute — TODO confirm.
//   theta      - RoPE base; defaults to 10000.0f when not supplied.
//   use_scaled - enables the scaled-RoPE variant; defaults to false.
//   device     - device for the result; defaults to a default-constructed Device.
// NOTE(review): the layout of the returned tensor is not visible from this header.
std::shared_ptr<Tensor> PrecomputeFreqsCis(int64_t dim, int64_t end, float theta = 10000.0f, bool use_scaled = false,
Device device = Device());

// RoPE helper: applies rotary position embedding to a query tensor and a key tensor using `freqs_cis`
// and returns the pair (rotated q, rotated k).
// The call site in CausalSelfAttention::Forward passes q annotated as (B, T, H_local, D) and k as
// (B, T, KV_local, D) and rebinds both from the returned tuple.
// NOTE(review): freqs_cis is presumably the tensor produced by PrecomputeFreqsCis — confirm.
std::tuple<std::shared_ptr<Tensor>, std::shared_ptr<Tensor>>
ApplyRotaryEmbedding(const std::shared_ptr<Tensor> &xq, const std::shared_ptr<Tensor> &xk,
const std::shared_ptr<Tensor> &freqs_cis);
} // namespace infini_train
1 change: 1 addition & 0 deletions infini_train/include/nn/parallel/utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ std::vector<int> GetPipelineParallelGroupRanks(int global_rank);

// TP/SP Communication Helper Functions
std::vector<std::shared_ptr<Tensor>> GatherFromTPRegionFunc(const std::shared_ptr<Tensor> &input);
std::vector<std::shared_ptr<Tensor>> ScatterToSPRegionFunc(const std::shared_ptr<Tensor> &input);
std::vector<std::shared_ptr<Tensor>> ReduceScatterToSPRegionFunc(const std::shared_ptr<Tensor> &input);
std::vector<std::shared_ptr<Tensor>> GatherFromSPRegionFunc(const std::shared_ptr<Tensor> &input);
std::vector<std::shared_ptr<Tensor>> ScatterToTPRegionFunc(const std::shared_ptr<Tensor> &input);
Expand Down
147 changes: 29 additions & 118 deletions infini_train/src/nn/modules/transformer/causal_self_attention.cc
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
#include "infini_train/include/nn/modules/transformer/causal_self_attention.h"

#include <cmath>
#include <limits>
#include <memory>
#include <tuple>
#include <vector>
Expand All @@ -12,6 +13,7 @@
#include "infini_train/include/nn/modules/normalization.h"
#include "infini_train/include/nn/modules/sparse.h"
#include "infini_train/include/nn/modules/transformer/transformer_config.h"
#include "infini_train/include/nn/modules/transformer/utils.h"
#include "infini_train/include/nn/parallel/global.h"
#include "infini_train/include/nn/parallel/tensor_parallel.h"
#include "infini_train/include/tensor.h"
Expand Down Expand Up @@ -42,12 +44,9 @@ CausalSelfAttention::CausalSelfAttention(const TransformerConfig &config) : Clon
/*skip_bias_add=*/false,
/*sequence_parallel=*/nn::parallel::global::GetSequenceParallelEnabled());

// For standard attention (GPT2 style), precompute causal mask
if (config_.attention_type == AttentionType::kStandard) {
// causal mask: (1, 1, block_size, block_size)
buffers_[kParamBiasName] = function::Tril(nn::function::Ones({config_.block_size, config_.block_size}))
->View({1, 1, config_.block_size, config_.block_size});
}
// causal mask: (1, 1, block_size, block_size)
buffers_[kParamBiasName] = function::Tril(nn::function::Ones({config_.block_size, config_.block_size}))
->View({1, 1, config_.block_size, config_.block_size});
}

void CausalSelfAttention::SetupAttention(const TransformerConfig &config) {
Expand Down Expand Up @@ -76,125 +75,21 @@ void CausalSelfAttention::SetupAttention(const TransformerConfig &config) {

std::vector<std::shared_ptr<infini_train::Tensor>>
CausalSelfAttention::Forward(const std::vector<std::shared_ptr<infini_train::Tensor>> &x) {
if (config_.attention_type == AttentionType::kRoPE) {
return ForwardWithRoPE(x);
} else {
return ForwardStandard(x);
}
}

std::vector<std::shared_ptr<infini_train::Tensor>>
CausalSelfAttention::ForwardStandard(const std::vector<std::shared_ptr<infini_train::Tensor>> &x) {
auto tp_world_size = parallel::global::GetTensorParallelSize();

const auto B = x[0]->Dims()[0]; // bs
const auto C = x[0]->Dims()[2]; // n_embd
const int64_t head_dim = n_embd_ / n_head_; // per-head dim (global)
const int64_t local_C = n_embd_ / tp_world_size; // per-rank hidden

// (B, T, C) -> ColumnParallelLinear(C, 3*C) -> (B, T, 3 * local_C)
// -> Split -> (3, B, T, local_C)
auto qkv = (*modules_[kCAttnLayerName])(x)[0]->Split(local_C, 2);

// (B, T, local_C)
auto q = qkv[0];
auto k = qkv[1];
auto v = qkv[2];

// NOTE(zbl): Acquire full T after AllGather is performed in ColumnParallelLinear
const auto T = q->Dims()[1];

// View to multi-head: local_n_head * head_dim == local_C
// (B, T, local_C) -> (B, T, h_l, Dh) -> (B, h_l, T, Dh)
k = k->View({B, T, local_n_head_, head_dim})->Transpose(1, 2);
q = q->View({B, T, local_n_head_, head_dim})->Transpose(1, 2);
v = v->View({B, T, local_n_head_, head_dim})->Transpose(1, 2);

// (B, h_l, T, T)
auto att = q->Matmul(k->Transpose(-2, -1)) * (1.0 / std::sqrt(head_dim));
// (1, 1, T, T)
auto mask = buffers_[kParamBiasName]->Slice({0, 0, 0, 0}, {1, 1, T, T}, {1, 1, 1, 1});
// (1, 1, T, T) -> eq 0 -> (1, 1, T, T) -> masked_fill -> (B, h_l, T, T)
att = att->MaskedFill(mask == 0, -std::numeric_limits<float>::infinity());
// (B, h_l, T, T)
att = nn::function::Softmax(att, -1);
// (B, h_l, T, Dh)
auto y = att->Matmul(v);
// (B, h_l, T, Dh) -> (B, T, h_l, Dh) -> (B, T, local_C)
y = y->Transpose(1, 2)->Contiguous()->View({B, T, local_C});

// Get full tensor
// (B, T, local_C) -> RowParallelLinear(n_embd, n_embd) -> (B, T, C)
y = (*modules_[kCProjLayerName])({y})[0];
// (B, T, C) == (bs, seq_len, n_embd)
return {y};
}

// RoPE helper methods
std::tuple<std::shared_ptr<infini_train::Tensor>, std::shared_ptr<infini_train::Tensor>>
CausalSelfAttention::ApplyRotaryEmbedding(const std::shared_ptr<infini_train::Tensor> &xq,
const std::shared_ptr<infini_train::Tensor> &xk,
const std::shared_ptr<infini_train::Tensor> &freqs_cis) {
// Reshape freqs_cis for broadcasting
const auto &x_shape = xq->Dims(); // (B, T, H, D)
const int64_t T = x_shape[1];
const int64_t D = x_shape[3];

std::vector<int64_t> target_shape = {1, T, 1, D / 2, 2};
auto cos_sin = freqs_cis->View(target_shape); // -> (1, T, 1, D/2, 2)

auto cos = cos_sin->Slice(-1, 0, 1, 1)->Squeeze(-1); // (1, T, 1, D/2)
auto sin = cos_sin->Slice(-1, 1, 2, 1)->Squeeze(-1); // (1, T, 1, D/2)

auto slice_pair = [](const std::shared_ptr<Tensor> &x) {
auto even = x->Slice(-1, 0, x->Dims().back(), 2);
auto odd = x->Slice(-1, 1, x->Dims().back(), 2);
return std::make_pair(even, odd);
};

auto [q_even, q_odd] = slice_pair(xq);
auto q_rotated_left = q_even * cos - q_odd * sin;
auto q_rotated_right = q_even * sin + q_odd * cos;
auto q_rotated
= nn::function::Stack(std::vector<std::shared_ptr<Tensor>>{q_rotated_left, q_rotated_right}, -1)->Flatten(-2);

auto [k_even, k_odd] = slice_pair(xk);
auto k_rotated_left = k_even * cos - k_odd * sin;
auto k_rotated_right = k_even * sin + k_odd * cos;
auto k_rotated
= nn::function::Stack(std::vector<std::shared_ptr<Tensor>>{k_rotated_left, k_rotated_right}, -1)->Flatten(-2);

return {q_rotated, k_rotated};
}

std::shared_ptr<infini_train::Tensor> CausalSelfAttention::RepeatKV(const std::shared_ptr<infini_train::Tensor> &x,
int64_t n_rep) {
const auto &shape = x->Dims();
const int64_t B = shape[0], T = shape[1], H = shape[2], D = shape[3];

if (n_rep == 1) {
return x;
}

return x->View({B, T, H, 1, D})->RepeatInterleave(n_rep, 3)->Contiguous()->View({B, T, H * n_rep, D});
}

std::vector<std::shared_ptr<infini_train::Tensor>>
CausalSelfAttention::ForwardWithRoPE(const std::vector<std::shared_ptr<infini_train::Tensor>> &x) {
const auto B = x[0]->Dims()[0]; // bs
const auto C = x[0]->Dims()[2]; // n_embd

const auto tp_size = nn::parallel::global::GetTensorParallelSize();

const auto C_local = C / tp_size;
const auto H_local = n_head_ / tp_size;
const auto H_local = local_n_head_;
const auto KV_local = n_kv_head_ / tp_size;
const auto D = head_dim_; // n_embd / n_head

const auto freqs_cis = x.size() > 1 ? x[1] : nullptr;
const auto start_pos = x.size() > 2 ? x[2] : nullptr;
const auto mask = x.size() > 3 ? x[3] : nullptr;
CHECK(freqs_cis != nullptr) << "freqs_cis is null.";
if (config_.position_embedding_type == PositionEmbeddingType::kRoPE) {
CHECK(freqs_cis != nullptr) << "freqs_cis is null.";
}

// (B, T, C) -> (B, T, (H + 2 * n_kv_head) * D)
auto qkv = (*modules_[kCAttnLayerName])({x[0]})[0];
Expand All @@ -212,10 +107,10 @@ CausalSelfAttention::ForwardWithRoPE(const std::vector<std::shared_ptr<infini_tr
// v: (B, T, KV_local, D)
auto v = qkv->Slice(2, q_size_local + kv_size_local, q_size_local + 2 * kv_size_local)->View({B, T, KV_local, D});

// -> RoPE on q, k
// q: (B, T, H_local, D)
// k: (B, T, KV_local, D)
std::tie(q, k) = ApplyRotaryEmbedding(q, k, freqs_cis);
if (config_.position_embedding_type == PositionEmbeddingType::kRoPE) {
// q: (B, T, H_local, D), k: (B, T, KV_local, D)
std::tie(q, k) = ApplyRotaryEmbedding(q, k, freqs_cis);
}

// TODO(zbl): use kv cache during inference
// if (use_kv_) { ... }
Expand Down Expand Up @@ -243,6 +138,10 @@ CausalSelfAttention::ForwardWithRoPE(const std::vector<std::shared_ptr<infini_tr
if (mask) {
// mask: (1, 1, T, T)
att = att->MaskedFill(mask, std::numeric_limits<float>::lowest());
} else {
// fallback causal mask: (1, 1, T, T)
auto causal_mask = buffers_[kParamBiasName]->Slice({0, 0, 0, 0}, {1, 1, T, T}, {1, 1, 1, 1});
att = att->MaskedFill(causal_mask == 0, -std::numeric_limits<float>::infinity());
}
// (B, H_local, T, T)
att = nn::function::Softmax(att, -1);
Expand All @@ -257,4 +156,16 @@ CausalSelfAttention::ForwardWithRoPE(const std::vector<std::shared_ptr<infini_tr
return {y};
}

// GQA helper: repeats every entry along axis 2 of a rank-4 tensor `n_rep` times.
// The input is indexed as four axes (the original names them B, T, H, D); the result has shape
// (B, T, H * n_rep, D). When n_rep == 1 the input pointer is returned unchanged.
std::shared_ptr<infini_train::Tensor> CausalSelfAttention::RepeatKV(const std::shared_ptr<infini_train::Tensor> &x,
                                                                    int64_t n_rep) {
    const auto &dims = x->Dims();
    const int64_t dim_b = dims[0];
    const int64_t dim_t = dims[1];
    const int64_t dim_h = dims[2];
    const int64_t dim_d = dims[3];

    if (n_rep != 1) {
        // Insert a singleton axis after the head axis, repeat along it, then fold it back into the head axis.
        auto with_rep_axis = x->View({dim_b, dim_t, dim_h, 1, dim_d});
        auto repeated = with_rep_axis->RepeatInterleave(n_rep, 3)->Contiguous();
        return repeated->View({dim_b, dim_t, dim_h * n_rep, dim_d});
    }

    // Nothing to repeat: hand back the same tensor.
    return x;
}

} // namespace infini_train::nn
Loading
Loading