From a7df3a6a9b59e33de26ff3ccaab04bed495de5b7 Mon Sep 17 00:00:00 2001 From: chen Date: Tue, 28 Apr 2026 16:22:45 +0000 Subject: [PATCH 1/3] refactor: fuse base+LoRA matmuls before collective to fix loss divergence Inline base and LoRA matmuls, add locally, then issue a single AllGather/AllReduce instead of two separate collective ops. The prior two-collective approach caused floating-point divergence in DDP loss. Also fix LoadLoRAWeights to slice sharded tensors by tp_rank when the checkpoint shape differs from the partitioned model shape. --- .../src/nn/lora/lora_parallel_linear.cc | 104 ++++++++---------- infini_train/src/nn/lora/lora_utils.cc | 25 ++++- 2 files changed, 68 insertions(+), 61 deletions(-) diff --git a/infini_train/src/nn/lora/lora_parallel_linear.cc b/infini_train/src/nn/lora/lora_parallel_linear.cc index 760ed3d8..1b160c80 100644 --- a/infini_train/src/nn/lora/lora_parallel_linear.cc +++ b/infini_train/src/nn/lora/lora_parallel_linear.cc @@ -126,39 +126,35 @@ LoRAColumnParallelLinear::Forward(const std::vector> &in << "Forward() on merged LoRA with requires_grad=true. Call UnmergeWeights() before training."; if (!merged_) { - // 1. Compute base output via parent class - auto base_result = ColumnParallelLinear::Forward(input_tensors); - auto base_output = base_result[0]; - - // 2. Compute LoRA output using the SAME input that base module uses - // Match base input path exactly: use direct input if input_is_parallel_ or sequence_parallel_, - // otherwise copy to TP region - auto lora_input = (input_is_parallel_ || sequence_parallel_) - ? input_tensors[0] - : parallel::CopyToTPRegionFunc(input_tensors[0])[0]; + // Inline base + LoRA matmuls, add locally, then single collective op. + // This avoids 2 separate AllGather ops which cause floating-point divergence. + auto input = (input_is_parallel_ || sequence_parallel_) ? input_tensors[0] + : parallel::CopyToTPRegionFunc(input_tensors[0])[0]; if (sequence_parallel_) { - // Base uses GatherFromSPRegionFunc to gather sequence dimension - lora_input = parallel::GatherFromSPRegionFunc(lora_input)[0]; + input = parallel::GatherFromSPRegionFunc(input)[0]; } - // Compute LoRA: lora_A: [rank, in_features], lora_B: [out_per_partition, rank] - auto lora_proj = std::make_shared()->Apply({lora_input, parameters_[kParamLoraAName]})[0]; + // Base matmul (bias folded in when applicable, matching ColumnParallelLinear::Forward) + auto base_shard = std::make_shared()->Apply( + (bias_ && !skip_bias_add_) + ? std::vector>{input, parameters_.at(kParamWeightName), + parameters_[kParamBiasName]} + : std::vector>{input, parameters_.at(kParamWeightName)})[0]; + + // LoRA matmul (local) + // Wrap replicated lora_A through CopyToTPRegion so its gradient gets AllReduced in backward + auto lora_A = parallel::CopyToTPRegionFunc(parameters_[kParamLoraAName])[0]; + auto lora_proj = std::make_shared()->Apply({input, lora_A})[0]; auto lora_output = std::make_shared()->Apply({lora_proj, parameters_[kParamLoraBName]})[0]; - // Match base output layout (gather if base gathers) - if (gather_output_) { - lora_output = parallel::GatherFromTPRegionFunc(lora_output)[0]; - } - - auto scaled_lora = lora_output->Mul(config_.Scaling()); + // Local add before collective + auto combined = base_shard->Add(lora_output->Mul(config_.Scaling())); - // 3. Add LoRA contribution to base output - // Both should now have the same sequence dimension - auto output = base_output->Add(scaled_lora); + // Single collective op + auto output = gather_output_ ? parallel::GatherFromTPRegionFunc(combined)[0] : combined; - // Return in same format as base module return skip_bias_add_ - ? std::vector>{output, bias_ ? parameters_[kParamBiasName] : nullptr} + ? std::vector>{output, bias_ ? parameters_.at(kParamBiasName) : nullptr} : std::vector>{output}; } @@ -321,42 +317,32 @@ LoRARowParallelLinear::Forward(const std::vector> &input << "Forward() on merged LoRA with requires_grad=true. Call UnmergeWeights() before training."; if (!merged_) { - // Get effective input - match what base module uses - auto effective_input = input_tensors[0]; - const int64_t in_dim = effective_input->Dims().back(); - - if (!input_is_parallel_) { - // base would scatter; lora must match - effective_input = parallel::ScatterToTPRegionFunc(effective_input)[0]; - CHECK_EQ(effective_input->Dims().back(), in_features_per_partition_); - } else { - // input_is_parallel_=true means caller promised shard input - CHECK_EQ(in_dim, in_features_per_partition_) - << "RowParallel expects sharded input when input_is_parallel_=true. " - << "Got full in_dim=" << in_dim << " (likely upstream gathered TP output)."; + // Inline base + LoRA matmuls, add locally, then single collective op. + // This avoids 2 separate AllReduce ops which cause floating-point divergence. + auto input = input_is_parallel_ ? input_tensors[0] : parallel::ScatterToTPRegionFunc(input_tensors[0])[0]; + + // Base matmul (no bias — RowParallel adds bias AFTER collective) + auto base_shard = std::make_shared()->Apply({input, parameters_.at(kParamWeightName)})[0]; + + // LoRA matmul (local) + // Wrap replicated lora_B through CopyToTPRegion so its gradient gets AllReduced in backward + auto lora_proj = std::make_shared()->Apply({input, parameters_[kParamLoraAName]})[0]; + auto lora_B = parallel::CopyToTPRegionFunc(parameters_[kParamLoraBName])[0]; + auto lora_output = std::make_shared()->Apply({lora_proj, lora_B})[0]; + + // Local add before collective + auto combined = base_shard->Add(lora_output->Mul(config_.Scaling())); + + // Single collective op + auto output = reduce_output_ ? (sequence_parallel_ ? parallel::ReduceScatterToSPRegionFunc(combined)[0] + : parallel::ReduceFromTPRegionFunc(combined)[0]) + : combined; + + // Bias after collective (matching RowParallelLinear::Forward) + if (bias_ && !skip_bias_add_) { + output = output->Add(parameters_[kParamBiasName]); } - // 1) base output - use effective_input - auto base_result = RowParallelLinear::Forward({effective_input}); - auto base_output = base_result[0]; - - // 2) lora branch uses the SAME effective_input - auto lora_proj - = std::make_shared()->Apply({effective_input, parameters_[kParamLoraAName]})[0]; - auto lora_output = std::make_shared()->Apply({lora_proj, parameters_[kParamLoraBName]})[0]; - - // 3) apply same reduction as base - auto lora_out = lora_output; - if (reduce_output_) { - lora_out = sequence_parallel_ ? parallel::ReduceScatterToSPRegionFunc(lora_out)[0] - : parallel::ReduceFromTPRegionFunc(lora_out)[0]; - } - - auto scaled_lora = lora_out->Mul(config_.Scaling()); - CHECK_EQ(base_output->NumElements(), scaled_lora->NumElements()); - auto output = base_output->Add(scaled_lora); - - // Return in same format as base module return skip_bias_add_ ? std::vector>{output, bias_ ? parameters_[kParamBiasName] : nullptr} : std::vector>{output}; diff --git a/infini_train/src/nn/lora/lora_utils.cc b/infini_train/src/nn/lora/lora_utils.cc index 7b8f3668..56f5f012 100644 --- a/infini_train/src/nn/lora/lora_utils.cc +++ b/infini_train/src/nn/lora/lora_utils.cc @@ -15,6 +15,7 @@ #include "infini_train/include/nn/lora/lora_parallel_linear.h" #include "infini_train/include/nn/modules/linear.h" #include "infini_train/include/nn/modules/module.h" +#include "infini_train/include/nn/parallel/global.h" #include "infini_train/include/nn/parallel/tensor_parallel.h" #include "infini_train/include/tensor.h" @@ -392,10 +393,30 @@ void LoadLoRAWeights(std::shared_ptr model, const std::string &filepath) auto cpu_tensor = std::make_shared(dims, DataType::kFLOAT32, Device(Device::DeviceType::kCPU, 0)); file.read(reinterpret_cast(cpu_tensor->DataPtr()), num_elements * sizeof(float)); - // Load into model + // Load into model, slicing sharded tensors by tp_rank if shapes differ auto it = model_state_dict.find(name); if (it != model_state_dict.end()) { - it->second->CopyFrom(cpu_tensor); + auto &dst = it->second; + const auto &dst_dims = dst->Dims(); + if (dst_dims == dims) { + dst->CopyFrom(cpu_tensor); + } else { + // Determine which dim is sharded: find first dim where sizes differ + int shard_dim = -1; + for (int d = 0; d < static_cast(dims.size()); ++d) { + if (d < static_cast(dst_dims.size()) && dst_dims[d] != dims[d]) { + shard_dim = d; + break; + } + } + CHECK(shard_dim >= 0) << "LoadLoRAWeights: shape mismatch for " << name + << " but no differing dim found"; + int tp_size = parallel::global::GetTensorParallelSize(); + int64_t shard_size = dims[shard_dim] / tp_size; + int64_t start = parallel::tp_rank * shard_size; + auto sliced = cpu_tensor->Slice(shard_dim, start, start + shard_size); + dst->CopyFrom(sliced); + } } else { LOG(WARNING) << "LoRA parameter not found in model: " << name; } From 94671668ac4074ba36dba60fdd8a6384aa3792ba Mon Sep 17 00:00:00 2001 From: chen Date: Wed, 29 Apr 2026 09:03:19 +0000 Subject: [PATCH 2/3] fix: broadcast lora_A init from TP rank 0 to ensure consistent replicated weights --- .../src/nn/lora/lora_parallel_linear.cc | 35 ++++++++++++++----- scripts/run_models_and_profile.bash | 7 ++-- scripts/test_config.json | 34 +++++++++++------- 3 files changed, 53 insertions(+), 23 deletions(-) diff --git a/infini_train/src/nn/lora/lora_parallel_linear.cc b/infini_train/src/nn/lora/lora_parallel_linear.cc index 1b160c80..595ad2ca 100644 --- a/infini_train/src/nn/lora/lora_parallel_linear.cc +++ b/infini_train/src/nn/lora/lora_parallel_linear.cc @@ -11,6 +11,7 @@ #include "infini_train/include/nn/init.h" #include "infini_train/include/nn/modules/linear.h" #include "infini_train/include/nn/parallel/global.h" +#include "infini_train/include/nn/parallel/process_group.h" #include "infini_train/include/nn/parallel/tensor_parallel.h" #include "infini_train/include/nn/parallel/utils.h" #include "infini_train/include/tensor.h" @@ -89,22 +90,38 @@ LoRAColumnParallelLinear::LoRAColumnParallelLinear(std::shared_ptr(std::vector{config_.rank, in_features_}, DataType::kFLOAT32, device_) ->RequiresGrad(); - if (config_.use_kaiming_a) { - init::KaimingUniform(parameters_[kParamLoraAName], config_.kaiming_a_param); + + if (parallel::global::GetTensorParallelSize() > 1) { + const auto global_rank = device_.Rank().GlobalRank(); + auto *tp_group = parallel::ProcessGroupFactory::Instance(device_.type()) + ->Get(parallel::GetTensorParallelProcessGroupName(global_rank)); + const int tp_rank = tp_group->GetGroupRank(global_rank); + + // Only TP rank 0 generates random values; others zero-init. + // AllReduce(sum) then broadcasts rank-0's values to all TP ranks. + if (tp_rank == 0) { + if (config_.use_kaiming_a) { + init::KaimingUniform(parameters_[kParamLoraAName], config_.kaiming_a_param); + } else { + init::Normal(parameters_[kParamLoraAName], 0.0f, 0.02f); + } + } else { + init::Zeros(parameters_[kParamLoraAName]); + } + tp_group->AllReduce(parameters_[kParamLoraAName]); } else { - init::Normal(parameters_[kParamLoraAName], 0.0f, 0.02f); + if (config_.use_kaiming_a) { + init::KaimingUniform(parameters_[kParamLoraAName], config_.kaiming_a_param); + } else { + init::Normal(parameters_[kParamLoraAName], 0.0f, 0.02f); + } } - // lora_B: [out_per_partition, rank] - sharded like base weight parameters_[kParamLoraBName] = std::make_shared(std::vector{out_features_per_partition_, config_.rank}, DataType::kFLOAT32, device_) diff --git a/scripts/run_models_and_profile.bash b/scripts/run_models_and_profile.bash index e3c67293..95046bff 100755 --- a/scripts/run_models_and_profile.bash +++ b/scripts/run_models_and_profile.bash @@ -156,8 +156,9 @@ run_and_log() { > "$log_path" fi - # Write the current run command to the log - echo "[COMMAND] $cmd" >> "$log_path" + # Write the current run command to the log (expand $LORA_WEIGHTS_DIR) + local expanded_cmd="${cmd//\$LORA_WEIGHTS_DIR/$LORA_WEIGHTS_DIR}" + echo "[COMMAND] $expanded_cmd" >> "$log_path" # Run the command and append both stdout and stderr to the log file if ! eval "$cmd" >> "$log_path" 2>&1; then @@ -272,10 +273,12 @@ for ((id=0; id Date: Fri, 22 May 2026 09:20:19 +0000 Subject: [PATCH 3/3] feat: add rank-aware Broadcast/Scatter to ProcessGroup Introduce new multi-stream Broadcast and Scatter APIs that take a root_rank_in_group argument, and rename the legacy single-stream variants to BroadCast_/Scatter_ to disambiguate. --- docs/lora_usage_guide.md | 776 ++++-------------- .../include/nn/parallel/process_group.h | 15 +- infini_train/src/autograd/comm.cc | 4 +- infini_train/src/nn/parallel/process_group.cc | 116 ++- 4 files changed, 308 insertions(+), 603 deletions(-) diff --git a/docs/lora_usage_guide.md b/docs/lora_usage_guide.md index cc41a981..1eda48ae 100644 --- a/docs/lora_usage_guide.md +++ b/docs/lora_usage_guide.md @@ -1,715 +1,301 @@ # LoRA 使用说明 -本文档介绍如何在 InfiniTrain 中使用 LoRA (Low-Rank Adaptation) 进行高效微调。 - -## 目录 - -1. [快速开始](#快速开始) -2. [核心概念](#核心概念) -3. [命令行使用](#命令行使用-gpt2-示例) -4. [LoRAModel 包装器](#lora模型-包装器-推荐模式) -5. [API 参考](#api-参考) -6. [使用示例](#使用示例) -7. [最佳实践](#最佳实践) -8. [常见问题](#常见问题) +本文档描述 InfiniTrain 当前 LoRA 实现的实际用法。当前实现采用 PEFT-style **原地注入**: +`GetLoRAModel(model, config)` 会替换匹配的 Linear 模块、冻结非 LoRA 参数,并返回修改后的 +`std::shared_ptr`。仓库里没有独立的 `LoRAModel` 包装器类或 `lora_model.h`。 ## 快速开始 -### 头文件引入 +### 头文件 ```cpp -#include "nn/lora/lora_config.h" -#include "nn/lora/lora_linear.h" -#include "nn/lora/lora_utils.h" -// 如果使用张量并行 -#include "nn/lora/lora_parallel_linear.h" -// 如果使用 LoRAModel 包装器 -#include "nn/lora/lora_model.h" +#include "infini_train/include/nn/lora/lora_config.h" +#include "infini_train/include/nn/lora/lora_linear.h" +#include "infini_train/include/nn/lora/lora_parallel_linear.h" +#include "infini_train/include/nn/lora/lora_utils.h" ``` -### 最简示例 +### 最小示例 ```cpp using namespace infini_train::nn::lora; -// 1. 创建 LoRA 配置 LoRAConfig config; -config.rank = 8; // 低秩维度 -config.alpha = 16.0f; // 缩放因子 +config.rank = 8; +config.alpha = 16.0f; +config.target_modules = ParseLoRATargetModules("c_attn,attn.c_proj"); -// 2. 获取 LoRA 模型 (原地修改,自动冻结基础模型) -auto lora_model = GetLoRAModel(model, config); +// 原地注入 LoRA,并冻结非 LoRA 参数。 +model = GetLoRAModel(model, config); +PrintLoRASummary(model); -// 3. 获取可训练参数用于优化器 -auto trainable_params = nn::lora::GetLoRAParameters(lora_model); -auto optimizer = std::make_shared(trainable_params, lr); +auto params = GetLoRAParameters(model); +auto optimizer = optimizers::Adam::Create(/*learning_rate=*/1e-4)(params); -// 4. 训练循环 for (int step = 0; step < num_steps; ++step) { - auto loss = (*lora_model)(inputs); + optimizer->ZeroGrad(); + auto logits = (*model)({input})[0]; + auto loss = (*loss_fn)({logits, labels})[0]; loss->Backward(); optimizer->Step(); - optimizer->ZeroGrad(); } -// 5. 保存 LoRA 权重 -SaveLoRAWeights(lora_model, "lora_weights.bin"); +SaveLoRAWeights(model, "adapter_lora.bin"); ``` -## 核心概念 +## 核心行为 -### LoRA 原理 +LoRA 对 Linear 层追加低秩增量: -LoRA 通过低秩分解来近似权重更新: - -``` -原始: y = Wx + b -LoRA: y = Wx + b + (α/r) × x × A^T × B^T +```text +y = x @ W^T + b + (alpha / rank) * x @ A^T @ B^T ``` -其中: -- `W` 是冻结的原始权重 -- `A` 是形状为 `[rank, in_features]` 的可训练矩阵 -- `B` 是形状为 `[out_features, rank]` 的可训练矩阵 -- `α/r` 是缩放因子 - -### 参数效率 - -假设原始 Linear 层参数量为 `in × out`,LoRA 只需训练 `rank × (in + out)` 个参数。 - -例如:`in=4096, out=4096, rank=8` -- 原始参数:16,777,216 -- LoRA 参数:65,536 (仅 0.39%) +- `W` 和 `bias` 来自原基础层,注入后默认冻结。 +- `lora_A` 形状为 `[rank, in_features]`,默认 Kaiming uniform 初始化。 +- `lora_B` 形状为 `[out_features, rank]`,零初始化,因此注入初始时不改变基础模型输出。 +- LoRA 参数当前固定创建为 `float32`。 +- `dropout` 字段存在于 `LoRAConfig`,当前未实现。 -## LoRAModel 包装器类 +`GetLoRAModel` 会遍历 `NamedModules()`,按 `target_modules` 匹配模块路径: -### LoRAModel - -遵循 PEFT 模式的 LoRA 包装器,封装基础模型和 LoRA 配置。使用 `NamedModules()` 自动遍历模型层次结构。 - -```cpp -class LoRAModel : public Module { -public: - // 构造函数 - 自动遍历模型层次结构 - LoRAModel(std::shared_ptr base_model, - const LoRAConfig &config); - - // 获取所有参数 - std::vector> Parameters() const override; - - // LoRA 权重管理 - void SaveLoRA(const std::string &filepath) const; - void LoadLoRA(const std::string &filepath); - void Merge(); - void Unmerge(); - bool IsMerged() const; - - // 打印摘要 - void PrintSummary() const; - - // 访问基础模型 - std::shared_ptr base_model() const; - - // 获取 LoRA 配置 - const LoRAConfig &config() const; -}; -``` - -### 工厂函数 - -```cpp -template -std::shared_ptr CreateLoRAModel( - const ConfigType &model_config, - const LoRAConfig &lora_config) { - auto base_model = std::make_shared(model_config); - return std::make_shared(base_model, lora_config); -} -``` +- 匹配规则是完整组件后缀匹配,例如 `attn.c_proj` 可匹配 + `transformer.h.0.attn.c_proj`。 +- `c_proj` 会同时匹配 attention 和 MLP 中名为 `c_proj` 的层。 +- `attn.c_proj` 只匹配 attention output projection。 +- 若根模块本身匹配,返回值可能是新的 LoRA 模块,因此应使用返回的 `model`。 ## API 参考 -### LoRAConfig - 配置结构 +### LoRAConfig ```cpp struct LoRAConfig { - int64_t rank = 8; // 低秩维度 r - float alpha = 16.0f; // 缩放因子 α - float dropout = 0.0f; // Dropout 概率(暂未实现) - - // 目标模块名称(默认只对 attention 层应用) - // 注意:匹配模块名的后缀,而非完整路径 + int64_t rank = 8; + float alpha = 16.0f; + float dropout = 0.0f; // not implemented std::unordered_set target_modules = {"c_attn", "c_proj"}; + bool use_kaiming_a = true; + float kaiming_a_param = sqrtf(5.0f); - // 初始化参数 - bool use_kaiming_a = true; // A 矩阵使用 Kaiming 初始化 - float kaiming_a_param = sqrtf(5.0f); // Kaiming 初始化参数 (默认值与 PyTorch 一致) + LoRAConfig() = default; + LoRAConfig(int64_t r, float a, float d, + const std::unordered_set &targets); - // 计算缩放因子 - float Scaling() const; // 返回 alpha / rank - - // 检查模块是否应该应用 LoRA - // 匹配规则:模块名以 target_modules 中的任意一个结尾 + float Scaling() const; bool ShouldApplyLoRA(const std::string &module_name) const; }; ``` -### 模型应用函数 - -#### GetLoRAModel - -PEFT-style 运行时包装器,使用 `NamedModules()` 自动遍历模型层次结构,创建 LoRA 模型。 +推荐用法: ```cpp -std::shared_ptr GetLoRAModel( - std::shared_ptr model, // 目标模型 - const LoRAConfig &config // LoRA 配置 -); +LoRAConfig config; +config.rank = 8; +config.alpha = 16.0f; +config.target_modules = ParseLoRATargetModules("c_attn,attn.c_proj"); ``` -**参数说明:** -- `model`: 要包装的模型 -- `config`: LoRA 配置(包含 `target_modules` 指定目标层) +### 注入与参数管理 -**返回值:** `std::shared_ptr`,返回原地修改后的模型(已注入 LoRA 层并冻结基础模型) - -**使用示例:** ```cpp -// 配置 LoRA -LoRAConfig config{8, 16.0f}; -// 使用 ParseLoRATargetModules 解析逗号分隔的字符串 -config.target_modules = ParseLoRATargetModules("c_attn,c_proj"); // 只对 attention -// config.target_modules = ParseLoRATargetModules("c_attn,c_proj,c_fc,c_proj"); // 包含 MLP - -// 一行启用 LoRA (原地修改,自动冻结基础模型) -auto lora_model = nn::lora::GetLoRAModel(model, config); -``` - -#### InjectLoRALayers +std::shared_ptr GetLoRAModel(std::shared_ptr model, + const LoRAConfig &config); -使用 `NamedModules()` 自动遍历模型层次结构,将 LoRA 注入到所有匹配的层中。 +std::shared_ptr InjectLoRALayers(std::shared_ptr model, + const LoRAConfig &config); -```cpp -void InjectLoRALayers( - std::shared_ptr model, // 目标模型 - const LoRAConfig &config // LoRA 配置 -); -``` - -**参数说明:** -- `model`: 要注入 LoRA 的模型 -- `config`: LoRA 配置(通过 `target_modules` 指定目标层) - -### 参数管理函数 - -#### FreezeBaseModel / UnfreezeModel - -```cpp -// 冻结基础模型所有参数 void FreezeBaseModel(std::shared_ptr model); - -// 解冻所有参数 void UnfreezeModel(std::shared_ptr model); -``` -#### GetLoRAParameters / GetBaseParameters +std::vector> +GetLoRAParameters(const std::shared_ptr &model); -```cpp -// 获取 LoRA 参数(用于优化器) -std::vector> GetLoRAParameters( - const std::shared_ptr &model); - -// 获取基础模型参数 -std::vector> GetBaseParameters( - const std::shared_ptr &model); +std::vector> +GetBaseParameters(const std::shared_ptr &model); ``` -### 权重合并函数 +- `GetLoRAModel` = 注入 + 冻结基础参数 + 重新打开 LoRA 参数。 +- `InjectLoRALayers` 只做结构替换,不冻结参数。 +- 示例训练入口在启用 LoRA 时只把 `GetLoRAParameters(model)` 传给优化器。 +- 对 merged 状态的 LoRA 模块调用 `GetLoRAParameters` 会触发检查;继续训练前先 + `UnmergeLoRAWeights(model)`。 -#### MergeLoRAWeights / UnmergeLoRAWeights +### 合并、卸载、保存和加载 ```cpp -// 合并 LoRA 权重到基础权重: W' = W + (α/r) × B × A void MergeLoRAWeights(std::shared_ptr model); - -// 恢复原始基础权重 void UnmergeLoRAWeights(std::shared_ptr model); - -// 合并权重并卸载 LoRA 模块,返回纯基础模型 std::shared_ptr MergeAndUnload(std::shared_ptr model); -``` - -**使用场景:** -- 推理时合并权重可以消除额外计算开销 -- 导出模型时合并权重得到标准模型格式 -- `MergeAndUnload`: 导出完整的标准模型,替换所有 LoRA 模块为普通 Linear 层 - -### 保存/加载函数 - -```cpp -// 保存 LoRA 权重到文件 -void SaveLoRAWeights(const std::shared_ptr &model, - const std::string &filepath); - -// 从文件加载 LoRA 权重 -void LoadLoRAWeights(std::shared_ptr model, - const std::string &filepath); -// 获取 LoRA 状态字典 std::unordered_map> LoRAStateDict(const std::shared_ptr &model); -// 加载 LoRA 状态字典 void LoadLoRAStateDict( std::shared_ptr model, const std::unordered_map> &state_dict); -``` - -### 统计函数 - -```cpp -// 打印 LoRA 模型摘要 -void PrintLoRASummary(const std::shared_ptr &model); - -// 统计可训练参数数量 -int64_t CountTrainableParameters(const std::shared_ptr &model); - -// 统计总参数数量 -int64_t CountTotalParameters(const std::shared_ptr &model); -``` - -### 工具函数 - -```cpp -// 解析逗号分隔的目标模块字符串 -std::unordered_set ParseLoRATargetModules(const std::string &targets); - -// 示例: "c_attn,c_proj" -> {"c_attn", "c_proj"} -``` - -### 张量并行 LoRA 类 - -当使用张量并行 (TP) 时,`GetLoRAModel` 会自动检测并使用对应的 LoRA 包装器: - -```cpp -// LoRA for ColumnParallelLinear (e.g., QKV projection) -// LoRA A: [rank, in_features] - replicated across TP ranks -// LoRA B: [out_features_per_partition, rank] - sharded like base weight -class LoRAColumnParallelLinear; - -// LoRA for RowParallelLinear (e.g., output projection) -// LoRA A: [rank, in_features_per_partition] - sharded like base weight -// LoRA B: [out_features, rank] - replicated across TP ranks -class LoRARowParallelLinear; -``` - -**注意**: 使用张量并行时无需手动创建这些类,`GetLoRAModel` 会自动处理。 - -### TP=1 自动退化 - -`ColumnParallelLinear` 和 `RowParallelLinear` 在 TP=1 时会自动退化为普通 Linear,无需在模型代码中条件分支: - -```cpp -// 模型代码可以统一使用 ColumnParallelLinear/RowParallelLinear -// TP=1 时自动走 fast-path,等价于普通 Linear -modules_["c_attn"] = std::make_shared(...); -modules_["c_proj"] = std::make_shared(...); -``` - -这使得 LoRA 包装可以统一工作,无论是否使用张量并行。 - -## 使用示例 - -### 示例 1: GPT2 微调 - -```cpp -#include "example/gpt2/gpt2.h" -#include "nn/lora/lora_utils.h" - -using namespace infini_train::nn::lora; -int main() { - // 创建 GPT2 模型 - auto model = std::make_shared(config); - model->LoadWeights("gpt2_weights.bin"); - - // 配置 LoRA - LoRAConfig lora_config; - lora_config.rank = 8; - lora_config.alpha = 16.0f; - lora_config.target_modules = ParseLoRATargetModules("c_attn,c_proj"); // 只对 attention 层 - - // 获取 LoRA 模型 (原地修改,自动冻结基础模型) - auto lora_model = GetLoRAModel(model, lora_config); - - // 打印参数统计 - PrintLoRASummary(lora_model); - // 输出示例: - // ========== LoRA Model Summary ========== - // Total parameters: 124,439,808 - // Trainable parameters: 294,912 (0.24%) - // Frozen parameters: 124,144,896 - // ========================================= - - // 创建优化器(只优化 LoRA 参数) - auto trainable_params = nn::lora::GetLoRAParameters(lora_model); - auto optimizer = std::make_shared(trainable_params, /*lr=*/1e-4); - - // 训练循环 - for (int step = 0; step < num_steps; ++step) { - auto [loss, logits] = (*lora_model)({input_ids}); - loss->Backward(); - optimizer->Step(); - optimizer->ZeroGrad(); - - if (step % 100 == 0) { - std::cout << "Step " << step << ", Loss: " << loss->Item() << std::endl; - } - } - - // 保存 LoRA 权重(仅几 MB) - SaveLoRAWeights(lora_model, "gpt2_lora.bin"); - - return 0; -} -``` - -### 示例 2: LLaMA3 分布式微调 - -```cpp -#include "example/llama3/llama3.h" -#include "nn/lora/lora_utils.h" -#include "nn/parallel/process_group.h" - -using namespace infini_train::nn::lora; - -int main(int argc, char **argv) { - // 初始化分布式环境 - InitDistributed(argc, argv); - - // 创建 LLaMA3 模型(带张量并行) - LLaMA3Config config; - config.n_layers = 32; - config.tensor_parallel = 2; - - auto model = std::make_shared(config); - model->LoadWeights("llama3_weights/"); - - // 配置 LoRA(包含 MLP 层以获得更好效果) - LoRAConfig lora_config{16, 32.0f}; - lora_config.target_modules = ParseLoRATargetModules("c_attn,c_proj,c_fc,c_fc2,c_proj"); - - // 获取 LoRA 模型(原地修改,自动冻结基础模型) - auto lora_model = GetLoRAModel(model, lora_config); - - PrintLoRASummary(lora_model); - - // 训练... - - // 保存 - if (GetRank() == 0) { - SaveLoRAWeights(lora_model, "llama3_lora.bin"); - } - - return 0; -} -``` - -### 示例 3: 推理时合并权重 - -```cpp -// 加载基础模型 -auto model = std::make_shared(config); -model->LoadWeights("gpt2_weights.bin"); - -// 配置并获取 LoRA 模型 -LoRAConfig lora_config; -lora_config.rank = 8; -lora_config.alpha = 16.0f; -lora_config.target_modules = ParseLoRATargetModules("c_attn,c_proj"); -auto lora_model = GetLoRAModel(model, lora_config); - -// 加载 LoRA 权重 -LoadLoRAWeights(lora_model, "gpt2_lora.bin"); - -// 合并权重(推理时无额外开销) -MergeLoRAWeights(lora_model); - -// 现在可以像普通模型一样推理 -auto output = (*lora_model)({input_ids}); - -// 如果需要继续训练,先解除合并 -UnmergeLoRAWeights(lora_model); -``` - -### 示例 4: 导出标准模型 (MergeAndUnload) - -使用 `MergeAndUnload` 将 LoRA 模型转换为标准模型,可以直接保存为普通模型文件: - -```cpp -// 加载基础模型并应用 LoRA -auto model = std::make_shared(config); -model->LoadWeights("gpt2_weights.bin"); - -LoRAConfig lora_config; -lora_config.rank = 8; -lora_config.alpha = 16.0f; -lora_config.target_modules = ParseLoRATargetModules("c_attn,c_proj"); -auto lora_model = GetLoRAModel(model, lora_config); - -// 训练... -// ... - -// 加载训练好的 LoRA 权重 -LoadLoRAWeights(lora_model, "gpt2_lora.bin"); - -// 合并并卸载 LoRA,返回标准模型 -// lora_model 中的所有 LoRALinear 都被替换为普通 Linear -auto merged_model = MergeAndUnload(lora_model); - -// 保存为标准模型(与原始模型格式相同) -merged_model->SaveWeights("gpt2_finetuned.bin"); - -// 现在 merged_model 是一个普通模型,无需 LoRA 即可推理 -auto output = (*merged_model)({input_ids}); -``` - -### 示例 5: 自定义目标层 - -```cpp -// 对所有线性层应用 -LoRAConfig config; -config.rank = 8; -config.alpha = 16.0f; -config.target_modules = ParseLoRATargetModules("c_attn,c_proj,c_fc,c_proj,lm_head"); +void SaveLoRAWeights(const std::shared_ptr &model, + const std::string &filepath); -// 获取 LoRA 模型 -auto lora_model = GetLoRAModel(model, config); +void LoadLoRAWeights(std::shared_ptr model, + const std::string &filepath); ``` -## 最佳实践 - -### 1. 选择合适的 rank +- `MergeLoRAWeights` 将 `W += (alpha / rank) * B @ A` 写回基础权重,并冻结 LoRA 参数。 +- `UnmergeLoRAWeights` 撤销一次已合并的 LoRA 增量,并恢复 LoRA 参数可训练。 +- `MergeAndUnload` 会先合并,再把 LoRA 模块替换回普通 `Linear` / TP Linear,并解冻返回模型参数。 +- `SaveLoRAWeights` 只保存名字包含 `lora_A` / `lora_B` 的参数。 +- LoRA 权重文件是二进制文件,包含 magic `"LORA"`、version `1`、tensor 名称、维度和 float 数据。 +- 加载前模型必须已经用相同目标层注入 LoRA;找不到的 LoRA 参数会打印 warning。 +- 加载时如果文件里的 tensor 和当前参数同形状,会直接拷贝;如果当前参数是 TP 分片,会按第一个不同维度用 `parallel::tp_rank` 切片后再拷贝。 -| 任务类型 | 推荐 rank | 说明 | -|---------|----------|------| -| 简单分类任务 | 4-8 | 参数少,训练快 | -| 文本生成微调 | 8-16 | 平衡效果和效率 | -| 复杂任务适配 | 16-64 | 更强表达能力 | - -### 2. alpha 设置 - -- 通常设置 `alpha = 2 × rank` -- 较大的 alpha 会增加 LoRA 的影响 -- 可以通过调整 alpha 来控制微调强度 - -### 3. 目标层选择 +### 统计和解析工具 ```cpp -// 推荐:只对 attention 层(参数效率最高) -config.target_modules = ParseLoRATargetModules("c_attn,c_proj"); +int64_t CountTrainableParameters(const std::shared_ptr &model); +int64_t CountTotalParameters(const std::shared_ptr &model); +void PrintLoRASummary(const std::shared_ptr &model, + int global_rank = -1); -// 可选:包含 MLP 层(效果可能更好,但参数更多) -config.target_modules = ParseLoRATargetModules("c_attn,c_proj,c_fc,c_fc2,c_proj"); +std::unordered_set +ParseLoRATargetModules(const std::string &targets); ``` -### 4. 学习率 - -- LoRA 通常使用比全量微调更高的学习率 -- 推荐范围:1e-4 到 1e-3 -- 可以使用学习率预热和衰减 - -### 5. 内存优化 - -```cpp -// 只保存 LoRA 权重(几 MB vs 几 GB) -SaveLoRAWeights(model, "lora.bin"); +`ParseLoRATargetModules("c_attn, attn.c_proj")` 会去掉空白并忽略空项。 -// 推理时合并权重,消除额外计算 -MergeLoRAWeights(model); -``` +## 支持的模块 -## 模型层名称参考 +`GetLoRAModel` 会自动替换匹配到的线性层,目前支持: -### GPT2 模型结构 +- `nn::Linear` +- `parallel::ColumnParallelLinear` +- `parallel::RowParallelLinear` -``` -transformer.wte # Token Embedding -transformer.wpe # Position Embedding -transformer.h.{i}.ln_1 # LayerNorm 1 -transformer.h.{i}.attn.c_attn # QKV 投影 (ColumnParallel) -transformer.h.{i}.attn.c_proj # Output 投影 (RowParallel) -transformer.h.{i}.ln_2 # LayerNorm 2 -transformer.h.{i}.mlp.c_fc # MLP 第一层 (ColumnParallel) -transformer.h.{i}.mlp.c_proj # MLP 第二层 (RowParallel) -transformer.ln_f # Final LayerNorm -lm_head # Language Model Head -``` +如果模型用了 TP,不需要手动创建 `LoRAColumnParallelLinear` 或 `LoRARowParallelLinear`; +只要正常调用 `GetLoRAModel(model, config)`,注入逻辑会根据基础层类型自动选择对应的 LoRA 实现。 -### LLaMA3 模型结构 +加载 LoRA 权重时,如果文件里保存的是完整 tensor,而当前模型在 TP 下只需要某个 rank 的分片,加载函数会自动按当前 `tp_rank` 切片。 -``` -transformer.tok_emb # Token Embedding -transformer.h.{i}.attn_norm # RMSNorm (attention) -transformer.h.{i}.attn.c_attn # QKV 投影 (ColumnParallel) -transformer.h.{i}.attn.c_proj # Output 投影 (RowParallel) -transformer.h.{i}.ffn_norm # RMSNorm (FFN) -transformer.h.{i}.mlp.c_fc # FFN gate (ColumnParallel) -transformer.h.{i}.mlp.c_fc2 # FFN up (ColumnParallel) -transformer.h.{i}.mlp.c_proj # FFN down (RowParallel) -transformer.norm # Final RMSNorm -lm_head # Language Model Head -``` +## 命令行示例 -## 命令行使用 (GPT2 示例) +GPT2 和 Llama3 示例都通过 `--lora_rank > 0` 启用 LoRA,并在训练结束时按需保存。 -### 启用 LoRA 训练 +### GPT2 ```bash ./build/gpt2 \ - --device cuda \ - --input_bin data/train.bin \ - --llmc_filepath data/gpt2_124M.bin \ - --batch_size 4 \ - --sequence_length 64 \ - --num_iteration 10 \ - --learning_rate 1e-5 \ - --lora_rank 8 \ - --lora_alpha 16.0 \ - --lora_target_modules "c_attn,c_proj" \ - --lora_save_path data/lora_weights + --device cuda \ + --input_bin data/train.bin \ + --llmc_filepath data/gpt2_124M.bin \ + --batch_size 4 \ + --sequence_length 64 \ + --num_iteration 10 \ + --learning_rate 1e-4 \ + --lora_rank 8 \ + --lora_alpha 16.0 \ + --lora_target_modules "c_attn,attn.c_proj" \ + --lora_save_path gpt2_lora.bin ``` -### 命令行参数 +GPT2 默认 LoRA target 是 `"c_attn,c_proj"`。 -| 参数 | 默认值 | 说明 | -|------|--------|------| -| `--lora_rank` | 0 | LoRA 秩 (0 = 禁用) | -| `--lora_alpha` | 16.0 | LoRA 缩放因子 | -| `--lora_target_modules` | "c_attn,c_proj" | 目标模块 (逗号分隔: c_attn,c_proj,c_fc,c_proj) | -| `--lora_load_path` | "" | 加载已有 LoRA 权重 | -| `--lora_save_path` | "" | 保存 LoRA 权重路径 | +### Llama3 + +```bash +./build/llama3 \ + --device cuda \ + --input_bin data/train.bin \ + --llmc_filepath data/llama3.bin \ + --batch_size 4 \ + --sequence_length 64 \ + --num_iteration 10 \ + --learning_rate 1e-4 \ + --lora_rank 8 \ + --lora_alpha 16.0 \ + --lora_target_modules "c_attn,c_proj,c_fc,c_fc2" \ + --lora_save_path llama3_lora.bin +``` + +Llama3 默认 LoRA target 是 `"c_attn,c_proj,c_fc,c_fc2"`。 ### 加载已有 LoRA 权重 ```bash ./build/gpt2 \ - ... - --lora_rank 8 \ - --lora_alpha 16.0 \ - --lora_load_path data/lora_weights + --device cuda \ + --input_bin data/train.bin \ + --llmc_filepath data/gpt2_124M.bin \ + --lora_rank 8 \ + --lora_alpha 16.0 \ + --lora_target_modules "c_attn,attn.c_proj" \ + --lora_load_path gpt2_lora.bin ``` -## LoRAModel 包装器 (推荐模式) - -### 概述 - -`LoRAModel` 是一个包装器类,遵循 PEFT 设计模式,将 LoRA 作为包装器应用于基础模型,而不是直接修改模型代码。 - -### 优势 +加载时请保持 `rank` 和 `lora_target_modules` 与保存这份 LoRA 权重时一致。 -- **透明性**: 训练循环无需修改,直接使用 `(*model)(inputs)` -- **参数管理**: 自动获取可训练参数 -- **权重管理**: 内置 Save/Load/Merge 方法 +### 参数表 -### 使用示例 +| 参数 | 默认值 | 说明 | +| --- | --- | --- | +| `--lora_rank` | `0` | LoRA rank;`0` 表示禁用 LoRA | +| `--lora_alpha` | `16.0` | LoRA scaling 分子,实际缩放为 `alpha / rank` | +| `--lora_target_modules` | 模型相关 | 逗号分隔的目标模块后缀 | +| `--lora_load_path` | `""` | 启动时加载 LoRA 权重文件 | +| `--lora_save_path` | `""` | 训练结束后保存 LoRA 权重文件 | -```cpp -#include "infini_train/include/nn/lora/lora_model.h" +示例入口在启用 LoRA 后会: -using namespace infini_train::nn::lora; +1. 构造 `LoRAConfig{rank, alpha, 0.0f, ParseLoRATargetModules(...)}`。 +2. 调用 `model = GetLoRAModel(model, lora_config)`。 +3. 如果设置 `--lora_load_path`,调用 `LoadLoRAWeights`。 +4. 使用 `GetLoRAParameters(model)` 构建优化器参数列表。 +5. 如果设置 `--lora_save_path`,训练结束后调用 `SaveLoRAWeights`。 -int main() { - // 1. 创建基础模型 - auto base_model = std::make_shared(config); - base_model->LoadWeights("gpt2_weights.bin"); - - // 2. 创建 LoRA 配置 - LoRAConfig lora_config{8, 16.0f}; - lora_config.target_modules = ParseLoRATargetModules("c_attn,c_proj"); // 只对 attention 层 - - // 3. 创建 LoRA 包装器 (一行代码) - auto lora_model = std::make_shared(base_model, lora_config); - - // 4. 获取可训练参数用于优化器 - auto trainable_params = nn::lora::GetLoRAParameters(lora_model); - auto optimizer = std::make_shared(trainable_params, 1e-5); - - // 5. 打印摘要 - lora_model->PrintSummary(); - // 输出: - // ========== LoRA Model Summary ========== - // Total parameters: 176062464 - // Trainable parameters: 442368 (0.251256%) - // Frozen parameters: 175620096 - // ========================================= - - // 6. 训练循环 (无需修改) - for (int step = 0; step < num_steps; ++step) { - auto logits = (*lora_model)({x, y})[0]; - auto loss = (*loss_fn)({logits, y})[0]; - loss->Backward(); - optimizer->Step(); - optimizer->ZeroGrad(); - } - - // 7. 保存 LoRA 权重 - lora_model->SaveLoRA("lora_weights.bin"); - - return 0; -} -``` +## 目标模块建议 -### 工厂函数 - -对于任意模型类型,可以使用模板工厂函数: +常用配置: ```cpp -#include "infini_train/include/nn/lora/lora_model.h" - -auto lora_model = CreateLoRAModel( - model_config, // GPT2 模型配置 - lora_config // LoRA 配置 -); -``` - -### 两种使用方式的区别 - -| 特性 | `GetLoRAModel` | `LoRAModel` 包装器 | -|------|---------------|-------------------| -| 返回类型 | `std::shared_ptr` | `std::shared_ptr` | -| 修改方式 | 原地修改模型 | 创建新包装器 | -| 自动冻结 | 是 | 是 | -| 适用场景 | 简单场景,直接修改原模型 | 需要更精细控制 | +// GPT2 attention QKV + attention output projection. +config.target_modules = ParseLoRATargetModules("c_attn,attn.c_proj"); -### 推荐场景 +// GPT2/Llama3 覆盖所有名为 c_proj 的层,包括 attention 和 MLP output。 +config.target_modules = ParseLoRATargetModules("c_attn,c_proj"); -- **使用 `GetLoRAModel`**: 想要最小化代码改动,直接在原模型上启用 LoRA -- **使用 `LoRAModel`**: 需要更灵活的 API(如 `Merge()`/`Unmerge()` 方法),或者需要保留原始模型的引用 +// 包含 MLP 层。 +config.target_modules = ParseLoRATargetModules("c_attn,attn.c_proj,c_fc,c_fc2"); +``` -## 常见问题 +注意 `c_proj` 是宽匹配,`attn.c_proj` 是更精确匹配。若只想命中 attention output +projection,不要只写 `c_proj`。 -### Q: LoRA 权重文件有多大? +## 测试 -A: 取决于 rank 和目标层数量。以 GPT2-small (12层) 为例: -- rank=8, attention only: ~1.2 MB -- rank=16, attention + MLP: ~4.8 MB +LoRA 单元测试位于 `tests/lora/test_lora.cc`,覆盖: -### Q: 如何在不同任务间切换 LoRA? +- `LoRAConfig::Scaling` 和 target 后缀匹配。 +- `LoRALinear` 注入、forward、merge/unmerge。 +- `GetLoRAParameters`、参数统计、freeze/unfreeze。 +- `LoRAStateDict` 和 `MergeAndUnload`。 -A: 保存和加载不同的 LoRA 权重文件: -```cpp -// 任务 A -LoadLoRAWeights(model, "task_a_lora.bin"); -// 推理... +构建并运行: -// 任务 B -LoadLoRAWeights(model, "task_b_lora.bin"); -// 推理... +```bash +cmake -S . -B build -DBUILD_TEST=ON -DUSE_CUDA=ON -DUSE_NCCL=ON +cmake --build build -j +ctest --test-dir build -R LoRATest --output-on-failure ``` -### Q: 可以同时使用多个 LoRA 吗? +`scripts/test_config.json` 还包含 `lora` 测试组,覆盖 fp32/bfloat16、LoRA 权重加载,以及 DP、TP、SP、PP 组合。 + +## 常见注意事项 -A: 当前实现不支持多 LoRA 组合。如需此功能,可以: -1. 合并多个 LoRA 权重后加载 -2. 扩展实现支持 LoRA 堆叠 +- 直接用 `GetLoRAModel` 就好;现在没有单独的 `LoRAModel` 类,也没有 `lora_model.h`。 +- `LoRAConfig` 先默认创建再填字段;如果想一行构造,需要传完整的 4 个参数。 +- 保存和加载传的是具体文件名,比如 `gpt2_lora.bin`,不是目录。 +- LoRA 文件只存 `lora_A` / `lora_B`,不会保存基础模型权重。 +- TP 运行时加载 LoRA 权重,如果文件里是完整 tensor、当前 rank 只需要分片,加载函数会自动按当前 `tp_rank` 切一段。 +- `MergeLoRAWeights` 适合推理前用;如果还要继续训练,先调用 `UnmergeLoRAWeights`。 +- `MergeAndUnload` 会把 LoRA 模块变回普通 Linear,之后模型里就没有 `lora_A` / `lora_B` 了。 +- 现在一次只支持一套 LoRA adapter,不支持多个 adapter 叠加。 diff --git a/infini_train/include/nn/parallel/process_group.h b/infini_train/include/nn/parallel/process_group.h index 74bf80c6..9d2d251e 100644 --- a/infini_train/include/nn/parallel/process_group.h +++ b/infini_train/include/nn/parallel/process_group.h @@ -52,6 +52,15 @@ class ProcessGroup { function::ReduceOpType reduce_op = function::ReduceOpType::kSum, bool async_op = false) const; + // root_rank_in_group is ProcessGroup-local rank. Broadcast updates tensors in place. + virtual std::shared_ptr Broadcast(const std::vector> &tensors, int root_rank_in_group, + bool async_op = false) const; + + // Root provides rank-major input_tensors: rank * output_tensors.size() + tensor_index. + virtual std::shared_ptr Scatter(const std::vector> &output_tensors, + const std::vector> &input_tensors, + int root_rank_in_group, bool async_op = false) const; + virtual std::shared_ptr Send(std::vector> tensors, int dest_rank, bool async_op = false) const; @@ -60,13 +69,13 @@ class ProcessGroup { // Legacy communication APIs (Single-stream) virtual std::vector> - BroadCast(const std::vector> &input_tensors) const; + BroadCast_(const std::vector> &input_tensors) const; virtual std::vector> ReduceAddCoalesced(const std::vector>> &grads, Device destination) const; - virtual std::vector> Scatter(const std::shared_ptr &tensor, - std::vector devices, int64_t dim) const; + virtual std::vector> Scatter_(const std::shared_ptr &tensor, + std::vector devices, int64_t dim) const; virtual std::shared_ptr Gather(const std::vector> &tensors, Device destination, int64_t dim) const; diff --git a/infini_train/src/autograd/comm.cc b/infini_train/src/autograd/comm.cc index d524088a..df40f916 100644 --- a/infini_train/src/autograd/comm.cc +++ b/infini_train/src/autograd/comm.cc @@ -19,7 +19,7 @@ std::vector> Scatter::Forward(const std::vector> output_tensors; auto device = input->GetDevice().type(); - output_tensors = pg_->Scatter(input, target_gpus_, dim_); + output_tensors = pg_->Scatter_(input, target_gpus_, dim_); return output_tensors; } @@ -83,7 +83,7 @@ std::vector> Broadcast::Forward(const std::vectorBroadCast(input_tensors); + return pg_->BroadCast_(input_tensors); } void Broadcast::SetupContext(const std::vector> &input_tensors, diff --git a/infini_train/src/nn/parallel/process_group.cc b/infini_train/src/nn/parallel/process_group.cc index 3c4c4910..174aa645 100644 --- a/infini_train/src/nn/parallel/process_group.cc +++ b/infini_train/src/nn/parallel/process_group.cc @@ -194,6 +194,116 @@ std::shared_ptr ProcessGroup::ReduceScatter(const std::shared_ptr } } +std::shared_ptr ProcessGroup::Broadcast(const std::vector> &tensors, + int root_rank_in_group, bool async_op) const { + CHECK_GE(root_rank_in_group, 0); + CHECK_LT(root_rank_in_group, world_size_); + CHECK_GT(tensors.size(), 0); + CHECK_NOTNULL(tensors[0]); + + auto device = tensors[0]->GetDevice(); + auto group_rank = GetGroupRank(device.Rank().GlobalRank()); + core::DeviceGuard guard(device); + auto *compute_stream = runtime_impl_->GetStream(device); + auto *comm_stream = device_stream_map_.at(device.index()); + auto comm = device_comm_map_.at(device.index()); + + auto work = std::make_shared(device, comm); + runtime_impl_->EventRecord(work->ready_event(), compute_stream); + runtime_impl_->StreamWaitEvent(comm_stream, work->ready_event(), 0); + for (const auto &tensor : tensors) { + CHECK_NOTNULL(tensor); + CHECK_EQ(device, tensor->GetDevice()); + const void *send_buffer = (group_rank == root_rank_in_group) ? tensor->DataPtr() : nullptr; + ccl_impl_->Broadcast(send_buffer, tensor->DataPtr(), tensor->NumElements(), tensor->Dtype(), root_rank_in_group, + comm, comm_stream); + } + runtime_impl_->EventRecord(work->done_event(), comm_stream); + + if (async_op) { + return work; + } else { + work->WaitNonBlocking(); + return nullptr; + } +} + +std::shared_ptr ProcessGroup::Scatter(const std::vector> &output_tensors, + const std::vector> &input_tensors, + int root_rank_in_group, bool async_op) const { + CHECK_GE(root_rank_in_group, 0); + CHECK_LT(root_rank_in_group, world_size_); + CHECK_GT(output_tensors.size(), 0); + CHECK_NOTNULL(output_tensors[0]); + + auto device = output_tensors[0]->GetDevice(); + auto group_rank = GetGroupRank(device.Rank().GlobalRank()); + core::DeviceGuard guard(device); + auto *compute_stream = runtime_impl_->GetStream(device); + auto *comm_stream = device_stream_map_.at(device.index()); + auto comm = device_comm_map_.at(device.index()); + + for (const auto &output_tensor : output_tensors) { + CHECK_NOTNULL(output_tensor); + CHECK_EQ(device, output_tensor->GetDevice()); + } + + const bool is_root = group_rank == root_rank_in_group; + const size_t num_outputs = output_tensors.size(); + if (is_root) { + CHECK_EQ(input_tensors.size(), static_cast(world_size_) * num_outputs) + << "Root rank must provide rank-major input tensors for every rank."; + for (const auto &input_tensor : input_tensors) { + CHECK_NOTNULL(input_tensor); + CHECK_EQ(device, input_tensor->GetDevice()); + } + for (size_t tensor_idx = 0; tensor_idx < num_outputs; ++tensor_idx) { + const auto &local_input = input_tensors[static_cast(group_rank) * num_outputs + tensor_idx]; + CHECK(local_input->Dtype() == output_tensors[tensor_idx]->Dtype()); + CHECK(local_input->Dims() == output_tensors[tensor_idx]->Dims()); + } + } else { + CHECK(input_tensors.empty()) << "Only root rank should provide scatter input tensors."; + } + + auto work = std::make_shared(device, comm); + runtime_impl_->EventRecord(work->ready_event(), compute_stream); + runtime_impl_->StreamWaitEvent(comm_stream, work->ready_event(), 0); + if (is_root) { + for (size_t tensor_idx = 0; tensor_idx < num_outputs; ++tensor_idx) { + const auto &input_tensor = input_tensors[static_cast(group_rank) * num_outputs + tensor_idx]; + runtime_impl_->MemcpyAsync(output_tensors[tensor_idx]->DataPtr(), input_tensor->DataPtr(), + input_tensor->SizeInBytes(), core::MemcpyKind::kD2D, comm_stream); + } + + core::CclGroupGuard ccl_group_guard(device.type()); + for (int rank = 0; rank < world_size_; ++rank) { + if (rank == group_rank) { + continue; + } + for (size_t tensor_idx = 0; tensor_idx < num_outputs; ++tensor_idx) { + const auto &input_tensor = input_tensors[static_cast(rank) * num_outputs + tensor_idx]; + ccl_impl_->Send(input_tensor->DataPtr(), input_tensor->NumElements(), input_tensor->Dtype(), rank, comm, + comm_stream); + } + } + } else { + core::CclGroupGuard ccl_group_guard(device.type()); + for (const auto &output_tensor : output_tensors) { + ccl_impl_->Recv(output_tensor->DataPtr(), output_tensor->NumElements(), output_tensor->Dtype(), + root_rank_in_group, comm, comm_stream); + } + } + runtime_impl_->EventRecord(work->done_event(), comm_stream); + + if (async_op) { + return work; + } else { + work->WaitNonBlocking(); + return nullptr; + } +} + std::shared_ptr ProcessGroup::Send(std::vector> tensors, int dest_rank, bool async_op) const { CHECK_GT(tensors.size(), 0); @@ -249,7 +359,7 @@ std::shared_ptr ProcessGroup::Recv(std::vector> te } std::vector> -ProcessGroup::BroadCast(const std::vector> &input_tensors) const { +ProcessGroup::BroadCast_(const std::vector> &input_tensors) const { std::vector> outputs; std::vector streams; std::vector comms; @@ -329,8 +439,8 @@ ProcessGroup::ReduceAddCoalesced(const std::vector> ProcessGroup::Scatter(const std::shared_ptr &tensor, - std::vector devices, int64_t dim) const { +std::vector> ProcessGroup::Scatter_(const std::shared_ptr &tensor, + std::vector devices, int64_t dim) const { std::vector> outputs; auto split_tensors = tensor->Split(tensor->Dims()[dim] / devices.size(), dim); std::vector streams;