From 9b3fefcb8ce988c11991eacf32e2b374c42e83d1 Mon Sep 17 00:00:00 2001 From: kilinchange Date: Mon, 11 May 2026 11:32:03 +0000 Subject: [PATCH 01/11] feat: implement MoE infrastructure --- infini_train/include/autograd/moe.h | 26 + .../nn/modules/transformer/moe/experts.h | 25 + .../nn/modules/transformer/moe/moe_layer.h | 25 + .../nn/modules/transformer/moe/moe_utils.h | 9 + .../nn/modules/transformer/moe/router.h | 25 + .../modules/transformer/transformer_config.h | 33 + infini_train/src/autograd/moe.cc | 31 + infini_train/src/kernels/cpu/top1_mask.cc | 67 ++ infini_train/src/kernels/cuda/top1_mask.cu | 107 ++++ .../src/nn/modules/transformer/moe/experts.cc | 50 ++ .../nn/modules/transformer/moe/moe_layer.cc | 32 + .../nn/modules/transformer/moe/moe_utils.cc | 12 + .../src/nn/modules/transformer/moe/router.cc | 50 ++ .../src/nn/modules/transformer/transformer.cc | 7 +- .../test_transformer_architecture.cc | 600 ++++++++++++++++++ 15 files changed, 1098 insertions(+), 1 deletion(-) create mode 100644 infini_train/include/autograd/moe.h create mode 100644 infini_train/include/nn/modules/transformer/moe/experts.h create mode 100644 infini_train/include/nn/modules/transformer/moe/moe_layer.h create mode 100644 infini_train/include/nn/modules/transformer/moe/moe_utils.h create mode 100644 infini_train/include/nn/modules/transformer/moe/router.h create mode 100644 infini_train/src/autograd/moe.cc create mode 100644 infini_train/src/kernels/cpu/top1_mask.cc create mode 100644 infini_train/src/kernels/cuda/top1_mask.cu create mode 100644 infini_train/src/nn/modules/transformer/moe/experts.cc create mode 100644 infini_train/src/nn/modules/transformer/moe/moe_layer.cc create mode 100644 infini_train/src/nn/modules/transformer/moe/moe_utils.cc create mode 100644 infini_train/src/nn/modules/transformer/moe/router.cc create mode 100644 test/transformer/test_transformer_architecture.cc diff --git a/infini_train/include/autograd/moe.h b/infini_train/include/autograd/moe.h new file mode 100644 index 00000000..5317de8e --- /dev/null +++ b/infini_train/include/autograd/moe.h @@ -0,0 +1,26 @@ +#pragma once + +#include +#include + +#include "infini_train/include/autograd/function.h" + +namespace infini_train { +class Tensor; +} + +namespace infini_train::autograd { + +class Top1Mask : public Function { +public: + static constexpr char kType[] = "Top1MaskFunction"; + + Top1Mask() : Function(kType) {} + + std::vector> Forward(const std::vector> &input_tensors) override; + void SetupContext(const std::vector> &input_tensors, + const std::vector> &output_tensors) override; + std::vector> Backward(const std::vector> &grad_outputs) override; +}; + +} // namespace infini_train::autograd diff --git a/infini_train/include/nn/modules/transformer/moe/experts.h b/infini_train/include/nn/modules/transformer/moe/experts.h new file mode 100644 index 00000000..a3dda7f0 --- /dev/null +++ b/infini_train/include/nn/modules/transformer/moe/experts.h @@ -0,0 +1,25 @@ +#pragma once + +#include +#include + +#include "infini_train/include/nn/modules/module.h" +#include "infini_train/include/nn/modules/transformer/transformer_config.h" + +namespace infini_train::nn::moe { + +class SequentialMLP : public CloneableModule { +public: + static constexpr char kType[] = "SequentialMLP"; + static constexpr char kExpertNamePrefix[] = "expert_"; + + explicit SequentialMLP(const TransformerConfig &config); + + std::vector> Forward(const std::vector> &input_tensors) override; + +private: + TransformerConfig config_; + int64_t num_local_experts_ = 0; +}; + +} // namespace infini_train::nn::moe diff --git a/infini_train/include/nn/modules/transformer/moe/moe_layer.h b/infini_train/include/nn/modules/transformer/moe/moe_layer.h new file mode 100644 index 00000000..e5fdb3ab --- /dev/null +++ b/infini_train/include/nn/modules/transformer/moe/moe_layer.h @@ -0,0 +1,25 @@ +#pragma once + +#include +#include + +#include "infini_train/include/nn/modules/module.h" +#include "infini_train/include/nn/modules/transformer/transformer_config.h" + +namespace infini_train::nn::moe { + +class MoELayer : public CloneableModule { +public: + static constexpr char kType[] = "MoELayer"; + static constexpr char kRouterLayerName[] = "router"; + static constexpr char kExpertsLayerName[] = "experts"; + + explicit MoELayer(const TransformerConfig &config); + + std::vector> Forward(const std::vector> &input_tensors) override; + +private: + TransformerConfig config_; +}; + +} // namespace infini_train::nn::moe diff --git a/infini_train/include/nn/modules/transformer/moe/moe_utils.h b/infini_train/include/nn/modules/transformer/moe/moe_utils.h new file mode 100644 index 00000000..e0dd3744 --- /dev/null +++ b/infini_train/include/nn/modules/transformer/moe/moe_utils.h @@ -0,0 +1,9 @@ +#pragma once + +#include "infini_train/include/nn/modules/transformer/transformer_config.h" + +namespace infini_train::nn::moe { + +const MoEConfig &RequireMoEConfig(const TransformerConfig &config); + +} // namespace infini_train::nn::moe diff --git a/infini_train/include/nn/modules/transformer/moe/router.h b/infini_train/include/nn/modules/transformer/moe/router.h new file mode 100644 index 00000000..1279c217 --- /dev/null +++ b/infini_train/include/nn/modules/transformer/moe/router.h @@ -0,0 +1,25 @@ +#pragma once + +#include +#include + +#include "infini_train/include/nn/modules/module.h" +#include "infini_train/include/nn/modules/transformer/transformer_config.h" + +namespace infini_train::nn::moe { + +class TopKRouter : public CloneableModule { +public: + static constexpr char kType[] = "TopKRouter"; + static constexpr char kParamWeightName[] = "weight"; + static constexpr char kParamBiasName[] = "bias"; + + explicit TopKRouter(const TransformerConfig &config); + + std::vector> Forward(const std::vector> &input_tensors) override; + +private: + TransformerConfig config_; +}; + +} // namespace infini_train::nn::moe diff --git a/infini_train/include/nn/modules/transformer/transformer_config.h b/infini_train/include/nn/modules/transformer/transformer_config.h index 62379666..b55ce4fc 100644 --- a/infini_train/include/nn/modules/transformer/transformer_config.h +++ b/infini_train/include/nn/modules/transformer/transformer_config.h @@ -20,11 +20,42 @@ enum class MLPType { kSwiGLU // SwiGLU activation }; +enum class FFNType { + kDense, // Standard dense MLP + kMoE // Mixture-of-Experts MLP +}; + enum class NormType { kLayerNorm, // LayerNorm kRMSNorm // RMSNorm }; +enum class MoERouterType { + kTopK // Top-k router. The initial implementation supports top-1. +}; + +enum class MoEDispatcherType { + kLocal, // No cross-rank token exchange + kAllGather // Reserved for expert parallel MoE +}; + +enum class MoEExpertImpl { + kSequential // Run local experts sequentially +}; + +struct MoEConfig { + int64_t num_experts = 0; + int64_t expert_parallel_size = 1; + int64_t router_topk = 1; + float aux_loss_coeff = 0.0f; + std::optional expert_capacity_factor = std::nullopt; + bool pad_expert_input_to_capacity = false; + int64_t moe_ffn_hidden_size = 0; + MoERouterType router_type = MoERouterType::kTopK; + MoEDispatcherType dispatcher_type = MoEDispatcherType::kLocal; + MoEExpertImpl expert_impl = MoEExpertImpl::kSequential; +}; + struct TransformerConfig { int64_t block_size = 1024; // Max seq_len int64_t vocab_size = 50304; // Vocab size @@ -36,6 +67,7 @@ struct TransformerConfig { AttentionType attention_type = AttentionType::kStandard; // Attention mechanism type MLPType activation_type = MLPType::kGELU; // MLP activation type + FFNType ffn_type = FFNType::kDense; // Feed-forward module type NormType norm_type = NormType::kLayerNorm; // Normalization type bool add_bias_linear = true; // Whether to add learnable bias to all Linear layers in the Transformer block, @@ -48,6 +80,7 @@ struct TransformerConfig { float ffn_expansion_ratio = 4.0f; // MLP output: n_embd * ffn_expansion_ratio std::optional ffn_dim_multiplier = 1.5f; // FFN dim multiplier int64_t multiple_of = 256; // FFN dims must be multiple of this number + std::optional moe_config = std::nullopt; // RoPE config float rope_theta = 500000.0f; // theta in RoPE diff --git a/infini_train/src/autograd/moe.cc b/infini_train/src/autograd/moe.cc new file mode 100644 index 00000000..05134e82 --- /dev/null +++ b/infini_train/src/autograd/moe.cc @@ -0,0 +1,31 @@ +#include "infini_train/include/autograd/moe.h" + +#include "glog/logging.h" + +#include "infini_train/include/dispatcher.h" +#include "infini_train/include/tensor.h" + +namespace infini_train::autograd { + +std::vector> Top1Mask::Forward(const std::vector> &input_tensors) { + CHECK_EQ(input_tensors.size(), 1); + const auto &input = input_tensors[0]; + auto device = input->GetDevice().type(); + return {Dispatcher::Instance().Call>({device, "Top1MaskForward"}, input)}; +} + +void Top1Mask::SetupContext(const std::vector> &, + const std::vector> &output_tensors) { + saved_tensors_ = {output_tensors[0]}; +} + +std::vector> Top1Mask::Backward(const std::vector> &grad_outputs) { + CHECK_EQ(grad_outputs.size(), 1); + const auto &grad_output = grad_outputs[0]; + const auto &mask_values = saved_tensors_[0]; + auto device = grad_output->GetDevice().type(); + return { + Dispatcher::Instance().Call>({device, "Top1MaskBackward"}, grad_output, mask_values)}; +} + +} // namespace infini_train::autograd diff --git a/infini_train/src/kernels/cpu/top1_mask.cc b/infini_train/src/kernels/cpu/top1_mask.cc new file mode 100644 index 00000000..d6ae91d6 --- /dev/null +++ b/infini_train/src/kernels/cpu/top1_mask.cc @@ -0,0 +1,67 @@ +#include + +#include "glog/logging.h" + +#include "infini_train/include/dispatcher.h" +#include "infini_train/include/tensor.h" + +namespace infini_train::kernels::cpu { + +std::shared_ptr Top1MaskForward(const std::shared_ptr &input) { + CHECK(input->Dtype() == DataType::kFLOAT32) << "CPU Top1MaskForward currently supports float32 only"; + CHECK_GE(input->Dims().size(), 1); + + const auto &dims = input->Dims(); + const int64_t num_experts = dims.back(); + CHECK_GT(num_experts, 0); + const int64_t rows = input->NumElements() / num_experts; + + auto output = std::make_shared(dims, input->Dtype(), input->GetDevice()); + output->Fill(0.0f); + + const float *in = static_cast(input->DataPtr()); + float *out = static_cast(output->DataPtr()); + for (int64_t row = 0; row < rows; ++row) { + int64_t best_idx = 0; + float best_value = in[row * num_experts]; + for (int64_t expert_idx = 1; expert_idx < num_experts; ++expert_idx) { + const float value = in[row * num_experts + expert_idx]; + if (value > best_value) { + best_value = value; + best_idx = expert_idx; + } + } + out[row * num_experts + best_idx] = best_value; + } + + return output; +} + +std::shared_ptr Top1MaskBackward(const std::shared_ptr &grad_output, + const std::shared_ptr &mask_values) { + CHECK(grad_output->Dtype() == DataType::kFLOAT32) << "CPU Top1MaskBackward currently supports float32 only"; + CHECK(mask_values->Dtype() == DataType::kFLOAT32); + CHECK(grad_output->Dims() == mask_values->Dims()); + + auto grad_input = std::make_shared(grad_output->Dims(), grad_output->Dtype(), grad_output->GetDevice()); + grad_input->Fill(0.0f); + + const float *grad = static_cast(grad_output->DataPtr()); + const float *mask = static_cast(mask_values->DataPtr()); + float *out = static_cast(grad_input->DataPtr()); + for (int64_t i = 0; i < static_cast(grad_output->NumElements()); ++i) { + out[i] = mask[i] != 0.0f ? grad[i] : 0.0f; + } + + return grad_input; +} + +} // namespace infini_train::kernels::cpu + +#define REGISTER_CPU_TOP1_MASK_KERNEL(kernel_name) \ + REGISTER_KERNEL(infini_train::Device::DeviceType::kCPU, kernel_name, infini_train::kernels::cpu::kernel_name) + +REGISTER_CPU_TOP1_MASK_KERNEL(Top1MaskForward) +REGISTER_CPU_TOP1_MASK_KERNEL(Top1MaskBackward) + +#undef REGISTER_CPU_TOP1_MASK_KERNEL diff --git a/infini_train/src/kernels/cuda/top1_mask.cu b/infini_train/src/kernels/cuda/top1_mask.cu new file mode 100644 index 00000000..8fd00c91 --- /dev/null +++ b/infini_train/src/kernels/cuda/top1_mask.cu @@ -0,0 +1,107 @@ +#include "glog/logging.h" + +#include "infini_train/include/common/cuda/common_cuda.h" +#include "infini_train/include/core/runtime/device_guard.h" +#include "infini_train/include/dispatcher.h" +#include "infini_train/include/tensor.h" + +#include "infini_train/src/core/runtime/cuda/cuda_dispatch.h" +#include "infini_train/src/core/runtime/cuda/cuda_runtime_common.h" + +namespace infini_train::kernels::cuda { + +template +__global__ void Top1MaskForwardKernel(const T *__restrict__ input, T *__restrict__ output, int64_t rows, + int64_t num_experts) { + int64_t row = blockIdx.x * blockDim.x + threadIdx.x; + if (row >= rows) { + return; + } + + const int64_t offset = row * num_experts; + int64_t best_idx = 0; + float best_value = static_cast(input[offset]); + for (int64_t expert_idx = 1; expert_idx < num_experts; ++expert_idx) { + const float value = static_cast(input[offset + expert_idx]); + if (value > best_value) { + best_value = value; + best_idx = expert_idx; + } + } + for (int64_t expert_idx = 0; expert_idx < num_experts; ++expert_idx) { + output[offset + expert_idx] = expert_idx == best_idx ? input[offset + expert_idx] : T(0.0f); + } +} + +std::shared_ptr Top1MaskForward(const std::shared_ptr &input) { + CHECK_GE(input->Dims().size(), 1); + const auto &dims = input->Dims(); + const int64_t num_experts = dims.back(); + CHECK_GT(num_experts, 0); + const int64_t rows = input->NumElements() / num_experts; + + auto output = std::make_shared(dims, input->Dtype(), input->GetDevice()); + + auto device = input->GetDevice(); + const auto &stream = dynamic_cast( + infini_train::core::GetDeviceGuardImpl(device.type())->GetStream(device)) + ->cuda_stream(); + const int threads = 256; + const int blocks = static_cast((rows + threads - 1) / threads); + + core::cuda::DispatchCudaFunc( + input->Dtype(), + [=]() { + Top1MaskForwardKernel<<>>( + static_cast(input->DataPtr()), static_cast(output->DataPtr()), rows, num_experts); + }, + "CUDA Top1MaskForward"); + + return output; +} + +template +__global__ void Top1MaskBackwardKernel(const T *__restrict__ grad_output, const T *__restrict__ mask_values, + T *__restrict__ grad_input, int64_t total_elements) { + int64_t idx = blockIdx.x * blockDim.x + threadIdx.x; + if (idx >= total_elements) { + return; + } + grad_input[idx] = static_cast(mask_values[idx]) != 0.0f ? grad_output[idx] : T(0.0f); +} + +std::shared_ptr Top1MaskBackward(const std::shared_ptr &grad_output, + const std::shared_ptr &mask_values) { + CHECK(grad_output->Dims() == mask_values->Dims()); + CHECK(grad_output->Dtype() == mask_values->Dtype()); + auto grad_input = std::make_shared(grad_output->Dims(), grad_output->Dtype(), grad_output->GetDevice()); + + auto device = grad_output->GetDevice(); + const auto &stream = dynamic_cast( + infini_train::core::GetDeviceGuardImpl(device.type())->GetStream(device)) + ->cuda_stream(); + const int64_t total_elements = grad_output->NumElements(); + const int threads = 256; + const int blocks = static_cast((total_elements + threads - 1) / threads); + + core::cuda::DispatchCudaFunc( + grad_output->Dtype(), + [=]() { + Top1MaskBackwardKernel<<>>( + static_cast(grad_output->DataPtr()), static_cast(mask_values->DataPtr()), + static_cast(grad_input->DataPtr()), total_elements); + }, + "CUDA Top1MaskBackward"); + + return grad_input; +} + +} // namespace infini_train::kernels::cuda + +#define REGISTER_CUDA_TOP1_MASK_KERNEL(kernel_name) \ + REGISTER_KERNEL(infini_train::Device::DeviceType::kCUDA, kernel_name, infini_train::kernels::cuda::kernel_name) + +REGISTER_CUDA_TOP1_MASK_KERNEL(Top1MaskForward) +REGISTER_CUDA_TOP1_MASK_KERNEL(Top1MaskBackward) + +#undef REGISTER_CUDA_TOP1_MASK_KERNEL diff --git a/infini_train/src/nn/modules/transformer/moe/experts.cc b/infini_train/src/nn/modules/transformer/moe/experts.cc new file mode 100644 index 00000000..8f3b1be8 --- /dev/null +++ b/infini_train/src/nn/modules/transformer/moe/experts.cc @@ -0,0 +1,50 @@ +#include "infini_train/include/nn/modules/transformer/moe/experts.h" + +#include +#include +#include + +#include "glog/logging.h" + +#include "infini_train/include/nn/modules/transformer/mlp.h" +#include "infini_train/include/nn/modules/transformer/moe/moe_utils.h" +#include "infini_train/include/tensor.h" + +namespace infini_train::nn::moe { + +SequentialMLP::SequentialMLP(const TransformerConfig &config) : CloneableModule(kType), config_(config) { + const auto &moe_config = RequireMoEConfig(config_); + CHECK(moe_config.expert_impl == MoEExpertImpl::kSequential); + CHECK_EQ(moe_config.expert_parallel_size, 1) + << "Current InfiniTrain MoE implementation supports expert_parallel_size=1 only"; + CHECK(moe_config.dispatcher_type == MoEDispatcherType::kLocal) + << "Current InfiniTrain MoE implementation supports local dispatch only"; + + num_local_experts_ = moe_config.num_experts; + CHECK_GT(num_local_experts_, 0); + + for (int64_t expert_idx = 0; expert_idx < num_local_experts_; ++expert_idx) { + modules_[std::string(kExpertNamePrefix) + std::to_string(expert_idx)] = std::make_shared(config_); + } +} + +std::vector> SequentialMLP::Forward(const std::vector> &input_tensors) { + CHECK_EQ(input_tensors.size(), 2); + auto hidden_states = input_tensors[0]; + auto routing_probs = input_tensors[1]; + CHECK_EQ(routing_probs->Dims().back(), num_local_experts_); + + std::shared_ptr output = nullptr; + const int64_t expert_dim = static_cast(routing_probs->Dims().size()) - 1; + for (int64_t expert_idx = 0; expert_idx < num_local_experts_; ++expert_idx) { + auto expert_name = std::string(kExpertNamePrefix) + std::to_string(expert_idx); + auto expert_output = (*modules_.at(expert_name))({hidden_states})[0]; + auto expert_prob = routing_probs->Slice(expert_dim, expert_idx, expert_idx + 1); + auto weighted_output = expert_output * expert_prob; + output = output == nullptr ? weighted_output : output + weighted_output; + } + + return {output}; +} + +} // namespace infini_train::nn::moe diff --git a/infini_train/src/nn/modules/transformer/moe/moe_layer.cc b/infini_train/src/nn/modules/transformer/moe/moe_layer.cc new file mode 100644 index 00000000..8efd51c0 --- /dev/null +++ b/infini_train/src/nn/modules/transformer/moe/moe_layer.cc @@ -0,0 +1,32 @@ +#include "infini_train/include/nn/modules/transformer/moe/moe_layer.h" + +#include +#include + +#include "glog/logging.h" + +#include "infini_train/include/nn/modules/transformer/moe/experts.h" +#include "infini_train/include/nn/modules/transformer/moe/moe_utils.h" +#include "infini_train/include/nn/modules/transformer/moe/router.h" +#include "infini_train/include/tensor.h" + +namespace infini_train::nn::moe { + +MoELayer::MoELayer(const TransformerConfig &config) : CloneableModule(kType), config_(config) { + const auto &moe_config = RequireMoEConfig(config_); + CHECK(config_.ffn_type == FFNType::kMoE); + CHECK(moe_config.dispatcher_type == MoEDispatcherType::kLocal) + << "Current InfiniTrain MoE implementation supports local dispatch only"; + + modules_[kRouterLayerName] = std::make_shared(config_); + modules_[kExpertsLayerName] = std::make_shared(config_); +} + +std::vector> MoELayer::Forward(const std::vector> &input_tensors) { + CHECK_EQ(input_tensors.size(), 1); + auto hidden_states = input_tensors[0]; + auto routing_probs = (*modules_.at(kRouterLayerName))({hidden_states})[0]; + return (*modules_.at(kExpertsLayerName))({hidden_states, routing_probs}); +} + +} // namespace infini_train::nn::moe diff --git a/infini_train/src/nn/modules/transformer/moe/moe_utils.cc b/infini_train/src/nn/modules/transformer/moe/moe_utils.cc new file mode 100644 index 00000000..80ef01c1 --- /dev/null +++ b/infini_train/src/nn/modules/transformer/moe/moe_utils.cc @@ -0,0 +1,12 @@ +#include "infini_train/include/nn/modules/transformer/moe/moe_utils.h" + +#include "glog/logging.h" + +namespace infini_train::nn::moe { + +const MoEConfig &RequireMoEConfig(const TransformerConfig &config) { + CHECK(config.moe_config.has_value()) << "MoE layer requires TransformerConfig::moe_config"; + return config.moe_config.value(); +} + +} // namespace infini_train::nn::moe diff --git a/infini_train/src/nn/modules/transformer/moe/router.cc b/infini_train/src/nn/modules/transformer/moe/router.cc new file mode 100644 index 00000000..59dec209 --- /dev/null +++ b/infini_train/src/nn/modules/transformer/moe/router.cc @@ -0,0 +1,50 @@ +#include "infini_train/include/nn/modules/transformer/moe/router.h" + +#include +#include + +#include "glog/logging.h" + +#include "infini_train/include/autograd/linear.h" +#include "infini_train/include/autograd/moe.h" +#include "infini_train/include/nn/functional.h" +#include "infini_train/include/nn/init.h" +#include "infini_train/include/nn/modules/transformer/moe/moe_utils.h" +#include "infini_train/include/tensor.h" + +namespace infini_train::nn::moe { + +TopKRouter::TopKRouter(const TransformerConfig &config) : CloneableModule(kType), config_(config) { + const auto &moe_config = RequireMoEConfig(config_); + CHECK(moe_config.router_type == MoERouterType::kTopK); + CHECK_EQ(moe_config.router_topk, 1) << "Current InfiniTrain MoE implementation supports top-1 routing only"; + CHECK_GT(moe_config.num_experts, 0); + + parameters_[kParamWeightName] + = std::make_shared(std::vector{moe_config.num_experts, config_.n_embd}, DataType::kFLOAT32, + device_) + ->RequiresGrad(); + init::KaimingUniform(parameters_[kParamWeightName]); + + if (config_.add_bias_linear) { + parameters_[kParamBiasName] + = std::make_shared(std::vector{moe_config.num_experts}, DataType::kFLOAT32, device_) + ->RequiresGrad(); + parameters_[kParamBiasName]->Fill(0.0f); + } +} + +std::vector> TopKRouter::Forward(const std::vector> &input_tensors) { + CHECK_EQ(input_tensors.size(), 1); + std::vector> linear_inputs{input_tensors[0], parameters_.at(kParamWeightName)}; + if (parameters_.contains(kParamBiasName)) { + linear_inputs.push_back(parameters_.at(kParamBiasName)); + } + + auto logits = std::make_shared()->Apply(linear_inputs)[0]; + auto scores = function::Softmax(logits, -1); + auto routing_probs = std::make_shared()->Apply({scores})[0]; + return {routing_probs}; +} + +} // namespace infini_train::nn::moe diff --git a/infini_train/src/nn/modules/transformer/transformer.cc b/infini_train/src/nn/modules/transformer/transformer.cc index c7e0f28c..bdcde449 100644 --- a/infini_train/src/nn/modules/transformer/transformer.cc +++ b/infini_train/src/nn/modules/transformer/transformer.cc @@ -15,6 +15,7 @@ #include "infini_train/include/nn/modules/sparse.h" #include "infini_train/include/nn/modules/transformer/causal_self_attention.h" #include "infini_train/include/nn/modules/transformer/mlp.h" +#include "infini_train/include/nn/modules/transformer/moe/moe_layer.h" #include "infini_train/include/nn/modules/transformer/utils.h" #include "infini_train/include/nn/parallel/global.h" #include "infini_train/include/nn/parallel/tensor_parallel.h" @@ -86,7 +87,11 @@ TransformerLayer::TransformerLayer(const nn::TransformerConfig &config) : Clonea } modules_[kAttnLayerName] = std::make_shared(config); - modules_[kMlpLayerName] = std::make_shared(config); + if (config.ffn_type == FFNType::kMoE) { + modules_[kMlpLayerName] = std::make_shared(config); + } else { + modules_[kMlpLayerName] = std::make_shared(config); + } } std::vector> TransformerLayer::Forward(const std::vector> &x) { diff --git a/test/transformer/test_transformer_architecture.cc b/test/transformer/test_transformer_architecture.cc new file mode 100644 index 00000000..da3dd70e --- /dev/null +++ b/test/transformer/test_transformer_architecture.cc @@ -0,0 +1,600 @@ +#include +#include +#include + +#include "glog/logging.h" + +#include "example/gpt2/config.h" +#include "example/llama3/config.h" +#include "infini_train/include/nn/modules/activations.h" +#include "infini_train/include/nn/modules/normalization.h" +#include "infini_train/include/nn/modules/sparse.h" +#include "infini_train/include/nn/modules/transformer/causal_self_attention.h" +#include "infini_train/include/nn/modules/transformer/mlp.h" +#include "infini_train/include/nn/modules/transformer/moe/moe_layer.h" +#include "infini_train/include/nn/modules/transformer/transformer.h" +#include "infini_train/include/nn/modules/transformer/transformer_config.h" +#include "infini_train/include/nn/modules/transformer/utils.h" +#include "infini_train/include/nn/parallel/global.h" +#include "infini_train/include/tensor.h" + +using namespace infini_train; +namespace nn = infini_train::nn; + +// ============================================================================ +// Test 1: TransformerConfig Validation +// ============================================================================ +void TestConfigValidation() { + std::cout << "\n=== Test 1: TransformerConfig Validation ===" << std::endl; + + bool all_passed = true; + + // Test GPT2 config + auto gpt2_config = gpt2::GPT2Config(); + if (gpt2_config.attention_type != nn::AttentionType::kStandard) { + std::cout << "FAIL: GPT2 config should use Standard attention" << std::endl; + all_passed = false; + } + if (gpt2_config.activation_type != nn::MLPType::kGELU) { + std::cout << "FAIL: GPT2 config should use GELU activation" << std::endl; + all_passed = false; + } + if (gpt2_config.norm_type != nn::NormType::kLayerNorm) { + std::cout << "FAIL: GPT2 config should use LayerNorm" << std::endl; + all_passed = false; + } + if (!gpt2_config.add_bias_linear) { + std::cout << "FAIL: GPT2 config should have bias enabled" << std::endl; + all_passed = false; + } + if (!gpt2_config.tie_weights) { + std::cout << "FAIL: GPT2 config should have tied weights" << std::endl; + all_passed = false; + } + + // Test LLaMA3 config + auto llama3_config = llama3::LLaMA3Config(); + if (llama3_config.attention_type != nn::AttentionType::kRoPE) { + std::cout << "FAIL: LLaMA3 config should use RoPE attention" << std::endl; + all_passed = false; + } + if (llama3_config.activation_type != nn::MLPType::kSwiGLU) { + std::cout << "FAIL: LLaMA3 config should use SwiGLU activation" << std::endl; + all_passed = false; + } + if (llama3_config.norm_type != nn::NormType::kRMSNorm) { + std::cout << "FAIL: LLaMA3 config should use RMSNorm" << std::endl; + all_passed = false; + } + if (llama3_config.add_bias_linear) { + std::cout << "FAIL: LLaMA3 config should have bias disabled" << std::endl; + all_passed = false; + } + if (llama3_config.tie_weights) { + std::cout << "FAIL: LLaMA3 config should not have tied weights" << std::endl; + all_passed = false; + } + + // Test GQA detection + if (!llama3_config.UseGQA()) { + std::cout << "FAIL: LLaMA3 config should detect GQA (n_kv_head < n_head)" << std::endl; + all_passed = false; + } + if (gpt2_config.UseGQA()) { + std::cout << "FAIL: GPT2 config should not detect GQA (n_kv_head == n_head)" << std::endl; + all_passed = false; + } + + if (all_passed) { + std::cout << "SUCCESS: All config validations passed!" << std::endl; + } +} + +// ============================================================================ +// Test 2: Embedding Layer +// ============================================================================ +void TestEmbedding() { + std::cout << "\n=== Test 2: Embedding Layer ===" << std::endl; + + const int64_t vocab_size = 1000; + const int64_t embedding_dim = 128; + const int64_t batch_size = 2; + const int64_t seq_len = 16; + + try { + auto embedding = std::make_shared(vocab_size, embedding_dim); + + // Check parameters + auto params = embedding->Parameters(); + if (params.size() != 1) { + std::cout << "FAIL: Embedding should have 1 parameter, got " << params.size() << std::endl; + return; + } + + // Check weight shape + auto weight = embedding->parameter(nn::Embedding::kParamWeightName); + if (weight->Dims() != std::vector{vocab_size, embedding_dim}) { + std::cout << "FAIL: Embedding weight shape mismatch" << std::endl; + return; + } + + // Forward pass + auto input = std::make_shared(std::vector{batch_size, seq_len}, DataType::kINT64); + auto output = (*embedding)({input}); + + if (output.size() != 1) { + std::cout << "FAIL: Embedding forward should return 1 tensor" << std::endl; + return; + } + + const auto &out_dims = output[0]->Dims(); + if (out_dims != std::vector{batch_size, seq_len, embedding_dim}) { + std::cout << "FAIL: Embedding output shape mismatch. Expected [" << batch_size << ", " << seq_len << ", " + << embedding_dim << "], got [" << out_dims[0] << ", " << out_dims[1] << ", " << out_dims[2] << "]" + << std::endl; + return; + } + + std::cout << "SUCCESS: Embedding layer works correctly!" << std::endl; + + } catch (const std::exception &e) { std::cout << "FAIL: Exception: " << e.what() << std::endl; } +} + +// ============================================================================ +// Test 3: Normalization Layers (LayerNorm vs RMSNorm) +// ============================================================================ +void TestNormalization() { + std::cout << "\n=== Test 3: Normalization Layers ===" << std::endl; + + const int64_t hidden_size = 64; + const int64_t batch_size = 2; + const int64_t seq_len = 8; + + try { + // Test LayerNorm + auto layernorm = std::make_shared(std::vector{hidden_size}); + auto ln_params = layernorm->Parameters(); + if (ln_params.size() != 2) { + std::cout << "FAIL: LayerNorm should have 2 parameters (weight, bias), got " << ln_params.size() + << std::endl; + return; + } + + // Test RMSNorm + auto rmsnorm = std::make_shared(hidden_size); + auto rms_params = rmsnorm->Parameters(); + if (rms_params.size() != 1) { + std::cout << "FAIL: RMSNorm should have 1 parameter (weight), got " << rms_params.size() << std::endl; + return; + } + + // Forward pass for both + auto input + = std::make_shared(std::vector{batch_size, seq_len, hidden_size}, DataType::kFLOAT32); + + auto ln_output = (*layernorm)({input}); + auto rms_output = (*rmsnorm)({input}); + + if (ln_output[0]->Dims() != input->Dims()) { + std::cout << "FAIL: LayerNorm output shape mismatch" << std::endl; + return; + } + + if (rms_output[0]->Dims() != input->Dims()) { + std::cout << "FAIL: RMSNorm output shape mismatch" << std::endl; + return; + } + + std::cout << "SUCCESS: Normalization layers work correctly!" << std::endl; + + } catch (const std::exception &e) { std::cout << "FAIL: Exception: " << e.what() << std::endl; } +} + +// ============================================================================ +// Test 4: MLP Layer (GELU vs SwiGLU) +// ============================================================================ +void TestMlp() { + std::cout << "\n=== Test 4: MLP Layer ===" << std::endl; + + const int64_t hidden_size = 64; + const int64_t batch_size = 2; + const int64_t seq_len = 8; + + try { + // Test GPT2-style MLP (GELU) + nn::TransformerConfig gpt2_mlp_config; + gpt2_mlp_config.n_embd = hidden_size; + gpt2_mlp_config.activation_type = nn::MLPType::kGELU; + gpt2_mlp_config.ffn_expansion_ratio = 4.0f; + gpt2_mlp_config.add_bias_linear = true; + + auto gpt2_mlp = std::make_shared(gpt2_mlp_config); + auto gpt2_params = gpt2_mlp->Parameters(); + + // GPT2 MLP should have: c_fc.weight, c_fc.bias, c_proj.weight, c_proj.bias + if (gpt2_params.size() != 4) { + std::cout << "FAIL: GPT2 MLP should have 4 parameters, got " << gpt2_params.size() << std::endl; + return; + } + + // Test LLaMA3-style MLP (SwiGLU) + nn::TransformerConfig llama3_mlp_config; + llama3_mlp_config.n_embd = hidden_size; + llama3_mlp_config.activation_type = nn::MLPType::kSwiGLU; + llama3_mlp_config.ffn_expansion_ratio = 4.0f; + llama3_mlp_config.add_bias_linear = false; + llama3_mlp_config.ffn_dim_multiplier = 1.5f; + llama3_mlp_config.multiple_of = 256; + + auto llama3_mlp = std::make_shared(llama3_mlp_config); + auto llama3_params = llama3_mlp->Parameters(); + + // LLaMA3 MLP should have: c_fc.weight, c_fc2.weight, c_proj.weight (no bias) + if (llama3_params.size() != 3) { + std::cout << "FAIL: LLaMA3 MLP should have 3 parameters, got " << llama3_params.size() << std::endl; + return; + } + + // Forward pass + auto input + = std::make_shared(std::vector{batch_size, seq_len, hidden_size}, DataType::kFLOAT32); + + auto gpt2_output = (*gpt2_mlp)({input}); + auto llama3_output = (*llama3_mlp)({input}); + + // Output should have same hidden dimension + if (gpt2_output[0]->Dims()[2] != hidden_size) { + std::cout << "FAIL: GPT2 MLP output hidden dim mismatch" << std::endl; + return; + } + + if (llama3_output[0]->Dims()[2] != hidden_size) { + std::cout << "FAIL: LLaMA3 MLP output hidden dim mismatch" << std::endl; + return; + } + + std::cout << "SUCCESS: MLP layers work correctly!" << std::endl; + + } catch (const std::exception &e) { std::cout << "FAIL: Exception: " << e.what() << std::endl; } +} + +// ============================================================================ +// Test 5: CausalSelfAttention +// ============================================================================ +void TestAttention() { + std::cout << "\n=== Test 5: CausalSelfAttention ===" << std::endl; + + const int64_t hidden_size = 64; + const int64_t batch_size = 2; + const int64_t seq_len = 8; + const int64_t n_head = 4; + + try { + // Test standard attention (GPT2-style) + nn::TransformerConfig standard_config; + standard_config.n_embd = hidden_size; + standard_config.n_head = n_head; + standard_config.n_kv_head = n_head; + standard_config.attention_type = nn::AttentionType::kStandard; + standard_config.add_bias_linear = true; + + auto standard_attn = std::make_shared(standard_config); + auto standard_params = standard_attn->Parameters(); + + // Should have c_attn (QKV combined) and c_proj with biases + if (standard_params.size() != 4) { + std::cout << "FAIL: Standard attention should have 4 parameters, got " << standard_params.size() + << std::endl; + return; + } + + // Test RoPE attention with GQA (LLaMA3-style) + nn::TransformerConfig rope_config; + rope_config.n_embd = hidden_size; + rope_config.n_head = n_head; + rope_config.n_kv_head = 2; // GQA: fewer KV heads + rope_config.attention_type = nn::AttentionType::kRoPE; + rope_config.add_bias_linear = false; + + auto rope_attn = std::make_shared(rope_config); + auto rope_params = rope_attn->Parameters(); + + // RoPE attention without bias should have fewer params + if (rope_params.empty()) { + std::cout << "FAIL: RoPE attention should have parameters" << std::endl; + return; + } + + // Forward pass + auto input + = std::make_shared(std::vector{batch_size, seq_len, hidden_size}, DataType::kFLOAT32); + + auto standard_output = (*standard_attn)({input}); + if (standard_output[0]->Dims() != input->Dims()) { + std::cout << "FAIL: Standard attention output shape mismatch" << std::endl; + return; + } + + std::cout << "SUCCESS: CausalSelfAttention works correctly!" << std::endl; + + } catch (const std::exception &e) { std::cout << "FAIL: Exception: " << e.what() << std::endl; } +} + +// ============================================================================ +// Test 6: TransformerLayer +// ============================================================================ +void TestTransformerLayer() { + std::cout << "\n=== Test 6: TransformerLayer ===" << std::endl; + + const int64_t hidden_size = 64; + const int64_t batch_size = 2; + const int64_t seq_len = 8; + + try { + // Test GPT2-style layer + auto gpt2_config = gpt2::GPT2Config(); + gpt2_config.n_embd = hidden_size; + gpt2_config.n_head = 4; + gpt2_config.n_layer = 1; + + auto gpt2_layer = std::make_shared(gpt2_config); + auto gpt2_params = gpt2_layer->Parameters(); + + if (gpt2_params.empty()) { + std::cout << "FAIL: GPT2 TransformerLayer should have parameters" << std::endl; + return; + } + + // Forward pass + auto input + = std::make_shared(std::vector{batch_size, seq_len, hidden_size}, DataType::kFLOAT32); + + auto output = (*gpt2_layer)({input}); + if (output[0]->Dims() != input->Dims()) { + std::cout << "FAIL: TransformerLayer output shape mismatch" << std::endl; + return; + } + + // Test LLaMA3-style layer + auto llama3_config = llama3::LLaMA3Config(); + llama3_config.n_embd = hidden_size; + llama3_config.n_head = 4; + llama3_config.n_kv_head = 2; + llama3_config.n_layer = 1; + + auto llama3_layer = std::make_shared(llama3_config); + auto llama3_params = llama3_layer->Parameters(); + + if (llama3_params.empty()) { + std::cout << "FAIL: LLaMA3 TransformerLayer should have parameters" << std::endl; + return; + } + + std::cout << "SUCCESS: TransformerLayer works correctly!" << std::endl; + + } catch (const std::exception &e) { std::cout << "FAIL: Exception: " << e.what() << std::endl; } +} + +// ============================================================================ +// Test 7: TransformerModel Instantiation (GPT2) +// ============================================================================ +void TestGpt2Model() { + std::cout << "\n=== Test 7: GPT2 Model Instantiation ===" << std::endl; + + auto config = gpt2::GPT2Config(); + // Use smaller config for faster testing + config.n_layer = 2; + config.n_head = 4; + config.n_embd = 64; + + try { + auto model = std::make_shared(config); + + if (model == nullptr) { + std::cout << "FAIL: Failed to create GPT2 model" << std::endl; + return; + } + + auto params = model->Parameters(); + if (params.empty()) { + std::cout << "FAIL: GPT2 model has no parameters" << std::endl; + return; + } + + std::cout << "SUCCESS: GPT2 model created with " << params.size() << " parameters!" << std::endl; + + } catch (const std::exception &e) { std::cout << "FAIL: Exception: " << e.what() << std::endl; } +} + +// ============================================================================ +// Test 8: TransformerModel Instantiation (LLaMA3) +// ============================================================================ +void TestLlama3Model() { + std::cout << "\n=== Test 8: LLaMA3 Model Instantiation ===" << std::endl; + + auto config = llama3::LLaMA3Config(); + // Use smaller config for faster testing + config.n_layer = 2; + config.n_head = 4; + config.n_kv_head = 2; + config.n_embd = 64; + + try { + auto model = std::make_shared(config); + + if (model == nullptr) { + std::cout << "FAIL: Failed to create LLaMA3 model" << std::endl; + return; + } + + auto params = model->Parameters(); + if (params.empty()) { + std::cout << "FAIL: LLaMA3 model has no parameters" << std::endl; + return; + } + + std::cout << "SUCCESS: LLaMA3 model created with " << params.size() << " parameters!" << std::endl; + + } catch (const std::exception &e) { std::cout << "FAIL: Exception: " << e.what() << std::endl; } +} + +// ============================================================================ +// Test 9: RoPE Utilities +// ============================================================================ +void TestRopeUtils() { + std::cout << "\n=== Test 9: RoPE Utilities ===" << std::endl; + + const int64_t head_dim = 64; + const int64_t seq_len = 128; + + try { + // Test precompute freqs_cis + auto freqs_cis = PrecomputeFreqsCis(head_dim, seq_len); + + // freqs_cis shape: [seq_len, head_dim/2, 2] (cos and sin stacked on last dim) + const auto &dims = freqs_cis->Dims(); + if (dims.size() != 3) { + std::cout << "FAIL: freqs_cis should be 3D, got " << dims.size() << "D" << std::endl; + return; + } + if (dims[0] != seq_len) { + std::cout << "FAIL: freqs_cis seq_len mismatch. Expected " << seq_len << ", got " << dims[0] << std::endl; + return; + } + if (dims[1] != head_dim / 2) { + std::cout << "FAIL: freqs_cis head_dim/2 mismatch. Expected " << head_dim / 2 << ", got " << dims[1] + << std::endl; + return; + } + if (dims[2] != 2) { + std::cout << "FAIL: freqs_cis last dim should be 2 (cos, sin), got " << dims[2] << std::endl; + return; + } + + std::cout << "SUCCESS: RoPE utilities work correctly!" << std::endl; + + } catch (const std::exception &e) { std::cout << "FAIL: Exception: " << e.what() << std::endl; } +} + +// ============================================================================ +// Test 10: Model StateDict +// ============================================================================ +void TestStateDict() { + std::cout << "\n=== Test 10: Model StateDict ===" << std::endl; + + nn::TransformerConfig config; + config.n_layer = 1; + config.n_head = 2; + config.n_kv_head = 2; // Must set explicitly + config.n_embd = 32; + config.vocab_size = 1000; + config.attention_type = nn::AttentionType::kStandard; + config.activation_type = nn::MLPType::kGELU; + config.norm_type = nn::NormType::kLayerNorm; + config.add_bias_linear = true; + + try { + auto model = std::make_shared(config); + auto state_dict = model->StateDict(); + + if (state_dict.empty()) { + std::cout << "FAIL: StateDict should not be empty" << std::endl; + return; + } + + // StateDict includes both parameters and buffers, so it should have >= parameters count + auto params = model->Parameters(); + auto buffers = model->Buffers(); + + if (state_dict.size() < params.size()) { + std::cout << "FAIL: StateDict size (" << state_dict.size() << ") should be >= parameter count (" + << params.size() << ")" << std::endl; + return; + } + + // Expected: state_dict.size() == params.size() + buffers.size() + size_t expected_size = params.size() + buffers.size(); + if (state_dict.size() != expected_size) { + std::cout << "FAIL: StateDict size (" << state_dict.size() << ") should equal params (" << params.size() + << ") + buffers (" << buffers.size() << ") = " << expected_size << std::endl; + return; + } + + std::cout << "SUCCESS: StateDict works correctly with " << state_dict.size() << " entries (" << params.size() + << " params + " << buffers.size() << " buffers)!" << std::endl; + + } catch (const std::exception &e) { std::cout << "FAIL: Exception: " << e.what() << std::endl; } +} + +// ============================================================================ +// Test 11: MoE Layer MVP +// ============================================================================ +void TestMoELayer() { + std::cout << "\n=== Test 11: MoE Layer MVP ===" << std::endl; + + nn::TransformerConfig config; + config.n_embd = 32; + config.n_head = 2; + config.n_kv_head = 2; + config.activation_type = nn::MLPType::kGELU; + config.add_bias_linear = true; + config.ffn_type = nn::FFNType::kMoE; + config.moe_config = nn::MoEConfig{}; + config.moe_config->num_experts = 2; + config.moe_config->router_topk = 1; + + try { + auto moe = std::make_shared(config); + auto input = std::make_shared(std::vector{2, 4, config.n_embd}, DataType::kFLOAT32); + input->Uniform(); + + auto output = (*moe)({input}); + if (output.size() != 1) { + std::cout << "FAIL: MoELayer forward should return 1 tensor" << std::endl; + return; + } + if (output[0]->Dims() != input->Dims()) { + std::cout << "FAIL: MoELayer output shape mismatch" << std::endl; + return; + } + + auto params = moe->Parameters(); + if (params.empty()) { + std::cout << "FAIL: MoELayer should own router and expert parameters" << std::endl; + return; + } + + std::cout << "SUCCESS: MoE layer MVP forward works correctly!" << std::endl; + } catch (const std::exception &e) { std::cout << "FAIL: Exception: " << e.what() << std::endl; } +} + +// ============================================================================ +// Main +// ============================================================================ +int main(int argc, char *argv[]) { + google::InitGoogleLogging(argv[0]); + + nn::parallel::global::GlobalEnv::Instance().Init(1, 1, false, 1, 1); + + std::cout << "========================================" << std::endl; + std::cout << " Transformer architecture Tests" << std::endl; + std::cout << "========================================" << std::endl; + + TestConfigValidation(); + TestEmbedding(); + TestNormalization(); + TestMlp(); + TestAttention(); + TestTransformerLayer(); + TestGpt2Model(); + TestLlama3Model(); + TestRopeUtils(); + TestStateDict(); + TestMoELayer(); + + std::cout << "\n========================================" << std::endl; + std::cout << " All Tests Completed" << std::endl; + std::cout << "========================================" << std::endl; + + return 0; +} From ea6af08e18f37b356f2c9811fb6621169f9d2675 Mon Sep 17 00:00:00 2001 From: kilinchange Date: Wed, 13 May 2026 03:01:31 +0000 Subject: [PATCH 02/11] feat: support topk_router --- .../include/autograd/{moe.h => topk_mask.h} | 9 ++- .../modules/transformer/transformer_config.h | 2 +- .../src/autograd/{moe.cc => topk_mask.cc} | 13 ++-- .../cpu/{top1_mask.cc => topk_mask.cc} | 53 ++++++++++++----- .../cuda/{top1_mask.cu => topk_mask.cu} | 55 ++++++++++------- .../src/nn/modules/transformer/moe/router.cc | 8 ++- .../test_transformer_architecture.cc | 59 ++++++++++++------- 7 files changed, 126 insertions(+), 73 deletions(-) rename infini_train/include/autograd/{moe.h => topk_mask.h} (76%) rename infini_train/src/autograd/{moe.cc => topk_mask.cc} (70%) rename infini_train/src/kernels/cpu/{top1_mask.cc => topk_mask.cc} (50%) rename infini_train/src/kernels/cuda/{top1_mask.cu => topk_mask.cu} (66%) diff --git a/infini_train/include/autograd/moe.h b/infini_train/include/autograd/topk_mask.h similarity index 76% rename from infini_train/include/autograd/moe.h rename to infini_train/include/autograd/topk_mask.h index 5317de8e..355ef400 100644 --- a/infini_train/include/autograd/moe.h +++ b/infini_train/include/autograd/topk_mask.h @@ -11,16 +11,19 @@ class Tensor; namespace infini_train::autograd { -class Top1Mask : public Function { +class TopKMask : public Function { public: - static constexpr char kType[] = "Top1MaskFunction"; + static constexpr char kType[] = "TopKMaskFunction"; - Top1Mask() : Function(kType) {} + explicit TopKMask(int64_t topk) : Function(kType), topk_(topk) {} std::vector> Forward(const std::vector> &input_tensors) override; void SetupContext(const std::vector> &input_tensors, const std::vector> &output_tensors) override; std::vector> Backward(const std::vector> &grad_outputs) override; + +private: + int64_t topk_ = 1; }; } // namespace infini_train::autograd diff --git a/infini_train/include/nn/modules/transformer/transformer_config.h b/infini_train/include/nn/modules/transformer/transformer_config.h index b55ce4fc..3a96625d 100644 --- a/infini_train/include/nn/modules/transformer/transformer_config.h +++ b/infini_train/include/nn/modules/transformer/transformer_config.h @@ -31,7 +31,7 @@ enum class NormType { }; enum class MoERouterType { - kTopK // Top-k router. The initial implementation supports top-1. + kTopK // Top-k router. }; enum class MoEDispatcherType { diff --git a/infini_train/src/autograd/moe.cc b/infini_train/src/autograd/topk_mask.cc similarity index 70% rename from infini_train/src/autograd/moe.cc rename to infini_train/src/autograd/topk_mask.cc index 05134e82..16dc6629 100644 --- a/infini_train/src/autograd/moe.cc +++ b/infini_train/src/autograd/topk_mask.cc @@ -1,4 +1,4 @@ -#include "infini_train/include/autograd/moe.h" +#include "infini_train/include/autograd/topk_mask.h" #include "glog/logging.h" @@ -7,25 +7,26 @@ namespace infini_train::autograd { -std::vector> Top1Mask::Forward(const std::vector> &input_tensors) { +std::vector> TopKMask::Forward(const std::vector> &input_tensors) { CHECK_EQ(input_tensors.size(), 1); + CHECK_GT(topk_, 0); const auto &input = input_tensors[0]; auto device = input->GetDevice().type(); - return {Dispatcher::Instance().Call>({device, "Top1MaskForward"}, input)}; + return {Dispatcher::Instance().Call>({device, "TopKMaskForward"}, input, topk_)}; } -void Top1Mask::SetupContext(const std::vector> &, +void TopKMask::SetupContext(const std::vector> &, const std::vector> &output_tensors) { saved_tensors_ = {output_tensors[0]}; } -std::vector> Top1Mask::Backward(const std::vector> &grad_outputs) { +std::vector> TopKMask::Backward(const std::vector> &grad_outputs) { CHECK_EQ(grad_outputs.size(), 1); const auto &grad_output = grad_outputs[0]; const auto &mask_values = saved_tensors_[0]; auto device = grad_output->GetDevice().type(); return { - Dispatcher::Instance().Call>({device, "Top1MaskBackward"}, grad_output, mask_values)}; + Dispatcher::Instance().Call>({device, "TopKMaskBackward"}, grad_output, mask_values)}; } } // namespace infini_train::autograd diff --git a/infini_train/src/kernels/cpu/top1_mask.cc b/infini_train/src/kernels/cpu/topk_mask.cc similarity index 50% rename from infini_train/src/kernels/cpu/top1_mask.cc rename to infini_train/src/kernels/cpu/topk_mask.cc index d6ae91d6..6a7191b9 100644 --- a/infini_train/src/kernels/cpu/top1_mask.cc +++ b/infini_train/src/kernels/cpu/topk_mask.cc @@ -1,4 +1,6 @@ +#include #include +#include #include "glog/logging.h" @@ -7,13 +9,15 @@ namespace infini_train::kernels::cpu { -std::shared_ptr Top1MaskForward(const std::shared_ptr &input) { - CHECK(input->Dtype() == DataType::kFLOAT32) << "CPU Top1MaskForward currently supports float32 only"; +std::shared_ptr TopKMaskForward(const std::shared_ptr &input, int64_t topk) { + CHECK(input->Dtype() == DataType::kFLOAT32) << "CPU TopKMaskForward currently supports float32 only"; CHECK_GE(input->Dims().size(), 1); const auto &dims = input->Dims(); const int64_t num_experts = dims.back(); CHECK_GT(num_experts, 0); + CHECK_GT(topk, 0); + CHECK_LE(topk, num_experts); const int64_t rows = input->NumElements() / num_experts; auto output = std::make_shared(dims, input->Dtype(), input->GetDevice()); @@ -22,24 +26,41 @@ std::shared_ptr Top1MaskForward(const std::shared_ptr &input) { const float *in = static_cast(input->DataPtr()); float *out = static_cast(output->DataPtr()); for (int64_t row = 0; row < rows; ++row) { - int64_t best_idx = 0; - float best_value = in[row * num_experts]; - for (int64_t expert_idx = 1; expert_idx < num_experts; ++expert_idx) { - const float value = in[row * num_experts + expert_idx]; - if (value > best_value) { - best_value = value; - best_idx = expert_idx; + const int64_t row_offset = row * num_experts; + std::vector selected_experts(num_experts, false); + float selected_sum = 0.0f; + for (int64_t selected = 0; selected < topk; ++selected) { + int64_t best_idx = -1; + float best_value = -std::numeric_limits::infinity(); + for (int64_t expert_idx = 0; expert_idx < num_experts; ++expert_idx) { + if (selected_experts[expert_idx]) { + continue; + } + const float value = in[row_offset + expert_idx]; + if (value > best_value) { + best_value = value; + best_idx = expert_idx; + } + } + CHECK_GE(best_idx, 0); + selected_experts[best_idx] = true; + out[row_offset + best_idx] = best_value; + selected_sum += best_value; + } + if (topk > 1 && selected_sum != 0.0f) { + for (int64_t expert_idx = 0; expert_idx < num_experts; ++expert_idx) { + out[row_offset + expert_idx] + = out[row_offset + expert_idx] == 0.0f ? 0.0f : out[row_offset + expert_idx] / selected_sum; } } - out[row * num_experts + best_idx] = best_value; } return output; } -std::shared_ptr Top1MaskBackward(const std::shared_ptr &grad_output, +std::shared_ptr TopKMaskBackward(const std::shared_ptr &grad_output, const std::shared_ptr &mask_values) { - CHECK(grad_output->Dtype() == DataType::kFLOAT32) << "CPU Top1MaskBackward currently supports float32 only"; + CHECK(grad_output->Dtype() == DataType::kFLOAT32) << "CPU TopKMaskBackward currently supports float32 only"; CHECK(mask_values->Dtype() == DataType::kFLOAT32); CHECK(grad_output->Dims() == mask_values->Dims()); @@ -58,10 +79,10 @@ std::shared_ptr Top1MaskBackward(const std::shared_ptr &grad_out } // namespace infini_train::kernels::cpu -#define REGISTER_CPU_TOP1_MASK_KERNEL(kernel_name) \ +#define REGISTER_CPU_TOPK_MASK_KERNEL(kernel_name) \ REGISTER_KERNEL(infini_train::Device::DeviceType::kCPU, kernel_name, infini_train::kernels::cpu::kernel_name) -REGISTER_CPU_TOP1_MASK_KERNEL(Top1MaskForward) -REGISTER_CPU_TOP1_MASK_KERNEL(Top1MaskBackward) +REGISTER_CPU_TOPK_MASK_KERNEL(TopKMaskForward) +REGISTER_CPU_TOPK_MASK_KERNEL(TopKMaskBackward) -#undef REGISTER_CPU_TOP1_MASK_KERNEL +#undef REGISTER_CPU_TOPK_MASK_KERNEL diff --git a/infini_train/src/kernels/cuda/top1_mask.cu b/infini_train/src/kernels/cuda/topk_mask.cu similarity index 66% rename from infini_train/src/kernels/cuda/top1_mask.cu rename to infini_train/src/kernels/cuda/topk_mask.cu index 8fd00c91..e38c793e 100644 --- a/infini_train/src/kernels/cuda/top1_mask.cu +++ b/infini_train/src/kernels/cuda/topk_mask.cu @@ -11,33 +11,44 @@ namespace infini_train::kernels::cuda { template -__global__ void Top1MaskForwardKernel(const T *__restrict__ input, T *__restrict__ output, int64_t rows, - int64_t num_experts) { +__global__ void TopKMaskForwardKernel(const T *__restrict__ input, T *__restrict__ output, int64_t rows, + int64_t num_experts, int64_t topk) { int64_t row = blockIdx.x * blockDim.x + threadIdx.x; if (row >= rows) { return; } const int64_t offset = row * num_experts; - int64_t best_idx = 0; - float best_value = static_cast(input[offset]); - for (int64_t expert_idx = 1; expert_idx < num_experts; ++expert_idx) { + float selected_sum = 0.0f; + for (int64_t expert_idx = 0; expert_idx < num_experts; ++expert_idx) { const float value = static_cast(input[offset + expert_idx]); - if (value > best_value) { - best_value = value; - best_idx = expert_idx; + int64_t rank = 0; + for (int64_t other_idx = 0; other_idx < num_experts; ++other_idx) { + const float other_value = static_cast(input[offset + other_idx]); + if (other_value > value || (other_value == value && other_idx < expert_idx)) { + ++rank; + } } + const bool selected = rank < topk; + output[offset + expert_idx] = selected ? input[offset + expert_idx] : T(0.0f); + selected_sum += selected ? value : 0.0f; } - for (int64_t expert_idx = 0; expert_idx < num_experts; ++expert_idx) { - output[offset + expert_idx] = expert_idx == best_idx ? input[offset + expert_idx] : T(0.0f); + if (topk > 1 && selected_sum != 0.0f) { + for (int64_t expert_idx = 0; expert_idx < num_experts; ++expert_idx) { + if (static_cast(output[offset + expert_idx]) != 0.0f) { + output[offset + expert_idx] = T(static_cast(output[offset + expert_idx]) / selected_sum); + } + } } } -std::shared_ptr Top1MaskForward(const std::shared_ptr &input) { +std::shared_ptr TopKMaskForward(const std::shared_ptr &input, int64_t topk) { CHECK_GE(input->Dims().size(), 1); const auto &dims = input->Dims(); const int64_t num_experts = dims.back(); CHECK_GT(num_experts, 0); + CHECK_GT(topk, 0); + CHECK_LE(topk, num_experts); const int64_t rows = input->NumElements() / num_experts; auto output = std::make_shared(dims, input->Dtype(), input->GetDevice()); @@ -52,16 +63,16 @@ std::shared_ptr Top1MaskForward(const std::shared_ptr &input) { core::cuda::DispatchCudaFunc( input->Dtype(), [=]() { - Top1MaskForwardKernel<<>>( - static_cast(input->DataPtr()), static_cast(output->DataPtr()), rows, num_experts); + TopKMaskForwardKernel<<>>( + static_cast(input->DataPtr()), static_cast(output->DataPtr()), rows, num_experts, topk); }, - "CUDA Top1MaskForward"); + "CUDA TopKMaskForward"); return output; } template -__global__ void Top1MaskBackwardKernel(const T *__restrict__ grad_output, const T *__restrict__ mask_values, +__global__ void TopKMaskBackwardKernel(const T *__restrict__ grad_output, const T *__restrict__ mask_values, T *__restrict__ grad_input, int64_t total_elements) { int64_t idx = blockIdx.x * blockDim.x + threadIdx.x; if (idx >= total_elements) { @@ -70,7 +81,7 @@ __global__ void Top1MaskBackwardKernel(const T *__restrict__ grad_output, const grad_input[idx] = static_cast(mask_values[idx]) != 0.0f ? grad_output[idx] : T(0.0f); } -std::shared_ptr Top1MaskBackward(const std::shared_ptr &grad_output, +std::shared_ptr TopKMaskBackward(const std::shared_ptr &grad_output, const std::shared_ptr &mask_values) { CHECK(grad_output->Dims() == mask_values->Dims()); CHECK(grad_output->Dtype() == mask_values->Dtype()); @@ -87,21 +98,21 @@ std::shared_ptr Top1MaskBackward(const std::shared_ptr &grad_out core::cuda::DispatchCudaFunc( grad_output->Dtype(), [=]() { - Top1MaskBackwardKernel<<>>( + TopKMaskBackwardKernel<<>>( static_cast(grad_output->DataPtr()), static_cast(mask_values->DataPtr()), static_cast(grad_input->DataPtr()), total_elements); }, - "CUDA Top1MaskBackward"); + "CUDA TopKMaskBackward"); return grad_input; } } // namespace infini_train::kernels::cuda -#define REGISTER_CUDA_TOP1_MASK_KERNEL(kernel_name) \ +#define REGISTER_CUDA_TOPK_MASK_KERNEL(kernel_name) \ REGISTER_KERNEL(infini_train::Device::DeviceType::kCUDA, kernel_name, infini_train::kernels::cuda::kernel_name) -REGISTER_CUDA_TOP1_MASK_KERNEL(Top1MaskForward) -REGISTER_CUDA_TOP1_MASK_KERNEL(Top1MaskBackward) +REGISTER_CUDA_TOPK_MASK_KERNEL(TopKMaskForward) +REGISTER_CUDA_TOPK_MASK_KERNEL(TopKMaskBackward) -#undef REGISTER_CUDA_TOP1_MASK_KERNEL +#undef REGISTER_CUDA_TOPK_MASK_KERNEL diff --git a/infini_train/src/nn/modules/transformer/moe/router.cc b/infini_train/src/nn/modules/transformer/moe/router.cc index 59dec209..851c57be 100644 --- a/infini_train/src/nn/modules/transformer/moe/router.cc +++ b/infini_train/src/nn/modules/transformer/moe/router.cc @@ -6,7 +6,7 @@ #include "glog/logging.h" #include "infini_train/include/autograd/linear.h" -#include "infini_train/include/autograd/moe.h" +#include "infini_train/include/autograd/topk_mask.h" #include "infini_train/include/nn/functional.h" #include "infini_train/include/nn/init.h" #include "infini_train/include/nn/modules/transformer/moe/moe_utils.h" @@ -17,8 +17,9 @@ namespace infini_train::nn::moe { TopKRouter::TopKRouter(const TransformerConfig &config) : CloneableModule(kType), config_(config) { const auto &moe_config = RequireMoEConfig(config_); CHECK(moe_config.router_type == MoERouterType::kTopK); - CHECK_EQ(moe_config.router_topk, 1) << "Current InfiniTrain MoE implementation supports top-1 routing only"; CHECK_GT(moe_config.num_experts, 0); + CHECK_GT(moe_config.router_topk, 0); + CHECK_LE(moe_config.router_topk, moe_config.num_experts); parameters_[kParamWeightName] = std::make_shared(std::vector{moe_config.num_experts, config_.n_embd}, DataType::kFLOAT32, @@ -43,7 +44,8 @@ std::vector> TopKRouter::Forward(const std::vector()->Apply(linear_inputs)[0]; auto scores = function::Softmax(logits, -1); - auto routing_probs = std::make_shared()->Apply({scores})[0]; + const auto &moe_config = RequireMoEConfig(config_); + auto routing_probs = std::make_shared(moe_config.router_topk)->Apply({scores})[0]; return {routing_probs}; } diff --git a/test/transformer/test_transformer_architecture.cc b/test/transformer/test_transformer_architecture.cc index da3dd70e..469ff386 100644 --- a/test/transformer/test_transformer_architecture.cc +++ b/test/transformer/test_transformer_architecture.cc @@ -527,10 +527,10 @@ void TestStateDict() { } // ============================================================================ -// Test 11: MoE Layer MVP +// Test 11: MoE Layer // ============================================================================ void TestMoELayer() { - std::cout << "\n=== Test 11: MoE Layer MVP ===" << std::endl; + std::cout << "\n=== Test 11: MoE Layer ===" << std::endl; nn::TransformerConfig config; config.n_embd = 32; @@ -543,29 +543,43 @@ void TestMoELayer() { config.moe_config->num_experts = 2; config.moe_config->router_topk = 1; - try { - auto moe = std::make_shared(config); - auto input = std::make_shared(std::vector{2, 4, config.n_embd}, DataType::kFLOAT32); - input->Uniform(); + auto moe = std::make_shared(config); + auto input = std::make_shared(std::vector{2, 4, config.n_embd}, DataType::kFLOAT32); + input->Uniform(); - auto output = (*moe)({input}); - if (output.size() != 1) { - std::cout << "FAIL: MoELayer forward should return 1 tensor" << std::endl; - return; - } - if (output[0]->Dims() != input->Dims()) { - std::cout << "FAIL: MoELayer output shape mismatch" << std::endl; - return; - } + auto output = (*moe)({input}); + CHECK_EQ(output.size(), 1); + CHECK(output[0]->Dims() == input->Dims()); - auto params = moe->Parameters(); - if (params.empty()) { - std::cout << "FAIL: MoELayer should own router and expert parameters" << std::endl; - return; - } + auto params = moe->Parameters(); + CHECK(!params.empty()); - std::cout << "SUCCESS: MoE layer MVP forward works correctly!" << std::endl; - } catch (const std::exception &e) { std::cout << "FAIL: Exception: " << e.what() << std::endl; } + std::cout << "SUCCESS: MoE layer forward works correctly!" << std::endl; +} + +void TestMoELayerTop2() { + std::cout << "\n=== Test 12: MoE Layer Top-2 ===" << std::endl; + + nn::TransformerConfig config; + config.n_embd = 32; + config.n_head = 2; + config.n_kv_head = 2; + config.activation_type = nn::MLPType::kGELU; + config.add_bias_linear = true; + config.ffn_type = nn::FFNType::kMoE; + config.moe_config = nn::MoEConfig{}; + config.moe_config->num_experts = 4; + config.moe_config->router_topk = 2; + + auto moe = std::make_shared(config); + auto input = std::make_shared(std::vector{2, 4, config.n_embd}, DataType::kFLOAT32); + input->Uniform(); + + auto output = (*moe)({input}); + CHECK_EQ(output.size(), 1); + CHECK(output[0]->Dims() == input->Dims()); + + std::cout << "SUCCESS: MoE layer top-2 forward works correctly!" << std::endl; } // ============================================================================ @@ -591,6 +605,7 @@ int main(int argc, char *argv[]) { TestRopeUtils(); TestStateDict(); TestMoELayer(); + TestMoELayerTop2(); std::cout << "\n========================================" << std::endl; std::cout << " All Tests Completed" << std::endl; From bfb59245fce78ddad59c792d6aff0fa3becc35a1 Mon Sep 17 00:00:00 2001 From: kilinchange Date: Wed, 13 May 2026 07:36:52 +0000 Subject: [PATCH 03/11] feat: support moe_ffn_hidden_size config --- infini_train/src/nn/modules/transformer/mlp.cc | 7 ++++++- test/transformer/test_transformer_architecture.cc | 13 +++++++++++-- 2 files changed, 17 insertions(+), 3 deletions(-) diff --git a/infini_train/src/nn/modules/transformer/mlp.cc b/infini_train/src/nn/modules/transformer/mlp.cc index 3af341b2..ac35d144 100644 --- a/infini_train/src/nn/modules/transformer/mlp.cc +++ b/infini_train/src/nn/modules/transformer/mlp.cc @@ -35,9 +35,14 @@ MLP::MLP(const TransformerConfig &config) : CloneableModule(kType) { } // Round up to multiple_of - int64_t before_round = ffn_hidden; ffn_hidden = (ffn_hidden + config.multiple_of - 1) / config.multiple_of * config.multiple_of; + if (config.ffn_type == FFNType::kMoE && config.moe_config.has_value() + && config.moe_config->moe_ffn_hidden_size > 0) { + ffn_hidden = config.moe_config->moe_ffn_hidden_size; + } + CHECK_GT(ffn_hidden, 0); + // c_fc: ColumnParallel (input full, output parallel) modules_[kCFcLayerName] = std::make_shared( /*in_features=*/config.n_embd, /*out_features=*/ffn_hidden, diff --git a/test/transformer/test_transformer_architecture.cc b/test/transformer/test_transformer_architecture.cc index 469ff386..42efda2d 100644 --- a/test/transformer/test_transformer_architecture.cc +++ b/test/transformer/test_transformer_architecture.cc @@ -564,12 +564,13 @@ void TestMoELayerTop2() { config.n_embd = 32; config.n_head = 2; config.n_kv_head = 2; - config.activation_type = nn::MLPType::kGELU; - config.add_bias_linear = true; + config.activation_type = nn::MLPType::kSwiGLU; + config.add_bias_linear = false; config.ffn_type = nn::FFNType::kMoE; config.moe_config = nn::MoEConfig{}; config.moe_config->num_experts = 4; config.moe_config->router_topk = 2; + config.moe_config->moe_ffn_hidden_size = 48; auto moe = std::make_shared(config); auto input = std::make_shared(std::vector{2, 4, config.n_embd}, DataType::kFLOAT32); @@ -579,6 +580,14 @@ void TestMoELayerTop2() { CHECK_EQ(output.size(), 1); CHECK(output[0]->Dims() == input->Dims()); + auto state = moe->StateDict(); + CHECK(state.contains("experts.expert_0.c_fc.weight")); + CHECK(state.contains("experts.expert_0.c_fc2.weight")); + CHECK(state.contains("experts.expert_0.c_proj.weight")); + CHECK(state.at("experts.expert_0.c_fc.weight")->Dims() == std::vector({48, config.n_embd})); + CHECK(state.at("experts.expert_0.c_fc2.weight")->Dims() == std::vector({48, config.n_embd})); + CHECK(state.at("experts.expert_0.c_proj.weight")->Dims() == std::vector({config.n_embd, 48})); + std::cout << "SUCCESS: MoE layer top-2 forward works correctly!" << std::endl; } From fc361c7e2408762bc6a488ecb2e19694cdd88a99 Mon Sep 17 00:00:00 2001 From: kilinchange Date: Tue, 26 May 2026 11:24:26 +0000 Subject: [PATCH 04/11] feat: add bool datatype --- infini_train/include/core/backend_type_map.h | 3 +++ infini_train/include/datatype.h | 3 +++ infini_train/include/dtype_dispatch.h | 5 +++- infini_train/src/kernels/cpu/cast.cc | 3 ++- infini_train/src/kernels/cpu/fill.cc | 2 +- infini_train/src/kernels/cuda/cast.cu | 3 ++- infini_train/src/kernels/cuda/concat.cu | 4 +-- infini_train/src/kernels/cuda/elementwise.cu | 28 ++++++++++---------- infini_train/src/kernels/cuda/fill.cu | 2 +- infini_train/src/kernels/cuda/slice.cu | 4 +-- infini_train/src/kernels/cuda/split.cu | 4 +-- infini_train/src/kernels/cuda/stack.cu | 4 +-- infini_train/src/kernels/cuda/transform.cu | 22 +++++++-------- infini_train/src/utils/precision_checker.cc | 2 ++ 14 files changed, 51 insertions(+), 38 deletions(-) diff --git a/infini_train/include/core/backend_type_map.h b/infini_train/include/core/backend_type_map.h index f67b8da7..38fea110 100644 --- a/infini_train/include/core/backend_type_map.h +++ b/infini_train/include/core/backend_type_map.h @@ -48,6 +48,9 @@ template struct BackendTypeMap; // ----------------------------------------------------------------------------- #define INFINI_REGISTER_STANDARD_BACKEND_TYPES(DEV) \ namespace infini_train::core { \ + template <> struct BackendTypeMap { \ + using type = bool; \ + }; \ template <> struct BackendTypeMap { \ using type = uint8_t; \ }; \ diff --git a/infini_train/include/datatype.h b/infini_train/include/datatype.h index cf637300..0ae8fda7 100644 --- a/infini_train/include/datatype.h +++ b/infini_train/include/datatype.h @@ -84,6 +84,7 @@ struct alignas(2) BF16 { // DataType enum and metadata tables // ----------------------------------------------------------------------------- enum class DataType : int8_t { + kBOOL, kUINT8, kINT8, kUINT16, @@ -99,12 +100,14 @@ enum class DataType : int8_t { }; inline const std::unordered_map kDataTypeToSize = { + {DataType::kBOOL, 1}, {DataType::kUINT8, 1}, {DataType::kINT8, 1}, {DataType::kUINT16, 2}, {DataType::kINT16, 2}, {DataType::kUINT32, 4}, {DataType::kINT32, 4}, {DataType::kUINT64, 8}, {DataType::kINT64, 8}, {DataType::kBFLOAT16, 2}, {DataType::kFLOAT16, 2}, {DataType::kFLOAT32, 4}, {DataType::kFLOAT64, 8}, }; inline const std::unordered_map kDataTypeToDesc = { + {DataType::kBOOL, "bool"}, {DataType::kUINT8, "uint8"}, {DataType::kINT8, "int8"}, {DataType::kUINT16, "uint16"}, {DataType::kINT16, "int16"}, {DataType::kUINT32, "uint32"}, {DataType::kINT32, "int32"}, {DataType::kUINT64, "uint64"}, {DataType::kINT64, "int64"}, {DataType::kBFLOAT16, "bf16"}, diff --git a/infini_train/include/dtype_dispatch.h b/infini_train/include/dtype_dispatch.h index e3db38b8..8bd5054b 100644 --- a/infini_train/include/dtype_dispatch.h +++ b/infini_train/include/dtype_dispatch.h @@ -180,10 +180,11 @@ namespace infini_train { #define INFINI_FLOATING_TYPES DataType::kFLOAT32, DataType::kFLOAT64 #define INFINI_REDUCED_FLOATING_TYPES DataType::kFLOAT16, DataType::kBFLOAT16 #define INFINI_ALL_FLOATING_TYPES INFINI_FLOATING_TYPES, INFINI_REDUCED_FLOATING_TYPES +#define INFINI_LOGICAL_TYPES DataType::kBOOL #define INFINI_SIGNED_INTEGRAL_TYPES DataType::kINT8, DataType::kINT16, DataType::kINT32, DataType::kINT64 #define INFINI_UNSIGNED_INTEGRAL_TYPES DataType::kUINT8, DataType::kUINT16, DataType::kUINT32, DataType::kUINT64 #define INFINI_ALL_INTEGRAL_TYPES INFINI_SIGNED_INTEGRAL_TYPES, INFINI_UNSIGNED_INTEGRAL_TYPES -#define INFINI_ALL_TYPES INFINI_ALL_FLOATING_TYPES, INFINI_ALL_INTEGRAL_TYPES +#define INFINI_ALL_NUMERIC_TYPES INFINI_ALL_FLOATING_TYPES, INFINI_ALL_INTEGRAL_TYPES #define INFINI_8_BIT_TYPES DataType::kINT8, DataType::kUINT8 #define INFINI_16_BIT_TYPES DataType::kINT16, DataType::kUINT16, DataType::kFLOAT16, DataType::kBFLOAT16 #define INFINI_32_BIT_TYPES DataType::kINT32, DataType::kUINT32, DataType::kFLOAT32 @@ -242,6 +243,7 @@ auto DispatchByTypeMap(DataType dtype, Functor &&func, std::string_view context_ } \ } + CASE_FOR_TYPE(DataType::kBOOL) CASE_FOR_TYPE(DataType::kUINT8) CASE_FOR_TYPE(DataType::kINT8) CASE_FOR_TYPE(DataType::kUINT16) @@ -290,6 +292,7 @@ struct TypeMapDispatcher { break; \ } + CASE_FOR_TYPE(DataType::kBOOL) CASE_FOR_TYPE(DataType::kUINT8) CASE_FOR_TYPE(DataType::kINT8) CASE_FOR_TYPE(DataType::kUINT16) diff --git a/infini_train/src/kernels/cpu/cast.cc b/infini_train/src/kernels/cpu/cast.cc index 114a5597..c3a2e595 100644 --- a/infini_train/src/kernels/cpu/cast.cc +++ b/infini_train/src/kernels/cpu/cast.cc @@ -13,7 +13,8 @@ std::shared_ptr Cast(std::shared_ptr input, DataType dtype) { auto device = input->GetDevice(); auto dst_tensor = std::make_shared(input->Dims(), dtype, device); - core::cpu::DispatchCpuFunc, DataTypeList>( + core::cpu::DispatchCpuFunc, + DataTypeList>( {dtype, input->Dtype()}, [=]() { auto dst = static_cast(dst_tensor->DataPtr()); diff --git a/infini_train/src/kernels/cpu/fill.cc b/infini_train/src/kernels/cpu/fill.cc index 5f8b7cd3..7bcda6bd 100644 --- a/infini_train/src/kernels/cpu/fill.cc +++ b/infini_train/src/kernels/cpu/fill.cc @@ -8,7 +8,7 @@ namespace infini_train::kernels::cpu { void Fill(std::shared_ptr tensor, Scalar scalar) { - core::cpu::DispatchCpuFunc( + core::cpu::DispatchCpuFunc( tensor->Dtype(), [=]() { auto data = reinterpret_cast(tensor->DataPtr()); diff --git a/infini_train/src/kernels/cuda/cast.cu b/infini_train/src/kernels/cuda/cast.cu index 16190912..96a70ae2 100644 --- a/infini_train/src/kernels/cuda/cast.cu +++ b/infini_train/src/kernels/cuda/cast.cu @@ -34,7 +34,8 @@ std::shared_ptr Cast(std::shared_ptr input, DataType dtype) { dim3 grid_dims(CEIL_DIV(num_elements, block_dims.x)); const size_t step = grid_dims.x * block_dims.x; - core::cuda::DispatchCudaFunc, DataTypeList>( + core::cuda::DispatchCudaFunc, + DataTypeList>( {dtype, input->Dtype()}, [=]() { auto dst = static_cast(dst_tensor->DataPtr()); diff --git a/infini_train/src/kernels/cuda/concat.cu b/infini_train/src/kernels/cuda/concat.cu index c158a5c3..a7fa7490 100644 --- a/infini_train/src/kernels/cuda/concat.cu +++ b/infini_train/src/kernels/cuda/concat.cu @@ -103,7 +103,7 @@ std::shared_ptr ConcatForward(const std::vector> int threads_per_block = 256; int num_blocks = static_cast((total + threads_per_block - 1) / threads_per_block); - core::cuda::DispatchCudaFunc( + core::cuda::DispatchCudaFunc( dtype, [=, &inputs, &host_offsets]() { std::vector host_input_ptrs; @@ -208,7 +208,7 @@ std::vector> ConcatBackward(const std::shared_ptr((total + threads_per_block - 1) / threads_per_block); - core::cuda::DispatchCudaFunc( + core::cuda::DispatchCudaFunc( dtype, [=, &grads, &host_offsets]() { std::vector host_ptrs; diff --git a/infini_train/src/kernels/cuda/elementwise.cu b/infini_train/src/kernels/cuda/elementwise.cu index fe63e0b2..92ce9915 100644 --- a/infini_train/src/kernels/cuda/elementwise.cu +++ b/infini_train/src/kernels/cuda/elementwise.cu @@ -1018,7 +1018,7 @@ std::shared_ptr EqualsForward(const std::shared_ptr &a, const st DISPATCH(a->Dtype(), return BinaryForward(a, b, [] __device__(auto x, auto y) { return (x == y) ? decltype(x){1} : decltype(x){0}; }); - , INFINI_ALL_TYPES) + , INFINI_ALL_NUMERIC_TYPES) } std::shared_ptr EqualsScalarForward(const std::shared_ptr &a, float scalar) { @@ -1033,7 +1033,7 @@ std::shared_ptr EqualsScalarForward(const std::shared_ptr &a, fl std::shared_ptr LtForward(const std::shared_ptr &a, const std::shared_ptr &b) { DISPATCH(a->Dtype(), return BinaryForward( a, b, [] __device__(auto x, auto y) { return x < y ? decltype(x){1} : decltype(x){0}; }); - , INFINI_ALL_TYPES) + , INFINI_ALL_NUMERIC_TYPES) } std::shared_ptr LtScalarForward(const std::shared_ptr &a, float scalar) { @@ -1042,14 +1042,14 @@ std::shared_ptr LtScalarForward(const std::shared_ptr &a, float return (x < static_cast(scalar)) ? decltype(x){1} : decltype(x){0}; }); - , INFINI_ALL_TYPES) + , INFINI_ALL_NUMERIC_TYPES) } std::shared_ptr LeForward(const std::shared_ptr &a, const std::shared_ptr &b) { DISPATCH(a->Dtype(), return BinaryForward(a, b, [] __device__(auto x, auto y) { return (x <= y) ? decltype(x){1} : decltype(x){0}; }); - , INFINI_ALL_TYPES) + , INFINI_ALL_NUMERIC_TYPES) } std::shared_ptr LeScalarForward(const std::shared_ptr &a, float scalar) { @@ -1058,13 +1058,13 @@ std::shared_ptr LeScalarForward(const std::shared_ptr &a, float return (x <= static_cast(scalar)) ? decltype(x){1} : decltype(x){0}; }); - , INFINI_ALL_TYPES) + , INFINI_ALL_NUMERIC_TYPES) } std::shared_ptr GtForward(const std::shared_ptr &a, const std::shared_ptr &b) { DISPATCH(a->Dtype(), return BinaryForward( a, b, [] __device__(auto x, auto y) { return x > y ? decltype(x){1} : decltype(x){0}; }); - , INFINI_ALL_TYPES) + , INFINI_ALL_NUMERIC_TYPES) } std::shared_ptr GtScalarForward(const std::shared_ptr &a, float scalar) { @@ -1073,14 +1073,14 @@ std::shared_ptr GtScalarForward(const std::shared_ptr &a, float return (x > static_cast(scalar)) ? decltype(x){1} : decltype(x){0}; }); - , INFINI_ALL_TYPES) + , INFINI_ALL_NUMERIC_TYPES) } std::shared_ptr GeForward(const std::shared_ptr &a, const std::shared_ptr &b) { DISPATCH(a->Dtype(), return BinaryForward(a, b, [] __device__(auto x, auto y) { return (x >= y) ? decltype(x){1} : decltype(x){0}; }); - , INFINI_ALL_TYPES) + , INFINI_ALL_NUMERIC_TYPES) } std::shared_ptr GeScalarForward(const std::shared_ptr &a, float scalar) { @@ -1089,7 +1089,7 @@ std::shared_ptr GeScalarForward(const std::shared_ptr &a, float return (x >= static_cast(scalar)) ? decltype(x){1} : decltype(x){0}; }); - , INFINI_ALL_TYPES) + , INFINI_ALL_NUMERIC_TYPES) } std::shared_ptr OrForward(const std::shared_ptr &a, const std::shared_ptr &b) { @@ -1098,7 +1098,7 @@ std::shared_ptr OrForward(const std::shared_ptr &a, const std::s return (x != decltype(x){0} || y != decltype(y){0}) ? decltype(x){1} : decltype(x){0}; }); - , INFINI_ALL_TYPES) + , INFINI_ALL_NUMERIC_TYPES) } std::shared_ptr AndForward(const std::shared_ptr &a, const std::shared_ptr &b) { @@ -1107,7 +1107,7 @@ std::shared_ptr AndForward(const std::shared_ptr &a, const std:: return (x != decltype(x){0} && y != decltype(y){0}) ? decltype(x){1} : decltype(x){0}; }); - , INFINI_ALL_TYPES) + , INFINI_ALL_NUMERIC_TYPES) } std::shared_ptr AddForward(const std::shared_ptr &a, const std::shared_ptr &b) { @@ -1125,19 +1125,19 @@ std::pair, std::shared_ptr> AddBackward(const st std::shared_ptr AddScalarForward(const std::shared_ptr &a, float scalar) { DISPATCH(a->Dtype(), return UnaryForward(a, [scalar] __device__(auto x) { return Add(x, static_cast(scalar)); }); - , INFINI_ALL_TYPES) + , INFINI_ALL_NUMERIC_TYPES) } std::shared_ptr AddScalarBackward(const std::shared_ptr &grad_output) { DISPATCH(grad_output->Dtype(), return UnaryBackward(grad_output, nullptr, [] __device__(auto x) { return common::cuda::Cast(1); }); - , INFINI_ALL_TYPES) + , INFINI_ALL_NUMERIC_TYPES) } std::shared_ptr SubForward(const std::shared_ptr &a, const std::shared_ptr &b) { DISPATCH(a->Dtype(), return BinaryForward(a, b, [] __device__(auto x, auto y) { return Sub(x, y); }); - , INFINI_ALL_TYPES) + , INFINI_ALL_NUMERIC_TYPES) } std::pair, std::shared_ptr> SubBackward(const std::shared_ptr &grad_output, diff --git a/infini_train/src/kernels/cuda/fill.cu b/infini_train/src/kernels/cuda/fill.cu index f5532779..3ddead5c 100644 --- a/infini_train/src/kernels/cuda/fill.cu +++ b/infini_train/src/kernels/cuda/fill.cu @@ -28,7 +28,7 @@ void Fill(std::shared_ptr tensor, Scalar scalar) { infini_train::core::GetDeviceGuardImpl(device.type())->GetStream(device)) ->cuda_stream(); - core::cuda::DispatchCudaFunc( + core::cuda::DispatchCudaFunc( tensor->Dtype(), [=]() { const T casted_value = scalar.to(); diff --git a/infini_train/src/kernels/cuda/slice.cu b/infini_train/src/kernels/cuda/slice.cu index 35bd2ac5..d030d73a 100644 --- a/infini_train/src/kernels/cuda/slice.cu +++ b/infini_train/src/kernels/cuda/slice.cu @@ -92,7 +92,7 @@ std::shared_ptr SliceForward(const std::shared_ptr &input, const int threads_per_block = 256; int num_blocks = (total_elements + threads_per_block - 1) / threads_per_block; - core::cuda::DispatchCudaFunc( + core::cuda::DispatchCudaFunc( dtype, [=]() { SliceForwardKernel<<>>( @@ -185,7 +185,7 @@ std::shared_ptr SliceBackward(const std::shared_ptr &grad_output int threads_per_block = 256; int num_blocks = (total_elements + threads_per_block - 1) / threads_per_block; - core::cuda::DispatchCudaFunc( + core::cuda::DispatchCudaFunc( grad_output_dtype, [=]() { SliceBackwardKernel<<>>( diff --git a/infini_train/src/kernels/cuda/split.cu b/infini_train/src/kernels/cuda/split.cu index f208695f..bda0dd70 100644 --- a/infini_train/src/kernels/cuda/split.cu +++ b/infini_train/src/kernels/cuda/split.cu @@ -59,7 +59,7 @@ std::vector> SplitForward(const std::shared_ptr infini_train::core::GetDeviceGuardImpl(device.type())->GetStream(device)) ->cuda_stream(); - core::cuda::DispatchCudaFunc( + core::cuda::DispatchCudaFunc( dtype, [=]() { SplitForwardKernel<<>>( @@ -166,7 +166,7 @@ std::shared_ptr SplitBackward(const std::vector &input_dims, in CHECK_GE(dim, 0) << "Currently we do not support negative dimension"; CHECK_LT(dim, input_dims.size()); - return core::cuda::DispatchCudaFunc( + return core::cuda::DispatchCudaFunc( grad_outputs[0]->Dtype(), [=]() { return LaunchSplitBackward(input_dims, split_size, dim, grad_outputs); }, "CUDA SplitBackward"); diff --git a/infini_train/src/kernels/cuda/stack.cu b/infini_train/src/kernels/cuda/stack.cu index 562fa5ec..841940ea 100644 --- a/infini_train/src/kernels/cuda/stack.cu +++ b/infini_train/src/kernels/cuda/stack.cu @@ -61,7 +61,7 @@ std::shared_ptr StackForward(const std::vector> int threads_per_block = 256; int num_blocks = (total + threads_per_block - 1) / threads_per_block; - core::cuda::DispatchCudaFunc( + core::cuda::DispatchCudaFunc( dtype, [=]() { std::vector host_input_ptrs; @@ -129,7 +129,7 @@ std::vector> StackBackward(const std::vector &i int threads_per_block = 256; int num_blocks = (total + threads_per_block - 1) / threads_per_block; - core::cuda::DispatchCudaFunc( + core::cuda::DispatchCudaFunc( dtype, [=]() { std::vector host_ptrs; diff --git a/infini_train/src/kernels/cuda/transform.cu b/infini_train/src/kernels/cuda/transform.cu index 2bb35598..88f0e10f 100644 --- a/infini_train/src/kernels/cuda/transform.cu +++ b/infini_train/src/kernels/cuda/transform.cu @@ -47,7 +47,7 @@ std::shared_ptr TrilForward(const std::shared_ptr &input, int64_ infini_train::core::GetDeviceGuardImpl(device.type())->GetStream(device)) ->cuda_stream(); - core::cuda::DispatchCudaFunc( + core::cuda::DispatchCudaFunc( input->Dtype(), [=]() { TrilForwardKernel<<>>( @@ -90,7 +90,7 @@ std::shared_ptr TrilBackward(const std::shared_ptr &grad_output, infini_train::core::GetDeviceGuardImpl(device.type())->GetStream(device)) ->cuda_stream(); - core::cuda::DispatchCudaFunc( + core::cuda::DispatchCudaFunc( dtype, [=]() { grad_input->Fill(0.0); @@ -135,7 +135,7 @@ std::shared_ptr TriuForward(const std::shared_ptr &input, int64_ infini_train::core::GetDeviceGuardImpl(device.type())->GetStream(device)) ->cuda_stream(); - core::cuda::DispatchCudaFunc( + core::cuda::DispatchCudaFunc( input->Dtype(), [=]() { TriuForwardKernel<<>>( @@ -177,7 +177,7 @@ std::shared_ptr TriuBackward(const std::shared_ptr &grad_output, infini_train::core::GetDeviceGuardImpl(device.type())->GetStream(device)) ->cuda_stream(); - core::cuda::DispatchCudaFunc( + core::cuda::DispatchCudaFunc( dtype, [=]() { grad_input->Fill(0.0); @@ -269,7 +269,7 @@ std::shared_ptr TransposeForward(const std::shared_ptr &input, i int threads_per_block = 256; int num_blocks = (num_elements + threads_per_block - 1) / threads_per_block; - core::cuda::DispatchCudaFunc( + core::cuda::DispatchCudaFunc( dtype, [=]() { output->Fill(0.0); @@ -371,7 +371,7 @@ std::shared_ptr MaskForward(const std::shared_ptr &input, const int64_t inner = input->NumElements() / rows; int num_blocks = static_cast((input->NumElements() + threads_per_block - 1) / threads_per_block); - core::cuda::DispatchCudaFunc( + core::cuda::DispatchCudaFunc( dtype, [=]() { MaskLeadsForwardKernel<<>>( @@ -384,7 +384,7 @@ std::shared_ptr MaskForward(const std::shared_ptr &input, const int64_t batch_size = input->NumElements() / mask_size; int num_blocks = static_cast((input->NumElements() + threads_per_block - 1) / threads_per_block); - core::cuda::DispatchCudaFunc( + core::cuda::DispatchCudaFunc( dtype, [=]() { MaskForwardKernel<<>>( @@ -435,7 +435,7 @@ std::shared_ptr MaskBackward(const std::shared_ptr &grad_output, int64_t inner = grad_output->NumElements() / rows; int num_blocks = static_cast((grad_output->NumElements() + threads_per_block - 1) / threads_per_block); - core::cuda::DispatchCudaFunc( + core::cuda::DispatchCudaFunc( dtype, [=]() { grad_input->Fill(0.0); @@ -449,7 +449,7 @@ std::shared_ptr MaskBackward(const std::shared_ptr &grad_output, int64_t batch_size = grad_output->NumElements() / mask_size; int num_blocks = static_cast((grad_output->NumElements() + threads_per_block - 1) / threads_per_block); - core::cuda::DispatchCudaFunc( + core::cuda::DispatchCudaFunc( dtype, [=]() { grad_input->Fill(0.0); @@ -504,7 +504,7 @@ std::shared_ptr RepeatInterleaveForward(const std::shared_ptr &i infini_train::core::GetDeviceGuardImpl(device.type())->GetStream(device)) ->cuda_stream(); - core::cuda::DispatchCudaFunc( + core::cuda::DispatchCudaFunc( input->Dtype(), [=]() { RepeatInterleaveForwardKernel<<>>( @@ -562,7 +562,7 @@ std::shared_ptr RepeatInterleaveBackward(const std::shared_ptr & infini_train::core::GetDeviceGuardImpl(device.type())->GetStream(device)) ->cuda_stream(); - core::cuda::DispatchCudaFunc( + core::cuda::DispatchCudaFunc( grad_output->Dtype(), [=]() { grad_input->Fill(0.0); diff --git a/infini_train/src/utils/precision_checker.cc b/infini_train/src/utils/precision_checker.cc index d2cbd16a..2965284e 100644 --- a/infini_train/src/utils/precision_checker.cc +++ b/infini_train/src/utils/precision_checker.cc @@ -193,6 +193,8 @@ std::string FormatShape(const std::vector &shape) { std::string DataTypeToString(DataType dtype) { switch (dtype) { + case DataType::kBOOL: + return "bool"; case DataType::kFLOAT32: return "float32"; case DataType::kFLOAT16: From dbdf5699b97ae64ccb1780fb34cb5f78b155cb2a Mon Sep 17 00:00:00 2001 From: kilinchange Date: Tue, 26 May 2026 11:40:49 +0000 Subject: [PATCH 05/11] feat: add scatter operator, distinguish tensor and communication APIs via namespaces, and reorganize functions in misc files --- infini_train/include/autograd/comm.h | 4 +- infini_train/include/autograd/indexing.h | 47 ++++++ infini_train/include/autograd/misc.h | 113 -------------- infini_train/include/autograd/no_op.h | 30 ++++ infini_train/include/autograd/scatter.h | 29 ++++ infini_train/include/autograd/transform.h | 49 ++++++ infini_train/src/autograd/comm.cc | 4 +- infini_train/src/autograd/indexing.cc | 63 ++++++++ infini_train/src/autograd/misc.cc | 147 ------------------ infini_train/src/autograd/no_op.cc | 31 ++++ infini_train/src/autograd/scatter.cc | 34 ++++ infini_train/src/autograd/transform.cc | 64 ++++++++ infini_train/src/kernels/cpu/gather.cc | 11 +- infini_train/src/kernels/cpu/scatter.cc | 87 +++++++++++ infini_train/src/kernels/cuda/comm.cu | 8 +- infini_train/src/kernels/cuda/gather.cu | 42 +++-- infini_train/src/kernels/cuda/scatter.cu | 120 ++++++++++++++ .../src/nn/parallel/parallel_functional.cc | 6 +- infini_train/src/tensor.cc | 5 +- 19 files changed, 590 insertions(+), 304 deletions(-) create mode 100644 infini_train/include/autograd/indexing.h delete mode 100644 infini_train/include/autograd/misc.h create mode 100644 infini_train/include/autograd/no_op.h create mode 100644 infini_train/include/autograd/scatter.h create mode 100644 infini_train/src/autograd/indexing.cc delete mode 100644 infini_train/src/autograd/misc.cc create mode 100644 infini_train/src/autograd/no_op.cc create mode 100644 infini_train/src/autograd/scatter.cc create mode 100644 infini_train/src/kernels/cpu/scatter.cc create mode 100644 infini_train/src/kernels/cuda/scatter.cu diff --git a/infini_train/include/autograd/comm.h b/infini_train/include/autograd/comm.h index ec3cfe4a..c67372ee 100644 --- a/infini_train/include/autograd/comm.h +++ b/infini_train/include/autograd/comm.h @@ -15,7 +15,7 @@ class ProcessGroup; } // namespace nn::parallel } // namespace infini_train -namespace infini_train::autograd { +namespace infini_train::autograd::comm { class Scatter : public autograd::Function { public: static constexpr char kType[] = "ScatterFunction"; @@ -99,4 +99,4 @@ class ReduceAddCoalesced : public autograd::Function { std::vector target_gpus_; int64_t num_inputs_ = 0; }; -} // namespace infini_train::autograd +} // namespace infini_train::autograd::comm diff --git a/infini_train/include/autograd/indexing.h b/infini_train/include/autograd/indexing.h new file mode 100644 index 00000000..f35321c8 --- /dev/null +++ b/infini_train/include/autograd/indexing.h @@ -0,0 +1,47 @@ +#pragma once + +#include +#include + +#include "infini_train/include/autograd/function.h" + +namespace infini_train { +class Tensor; +} + +namespace infini_train::autograd { + +class Gather : public Function { +public: + static constexpr char kType[] = "GatherFunction"; + + Gather(int64_t dim = 0) : Function(kType), dim_(dim) {} + + std::vector> Forward(const std::vector> &input_tensors) override; + void SetupContext(const std::vector> &input_tensors, + const std::vector> &output_tensors) override; + std::vector> Backward(const std::vector> &grad_outputs) override; + +private: + const int64_t dim_ = 0; + std::vector input_dims_; +}; + +class Slice : public Function { +public: + static constexpr char kType[] = "SliceFunction"; + + Slice(const std::vector &starts, const std::vector &ends, const std::vector &steps) + : Function(kType), starts_(starts), ends_(ends), steps_(steps) {} + std::vector> Forward(const std::vector> &input_tensors) override; + void SetupContext(const std::vector> &input_tensors, + const std::vector> &output_tensors) override; + std::vector> Backward(const std::vector> &grad_outputs) override; + +private: + const std::vector starts_; + const std::vector ends_; + const std::vector steps_; +}; + +} // namespace infini_train::autograd diff --git a/infini_train/include/autograd/misc.h b/infini_train/include/autograd/misc.h deleted file mode 100644 index ccfca22d..00000000 --- a/infini_train/include/autograd/misc.h +++ /dev/null @@ -1,113 +0,0 @@ -#pragma once - -#include -#include - -#include "infini_train/include/autograd/function.h" - -namespace infini_train { -class Tensor; -} - -namespace infini_train::autograd { -class Split : public Function { -public: - static constexpr char kType[] = "SplitFunction"; - - Split(int64_t split_size, int dim = 0) : Function(kType), split_size_(split_size), dim_(dim) {} - - std::vector> Forward(const std::vector> &input_tensors) override; - void SetupContext(const std::vector> &input_tensors, - const std::vector> &output_tensors) override; - std::vector> Backward(const std::vector> &grad_outputs) override; - -private: - const int64_t split_size_ = 0; - const int dim_ = 0; - std::vector input_dims_; -}; - -// FIXME(zbl): This function aligns with torch.gather -// Currently named IndexGather to avoid conflict with communication operators -// Should be renamed to Gather later for interface consistency -class IndexGather : public Function { -public: - static constexpr char kType[] = "IndexGatherFunction"; - - IndexGather(int64_t dim = 0) : Function(kType), dim_(dim) {} - - std::vector> Forward(const std::vector> &input_tensors) override; - void SetupContext(const std::vector> &input_tensors, - const std::vector> &output_tensors) override; - std::vector> Backward(const std::vector> &grad_outputs) override; - -private: - const int64_t dim_ = 0; - std::vector input_dims_; -}; - -class NoOp : public Function { -public: - static constexpr char kType[] = "NoOpFunction"; - - explicit NoOp(const std::vector &output_dims) : Function(kType), output_dims_(output_dims) {} - - std::vector> Forward(const std::vector> &input_tensors) override; - void SetupContext(const std::vector> &input_tensors, - const std::vector> &output_tensors) override; - std::vector> Backward(const std::vector> &grad_outputs) override; - -private: - const std::vector output_dims_; - std::vector input_dims_; -}; - -class Slice : public Function { -public: - static constexpr char kType[] = "SliceFunction"; - - Slice(const std::vector &starts, const std::vector &ends, const std::vector &steps) - : Function(kType), starts_(starts), ends_(ends), steps_(steps) {} - std::vector> Forward(const std::vector> &input_tensors) override; - void SetupContext(const std::vector> &input_tensors, - const std::vector> &output_tensors) override; - std::vector> Backward(const std::vector> &grad_outputs) override; - -private: - const std::vector starts_; - const std::vector ends_; - const std::vector steps_; -}; - -class Stack : public Function { -public: - static constexpr char kType[] = "StackFunction"; - - Stack(int64_t dim) : Function(kType), dim_(dim) {} - - std::vector> Forward(const std::vector> &input_tensors) override; - void SetupContext(const std::vector> &input_tensors, - const std::vector> &output_tensors) override; - std::vector> Backward(const std::vector> &grad_outputs) override; - -private: - int64_t dim_ = 0; - std::vector input_dims_; -}; - -class Concat : public Function { -public: - static constexpr char kType[] = "ConcatFunction"; - - Concat(int64_t dim) : Function(kType), dim_(dim) {} - - std::vector> Forward(const std::vector> &input_tensors) override; - void SetupContext(const std::vector> &input_tensors, - const std::vector> &output_tensors) override; - std::vector> Backward(const std::vector> &grad_outputs) override; - -private: - const int64_t dim_ = 0; - std::vector> input_dims_list_; -}; -} // namespace infini_train::autograd diff --git a/infini_train/include/autograd/no_op.h b/infini_train/include/autograd/no_op.h new file mode 100644 index 00000000..a097393d --- /dev/null +++ b/infini_train/include/autograd/no_op.h @@ -0,0 +1,30 @@ +#pragma once + +#include +#include + +#include "infini_train/include/autograd/function.h" + +namespace infini_train { +class Tensor; +} + +namespace infini_train::autograd { + +class NoOp : public Function { +public: + static constexpr char kType[] = "NoOpFunction"; + + explicit NoOp(const std::vector &output_dims) : Function(kType), output_dims_(output_dims) {} + + std::vector> Forward(const std::vector> &input_tensors) override; + void SetupContext(const std::vector> &input_tensors, + const std::vector> &output_tensors) override; + std::vector> Backward(const std::vector> &grad_outputs) override; + +private: + const std::vector output_dims_; + std::vector input_dims_; +}; + +} // namespace infini_train::autograd diff --git a/infini_train/include/autograd/scatter.h b/infini_train/include/autograd/scatter.h new file mode 100644 index 00000000..3d6f830a --- /dev/null +++ b/infini_train/include/autograd/scatter.h @@ -0,0 +1,29 @@ +#pragma once + +#include +#include + +#include "infini_train/include/autograd/function.h" + +namespace infini_train { +class Tensor; +} + +namespace infini_train::autograd { + +class Scatter : public Function { +public: + static constexpr char kType[] = "ScatterFunction"; + + explicit Scatter(const std::vector &output_dims) : Function(kType), output_dims_(output_dims) {} + + std::vector> Forward(const std::vector> &input_tensors) override; + void SetupContext(const std::vector> &input_tensors, + const std::vector> &output_tensors) override; + std::vector> Backward(const std::vector> &grad_outputs) override; + +private: + std::vector output_dims_; +}; + +} // namespace infini_train::autograd diff --git a/infini_train/include/autograd/transform.h b/infini_train/include/autograd/transform.h index 92ce71ea..e345061a 100644 --- a/infini_train/include/autograd/transform.h +++ b/infini_train/include/autograd/transform.h @@ -78,4 +78,53 @@ class RepeatInterleave : public Function { std::vector input_dims_; }; +class Split : public Function { +public: + static constexpr char kType[] = "SplitFunction"; + + Split(int64_t split_size, int dim = 0) : Function(kType), split_size_(split_size), dim_(dim) {} + + std::vector> Forward(const std::vector> &input_tensors) override; + void SetupContext(const std::vector> &input_tensors, + const std::vector> &output_tensors) override; + std::vector> Backward(const std::vector> &grad_outputs) override; + +private: + const int64_t split_size_ = 0; + const int dim_ = 0; + std::vector input_dims_; +}; + +class Stack : public Function { +public: + static constexpr char kType[] = "StackFunction"; + + Stack(int64_t dim) : Function(kType), dim_(dim) {} + + std::vector> Forward(const std::vector> &input_tensors) override; + void SetupContext(const std::vector> &input_tensors, + const std::vector> &output_tensors) override; + std::vector> Backward(const std::vector> &grad_outputs) override; + +private: + int64_t dim_ = 0; + std::vector input_dims_; +}; + +class Concat : public Function { +public: + static constexpr char kType[] = "ConcatFunction"; + + Concat(int64_t dim) : Function(kType), dim_(dim) {} + + std::vector> Forward(const std::vector> &input_tensors) override; + void SetupContext(const std::vector> &input_tensors, + const std::vector> &output_tensors) override; + std::vector> Backward(const std::vector> &grad_outputs) override; + +private: + const int64_t dim_ = 0; + std::vector> input_dims_list_; +}; + } // namespace infini_train::autograd diff --git a/infini_train/src/autograd/comm.cc b/infini_train/src/autograd/comm.cc index d524088a..325422b3 100644 --- a/infini_train/src/autograd/comm.cc +++ b/infini_train/src/autograd/comm.cc @@ -8,7 +8,7 @@ #include "infini_train/include/nn/parallel/process_group.h" #include "infini_train/include/tensor.h" -namespace infini_train::autograd { +namespace infini_train::autograd::comm { Scatter::Scatter(const std::vector &target_gpus, int64_t dim, const infini_train::nn::parallel::ProcessGroup *pg) @@ -122,4 +122,4 @@ std::vector> ReduceAddCoalesced::Backward(const std::vector> &grad_outputs) { return std::make_shared(target_gpus_)->Apply(grad_outputs); } -} // namespace infini_train::autograd +} // namespace infini_train::autograd::comm diff --git a/infini_train/src/autograd/indexing.cc b/infini_train/src/autograd/indexing.cc new file mode 100644 index 00000000..a9642256 --- /dev/null +++ b/infini_train/src/autograd/indexing.cc @@ -0,0 +1,63 @@ +#include "infini_train/include/autograd/indexing.h" + +#include "glog/logging.h" + +#include "infini_train/include/dispatcher.h" +#include "infini_train/include/tensor.h" + +namespace infini_train::autograd { +std::vector> Gather::Forward(const std::vector> &input_tensors) { + CHECK_EQ(input_tensors.size(), 2); + const auto &input = input_tensors[0]; + const auto &index = input_tensors[1]; + + auto device = input->GetDevice().type(); + auto kernel = Dispatcher::Instance().GetKernel({device, "GatherForward"}); + return {kernel.Call>(input, index, dim_)}; +} + +void Gather::SetupContext(const std::vector> &input_tensors, + const std::vector> &) { + const auto &input = input_tensors[0]; + const auto &index = input_tensors[1]; + input_dims_ = input->Dims(); + saved_tensors_ = {index}; +} + +std::vector> Gather::Backward(const std::vector> &grad_outputs) { + CHECK_EQ(grad_outputs.size(), 1); + const auto &grad_output = grad_outputs[0]; + const auto &index = saved_tensors_[0]; + + auto device = grad_outputs[0]->GetDevice(); + auto kernel = Dispatcher::Instance().GetKernel({device.type(), "GatherBackward"}); + return {kernel.Call>(grad_output, index, dim_, input_dims_), nullptr}; +} + +std::vector> Slice::Forward(const std::vector> &input_tensors) { + CHECK_EQ(input_tensors.size(), 1); + const auto &input = input_tensors[0]; + + auto device = input->GetDevice().type(); + return { + Dispatcher::Instance().Call>({device, "SliceForward"}, input, starts_, ends_, steps_)}; +} + +void Slice::SetupContext(const std::vector> &input_tensors, + const std::vector> &) { + // FIXME(dcj): only input's dim need to be saved + const auto &input = input_tensors[0]; + saved_tensors_ = {input}; +} + +std::vector> Slice::Backward(const std::vector> &grad_outputs) { + CHECK_EQ(saved_tensors_.size(), 1); + const auto &input = saved_tensors_[0]; + const auto &grad_output = grad_outputs[0]; + + auto device = input->GetDevice().type(); + return {Dispatcher::Instance().Call>({device, "SliceBackward"}, grad_output, input, starts_, + ends_, steps_)}; +} + +} // namespace infini_train::autograd diff --git a/infini_train/src/autograd/misc.cc b/infini_train/src/autograd/misc.cc deleted file mode 100644 index 601258eb..00000000 --- a/infini_train/src/autograd/misc.cc +++ /dev/null @@ -1,147 +0,0 @@ -#include "infini_train/include/autograd/misc.h" - -#include "glog/logging.h" - -#include "infini_train/include/dispatcher.h" -#include "infini_train/include/tensor.h" - -namespace infini_train::autograd { -std::vector> Split::Forward(const std::vector> &input_tensors) { - CHECK_EQ(input_tensors.size(), 1); - const auto &input = input_tensors[0]; - - auto device = input->GetDevice().type(); - return {Dispatcher::Instance().Call>>({device, "SplitForward"}, input, - split_size_, dim_)}; -} - -void Split::SetupContext(const std::vector> &input_tensors, - const std::vector> &) { - const auto &input = input_tensors[0]; - input_dims_ = input->Dims(); -} - -std::vector> Split::Backward(const std::vector> &grad_outputs) { - auto device = grad_outputs[0]->GetDevice(); - return {Dispatcher::Instance().Call>({device.type(), "SplitBackward"}, input_dims_, - split_size_, dim_, grad_outputs)}; -} - -std::vector> IndexGather::Forward(const std::vector> &input_tensors) { - CHECK_EQ(input_tensors.size(), 2); - const auto &input = input_tensors[0]; - const auto &index = input_tensors[1]; - - auto device = input->GetDevice().type(); - auto kernel = Dispatcher::Instance().GetKernel({device, "IndexGatherForward"}); - return {kernel.Call>(input, index, dim_)}; -} - -void IndexGather::SetupContext(const std::vector> &input_tensors, - const std::vector> &) { - const auto &input = input_tensors[0]; - const auto &index = input_tensors[1]; - input_dims_ = input->Dims(); - saved_tensors_ = {index}; -} - -std::vector> IndexGather::Backward(const std::vector> &grad_outputs) { - CHECK_EQ(grad_outputs.size(), 1); - const auto &grad_output = grad_outputs[0]; - const auto &index = saved_tensors_[0]; - - auto device = grad_outputs[0]->GetDevice(); - auto kernel = Dispatcher::Instance().GetKernel({device.type(), "IndexGatherBackward"}); - return {kernel.Call>(grad_output, index, dim_, input_dims_)}; -} - -std::vector> NoOp::Forward(const std::vector> &input_tensors) { - CHECK_EQ(input_tensors.size(), 1); - const auto &input = input_tensors[0]; - - auto device = input->GetDevice().type(); - return {Dispatcher::Instance().Call>({device, "NoOpForward"}, input, output_dims_)}; -} - -void NoOp::SetupContext(const std::vector> &input_tensors, - const std::vector> &output_tensors) { - const auto &input = input_tensors[0]; - input_dims_ = input->Dims(); -} - -std::vector> NoOp::Backward(const std::vector> &grad_outputs) { - CHECK_EQ(grad_outputs.size(), 1); - const auto &grad_output = grad_outputs[0]; - - auto device = grad_output->GetDevice().type(); - return {Dispatcher::Instance().Call>({device, "NoOpBackward"}, input_dims_, grad_output)}; -} - -std::vector> Slice::Forward(const std::vector> &input_tensors) { - CHECK_EQ(input_tensors.size(), 1); - const auto &input = input_tensors[0]; - - auto device = input->GetDevice().type(); - return { - Dispatcher::Instance().Call>({device, "SliceForward"}, input, starts_, ends_, steps_)}; -} - -void Slice::SetupContext(const std::vector> &input_tensors, - const std::vector> &) { - // FIXME(dcj): only input's dim need to be saved - const auto &input = input_tensors[0]; - saved_tensors_ = {input}; -} - -std::vector> Slice::Backward(const std::vector> &grad_outputs) { - CHECK_EQ(saved_tensors_.size(), 1); - const auto &input = saved_tensors_[0]; - const auto &grad_output = grad_outputs[0]; - - auto device = input->GetDevice().type(); - return {Dispatcher::Instance().Call>({device, "SliceBackward"}, grad_output, input, starts_, - ends_, steps_)}; -} - -std::vector> Stack::Forward(const std::vector> &input_tensors) { - CHECK_GE(input_tensors.size(), 2); - const auto device = input_tensors[0]->GetDevice().type(); - - return {Dispatcher::Instance().Call>({device, "StackForward"}, input_tensors, dim_)}; -} - -void Stack::SetupContext(const std::vector> &input_tensors, - const std::vector> &) { - const auto &input = input_tensors[0]; - input_dims_ = input->Dims(); -} - -std::vector> Stack::Backward(const std::vector> &grad_outputs) { - const auto &grad_output = grad_outputs[0]; - - auto device = grad_output->GetDevice().type(); - return {Dispatcher::Instance().Call>>({device, "StackBackward"}, input_dims_, - dim_, grad_output)}; -} - -std::vector> Concat::Forward(const std::vector> &input_tensors) { - CHECK_GE(input_tensors.size(), 2); - const auto device = input_tensors[0]->GetDevice().type(); - - auto kernel = Dispatcher::Instance().GetKernel({device, "ConcatForward"}); - return {kernel.Call>(input_tensors, dim_)}; -} - -void Concat::SetupContext(const std::vector> &input_tensors, - const std::vector> &) { - for (auto input : input_tensors) { input_dims_list_.push_back(input->Dims()); } -} - -std::vector> Concat::Backward(const std::vector> &grad_outputs) { - const auto &grad_output = grad_outputs[0]; - - auto device = grad_output->GetDevice().type(); - auto kernel = Dispatcher::Instance().GetKernel({device, "ConcatBackward"}); - return kernel.Call>>(grad_output, input_dims_list_, dim_); -} -} // namespace infini_train::autograd diff --git a/infini_train/src/autograd/no_op.cc b/infini_train/src/autograd/no_op.cc new file mode 100644 index 00000000..b4247dec --- /dev/null +++ b/infini_train/src/autograd/no_op.cc @@ -0,0 +1,31 @@ +#include "infini_train/include/autograd/no_op.h" + +#include "glog/logging.h" + +#include "infini_train/include/dispatcher.h" +#include "infini_train/include/tensor.h" + +namespace infini_train::autograd { +std::vector> NoOp::Forward(const std::vector> &input_tensors) { + CHECK_EQ(input_tensors.size(), 1); + const auto &input = input_tensors[0]; + + auto device = input->GetDevice().type(); + return {Dispatcher::Instance().Call>({device, "NoOpForward"}, input, output_dims_)}; +} + +void NoOp::SetupContext(const std::vector> &input_tensors, + const std::vector> &output_tensors) { + const auto &input = input_tensors[0]; + input_dims_ = input->Dims(); +} + +std::vector> NoOp::Backward(const std::vector> &grad_outputs) { + CHECK_EQ(grad_outputs.size(), 1); + const auto &grad_output = grad_outputs[0]; + + auto device = grad_output->GetDevice().type(); + return {Dispatcher::Instance().Call>({device, "NoOpBackward"}, input_dims_, grad_output)}; +} + +} // namespace infini_train::autograd diff --git a/infini_train/src/autograd/scatter.cc b/infini_train/src/autograd/scatter.cc new file mode 100644 index 00000000..472fd543 --- /dev/null +++ b/infini_train/src/autograd/scatter.cc @@ -0,0 +1,34 @@ +#include "infini_train/include/autograd/scatter.h" + +#include "glog/logging.h" + +#include "infini_train/include/dispatcher.h" +#include "infini_train/include/tensor.h" + +namespace infini_train::autograd { + +std::vector> Scatter::Forward(const std::vector> &input_tensors) { + CHECK_EQ(input_tensors.size(), 2); + const auto &values = input_tensors[0]; + const auto &indices = input_tensors[1]; + auto device = values->GetDevice().type(); + return {Dispatcher::Instance().Call>({device, "ScatterForward"}, values, indices, + output_dims_)}; +} + +void Scatter::SetupContext(const std::vector> &input_tensors, + const std::vector> &) { + saved_tensors_ = {input_tensors[1]}; +} + +std::vector> Scatter::Backward(const std::vector> &grad_outputs) { + CHECK_EQ(grad_outputs.size(), 1); + const auto &grad_output = grad_outputs[0]; + const auto &indices = saved_tensors_[0]; + auto device = grad_output->GetDevice().type(); + auto grad_values + = Dispatcher::Instance().Call>({device, "ScatterBackward"}, grad_output, indices); + return {grad_values, nullptr}; +} + +} // namespace infini_train::autograd diff --git a/infini_train/src/autograd/transform.cc b/infini_train/src/autograd/transform.cc index 4fae05bb..a85e9cd4 100644 --- a/infini_train/src/autograd/transform.cc +++ b/infini_train/src/autograd/transform.cc @@ -89,4 +89,68 @@ RepeatInterleave::Backward(const std::vector> &grad_outp return {Dispatcher::Instance().Call>({device, "RepeatInterleaveBackward"}, grad_output, input_dims_, dim_)}; } + +std::vector> Split::Forward(const std::vector> &input_tensors) { + CHECK_EQ(input_tensors.size(), 1); + const auto &input = input_tensors[0]; + + auto device = input->GetDevice().type(); + return {Dispatcher::Instance().Call>>({device, "SplitForward"}, input, + split_size_, dim_)}; +} + +void Split::SetupContext(const std::vector> &input_tensors, + const std::vector> &) { + const auto &input = input_tensors[0]; + input_dims_ = input->Dims(); +} + +std::vector> Split::Backward(const std::vector> &grad_outputs) { + auto device = grad_outputs[0]->GetDevice(); + return {Dispatcher::Instance().Call>({device.type(), "SplitBackward"}, input_dims_, + split_size_, dim_, grad_outputs)}; +} + +std::vector> Stack::Forward(const std::vector> &input_tensors) { + CHECK_GE(input_tensors.size(), 2); + const auto device = input_tensors[0]->GetDevice().type(); + + return {Dispatcher::Instance().Call>({device, "StackForward"}, input_tensors, dim_)}; +} + +void Stack::SetupContext(const std::vector> &input_tensors, + const std::vector> &) { + const auto &input = input_tensors[0]; + input_dims_ = input->Dims(); +} + +std::vector> Stack::Backward(const std::vector> &grad_outputs) { + const auto &grad_output = grad_outputs[0]; + + auto device = grad_output->GetDevice().type(); + return {Dispatcher::Instance().Call>>({device, "StackBackward"}, input_dims_, + dim_, grad_output)}; +} + +std::vector> Concat::Forward(const std::vector> &input_tensors) { + CHECK_GE(input_tensors.size(), 2); + const auto device = input_tensors[0]->GetDevice().type(); + + auto kernel = Dispatcher::Instance().GetKernel({device, "ConcatForward"}); + return {kernel.Call>(input_tensors, dim_)}; +} + +void Concat::SetupContext(const std::vector> &input_tensors, + const std::vector> &) { + for (auto input : input_tensors) { input_dims_list_.push_back(input->Dims()); } +} + +std::vector> Concat::Backward(const std::vector> &grad_outputs) { + const auto &grad_output = grad_outputs[0]; + + auto device = grad_output->GetDevice().type(); + auto kernel = Dispatcher::Instance().GetKernel({device, "ConcatBackward"}); + return kernel.Call>>(grad_output, input_dims_list_, dim_); +} + } // namespace infini_train::autograd diff --git a/infini_train/src/kernels/cpu/gather.cc b/infini_train/src/kernels/cpu/gather.cc index b59fd45f..bd28fea8 100644 --- a/infini_train/src/kernels/cpu/gather.cc +++ b/infini_train/src/kernels/cpu/gather.cc @@ -8,10 +8,7 @@ #include "infini_train/include/tensor.h" namespace infini_train::kernels::cpu { -// FIXME(zbl): This kernel aligns with torch.gather -// Currently named IndexGather to avoid conflict with communication operators -// Should be renamed to Gather later for interface consistency -std::shared_ptr IndexGatherForward(const std::shared_ptr &input, const std::shared_ptr &index, +std::shared_ptr GatherForward(const std::shared_ptr &input, const std::shared_ptr &index, int64_t dim) { const auto &in_dims = input->Dims(); const auto &idx_dims = index->Dims(); @@ -103,7 +100,7 @@ std::shared_ptr IndexGatherForward(const std::shared_ptr &input, return out; } -std::shared_ptr IndexGatherBackward(const std::shared_ptr &grad_output, +std::shared_ptr GatherBackward(const std::shared_ptr &grad_output, const std::shared_ptr &index, int64_t dim, const std::vector &input_dims) { const auto &in_dims = input_dims; @@ -199,7 +196,7 @@ std::shared_ptr IndexGatherBackward(const std::shared_ptr &grad_ #define REGISTER_CPU_GATHER_KERNEL(kernel_name) \ REGISTER_KERNEL(infini_train::Device::DeviceType::kCPU, kernel_name, infini_train::kernels::cpu::kernel_name) -REGISTER_CPU_GATHER_KERNEL(IndexGatherForward) -REGISTER_CPU_GATHER_KERNEL(IndexGatherBackward) +REGISTER_CPU_GATHER_KERNEL(GatherForward) +REGISTER_CPU_GATHER_KERNEL(GatherBackward) #undef REGISTER_CPU_GATHER_KERNEL diff --git a/infini_train/src/kernels/cpu/scatter.cc b/infini_train/src/kernels/cpu/scatter.cc new file mode 100644 index 00000000..b98326ea --- /dev/null +++ b/infini_train/src/kernels/cpu/scatter.cc @@ -0,0 +1,87 @@ +#include +#include +#include +#include + +#include "glog/logging.h" + +#include "infini_train/include/dispatcher.h" +#include "infini_train/include/tensor.h" + +namespace infini_train::kernels::cpu { + +std::shared_ptr ScatterForward(const std::shared_ptr &values, + const std::shared_ptr &indices, + const std::vector &output_dims) { + CHECK(indices->Dtype() == DataType::kINT64) << "CPU ScatterForward expects int64 indices"; + CHECK(values->Dims() == indices->Dims()); + CHECK(!output_dims.empty()); + CHECK_EQ(values->Dims().size(), output_dims.size()); + CHECK_GT(values->Dims().back(), 0); + CHECK_GT(output_dims.back(), 0); + + const int64_t topk = values->Dims().back(); + const int64_t num_experts = output_dims.back(); + const int64_t rows = values->NumElements() / topk; + size_t output_numel = 1; + for (const auto dim : output_dims) { output_numel *= static_cast(dim); } + CHECK_EQ(output_numel, static_cast(rows * num_experts)); + + auto output = std::make_shared(output_dims, values->Dtype(), values->GetDevice()); + std::memset(output->DataPtr(), 0, output->SizeInBytes()); + + const size_t elem_size = kDataTypeToSize.at(values->Dtype()); + const auto *src = static_cast(values->DataPtr()); + auto *dst = static_cast(output->DataPtr()); + const auto *idx = static_cast(indices->DataPtr()); + for (int64_t row = 0; row < rows; ++row) { + for (int64_t selected = 0; selected < topk; ++selected) { + const int64_t expert_idx = idx[row * topk + selected]; + CHECK_GE(expert_idx, 0); + CHECK_LT(expert_idx, num_experts); + std::memcpy(dst + (row * num_experts + expert_idx) * elem_size, + src + (row * topk + selected) * elem_size, elem_size); + } + } + + return output; +} + +std::shared_ptr ScatterBackward(const std::shared_ptr &grad_output, + const std::shared_ptr &indices) { + CHECK(indices->Dtype() == DataType::kINT64) << "CPU ScatterBackward expects int64 indices"; + CHECK_GE(grad_output->Dims().size(), 1); + CHECK_GE(indices->Dims().size(), 1); + + const int64_t num_experts = grad_output->Dims().back(); + const int64_t topk = indices->Dims().back(); + const int64_t rows = indices->NumElements() / topk; + CHECK_EQ(grad_output->NumElements(), static_cast(rows * num_experts)); + + auto grad_values = std::make_shared(indices->Dims(), grad_output->Dtype(), grad_output->GetDevice()); + const size_t elem_size = kDataTypeToSize.at(grad_output->Dtype()); + const auto *src = static_cast(grad_output->DataPtr()); + auto *dst = static_cast(grad_values->DataPtr()); + const auto *idx = static_cast(indices->DataPtr()); + for (int64_t row = 0; row < rows; ++row) { + for (int64_t selected = 0; selected < topk; ++selected) { + const int64_t expert_idx = idx[row * topk + selected]; + CHECK_GE(expert_idx, 0); + CHECK_LT(expert_idx, num_experts); + std::memcpy(dst + (row * topk + selected) * elem_size, + src + (row * num_experts + expert_idx) * elem_size, elem_size); + } + } + + return grad_values; +} + +} // namespace infini_train::kernels::cpu + +#define REGISTER_CPU_SCATTER_KERNEL(kernel_name) \ + REGISTER_KERNEL(infini_train::Device::DeviceType::kCPU, kernel_name, infini_train::kernels::cpu::kernel_name) + +REGISTER_CPU_SCATTER_KERNEL(ScatterForward) +REGISTER_CPU_SCATTER_KERNEL(ScatterBackward) + +#undef REGISTER_CPU_SCATTER_KERNEL diff --git a/infini_train/src/kernels/cuda/comm.cu b/infini_train/src/kernels/cuda/comm.cu index b4bdafd8..6300ffdb 100644 --- a/infini_train/src/kernels/cuda/comm.cu +++ b/infini_train/src/kernels/cuda/comm.cu @@ -9,7 +9,7 @@ #include "infini_train/include/nn/functional.h" #include "infini_train/include/tensor.h" -namespace infini_train::kernels::cuda { +namespace infini_train::kernels::cuda::comm { std::vector> Broadcast(const std::vector> &input_tensors, const std::vector &devices) { @@ -69,11 +69,11 @@ std::shared_ptr Gather(const std::vector> &tenso auto view_kernel = Dispatcher::Instance().GetKernel({destination.type(), "NoOpForward"}); return view_kernel.Call>(gathered_tensor, new_dims); } -} // namespace infini_train::kernels::cuda +} // namespace infini_train::kernels::cuda::comm #define REGISTER_CUDA_COMM_KERNEL(kernel_name) \ - REGISTER_KERNEL(infini_train::Device::DeviceType::kCUDA, Comm##kernel_name, \ - infini_train::kernels::cuda::kernel_name) + REGISTER_KERNEL(infini_train::Device::DeviceType::kCUDA, kernel_name, \ + infini_train::kernels::cuda::comm::kernel_name) REGISTER_CUDA_COMM_KERNEL(Broadcast) REGISTER_CUDA_COMM_KERNEL(Scatter) diff --git a/infini_train/src/kernels/cuda/gather.cu b/infini_train/src/kernels/cuda/gather.cu index 12d0567d..5216f28e 100644 --- a/infini_train/src/kernels/cuda/gather.cu +++ b/infini_train/src/kernels/cuda/gather.cu @@ -9,15 +9,11 @@ #include "infini_train/src/core/runtime/cuda/cuda_runtime_common.h" namespace infini_train::kernels::cuda { -// FIXME(zbl): This kernel aligns with torch.gather -// Currently named IndexGather to avoid conflict with communication operators -// Should be renamed to Gather later for interface consistency template -__global__ void IndexGatherForwardKernel(const T *__restrict__ input, const int64_t *__restrict__ norm_index, - T *__restrict__ output, const int64_t *__restrict__ out_dims, - const int64_t *__restrict__ in_strides, - const int64_t *__restrict__ out_strides, int num_dims, int gather_dim, - int64_t dim_size_gather, int64_t total_elements) { +__global__ void GatherForwardKernel(const T *__restrict__ input, const int64_t *__restrict__ norm_index, + T *__restrict__ output, const int64_t *__restrict__ out_dims, + const int64_t *__restrict__ in_strides, const int64_t *__restrict__ out_strides, + int num_dims, int gather_dim, int64_t dim_size_gather, int64_t total_elements) { int64_t out_idx = blockIdx.x * blockDim.x + threadIdx.x; if (out_idx >= total_elements) { return; @@ -43,8 +39,8 @@ __global__ void IndexGatherForwardKernel(const T *__restrict__ input, const int6 output[out_idx] = input[in_linear]; } -std::shared_ptr IndexGatherForward(const std::shared_ptr &input, const std::shared_ptr &index, - int64_t dim) { +std::shared_ptr GatherForward(const std::shared_ptr &input, const std::shared_ptr &index, + int64_t dim) { const auto &in_dims = input->Dims(); const auto &idx_dims = index->Dims(); CHECK_EQ(in_dims.size(), idx_dims.size()); @@ -103,23 +99,22 @@ std::shared_ptr IndexGatherForward(const std::shared_ptr &input, core::cuda::DispatchCudaFunc( dtype, [=]() { - IndexGatherForwardKernel<<>>( + GatherForwardKernel<<>>( static_cast(input->DataPtr()), static_cast(index->DataPtr()), static_cast(out->DataPtr()), out_dims_dev, in_strides_dev, out_strides_dev, (int)num_dims, (int)dim, gather_dim_size, total_elements); }, - "CUDA IndexGatherForward"); + "CUDA GatherForward"); CUDA_CHECK(cudaFreeAsync(dev_buf, stream)); return out; } template -__global__ void IndexGatherBackwardKernel(const T *__restrict__ grad_output, const int64_t *__restrict__ index, - T *__restrict__ grad_input, const int64_t *__restrict__ out_dims, - const int64_t *__restrict__ in_strides, - const int64_t *__restrict__ out_strides, int num_dims, int gather_dim, - int64_t dim_size_gather, int64_t total_elements) { +__global__ void GatherBackwardKernel(const T *__restrict__ grad_output, const int64_t *__restrict__ index, + T *__restrict__ grad_input, const int64_t *__restrict__ out_dims, + const int64_t *__restrict__ in_strides, const int64_t *__restrict__ out_strides, + int num_dims, int gather_dim, int64_t dim_size_gather, int64_t total_elements) { int64_t out_idx = blockIdx.x * blockDim.x + threadIdx.x; if (out_idx >= total_elements) { return; @@ -149,9 +144,8 @@ __global__ void IndexGatherBackwardKernel(const T *__restrict__ grad_output, con atomicAdd(&grad_input[in_linear], grad_output[out_idx]); } -std::shared_ptr IndexGatherBackward(const std::shared_ptr &grad_output, - const std::shared_ptr &index, int64_t dim, - const std::vector &input_dims) { +std::shared_ptr GatherBackward(const std::shared_ptr &grad_output, const std::shared_ptr &index, + int64_t dim, const std::vector &input_dims) { const auto &in_dims = input_dims; const auto &idx_dims = index->Dims(); CHECK_EQ(in_dims.size(), idx_dims.size()); @@ -210,12 +204,12 @@ std::shared_ptr IndexGatherBackward(const std::shared_ptr &grad_ core::cuda::DispatchCudaFunc( dtype, [=]() { - IndexGatherBackwardKernel<<>>( + GatherBackwardKernel<<>>( static_cast(grad_output->DataPtr()), static_cast(index->DataPtr()), static_cast(grad_input->DataPtr()), out_dims_dev, in_strides_dev, out_strides_dev, (int)num_dims, (int)dim, gather_dim_size, total_elements); }, - "CUDA IndexGatherBackward"); + "CUDA GatherBackward"); CUDA_CHECK(cudaFreeAsync(dev_buf, stream)); return grad_input; @@ -226,7 +220,7 @@ std::shared_ptr IndexGatherBackward(const std::shared_ptr &grad_ #define REGISTER_CUDA_GATHER_KERNEL(kernel_name) \ REGISTER_KERNEL(infini_train::Device::DeviceType::kCUDA, kernel_name, infini_train::kernels::cuda::kernel_name) -REGISTER_CUDA_GATHER_KERNEL(IndexGatherForward) -REGISTER_CUDA_GATHER_KERNEL(IndexGatherBackward) +REGISTER_CUDA_GATHER_KERNEL(GatherForward) +REGISTER_CUDA_GATHER_KERNEL(GatherBackward) #undef REGISTER_CUDA_GATHER_KERNEL diff --git a/infini_train/src/kernels/cuda/scatter.cu b/infini_train/src/kernels/cuda/scatter.cu new file mode 100644 index 00000000..9ebb173a --- /dev/null +++ b/infini_train/src/kernels/cuda/scatter.cu @@ -0,0 +1,120 @@ +#include "glog/logging.h" + +#include "infini_train/include/common/cuda/common_cuda.h" +#include "infini_train/include/core/runtime/device_guard.h" +#include "infini_train/include/dispatcher.h" +#include "infini_train/include/tensor.h" + +#include "infini_train/src/core/runtime/cuda/cuda_dispatch.h" +#include "infini_train/src/core/runtime/cuda/cuda_runtime_common.h" + +namespace infini_train::kernels::cuda { + +template +__global__ void ScatterForwardKernel(const T *__restrict__ values, const int64_t *__restrict__ indices, + T *__restrict__ output, int64_t rows, int64_t topk, int64_t num_experts) { + int64_t idx = blockIdx.x * blockDim.x + threadIdx.x; + const int64_t total = rows * topk; + if (idx >= total) { + return; + } + + const int64_t row = idx / topk; + const int64_t expert_idx = indices[idx]; + output[row * num_experts + expert_idx] = values[idx]; +} + +std::shared_ptr ScatterForward(const std::shared_ptr &values, const std::shared_ptr &indices, + const std::vector &output_dims) { + CHECK(indices->Dtype() == DataType::kINT64) << "CUDA ScatterForward expects int64 indices"; + CHECK(values->Dims() == indices->Dims()); + CHECK(!output_dims.empty()); + CHECK_EQ(values->Dims().size(), output_dims.size()); + CHECK_GT(values->Dims().back(), 0); + CHECK_GT(output_dims.back(), 0); + + const int64_t topk = values->Dims().back(); + const int64_t num_experts = output_dims.back(); + CHECK_GT(num_experts, 0); + const int64_t rows = values->NumElements() / topk; + + auto output = std::make_shared(output_dims, values->Dtype(), values->GetDevice()); + + auto device = values->GetDevice(); + const auto &stream = dynamic_cast( + infini_train::core::GetDeviceGuardImpl(device.type())->GetStream(device)) + ->cuda_stream(); + CUDA_CHECK(cudaMemsetAsync(output->DataPtr(), 0, output->SizeInBytes(), stream)); + const int threads = 256; + const int blocks = static_cast(((rows * topk) + threads - 1) / threads); + if (values->Dtype() == DataType::kBOOL) { + ScatterForwardKernel<<>>( + static_cast(values->DataPtr()), static_cast(indices->DataPtr()), + static_cast(output->DataPtr()), rows, topk, num_experts); + CUDA_CHECK(cudaGetLastError()); + } else { + core::cuda::DispatchCudaFunc( + values->Dtype(), + [=]() { + ScatterForwardKernel<<>>( + static_cast(values->DataPtr()), static_cast(indices->DataPtr()), + static_cast(output->DataPtr()), rows, topk, num_experts); + }, + "CUDA ScatterForward"); + } + return output; +} + +template +__global__ void ScatterBackwardKernel(const T *__restrict__ grad_output, const int64_t *__restrict__ indices, + T *__restrict__ grad_values, int64_t rows, int64_t topk, int64_t num_experts) { + int64_t idx = blockIdx.x * blockDim.x + threadIdx.x; + const int64_t total = rows * topk; + if (idx >= total) { + return; + } + const int64_t row = idx / topk; + const int64_t expert_idx = indices[idx]; + grad_values[idx] = grad_output[row * num_experts + expert_idx]; +} + +std::shared_ptr ScatterBackward(const std::shared_ptr &grad_output, + const std::shared_ptr &indices) { + CHECK(indices->Dtype() == DataType::kINT64) << "CUDA ScatterBackward expects int64 indices"; + CHECK_GE(grad_output->Dims().size(), 1); + CHECK_GE(indices->Dims().size(), 1); + const int64_t num_experts = grad_output->Dims().back(); + const int64_t topk = indices->Dims().back(); + const int64_t rows = indices->NumElements() / topk; + CHECK_EQ(grad_output->NumElements(), static_cast(rows * num_experts)); + + auto grad_values = std::make_shared(indices->Dims(), grad_output->Dtype(), grad_output->GetDevice()); + + auto device = grad_output->GetDevice(); + const auto &stream = dynamic_cast( + infini_train::core::GetDeviceGuardImpl(device.type())->GetStream(device)) + ->cuda_stream(); + const int threads = 256; + const int blocks = static_cast(((rows * topk) + threads - 1) / threads); + + core::cuda::DispatchCudaFunc( + grad_output->Dtype(), + [=]() { + ScatterBackwardKernel<<>>( + static_cast(grad_output->DataPtr()), static_cast(indices->DataPtr()), + static_cast(grad_values->DataPtr()), rows, topk, num_experts); + }, + "CUDA ScatterBackward"); + + return grad_values; +} + +} // namespace infini_train::kernels::cuda + +#define REGISTER_CUDA_SCATTER_KERNEL(kernel_name) \ + REGISTER_KERNEL(infini_train::Device::DeviceType::kCUDA, kernel_name, infini_train::kernels::cuda::kernel_name) + +REGISTER_CUDA_SCATTER_KERNEL(ScatterForward) +REGISTER_CUDA_SCATTER_KERNEL(ScatterBackward) + +#undef REGISTER_CUDA_SCATTER_KERNEL diff --git a/infini_train/src/nn/parallel/parallel_functional.cc b/infini_train/src/nn/parallel/parallel_functional.cc index 31db11ec..ffd218d7 100644 --- a/infini_train/src/nn/parallel/parallel_functional.cc +++ b/infini_train/src/nn/parallel/parallel_functional.cc @@ -44,7 +44,7 @@ std::vector>> Scatter(const std::vector &devices, int dim) { std::vector>> output_tensors; for (const auto &tensor : input_tensors) { - output_tensors.emplace_back(std::make_shared(devices, dim)->Apply({tensor})); + output_tensors.emplace_back(std::make_shared(devices, dim)->Apply({tensor})); } std::vector>> transposed_output_tensors; transposed_output_tensors.resize(devices.size()); @@ -59,7 +59,7 @@ std::vector> Gather(const std::vector> gather_tensors; for (const auto &tensor : tensors) { gather_tensors.push_back(tensor[0]); } - return std::make_shared(target_device, dim)->Apply(gather_tensors); + return std::make_shared(target_device, dim)->Apply(gather_tensors); } std::vector>> @@ -67,7 +67,7 @@ BroadcastCoalescedReshape(const std::vector> &tensors, c if (tensors.empty()) { return {}; } - auto tensor_copies = std::make_shared(devices)->Apply(tensors); + auto tensor_copies = std::make_shared(devices)->Apply(tensors); std::vector>> tensor_copies_reshaped(devices.size()); for (int replica_idx = 0; replica_idx < devices.size(); ++replica_idx) { tensor_copies_reshaped[replica_idx].resize(tensors.size()); diff --git a/infini_train/src/tensor.cc b/infini_train/src/tensor.cc index f7947030..76fb2f00 100644 --- a/infini_train/src/tensor.cc +++ b/infini_train/src/tensor.cc @@ -14,7 +14,8 @@ #include "infini_train/include/autograd/function.h" #include "infini_train/include/autograd/function_hook.h" #include "infini_train/include/autograd/matmul.h" -#include "infini_train/include/autograd/misc.h" +#include "infini_train/include/autograd/indexing.h" +#include "infini_train/include/autograd/no_op.h" #include "infini_train/include/autograd/outer.h" #include "infini_train/include/autograd/reduction.h" #include "infini_train/include/autograd/transform.h" @@ -356,7 +357,7 @@ std::vector> Tensor::Split(int split_size, int dim) { std::shared_ptr Tensor::Gather(int dim, const std::shared_ptr &index) { CHECK(GetDevice() == index->GetDevice()) << "index must be on the same device as input."; - return std::make_shared(dim)->Apply({shared_from_this(), index})[0]; + return std::make_shared(dim)->Apply({shared_from_this(), index})[0]; } std::shared_ptr Tensor::RepeatInterleave(int64_t repeat, int64_t dim) { From e9567ce21c92348d7c38b5a4d84a36d1dc0f20e5 Mon Sep 17 00:00:00 2001 From: kilinchange Date: Tue, 26 May 2026 12:24:26 +0000 Subject: [PATCH 06/11] test: migrate test_transformer_architecture to ctest framework --- .../include/autograd/{indexing.h => gather.h} | 17 - infini_train/include/autograd/transform.h | 17 + infini_train/include/datatype.h | 18 +- .../src/autograd/{indexing.cc => gather.cc} | 28 +- infini_train/src/autograd/transform.cc | 26 + infini_train/src/kernels/cpu/gather.cc | 7 +- infini_train/src/kernels/cpu/scatter.cc | 13 +- infini_train/src/nn/functional.cc | 1 - infini_train/src/tensor.cc | 2 +- .../test_transformer_architecture.cc | 624 ------------------ .../test_transformer_architecture.cc | 159 +++++ 11 files changed, 222 insertions(+), 690 deletions(-) rename infini_train/include/autograd/{indexing.h => gather.h} (51%) rename infini_train/src/autograd/{indexing.cc => gather.cc} (53%) delete mode 100644 test/transformer/test_transformer_architecture.cc diff --git a/infini_train/include/autograd/indexing.h b/infini_train/include/autograd/gather.h similarity index 51% rename from infini_train/include/autograd/indexing.h rename to infini_train/include/autograd/gather.h index f35321c8..0fb44c51 100644 --- a/infini_train/include/autograd/indexing.h +++ b/infini_train/include/autograd/gather.h @@ -27,21 +27,4 @@ class Gather : public Function { std::vector input_dims_; }; -class Slice : public Function { -public: - static constexpr char kType[] = "SliceFunction"; - - Slice(const std::vector &starts, const std::vector &ends, const std::vector &steps) - : Function(kType), starts_(starts), ends_(ends), steps_(steps) {} - std::vector> Forward(const std::vector> &input_tensors) override; - void SetupContext(const std::vector> &input_tensors, - const std::vector> &output_tensors) override; - std::vector> Backward(const std::vector> &grad_outputs) override; - -private: - const std::vector starts_; - const std::vector ends_; - const std::vector steps_; -}; - } // namespace infini_train::autograd diff --git a/infini_train/include/autograd/transform.h b/infini_train/include/autograd/transform.h index e345061a..88b7d56e 100644 --- a/infini_train/include/autograd/transform.h +++ b/infini_train/include/autograd/transform.h @@ -127,4 +127,21 @@ class Concat : public Function { std::vector> input_dims_list_; }; +class Slice : public Function { +public: + static constexpr char kType[] = "SliceFunction"; + + Slice(const std::vector &starts, const std::vector &ends, const std::vector &steps) + : Function(kType), starts_(starts), ends_(ends), steps_(steps) {} + std::vector> Forward(const std::vector> &input_tensors) override; + void SetupContext(const std::vector> &input_tensors, + const std::vector> &output_tensors) override; + std::vector> Backward(const std::vector> &grad_outputs) override; + +private: + const std::vector starts_; + const std::vector ends_; + const std::vector steps_; +}; + } // namespace infini_train::autograd diff --git a/infini_train/include/datatype.h b/infini_train/include/datatype.h index 0ae8fda7..6efa849c 100644 --- a/infini_train/include/datatype.h +++ b/infini_train/include/datatype.h @@ -100,18 +100,18 @@ enum class DataType : int8_t { }; inline const std::unordered_map kDataTypeToSize = { - {DataType::kBOOL, 1}, - {DataType::kUINT8, 1}, {DataType::kINT8, 1}, {DataType::kUINT16, 2}, {DataType::kINT16, 2}, - {DataType::kUINT32, 4}, {DataType::kINT32, 4}, {DataType::kUINT64, 8}, {DataType::kINT64, 8}, - {DataType::kBFLOAT16, 2}, {DataType::kFLOAT16, 2}, {DataType::kFLOAT32, 4}, {DataType::kFLOAT64, 8}, + {DataType::kBOOL, 1}, {DataType::kUINT8, 1}, {DataType::kINT8, 1}, {DataType::kUINT16, 2}, + {DataType::kINT16, 2}, {DataType::kUINT32, 4}, {DataType::kINT32, 4}, {DataType::kUINT64, 8}, + {DataType::kINT64, 8}, {DataType::kBFLOAT16, 2}, {DataType::kFLOAT16, 2}, {DataType::kFLOAT32, 4}, + {DataType::kFLOAT64, 8}, }; inline const std::unordered_map kDataTypeToDesc = { - {DataType::kBOOL, "bool"}, - {DataType::kUINT8, "uint8"}, {DataType::kINT8, "int8"}, {DataType::kUINT16, "uint16"}, - {DataType::kINT16, "int16"}, {DataType::kUINT32, "uint32"}, {DataType::kINT32, "int32"}, - {DataType::kUINT64, "uint64"}, {DataType::kINT64, "int64"}, {DataType::kBFLOAT16, "bf16"}, - {DataType::kFLOAT16, "fp16"}, {DataType::kFLOAT32, "fp32"}, {DataType::kFLOAT64, "fp64"}, + {DataType::kBOOL, "bool"}, {DataType::kUINT8, "uint8"}, {DataType::kINT8, "int8"}, + {DataType::kUINT16, "uint16"}, {DataType::kINT16, "int16"}, {DataType::kUINT32, "uint32"}, + {DataType::kINT32, "int32"}, {DataType::kUINT64, "uint64"}, {DataType::kINT64, "int64"}, + {DataType::kBFLOAT16, "bf16"}, {DataType::kFLOAT16, "fp16"}, {DataType::kFLOAT32, "fp32"}, + {DataType::kFLOAT64, "fp64"}, }; // ============================================================================= diff --git a/infini_train/src/autograd/indexing.cc b/infini_train/src/autograd/gather.cc similarity index 53% rename from infini_train/src/autograd/indexing.cc rename to infini_train/src/autograd/gather.cc index a9642256..a30cb013 100644 --- a/infini_train/src/autograd/indexing.cc +++ b/infini_train/src/autograd/gather.cc @@ -1,4 +1,4 @@ -#include "infini_train/include/autograd/indexing.h" +#include "infini_train/include/autograd/gather.h" #include "glog/logging.h" @@ -34,30 +34,4 @@ std::vector> Gather::Backward(const std::vector>(grad_output, index, dim_, input_dims_), nullptr}; } -std::vector> Slice::Forward(const std::vector> &input_tensors) { - CHECK_EQ(input_tensors.size(), 1); - const auto &input = input_tensors[0]; - - auto device = input->GetDevice().type(); - return { - Dispatcher::Instance().Call>({device, "SliceForward"}, input, starts_, ends_, steps_)}; -} - -void Slice::SetupContext(const std::vector> &input_tensors, - const std::vector> &) { - // FIXME(dcj): only input's dim need to be saved - const auto &input = input_tensors[0]; - saved_tensors_ = {input}; -} - -std::vector> Slice::Backward(const std::vector> &grad_outputs) { - CHECK_EQ(saved_tensors_.size(), 1); - const auto &input = saved_tensors_[0]; - const auto &grad_output = grad_outputs[0]; - - auto device = input->GetDevice().type(); - return {Dispatcher::Instance().Call>({device, "SliceBackward"}, grad_output, input, starts_, - ends_, steps_)}; -} - } // namespace infini_train::autograd diff --git a/infini_train/src/autograd/transform.cc b/infini_train/src/autograd/transform.cc index a85e9cd4..e38d5616 100644 --- a/infini_train/src/autograd/transform.cc +++ b/infini_train/src/autograd/transform.cc @@ -153,4 +153,30 @@ std::vector> Concat::Backward(const std::vector>>(grad_output, input_dims_list_, dim_); } +std::vector> Slice::Forward(const std::vector> &input_tensors) { + CHECK_EQ(input_tensors.size(), 1); + const auto &input = input_tensors[0]; + + auto device = input->GetDevice().type(); + return { + Dispatcher::Instance().Call>({device, "SliceForward"}, input, starts_, ends_, steps_)}; +} + +void Slice::SetupContext(const std::vector> &input_tensors, + const std::vector> &) { + // FIXME(dcj): only input's dim need to be saved + const auto &input = input_tensors[0]; + saved_tensors_ = {input}; +} + +std::vector> Slice::Backward(const std::vector> &grad_outputs) { + CHECK_EQ(saved_tensors_.size(), 1); + const auto &input = saved_tensors_[0]; + const auto &grad_output = grad_outputs[0]; + + auto device = input->GetDevice().type(); + return {Dispatcher::Instance().Call>({device, "SliceBackward"}, grad_output, input, starts_, + ends_, steps_)}; +} + } // namespace infini_train::autograd diff --git a/infini_train/src/kernels/cpu/gather.cc b/infini_train/src/kernels/cpu/gather.cc index bd28fea8..af39fc0f 100644 --- a/infini_train/src/kernels/cpu/gather.cc +++ b/infini_train/src/kernels/cpu/gather.cc @@ -9,7 +9,7 @@ namespace infini_train::kernels::cpu { std::shared_ptr GatherForward(const std::shared_ptr &input, const std::shared_ptr &index, - int64_t dim) { + int64_t dim) { const auto &in_dims = input->Dims(); const auto &idx_dims = index->Dims(); CHECK_EQ(in_dims.size(), idx_dims.size()); @@ -100,9 +100,8 @@ std::shared_ptr GatherForward(const std::shared_ptr &input, cons return out; } -std::shared_ptr GatherBackward(const std::shared_ptr &grad_output, - const std::shared_ptr &index, int64_t dim, - const std::vector &input_dims) { +std::shared_ptr GatherBackward(const std::shared_ptr &grad_output, const std::shared_ptr &index, + int64_t dim, const std::vector &input_dims) { const auto &in_dims = input_dims; const auto &idx_dims = index->Dims(); CHECK_EQ(in_dims.size(), idx_dims.size()); diff --git a/infini_train/src/kernels/cpu/scatter.cc b/infini_train/src/kernels/cpu/scatter.cc index b98326ea..1a9cf62e 100644 --- a/infini_train/src/kernels/cpu/scatter.cc +++ b/infini_train/src/kernels/cpu/scatter.cc @@ -10,8 +10,7 @@ namespace infini_train::kernels::cpu { -std::shared_ptr ScatterForward(const std::shared_ptr &values, - const std::shared_ptr &indices, +std::shared_ptr ScatterForward(const std::shared_ptr &values, const std::shared_ptr &indices, const std::vector &output_dims) { CHECK(indices->Dtype() == DataType::kINT64) << "CPU ScatterForward expects int64 indices"; CHECK(values->Dims() == indices->Dims()); @@ -39,8 +38,8 @@ std::shared_ptr ScatterForward(const std::shared_ptr &values, const int64_t expert_idx = idx[row * topk + selected]; CHECK_GE(expert_idx, 0); CHECK_LT(expert_idx, num_experts); - std::memcpy(dst + (row * num_experts + expert_idx) * elem_size, - src + (row * topk + selected) * elem_size, elem_size); + std::memcpy(dst + (row * num_experts + expert_idx) * elem_size, src + (row * topk + selected) * elem_size, + elem_size); } } @@ -68,8 +67,8 @@ std::shared_ptr ScatterBackward(const std::shared_ptr &grad_outp const int64_t expert_idx = idx[row * topk + selected]; CHECK_GE(expert_idx, 0); CHECK_LT(expert_idx, num_experts); - std::memcpy(dst + (row * topk + selected) * elem_size, - src + (row * num_experts + expert_idx) * elem_size, elem_size); + std::memcpy(dst + (row * topk + selected) * elem_size, src + (row * num_experts + expert_idx) * elem_size, + elem_size); } } @@ -78,7 +77,7 @@ std::shared_ptr ScatterBackward(const std::shared_ptr &grad_outp } // namespace infini_train::kernels::cpu -#define REGISTER_CPU_SCATTER_KERNEL(kernel_name) \ +#define REGISTER_CPU_SCATTER_KERNEL(kernel_name) \ REGISTER_KERNEL(infini_train::Device::DeviceType::kCPU, kernel_name, infini_train::kernels::cpu::kernel_name) REGISTER_CPU_SCATTER_KERNEL(ScatterForward) diff --git a/infini_train/src/nn/functional.cc b/infini_train/src/nn/functional.cc index b02f185a..c33e2368 100644 --- a/infini_train/src/nn/functional.cc +++ b/infini_train/src/nn/functional.cc @@ -6,7 +6,6 @@ #include "infini_train/include/autograd/activations.h" #include "infini_train/include/autograd/elementwise.h" -#include "infini_train/include/autograd/misc.h" #include "infini_train/include/autograd/reduction.h" #include "infini_train/include/autograd/softmax.h" #include "infini_train/include/autograd/transform.h" diff --git a/infini_train/src/tensor.cc b/infini_train/src/tensor.cc index 76fb2f00..44860a0f 100644 --- a/infini_train/src/tensor.cc +++ b/infini_train/src/tensor.cc @@ -13,8 +13,8 @@ #include "infini_train/include/autograd/elementwise.h" #include "infini_train/include/autograd/function.h" #include "infini_train/include/autograd/function_hook.h" +#include "infini_train/include/autograd/gather.h" #include "infini_train/include/autograd/matmul.h" -#include "infini_train/include/autograd/indexing.h" #include "infini_train/include/autograd/no_op.h" #include "infini_train/include/autograd/outer.h" #include "infini_train/include/autograd/reduction.h" diff --git a/test/transformer/test_transformer_architecture.cc b/test/transformer/test_transformer_architecture.cc deleted file mode 100644 index 42efda2d..00000000 --- a/test/transformer/test_transformer_architecture.cc +++ /dev/null @@ -1,624 +0,0 @@ -#include -#include -#include - -#include "glog/logging.h" - -#include "example/gpt2/config.h" -#include "example/llama3/config.h" -#include "infini_train/include/nn/modules/activations.h" -#include "infini_train/include/nn/modules/normalization.h" -#include "infini_train/include/nn/modules/sparse.h" -#include "infini_train/include/nn/modules/transformer/causal_self_attention.h" -#include "infini_train/include/nn/modules/transformer/mlp.h" -#include "infini_train/include/nn/modules/transformer/moe/moe_layer.h" -#include "infini_train/include/nn/modules/transformer/transformer.h" -#include "infini_train/include/nn/modules/transformer/transformer_config.h" -#include "infini_train/include/nn/modules/transformer/utils.h" -#include "infini_train/include/nn/parallel/global.h" -#include "infini_train/include/tensor.h" - -using namespace infini_train; -namespace nn = infini_train::nn; - -// ============================================================================ -// Test 1: TransformerConfig Validation -// ============================================================================ -void TestConfigValidation() { - std::cout << "\n=== Test 1: TransformerConfig Validation ===" << std::endl; - - bool all_passed = true; - - // Test GPT2 config - auto gpt2_config = gpt2::GPT2Config(); - if (gpt2_config.attention_type != nn::AttentionType::kStandard) { - std::cout << "FAIL: GPT2 config should use Standard attention" << std::endl; - all_passed = false; - } - if (gpt2_config.activation_type != nn::MLPType::kGELU) { - std::cout << "FAIL: GPT2 config should use GELU activation" << std::endl; - all_passed = false; - } - if (gpt2_config.norm_type != nn::NormType::kLayerNorm) { - std::cout << "FAIL: GPT2 config should use LayerNorm" << std::endl; - all_passed = false; - } - if (!gpt2_config.add_bias_linear) { - std::cout << "FAIL: GPT2 config should have bias enabled" << std::endl; - all_passed = false; - } - if (!gpt2_config.tie_weights) { - std::cout << "FAIL: GPT2 config should have tied weights" << std::endl; - all_passed = false; - } - - // Test LLaMA3 config - auto llama3_config = llama3::LLaMA3Config(); - if (llama3_config.attention_type != nn::AttentionType::kRoPE) { - std::cout << "FAIL: LLaMA3 config should use RoPE attention" << std::endl; - all_passed = false; - } - if (llama3_config.activation_type != nn::MLPType::kSwiGLU) { - std::cout << "FAIL: LLaMA3 config should use SwiGLU activation" << std::endl; - all_passed = false; - } - if (llama3_config.norm_type != nn::NormType::kRMSNorm) { - std::cout << "FAIL: LLaMA3 config should use RMSNorm" << std::endl; - all_passed = false; - } - if (llama3_config.add_bias_linear) { - std::cout << "FAIL: LLaMA3 config should have bias disabled" << std::endl; - all_passed = false; - } - if (llama3_config.tie_weights) { - std::cout << "FAIL: LLaMA3 config should not have tied weights" << std::endl; - all_passed = false; - } - - // Test GQA detection - if (!llama3_config.UseGQA()) { - std::cout << "FAIL: LLaMA3 config should detect GQA (n_kv_head < n_head)" << std::endl; - all_passed = false; - } - if (gpt2_config.UseGQA()) { - std::cout << "FAIL: GPT2 config should not detect GQA (n_kv_head == n_head)" << std::endl; - all_passed = false; - } - - if (all_passed) { - std::cout << "SUCCESS: All config validations passed!" << std::endl; - } -} - -// ============================================================================ -// Test 2: Embedding Layer -// ============================================================================ -void TestEmbedding() { - std::cout << "\n=== Test 2: Embedding Layer ===" << std::endl; - - const int64_t vocab_size = 1000; - const int64_t embedding_dim = 128; - const int64_t batch_size = 2; - const int64_t seq_len = 16; - - try { - auto embedding = std::make_shared(vocab_size, embedding_dim); - - // Check parameters - auto params = embedding->Parameters(); - if (params.size() != 1) { - std::cout << "FAIL: Embedding should have 1 parameter, got " << params.size() << std::endl; - return; - } - - // Check weight shape - auto weight = embedding->parameter(nn::Embedding::kParamWeightName); - if (weight->Dims() != std::vector{vocab_size, embedding_dim}) { - std::cout << "FAIL: Embedding weight shape mismatch" << std::endl; - return; - } - - // Forward pass - auto input = std::make_shared(std::vector{batch_size, seq_len}, DataType::kINT64); - auto output = (*embedding)({input}); - - if (output.size() != 1) { - std::cout << "FAIL: Embedding forward should return 1 tensor" << std::endl; - return; - } - - const auto &out_dims = output[0]->Dims(); - if (out_dims != std::vector{batch_size, seq_len, embedding_dim}) { - std::cout << "FAIL: Embedding output shape mismatch. Expected [" << batch_size << ", " << seq_len << ", " - << embedding_dim << "], got [" << out_dims[0] << ", " << out_dims[1] << ", " << out_dims[2] << "]" - << std::endl; - return; - } - - std::cout << "SUCCESS: Embedding layer works correctly!" << std::endl; - - } catch (const std::exception &e) { std::cout << "FAIL: Exception: " << e.what() << std::endl; } -} - -// ============================================================================ -// Test 3: Normalization Layers (LayerNorm vs RMSNorm) -// ============================================================================ -void TestNormalization() { - std::cout << "\n=== Test 3: Normalization Layers ===" << std::endl; - - const int64_t hidden_size = 64; - const int64_t batch_size = 2; - const int64_t seq_len = 8; - - try { - // Test LayerNorm - auto layernorm = std::make_shared(std::vector{hidden_size}); - auto ln_params = layernorm->Parameters(); - if (ln_params.size() != 2) { - std::cout << "FAIL: LayerNorm should have 2 parameters (weight, bias), got " << ln_params.size() - << std::endl; - return; - } - - // Test RMSNorm - auto rmsnorm = std::make_shared(hidden_size); - auto rms_params = rmsnorm->Parameters(); - if (rms_params.size() != 1) { - std::cout << "FAIL: RMSNorm should have 1 parameter (weight), got " << rms_params.size() << std::endl; - return; - } - - // Forward pass for both - auto input - = std::make_shared(std::vector{batch_size, seq_len, hidden_size}, DataType::kFLOAT32); - - auto ln_output = (*layernorm)({input}); - auto rms_output = (*rmsnorm)({input}); - - if (ln_output[0]->Dims() != input->Dims()) { - std::cout << "FAIL: LayerNorm output shape mismatch" << std::endl; - return; - } - - if (rms_output[0]->Dims() != input->Dims()) { - std::cout << "FAIL: RMSNorm output shape mismatch" << std::endl; - return; - } - - std::cout << "SUCCESS: Normalization layers work correctly!" << std::endl; - - } catch (const std::exception &e) { std::cout << "FAIL: Exception: " << e.what() << std::endl; } -} - -// ============================================================================ -// Test 4: MLP Layer (GELU vs SwiGLU) -// ============================================================================ -void TestMlp() { - std::cout << "\n=== Test 4: MLP Layer ===" << std::endl; - - const int64_t hidden_size = 64; - const int64_t batch_size = 2; - const int64_t seq_len = 8; - - try { - // Test GPT2-style MLP (GELU) - nn::TransformerConfig gpt2_mlp_config; - gpt2_mlp_config.n_embd = hidden_size; - gpt2_mlp_config.activation_type = nn::MLPType::kGELU; - gpt2_mlp_config.ffn_expansion_ratio = 4.0f; - gpt2_mlp_config.add_bias_linear = true; - - auto gpt2_mlp = std::make_shared(gpt2_mlp_config); - auto gpt2_params = gpt2_mlp->Parameters(); - - // GPT2 MLP should have: c_fc.weight, c_fc.bias, c_proj.weight, c_proj.bias - if (gpt2_params.size() != 4) { - std::cout << "FAIL: GPT2 MLP should have 4 parameters, got " << gpt2_params.size() << std::endl; - return; - } - - // Test LLaMA3-style MLP (SwiGLU) - nn::TransformerConfig llama3_mlp_config; - llama3_mlp_config.n_embd = hidden_size; - llama3_mlp_config.activation_type = nn::MLPType::kSwiGLU; - llama3_mlp_config.ffn_expansion_ratio = 4.0f; - llama3_mlp_config.add_bias_linear = false; - llama3_mlp_config.ffn_dim_multiplier = 1.5f; - llama3_mlp_config.multiple_of = 256; - - auto llama3_mlp = std::make_shared(llama3_mlp_config); - auto llama3_params = llama3_mlp->Parameters(); - - // LLaMA3 MLP should have: c_fc.weight, c_fc2.weight, c_proj.weight (no bias) - if (llama3_params.size() != 3) { - std::cout << "FAIL: LLaMA3 MLP should have 3 parameters, got " << llama3_params.size() << std::endl; - return; - } - - // Forward pass - auto input - = std::make_shared(std::vector{batch_size, seq_len, hidden_size}, DataType::kFLOAT32); - - auto gpt2_output = (*gpt2_mlp)({input}); - auto llama3_output = (*llama3_mlp)({input}); - - // Output should have same hidden dimension - if (gpt2_output[0]->Dims()[2] != hidden_size) { - std::cout << "FAIL: GPT2 MLP output hidden dim mismatch" << std::endl; - return; - } - - if (llama3_output[0]->Dims()[2] != hidden_size) { - std::cout << "FAIL: LLaMA3 MLP output hidden dim mismatch" << std::endl; - return; - } - - std::cout << "SUCCESS: MLP layers work correctly!" << std::endl; - - } catch (const std::exception &e) { std::cout << "FAIL: Exception: " << e.what() << std::endl; } -} - -// ============================================================================ -// Test 5: CausalSelfAttention -// ============================================================================ -void TestAttention() { - std::cout << "\n=== Test 5: CausalSelfAttention ===" << std::endl; - - const int64_t hidden_size = 64; - const int64_t batch_size = 2; - const int64_t seq_len = 8; - const int64_t n_head = 4; - - try { - // Test standard attention (GPT2-style) - nn::TransformerConfig standard_config; - standard_config.n_embd = hidden_size; - standard_config.n_head = n_head; - standard_config.n_kv_head = n_head; - standard_config.attention_type = nn::AttentionType::kStandard; - standard_config.add_bias_linear = true; - - auto standard_attn = std::make_shared(standard_config); - auto standard_params = standard_attn->Parameters(); - - // Should have c_attn (QKV combined) and c_proj with biases - if (standard_params.size() != 4) { - std::cout << "FAIL: Standard attention should have 4 parameters, got " << standard_params.size() - << std::endl; - return; - } - - // Test RoPE attention with GQA (LLaMA3-style) - nn::TransformerConfig rope_config; - rope_config.n_embd = hidden_size; - rope_config.n_head = n_head; - rope_config.n_kv_head = 2; // GQA: fewer KV heads - rope_config.attention_type = nn::AttentionType::kRoPE; - rope_config.add_bias_linear = false; - - auto rope_attn = std::make_shared(rope_config); - auto rope_params = rope_attn->Parameters(); - - // RoPE attention without bias should have fewer params - if (rope_params.empty()) { - std::cout << "FAIL: RoPE attention should have parameters" << std::endl; - return; - } - - // Forward pass - auto input - = std::make_shared(std::vector{batch_size, seq_len, hidden_size}, DataType::kFLOAT32); - - auto standard_output = (*standard_attn)({input}); - if (standard_output[0]->Dims() != input->Dims()) { - std::cout << "FAIL: Standard attention output shape mismatch" << std::endl; - return; - } - - std::cout << "SUCCESS: CausalSelfAttention works correctly!" << std::endl; - - } catch (const std::exception &e) { std::cout << "FAIL: Exception: " << e.what() << std::endl; } -} - -// ============================================================================ -// Test 6: TransformerLayer -// ============================================================================ -void TestTransformerLayer() { - std::cout << "\n=== Test 6: TransformerLayer ===" << std::endl; - - const int64_t hidden_size = 64; - const int64_t batch_size = 2; - const int64_t seq_len = 8; - - try { - // Test GPT2-style layer - auto gpt2_config = gpt2::GPT2Config(); - gpt2_config.n_embd = hidden_size; - gpt2_config.n_head = 4; - gpt2_config.n_layer = 1; - - auto gpt2_layer = std::make_shared(gpt2_config); - auto gpt2_params = gpt2_layer->Parameters(); - - if (gpt2_params.empty()) { - std::cout << "FAIL: GPT2 TransformerLayer should have parameters" << std::endl; - return; - } - - // Forward pass - auto input - = std::make_shared(std::vector{batch_size, seq_len, hidden_size}, DataType::kFLOAT32); - - auto output = (*gpt2_layer)({input}); - if (output[0]->Dims() != input->Dims()) { - std::cout << "FAIL: TransformerLayer output shape mismatch" << std::endl; - return; - } - - // Test LLaMA3-style layer - auto llama3_config = llama3::LLaMA3Config(); - llama3_config.n_embd = hidden_size; - llama3_config.n_head = 4; - llama3_config.n_kv_head = 2; - llama3_config.n_layer = 1; - - auto llama3_layer = std::make_shared(llama3_config); - auto llama3_params = llama3_layer->Parameters(); - - if (llama3_params.empty()) { - std::cout << "FAIL: LLaMA3 TransformerLayer should have parameters" << std::endl; - return; - } - - std::cout << "SUCCESS: TransformerLayer works correctly!" << std::endl; - - } catch (const std::exception &e) { std::cout << "FAIL: Exception: " << e.what() << std::endl; } -} - -// ============================================================================ -// Test 7: TransformerModel Instantiation (GPT2) -// ============================================================================ -void TestGpt2Model() { - std::cout << "\n=== Test 7: GPT2 Model Instantiation ===" << std::endl; - - auto config = gpt2::GPT2Config(); - // Use smaller config for faster testing - config.n_layer = 2; - config.n_head = 4; - config.n_embd = 64; - - try { - auto model = std::make_shared(config); - - if (model == nullptr) { - std::cout << "FAIL: Failed to create GPT2 model" << std::endl; - return; - } - - auto params = model->Parameters(); - if (params.empty()) { - std::cout << "FAIL: GPT2 model has no parameters" << std::endl; - return; - } - - std::cout << "SUCCESS: GPT2 model created with " << params.size() << " parameters!" << std::endl; - - } catch (const std::exception &e) { std::cout << "FAIL: Exception: " << e.what() << std::endl; } -} - -// ============================================================================ -// Test 8: TransformerModel Instantiation (LLaMA3) -// ============================================================================ -void TestLlama3Model() { - std::cout << "\n=== Test 8: LLaMA3 Model Instantiation ===" << std::endl; - - auto config = llama3::LLaMA3Config(); - // Use smaller config for faster testing - config.n_layer = 2; - config.n_head = 4; - config.n_kv_head = 2; - config.n_embd = 64; - - try { - auto model = std::make_shared(config); - - if (model == nullptr) { - std::cout << "FAIL: Failed to create LLaMA3 model" << std::endl; - return; - } - - auto params = model->Parameters(); - if (params.empty()) { - std::cout << "FAIL: LLaMA3 model has no parameters" << std::endl; - return; - } - - std::cout << "SUCCESS: LLaMA3 model created with " << params.size() << " parameters!" << std::endl; - - } catch (const std::exception &e) { std::cout << "FAIL: Exception: " << e.what() << std::endl; } -} - -// ============================================================================ -// Test 9: RoPE Utilities -// ============================================================================ -void TestRopeUtils() { - std::cout << "\n=== Test 9: RoPE Utilities ===" << std::endl; - - const int64_t head_dim = 64; - const int64_t seq_len = 128; - - try { - // Test precompute freqs_cis - auto freqs_cis = PrecomputeFreqsCis(head_dim, seq_len); - - // freqs_cis shape: [seq_len, head_dim/2, 2] (cos and sin stacked on last dim) - const auto &dims = freqs_cis->Dims(); - if (dims.size() != 3) { - std::cout << "FAIL: freqs_cis should be 3D, got " << dims.size() << "D" << std::endl; - return; - } - if (dims[0] != seq_len) { - std::cout << "FAIL: freqs_cis seq_len mismatch. Expected " << seq_len << ", got " << dims[0] << std::endl; - return; - } - if (dims[1] != head_dim / 2) { - std::cout << "FAIL: freqs_cis head_dim/2 mismatch. Expected " << head_dim / 2 << ", got " << dims[1] - << std::endl; - return; - } - if (dims[2] != 2) { - std::cout << "FAIL: freqs_cis last dim should be 2 (cos, sin), got " << dims[2] << std::endl; - return; - } - - std::cout << "SUCCESS: RoPE utilities work correctly!" << std::endl; - - } catch (const std::exception &e) { std::cout << "FAIL: Exception: " << e.what() << std::endl; } -} - -// ============================================================================ -// Test 10: Model StateDict -// ============================================================================ -void TestStateDict() { - std::cout << "\n=== Test 10: Model StateDict ===" << std::endl; - - nn::TransformerConfig config; - config.n_layer = 1; - config.n_head = 2; - config.n_kv_head = 2; // Must set explicitly - config.n_embd = 32; - config.vocab_size = 1000; - config.attention_type = nn::AttentionType::kStandard; - config.activation_type = nn::MLPType::kGELU; - config.norm_type = nn::NormType::kLayerNorm; - config.add_bias_linear = true; - - try { - auto model = std::make_shared(config); - auto state_dict = model->StateDict(); - - if (state_dict.empty()) { - std::cout << "FAIL: StateDict should not be empty" << std::endl; - return; - } - - // StateDict includes both parameters and buffers, so it should have >= parameters count - auto params = model->Parameters(); - auto buffers = model->Buffers(); - - if (state_dict.size() < params.size()) { - std::cout << "FAIL: StateDict size (" << state_dict.size() << ") should be >= parameter count (" - << params.size() << ")" << std::endl; - return; - } - - // Expected: state_dict.size() == params.size() + buffers.size() - size_t expected_size = params.size() + buffers.size(); - if (state_dict.size() != expected_size) { - std::cout << "FAIL: StateDict size (" << state_dict.size() << ") should equal params (" << params.size() - << ") + buffers (" << buffers.size() << ") = " << expected_size << std::endl; - return; - } - - std::cout << "SUCCESS: StateDict works correctly with " << state_dict.size() << " entries (" << params.size() - << " params + " << buffers.size() << " buffers)!" << std::endl; - - } catch (const std::exception &e) { std::cout << "FAIL: Exception: " << e.what() << std::endl; } -} - -// ============================================================================ -// Test 11: MoE Layer -// ============================================================================ -void TestMoELayer() { - std::cout << "\n=== Test 11: MoE Layer ===" << std::endl; - - nn::TransformerConfig config; - config.n_embd = 32; - config.n_head = 2; - config.n_kv_head = 2; - config.activation_type = nn::MLPType::kGELU; - config.add_bias_linear = true; - config.ffn_type = nn::FFNType::kMoE; - config.moe_config = nn::MoEConfig{}; - config.moe_config->num_experts = 2; - config.moe_config->router_topk = 1; - - auto moe = std::make_shared(config); - auto input = std::make_shared(std::vector{2, 4, config.n_embd}, DataType::kFLOAT32); - input->Uniform(); - - auto output = (*moe)({input}); - CHECK_EQ(output.size(), 1); - CHECK(output[0]->Dims() == input->Dims()); - - auto params = moe->Parameters(); - CHECK(!params.empty()); - - std::cout << "SUCCESS: MoE layer forward works correctly!" << std::endl; -} - -void TestMoELayerTop2() { - std::cout << "\n=== Test 12: MoE Layer Top-2 ===" << std::endl; - - nn::TransformerConfig config; - config.n_embd = 32; - config.n_head = 2; - config.n_kv_head = 2; - config.activation_type = nn::MLPType::kSwiGLU; - config.add_bias_linear = false; - config.ffn_type = nn::FFNType::kMoE; - config.moe_config = nn::MoEConfig{}; - config.moe_config->num_experts = 4; - config.moe_config->router_topk = 2; - config.moe_config->moe_ffn_hidden_size = 48; - - auto moe = std::make_shared(config); - auto input = std::make_shared(std::vector{2, 4, config.n_embd}, DataType::kFLOAT32); - input->Uniform(); - - auto output = (*moe)({input}); - CHECK_EQ(output.size(), 1); - CHECK(output[0]->Dims() == input->Dims()); - - auto state = moe->StateDict(); - CHECK(state.contains("experts.expert_0.c_fc.weight")); - CHECK(state.contains("experts.expert_0.c_fc2.weight")); - CHECK(state.contains("experts.expert_0.c_proj.weight")); - CHECK(state.at("experts.expert_0.c_fc.weight")->Dims() == std::vector({48, config.n_embd})); - CHECK(state.at("experts.expert_0.c_fc2.weight")->Dims() == std::vector({48, config.n_embd})); - CHECK(state.at("experts.expert_0.c_proj.weight")->Dims() == std::vector({config.n_embd, 48})); - - std::cout << "SUCCESS: MoE layer top-2 forward works correctly!" << std::endl; -} - -// ============================================================================ -// Main -// ============================================================================ -int main(int argc, char *argv[]) { - google::InitGoogleLogging(argv[0]); - - nn::parallel::global::GlobalEnv::Instance().Init(1, 1, false, 1, 1); - - std::cout << "========================================" << std::endl; - std::cout << " Transformer architecture Tests" << std::endl; - std::cout << "========================================" << std::endl; - - TestConfigValidation(); - TestEmbedding(); - TestNormalization(); - TestMlp(); - TestAttention(); - TestTransformerLayer(); - TestGpt2Model(); - TestLlama3Model(); - TestRopeUtils(); - TestStateDict(); - TestMoELayer(); - TestMoELayerTop2(); - - std::cout << "\n========================================" << std::endl; - std::cout << " All Tests Completed" << std::endl; - std::cout << "========================================" << std::endl; - - return 0; -} diff --git a/tests/transformer/test_transformer_architecture.cc b/tests/transformer/test_transformer_architecture.cc index ba62e1e3..ad7a9da3 100644 --- a/tests/transformer/test_transformer_architecture.cc +++ b/tests/transformer/test_transformer_architecture.cc @@ -4,10 +4,13 @@ #include "gtest/gtest.h" +#include "infini_train/include/autograd/topk.h" #include "infini_train/include/nn/modules/normalization.h" #include "infini_train/include/nn/modules/sparse.h" #include "infini_train/include/nn/modules/transformer/causal_self_attention.h" #include "infini_train/include/nn/modules/transformer/mlp.h" +#include "infini_train/include/nn/modules/transformer/moe/moe_layer.h" +#include "infini_train/include/nn/modules/transformer/moe/router.h" #include "infini_train/include/nn/modules/transformer/transformer.h" #include "infini_train/include/nn/modules/transformer/transformer_config.h" #include "infini_train/include/nn/modules/transformer/utils.h" @@ -189,4 +192,160 @@ TEST_P(TransformerModuleTest, StateDict) { EXPECT_GE(state_dict.size(), params.size()); } + +TEST_P(TransformerModuleTest, MoELayerTop1) { + nn::TransformerConfig config; + config.n_embd = 32; + config.n_head = 2; + config.n_kv_head = 2; + config.activation_type = nn::MLPType::kGELU; + config.add_bias_linear = true; + config.ffn_type = nn::FFNType::kMoE; + config.moe_config = nn::MoEConfig{}; + config.moe_config->num_experts = 2; + config.moe_config->router_topk = 1; + config.moe_config->router_pre_softmax = true; + + auto moe = std::make_shared(config); + moe->To(GetDevice()); + auto input = std::make_shared(std::vector{2, 4, config.n_embd}, DataType::kFLOAT32, GetDevice()); + input->Uniform(); + + auto output = (*moe)({input}); + ASSERT_EQ(output.size(), 1); + EXPECT_EQ(output[0]->Dims(), input->Dims()); + EXPECT_FALSE(moe->Parameters().empty()); +} + +TEST_P(TransformerModuleTest, MoELayerTop2SwiGLU) { + nn::TransformerConfig config; + config.n_embd = 32; + config.n_head = 2; + config.n_kv_head = 2; + config.activation_type = nn::MLPType::kSwiGLU; + config.add_bias_linear = false; + config.ffn_type = nn::FFNType::kMoE; + config.moe_config = nn::MoEConfig{}; + config.moe_config->num_experts = 4; + config.moe_config->router_topk = 2; + config.moe_config->moe_ffn_hidden_size = 48; + + auto moe = std::make_shared(config); + moe->To(GetDevice()); + auto input = std::make_shared(std::vector{2, 4, config.n_embd}, DataType::kFLOAT32, GetDevice()); + input->Uniform(); + + auto output = (*moe)({input}); + ASSERT_EQ(output.size(), 1); + EXPECT_EQ(output[0]->Dims(), input->Dims()); + + auto state = moe->StateDict(); + ASSERT_TRUE(state.contains("experts.expert_0.c_fc.weight")); + ASSERT_TRUE(state.contains("experts.expert_0.c_fc2.weight")); + ASSERT_TRUE(state.contains("experts.expert_0.c_proj.weight")); + EXPECT_EQ(state.at("experts.expert_0.c_fc.weight")->Dims(), (std::vector{48, config.n_embd})); + EXPECT_EQ(state.at("experts.expert_0.c_fc2.weight")->Dims(), (std::vector{48, config.n_embd})); + EXPECT_EQ(state.at("experts.expert_0.c_proj.weight")->Dims(), (std::vector{config.n_embd, 48})); +} + +TEST_P(TransformerModuleTest, TopKRouterMegatronOutputs) { + nn::TransformerConfig config; + config.n_embd = 32; + config.add_bias_linear = false; + config.ffn_type = nn::FFNType::kMoE; + config.moe_config = nn::MoEConfig{}; + config.moe_config->num_experts = 4; + config.moe_config->router_topk = 2; + + auto router = std::make_shared(config); + router->To(GetDevice()); + auto input = std::make_shared(std::vector{2, 4, config.n_embd}, DataType::kFLOAT32, GetDevice()); + input->Uniform(); + + auto output = (*router)({input}); + ASSERT_EQ(output.size(), 2); + EXPECT_EQ(output[0]->Dims(), (std::vector{2, 4, 4})); + EXPECT_EQ(output[1]->Dims(), (std::vector{2, 4, 4})); + EXPECT_EQ(output[0]->Dtype(), DataType::kFLOAT32); + EXPECT_EQ(output[1]->Dtype(), DataType::kBOOL); +} + +TEST_P(TransformerModuleTest, TopKTorchInterface) { + ONLY_CPU(); + const float data[] = {1.0f, 5.0f, 2.0f, 4.0f, 3.0f, 0.0f}; + auto input = std::make_shared(data, std::vector{2, 3}, DataType::kFLOAT32); + + auto largest_topk = std::make_shared(2, 1, true, true); + auto largest_values = largest_topk->Apply({input})[0]; + auto largest_indices = largest_topk->TopIndices(); + ASSERT_EQ(largest_values->Dims(), (std::vector{2, 2})); + ASSERT_EQ(largest_indices->Dims(), (std::vector{2, 2})); + const auto *largest_values_ptr = static_cast(largest_values->DataPtr()); + const auto *largest_indices_ptr = static_cast(largest_indices->DataPtr()); + EXPECT_FLOAT_EQ(largest_values_ptr[0], 5.0f); + EXPECT_FLOAT_EQ(largest_values_ptr[1], 2.0f); + EXPECT_FLOAT_EQ(largest_values_ptr[2], 4.0f); + EXPECT_FLOAT_EQ(largest_values_ptr[3], 3.0f); + EXPECT_EQ(largest_indices_ptr[0], 1); + EXPECT_EQ(largest_indices_ptr[1], 2); + EXPECT_EQ(largest_indices_ptr[2], 0); + EXPECT_EQ(largest_indices_ptr[3], 1); + + auto smallest_topk = std::make_shared(1, 0, false, true); + auto smallest_values = smallest_topk->Apply({input})[0]; + auto smallest_indices = smallest_topk->TopIndices(); + ASSERT_EQ(smallest_values->Dims(), (std::vector{1, 3})); + ASSERT_EQ(smallest_indices->Dims(), (std::vector{1, 3})); + const auto *smallest_values_ptr = static_cast(smallest_values->DataPtr()); + const auto *smallest_indices_ptr = static_cast(smallest_indices->DataPtr()); + EXPECT_FLOAT_EQ(smallest_values_ptr[0], 1.0f); + EXPECT_FLOAT_EQ(smallest_values_ptr[1], 3.0f); + EXPECT_FLOAT_EQ(smallest_values_ptr[2], 0.0f); + EXPECT_EQ(smallest_indices_ptr[0], 0); + EXPECT_EQ(smallest_indices_ptr[1], 1); + EXPECT_EQ(smallest_indices_ptr[2], 1); +} + +TEST_P(TransformerModuleTest, TopKRouterNormalization) { + ONLY_CPU(); + auto make_router = [](nn::MoEConfig::RouterScoreFunction score_function, bool pre_softmax) { + nn::TransformerConfig config; + config.n_embd = 2; + config.add_bias_linear = false; + config.ffn_type = nn::FFNType::kMoE; + config.moe_config = nn::MoEConfig{}; + config.moe_config->num_experts = 3; + config.moe_config->router_topk = 2; + config.moe_config->router_score_function = score_function; + config.moe_config->router_pre_softmax = pre_softmax; + auto router = std::make_shared(config); + auto weight = router->parameter(nn::moe::TopKRouter::kParamWeightName); + auto *weight_ptr = static_cast(weight->DataPtr()); + weight_ptr[0] = 1.0f; + weight_ptr[1] = 0.0f; + weight_ptr[2] = 2.0f; + weight_ptr[3] = 0.0f; + weight_ptr[4] = 0.0f; + weight_ptr[5] = 0.0f; + return router; + }; + + const float input_data[] = {1.0f, 1.0f}; + auto input = std::make_shared(input_data, std::vector{1, 1, 2}, DataType::kFLOAT32); + + auto softmax_router = make_router(nn::MoEConfig::RouterScoreFunction::kSoftmax, false); + auto softmax_output = (*softmax_router)({input}); + const auto *softmax_probs = static_cast(softmax_output[0]->DataPtr()); + EXPECT_NEAR(softmax_probs[0] + softmax_probs[1] + softmax_probs[2], 1.0f, 1e-5f); + EXPECT_GT(softmax_probs[1], softmax_probs[0]); + EXPECT_FLOAT_EQ(softmax_probs[2], 0.0f); + + auto sigmoid_router = make_router(nn::MoEConfig::RouterScoreFunction::kSigmoid, true); + auto sigmoid_output = (*sigmoid_router)({input}); + const auto *sigmoid_probs = static_cast(sigmoid_output[0]->DataPtr()); + EXPECT_NEAR(sigmoid_probs[0] + sigmoid_probs[1] + sigmoid_probs[2], 1.0f, 1e-5f); + EXPECT_GT(sigmoid_probs[1], sigmoid_probs[0]); + EXPECT_FLOAT_EQ(sigmoid_probs[2], 0.0f); +} + INFINI_TRAIN_REGISTER_TEST(TransformerModuleTest); From 4a3212c4533bcab3b3a86bc2a04366228342a6ea Mon Sep 17 00:00:00 2001 From: kilinchange Date: Wed, 27 May 2026 02:51:20 +0000 Subject: [PATCH 07/11] refactor: rename topk_mask to topk and align with torch.topk API --- infini_train/include/autograd/topk.h | 40 ++++++ infini_train/include/autograd/topk_mask.h | 29 ---- infini_train/src/autograd/topk.cc | 39 ++++++ infini_train/src/autograd/topk_mask.cc | 32 ----- infini_train/src/kernels/cpu/topk.cc | 124 +++++++++++++++++ infini_train/src/kernels/cpu/topk_mask.cc | 88 ------------ infini_train/src/kernels/cuda/topk.cu | 155 +++++++++++++++++++++ infini_train/src/kernels/cuda/topk_mask.cu | 118 ---------------- 8 files changed, 358 insertions(+), 267 deletions(-) create mode 100644 infini_train/include/autograd/topk.h delete mode 100644 infini_train/include/autograd/topk_mask.h create mode 100644 infini_train/src/autograd/topk.cc delete mode 100644 infini_train/src/autograd/topk_mask.cc create mode 100644 infini_train/src/kernels/cpu/topk.cc delete mode 100644 infini_train/src/kernels/cpu/topk_mask.cc create mode 100644 infini_train/src/kernels/cuda/topk.cu delete mode 100644 infini_train/src/kernels/cuda/topk_mask.cu diff --git a/infini_train/include/autograd/topk.h b/infini_train/include/autograd/topk.h new file mode 100644 index 00000000..7752efca --- /dev/null +++ b/infini_train/include/autograd/topk.h @@ -0,0 +1,40 @@ +#pragma once + +#include +#include + +#include "infini_train/include/autograd/function.h" + +namespace infini_train { +class Tensor; +} + +namespace infini_train::autograd { + +// FIXME(dcj): Align this API with torch.topk and return both values and indices from Forward once +// InfiniTrain autograd supports marking individual outputs as non-differentiable. Today indices +// are exposed through TopIndices() to avoid waiting for gradients on metadata outputs. +class TopK : public Function { +public: + static constexpr char kType[] = "TopKFunction"; + + explicit TopK(int64_t topk, int64_t dim = -1, bool largest = true, bool sorted = true) + : Function(kType), topk_(topk), dim_(dim), largest_(largest), sorted_(sorted) {} + + std::vector> Forward(const std::vector> &input_tensors) override; + void SetupContext(const std::vector> &input_tensors, + const std::vector> &output_tensors) override; + std::vector> Backward(const std::vector> &grad_outputs) override; + + std::shared_ptr TopIndices() const; + +private: + int64_t topk_ = 1; + int64_t dim_ = -1; + bool largest_ = true; + bool sorted_ = true; + std::shared_ptr top_indices_; + std::vector input_dims_; +}; + +} // namespace infini_train::autograd diff --git a/infini_train/include/autograd/topk_mask.h b/infini_train/include/autograd/topk_mask.h deleted file mode 100644 index 355ef400..00000000 --- a/infini_train/include/autograd/topk_mask.h +++ /dev/null @@ -1,29 +0,0 @@ -#pragma once - -#include -#include - -#include "infini_train/include/autograd/function.h" - -namespace infini_train { -class Tensor; -} - -namespace infini_train::autograd { - -class TopKMask : public Function { -public: - static constexpr char kType[] = "TopKMaskFunction"; - - explicit TopKMask(int64_t topk) : Function(kType), topk_(topk) {} - - std::vector> Forward(const std::vector> &input_tensors) override; - void SetupContext(const std::vector> &input_tensors, - const std::vector> &output_tensors) override; - std::vector> Backward(const std::vector> &grad_outputs) override; - -private: - int64_t topk_ = 1; -}; - -} // namespace infini_train::autograd diff --git a/infini_train/src/autograd/topk.cc b/infini_train/src/autograd/topk.cc new file mode 100644 index 00000000..4e0420b8 --- /dev/null +++ b/infini_train/src/autograd/topk.cc @@ -0,0 +1,39 @@ +#include "infini_train/include/autograd/topk.h" + +#include "glog/logging.h" + +#include "infini_train/include/dispatcher.h" +#include "infini_train/include/tensor.h" + +namespace infini_train::autograd { + +std::vector> TopK::Forward(const std::vector> &input_tensors) { + CHECK_EQ(input_tensors.size(), 1); + CHECK_GT(topk_, 0); + const auto &input = input_tensors[0]; + auto device = input->GetDevice().type(); + auto topk_outputs = Dispatcher::Instance().Call>>( + {device, "TopKForward"}, input, topk_, dim_, largest_, sorted_); + CHECK_EQ(topk_outputs.size(), 2); + top_indices_ = topk_outputs[1]; + return {topk_outputs[0]}; +} + +void TopK::SetupContext(const std::vector> &input_tensors, + const std::vector> &) { + input_dims_ = input_tensors[0]->Dims(); + saved_tensors_ = {top_indices_}; +} + +std::vector> TopK::Backward(const std::vector> &grad_outputs) { + CHECK_EQ(grad_outputs.size(), 1); + const auto &top_grad = grad_outputs[0]; + const auto &top_indices = saved_tensors_[0]; + auto device = top_grad->GetDevice().type(); + return {Dispatcher::Instance().Call>({device, "TopKBackward"}, top_grad, top_indices, + input_dims_, dim_)}; +} + +std::shared_ptr TopK::TopIndices() const { return top_indices_; } + +} // namespace infini_train::autograd diff --git a/infini_train/src/autograd/topk_mask.cc b/infini_train/src/autograd/topk_mask.cc deleted file mode 100644 index 16dc6629..00000000 --- a/infini_train/src/autograd/topk_mask.cc +++ /dev/null @@ -1,32 +0,0 @@ -#include "infini_train/include/autograd/topk_mask.h" - -#include "glog/logging.h" - -#include "infini_train/include/dispatcher.h" -#include "infini_train/include/tensor.h" - -namespace infini_train::autograd { - -std::vector> TopKMask::Forward(const std::vector> &input_tensors) { - CHECK_EQ(input_tensors.size(), 1); - CHECK_GT(topk_, 0); - const auto &input = input_tensors[0]; - auto device = input->GetDevice().type(); - return {Dispatcher::Instance().Call>({device, "TopKMaskForward"}, input, topk_)}; -} - -void TopKMask::SetupContext(const std::vector> &, - const std::vector> &output_tensors) { - saved_tensors_ = {output_tensors[0]}; -} - -std::vector> TopKMask::Backward(const std::vector> &grad_outputs) { - CHECK_EQ(grad_outputs.size(), 1); - const auto &grad_output = grad_outputs[0]; - const auto &mask_values = saved_tensors_[0]; - auto device = grad_output->GetDevice().type(); - return { - Dispatcher::Instance().Call>({device, "TopKMaskBackward"}, grad_output, mask_values)}; -} - -} // namespace infini_train::autograd diff --git a/infini_train/src/kernels/cpu/topk.cc b/infini_train/src/kernels/cpu/topk.cc new file mode 100644 index 00000000..9e191143 --- /dev/null +++ b/infini_train/src/kernels/cpu/topk.cc @@ -0,0 +1,124 @@ +#include +#include +#include +#include +#include + +#include "glog/logging.h" + +#include "infini_train/include/dispatcher.h" +#include "infini_train/include/tensor.h" + +namespace infini_train::kernels::cpu { + +std::vector> TopKForward(const std::shared_ptr &input, int64_t topk, int64_t dim, + bool largest, bool sorted) { + CHECK(input->Dtype() == DataType::kFLOAT32) << "CPU TopKForward currently supports float32 only"; + CHECK_GE(input->Dims().size(), 1); + (void)sorted; + + const auto &dims = input->Dims(); + if (dim < 0) { + dim += static_cast(dims.size()); + } + CHECK_GE(dim, 0); + CHECK_LT(dim, static_cast(dims.size())); + + const int64_t dim_size = dims[dim]; + CHECK_GT(dim_size, 0); + CHECK_GT(topk, 0); + CHECK_LE(topk, dim_size); + + int64_t outer_size = 1; + for (int64_t idx = 0; idx < dim; ++idx) { outer_size *= dims[idx]; } + int64_t inner_size = 1; + for (size_t idx = static_cast(dim) + 1; idx < dims.size(); ++idx) { inner_size *= dims[idx]; } + + auto topk_dims = dims; + topk_dims[dim] = topk; + auto top_values = std::make_shared(topk_dims, input->Dtype(), input->GetDevice()); + auto top_indices = std::make_shared(topk_dims, DataType::kINT64, input->GetDevice()); + + const float *in = static_cast(input->DataPtr()); + float *values = static_cast(top_values->DataPtr()); + int64_t *indices = static_cast(top_indices->DataPtr()); + for (int64_t outer = 0; outer < outer_size; ++outer) { + for (int64_t inner = 0; inner < inner_size; ++inner) { + std::vector selected_indices(dim_size, false); + for (int64_t selected = 0; selected < topk; ++selected) { + int64_t best_idx = -1; + float best_value + = largest ? -std::numeric_limits::infinity() : std::numeric_limits::infinity(); + for (int64_t idx = 0; idx < dim_size; ++idx) { + if (selected_indices[idx]) { + continue; + } + const float value = in[outer * dim_size * inner_size + idx * inner_size + inner]; + const bool better = largest ? value > best_value : value < best_value; + if (better) { + best_value = value; + best_idx = idx; + } + } + CHECK_GE(best_idx, 0); + selected_indices[best_idx] = true; + const int64_t out_offset = outer * topk * inner_size + selected * inner_size + inner; + values[out_offset] = best_value; + indices[out_offset] = best_idx; + } + } + } + + return {top_values, top_indices}; +} + +std::shared_ptr TopKBackward(const std::shared_ptr &grad_values, const std::shared_ptr &indices, + const std::vector &input_dims, int64_t dim) { + CHECK(indices->Dtype() == DataType::kINT64) << "CPU TopKBackward expects int64 indices"; + CHECK(grad_values->Dims() == indices->Dims()); + CHECK(!input_dims.empty()); + if (dim < 0) { + dim += static_cast(input_dims.size()); + } + CHECK_GE(dim, 0); + CHECK_LT(dim, static_cast(input_dims.size())); + + const int64_t dim_size = input_dims[dim]; + const int64_t topk = indices->Dims()[dim]; + int64_t outer_size = 1; + for (int64_t idx = 0; idx < dim; ++idx) { outer_size *= input_dims[idx]; } + int64_t inner_size = 1; + for (size_t idx = static_cast(dim) + 1; idx < input_dims.size(); ++idx) { inner_size *= input_dims[idx]; } + + auto grad_input = std::make_shared(input_dims, grad_values->Dtype(), grad_values->GetDevice()); + std::memset(grad_input->DataPtr(), 0, grad_input->SizeInBytes()); + + const size_t elem_size = kDataTypeToSize.at(grad_values->Dtype()); + const auto *src = static_cast(grad_values->DataPtr()); + auto *dst = static_cast(grad_input->DataPtr()); + const auto *idx_ptr = static_cast(indices->DataPtr()); + for (int64_t outer = 0; outer < outer_size; ++outer) { + for (int64_t inner = 0; inner < inner_size; ++inner) { + for (int64_t selected = 0; selected < topk; ++selected) { + const int64_t out_offset = outer * topk * inner_size + selected * inner_size + inner; + const int64_t selected_idx = idx_ptr[out_offset]; + CHECK_GE(selected_idx, 0); + CHECK_LT(selected_idx, dim_size); + std::memcpy(dst + (outer * dim_size * inner_size + selected_idx * inner_size + inner) * elem_size, + src + out_offset * elem_size, elem_size); + } + } + } + + return grad_input; +} + +} // namespace infini_train::kernels::cpu + +#define REGISTER_CPU_TOPK_KERNEL(kernel_name) \ + REGISTER_KERNEL(infini_train::Device::DeviceType::kCPU, kernel_name, infini_train::kernels::cpu::kernel_name) + +REGISTER_CPU_TOPK_KERNEL(TopKForward) +REGISTER_CPU_TOPK_KERNEL(TopKBackward) + +#undef REGISTER_CPU_TOPK_KERNEL diff --git a/infini_train/src/kernels/cpu/topk_mask.cc b/infini_train/src/kernels/cpu/topk_mask.cc deleted file mode 100644 index 6a7191b9..00000000 --- a/infini_train/src/kernels/cpu/topk_mask.cc +++ /dev/null @@ -1,88 +0,0 @@ -#include -#include -#include - -#include "glog/logging.h" - -#include "infini_train/include/dispatcher.h" -#include "infini_train/include/tensor.h" - -namespace infini_train::kernels::cpu { - -std::shared_ptr TopKMaskForward(const std::shared_ptr &input, int64_t topk) { - CHECK(input->Dtype() == DataType::kFLOAT32) << "CPU TopKMaskForward currently supports float32 only"; - CHECK_GE(input->Dims().size(), 1); - - const auto &dims = input->Dims(); - const int64_t num_experts = dims.back(); - CHECK_GT(num_experts, 0); - CHECK_GT(topk, 0); - CHECK_LE(topk, num_experts); - const int64_t rows = input->NumElements() / num_experts; - - auto output = std::make_shared(dims, input->Dtype(), input->GetDevice()); - output->Fill(0.0f); - - const float *in = static_cast(input->DataPtr()); - float *out = static_cast(output->DataPtr()); - for (int64_t row = 0; row < rows; ++row) { - const int64_t row_offset = row * num_experts; - std::vector selected_experts(num_experts, false); - float selected_sum = 0.0f; - for (int64_t selected = 0; selected < topk; ++selected) { - int64_t best_idx = -1; - float best_value = -std::numeric_limits::infinity(); - for (int64_t expert_idx = 0; expert_idx < num_experts; ++expert_idx) { - if (selected_experts[expert_idx]) { - continue; - } - const float value = in[row_offset + expert_idx]; - if (value > best_value) { - best_value = value; - best_idx = expert_idx; - } - } - CHECK_GE(best_idx, 0); - selected_experts[best_idx] = true; - out[row_offset + best_idx] = best_value; - selected_sum += best_value; - } - if (topk > 1 && selected_sum != 0.0f) { - for (int64_t expert_idx = 0; expert_idx < num_experts; ++expert_idx) { - out[row_offset + expert_idx] - = out[row_offset + expert_idx] == 0.0f ? 0.0f : out[row_offset + expert_idx] / selected_sum; - } - } - } - - return output; -} - -std::shared_ptr TopKMaskBackward(const std::shared_ptr &grad_output, - const std::shared_ptr &mask_values) { - CHECK(grad_output->Dtype() == DataType::kFLOAT32) << "CPU TopKMaskBackward currently supports float32 only"; - CHECK(mask_values->Dtype() == DataType::kFLOAT32); - CHECK(grad_output->Dims() == mask_values->Dims()); - - auto grad_input = std::make_shared(grad_output->Dims(), grad_output->Dtype(), grad_output->GetDevice()); - grad_input->Fill(0.0f); - - const float *grad = static_cast(grad_output->DataPtr()); - const float *mask = static_cast(mask_values->DataPtr()); - float *out = static_cast(grad_input->DataPtr()); - for (int64_t i = 0; i < static_cast(grad_output->NumElements()); ++i) { - out[i] = mask[i] != 0.0f ? grad[i] : 0.0f; - } - - return grad_input; -} - -} // namespace infini_train::kernels::cpu - -#define REGISTER_CPU_TOPK_MASK_KERNEL(kernel_name) \ - REGISTER_KERNEL(infini_train::Device::DeviceType::kCPU, kernel_name, infini_train::kernels::cpu::kernel_name) - -REGISTER_CPU_TOPK_MASK_KERNEL(TopKMaskForward) -REGISTER_CPU_TOPK_MASK_KERNEL(TopKMaskBackward) - -#undef REGISTER_CPU_TOPK_MASK_KERNEL diff --git a/infini_train/src/kernels/cuda/topk.cu b/infini_train/src/kernels/cuda/topk.cu new file mode 100644 index 00000000..32044c3f --- /dev/null +++ b/infini_train/src/kernels/cuda/topk.cu @@ -0,0 +1,155 @@ +#include "glog/logging.h" + +#include "infini_train/include/common/cuda/common_cuda.h" +#include "infini_train/include/core/runtime/device_guard.h" +#include "infini_train/include/dispatcher.h" +#include "infini_train/include/tensor.h" + +#include "infini_train/src/core/runtime/cuda/cuda_dispatch.h" +#include "infini_train/src/core/runtime/cuda/cuda_runtime_common.h" + +namespace infini_train::kernels::cuda { + +template +__global__ void TopKForwardKernel(const T *__restrict__ input, T *__restrict__ top_values, + int64_t *__restrict__ top_indices, int64_t rows, int64_t dim_size, int64_t inner_size, + int64_t topk, bool largest) { + int64_t row = blockIdx.x * blockDim.x + threadIdx.x; + if (row >= rows) { + return; + } + + const int64_t outer = row / inner_size; + const int64_t inner = row % inner_size; + for (int64_t idx = 0; idx < dim_size; ++idx) { + const float value = static_cast(input[outer * dim_size * inner_size + idx * inner_size + inner]); + int64_t rank = 0; + for (int64_t other_idx = 0; other_idx < dim_size; ++other_idx) { + const float other_value + = static_cast(input[outer * dim_size * inner_size + other_idx * inner_size + inner]); + const bool ranks_before = largest ? (other_value > value || (other_value == value && other_idx < idx)) + : (other_value < value || (other_value == value && other_idx < idx)); + if (ranks_before) { + ++rank; + } + } + if (rank < topk) { + const int64_t out_offset = outer * topk * inner_size + rank * inner_size + inner; + top_values[out_offset] = input[outer * dim_size * inner_size + idx * inner_size + inner]; + top_indices[out_offset] = idx; + } + } +} + +std::vector> TopKForward(const std::shared_ptr &input, int64_t topk, int64_t dim, + bool largest, bool sorted) { + CHECK_GE(input->Dims().size(), 1); + (void)sorted; + const auto &dims = input->Dims(); + if (dim < 0) { + dim += static_cast(dims.size()); + } + CHECK_GE(dim, 0); + CHECK_LT(dim, static_cast(dims.size())); + + const int64_t dim_size = dims[dim]; + CHECK_GT(dim_size, 0); + CHECK_GT(topk, 0); + CHECK_LE(topk, dim_size); + int64_t outer_size = 1; + for (int64_t idx = 0; idx < dim; ++idx) { outer_size *= dims[idx]; } + int64_t inner_size = 1; + for (size_t idx = static_cast(dim) + 1; idx < dims.size(); ++idx) { inner_size *= dims[idx]; } + const int64_t rows = outer_size * inner_size; + + auto topk_dims = dims; + topk_dims[dim] = topk; + auto top_values = std::make_shared(topk_dims, input->Dtype(), input->GetDevice()); + auto top_indices = std::make_shared(topk_dims, DataType::kINT64, input->GetDevice()); + + auto device = input->GetDevice(); + const auto &stream = dynamic_cast( + infini_train::core::GetDeviceGuardImpl(device.type())->GetStream(device)) + ->cuda_stream(); + const int threads = 256; + const int blocks = static_cast((rows + threads - 1) / threads); + + core::cuda::DispatchCudaFunc( + input->Dtype(), + [=]() { + TopKForwardKernel<<>>( + static_cast(input->DataPtr()), static_cast(top_values->DataPtr()), + static_cast(top_indices->DataPtr()), rows, dim_size, inner_size, topk, largest); + }, + "CUDA TopKForward"); + + return {top_values, top_indices}; +} + +template +__global__ void TopKBackwardKernel(const T *__restrict__ grad_values, const int64_t *__restrict__ indices, + T *__restrict__ grad_input, int64_t rows, int64_t dim_size, int64_t inner_size, + int64_t topk) { + int64_t row = blockIdx.x * blockDim.x + threadIdx.x; + if (row >= rows) { + return; + } + + const int64_t outer = row / inner_size; + const int64_t inner = row % inner_size; + for (int64_t selected = 0; selected < topk; ++selected) { + const int64_t out_offset = outer * topk * inner_size + selected * inner_size + inner; + const int64_t selected_idx = indices[out_offset]; + grad_input[outer * dim_size * inner_size + selected_idx * inner_size + inner] = grad_values[out_offset]; + } +} + +std::shared_ptr TopKBackward(const std::shared_ptr &grad_values, const std::shared_ptr &indices, + const std::vector &input_dims, int64_t dim) { + CHECK(indices->Dtype() == DataType::kINT64) << "CUDA TopKBackward expects int64 indices"; + CHECK(grad_values->Dims() == indices->Dims()); + CHECK(!input_dims.empty()); + if (dim < 0) { + dim += static_cast(input_dims.size()); + } + CHECK_GE(dim, 0); + CHECK_LT(dim, static_cast(input_dims.size())); + + const int64_t dim_size = input_dims[dim]; + const int64_t topk = indices->Dims()[dim]; + int64_t outer_size = 1; + for (int64_t idx = 0; idx < dim; ++idx) { outer_size *= input_dims[idx]; } + int64_t inner_size = 1; + for (size_t idx = static_cast(dim) + 1; idx < input_dims.size(); ++idx) { inner_size *= input_dims[idx]; } + const int64_t rows = outer_size * inner_size; + + auto grad_input = std::make_shared(input_dims, grad_values->Dtype(), grad_values->GetDevice()); + auto device = grad_values->GetDevice(); + const auto &stream = dynamic_cast( + infini_train::core::GetDeviceGuardImpl(device.type())->GetStream(device)) + ->cuda_stream(); + CUDA_CHECK(cudaMemsetAsync(grad_input->DataPtr(), 0, grad_input->SizeInBytes(), stream)); + + const int threads = 256; + const int blocks = static_cast((rows + threads - 1) / threads); + core::cuda::DispatchCudaFunc( + grad_values->Dtype(), + [=]() { + TopKBackwardKernel<<>>( + static_cast(grad_values->DataPtr()), static_cast(indices->DataPtr()), + static_cast(grad_input->DataPtr()), rows, dim_size, inner_size, topk); + }, + "CUDA TopKBackward"); + + return grad_input; +} + +} // namespace infini_train::kernels::cuda + +#define REGISTER_CUDA_TOPK_KERNEL(kernel_name) \ + REGISTER_KERNEL(infini_train::Device::DeviceType::kCUDA, kernel_name, infini_train::kernels::cuda::kernel_name) + +REGISTER_CUDA_TOPK_KERNEL(TopKForward) +REGISTER_CUDA_TOPK_KERNEL(TopKBackward) + +#undef REGISTER_CUDA_TOPK_KERNEL diff --git a/infini_train/src/kernels/cuda/topk_mask.cu b/infini_train/src/kernels/cuda/topk_mask.cu deleted file mode 100644 index e38c793e..00000000 --- a/infini_train/src/kernels/cuda/topk_mask.cu +++ /dev/null @@ -1,118 +0,0 @@ -#include "glog/logging.h" - -#include "infini_train/include/common/cuda/common_cuda.h" -#include "infini_train/include/core/runtime/device_guard.h" -#include "infini_train/include/dispatcher.h" -#include "infini_train/include/tensor.h" - -#include "infini_train/src/core/runtime/cuda/cuda_dispatch.h" -#include "infini_train/src/core/runtime/cuda/cuda_runtime_common.h" - -namespace infini_train::kernels::cuda { - -template -__global__ void TopKMaskForwardKernel(const T *__restrict__ input, T *__restrict__ output, int64_t rows, - int64_t num_experts, int64_t topk) { - int64_t row = blockIdx.x * blockDim.x + threadIdx.x; - if (row >= rows) { - return; - } - - const int64_t offset = row * num_experts; - float selected_sum = 0.0f; - for (int64_t expert_idx = 0; expert_idx < num_experts; ++expert_idx) { - const float value = static_cast(input[offset + expert_idx]); - int64_t rank = 0; - for (int64_t other_idx = 0; other_idx < num_experts; ++other_idx) { - const float other_value = static_cast(input[offset + other_idx]); - if (other_value > value || (other_value == value && other_idx < expert_idx)) { - ++rank; - } - } - const bool selected = rank < topk; - output[offset + expert_idx] = selected ? input[offset + expert_idx] : T(0.0f); - selected_sum += selected ? value : 0.0f; - } - if (topk > 1 && selected_sum != 0.0f) { - for (int64_t expert_idx = 0; expert_idx < num_experts; ++expert_idx) { - if (static_cast(output[offset + expert_idx]) != 0.0f) { - output[offset + expert_idx] = T(static_cast(output[offset + expert_idx]) / selected_sum); - } - } - } -} - -std::shared_ptr TopKMaskForward(const std::shared_ptr &input, int64_t topk) { - CHECK_GE(input->Dims().size(), 1); - const auto &dims = input->Dims(); - const int64_t num_experts = dims.back(); - CHECK_GT(num_experts, 0); - CHECK_GT(topk, 0); - CHECK_LE(topk, num_experts); - const int64_t rows = input->NumElements() / num_experts; - - auto output = std::make_shared(dims, input->Dtype(), input->GetDevice()); - - auto device = input->GetDevice(); - const auto &stream = dynamic_cast( - infini_train::core::GetDeviceGuardImpl(device.type())->GetStream(device)) - ->cuda_stream(); - const int threads = 256; - const int blocks = static_cast((rows + threads - 1) / threads); - - core::cuda::DispatchCudaFunc( - input->Dtype(), - [=]() { - TopKMaskForwardKernel<<>>( - static_cast(input->DataPtr()), static_cast(output->DataPtr()), rows, num_experts, topk); - }, - "CUDA TopKMaskForward"); - - return output; -} - -template -__global__ void TopKMaskBackwardKernel(const T *__restrict__ grad_output, const T *__restrict__ mask_values, - T *__restrict__ grad_input, int64_t total_elements) { - int64_t idx = blockIdx.x * blockDim.x + threadIdx.x; - if (idx >= total_elements) { - return; - } - grad_input[idx] = static_cast(mask_values[idx]) != 0.0f ? grad_output[idx] : T(0.0f); -} - -std::shared_ptr TopKMaskBackward(const std::shared_ptr &grad_output, - const std::shared_ptr &mask_values) { - CHECK(grad_output->Dims() == mask_values->Dims()); - CHECK(grad_output->Dtype() == mask_values->Dtype()); - auto grad_input = std::make_shared(grad_output->Dims(), grad_output->Dtype(), grad_output->GetDevice()); - - auto device = grad_output->GetDevice(); - const auto &stream = dynamic_cast( - infini_train::core::GetDeviceGuardImpl(device.type())->GetStream(device)) - ->cuda_stream(); - const int64_t total_elements = grad_output->NumElements(); - const int threads = 256; - const int blocks = static_cast((total_elements + threads - 1) / threads); - - core::cuda::DispatchCudaFunc( - grad_output->Dtype(), - [=]() { - TopKMaskBackwardKernel<<>>( - static_cast(grad_output->DataPtr()), static_cast(mask_values->DataPtr()), - static_cast(grad_input->DataPtr()), total_elements); - }, - "CUDA TopKMaskBackward"); - - return grad_input; -} - -} // namespace infini_train::kernels::cuda - -#define REGISTER_CUDA_TOPK_MASK_KERNEL(kernel_name) \ - REGISTER_KERNEL(infini_train::Device::DeviceType::kCUDA, kernel_name, infini_train::kernels::cuda::kernel_name) - -REGISTER_CUDA_TOPK_MASK_KERNEL(TopKMaskForward) -REGISTER_CUDA_TOPK_MASK_KERNEL(TopKMaskBackward) - -#undef REGISTER_CUDA_TOPK_MASK_KERNEL From abcfbef471a473b7b089097764516e060e56c93b Mon Sep 17 00:00:00 2001 From: kilinchange Date: Wed, 27 May 2026 09:28:27 +0000 Subject: [PATCH 08/11] refactor: refactor TopKRouter module to align with Megatron interface --- .../nn/modules/transformer/moe/moe_utils.h | 10 ++++ .../modules/transformer/transformer_config.h | 31 ++++++----- .../nn/modules/transformer/moe/moe_layer.cc | 9 ++-- .../nn/modules/transformer/moe/moe_utils.cc | 52 +++++++++++++++++++ .../src/nn/modules/transformer/moe/router.cc | 17 +++--- 5 files changed, 95 insertions(+), 24 deletions(-) diff --git a/infini_train/include/nn/modules/transformer/moe/moe_utils.h b/infini_train/include/nn/modules/transformer/moe/moe_utils.h index e0dd3744..6ce26f44 100644 --- a/infini_train/include/nn/modules/transformer/moe/moe_utils.h +++ b/infini_train/include/nn/modules/transformer/moe/moe_utils.h @@ -1,9 +1,19 @@ #pragma once +#include +#include +#include + #include "infini_train/include/nn/modules/transformer/transformer_config.h" +#include "infini_train/include/tensor.h" namespace infini_train::nn::moe { +std::vector> TopkRoutingWithScoreFunction(const std::shared_ptr &logits, int64_t topk, + bool use_pre_softmax, + std::optional scaling_factor, + const MoEConfig::RouterScoreFunction &score_function); + const MoEConfig &RequireMoEConfig(const TransformerConfig &config); } // namespace infini_train::nn::moe diff --git a/infini_train/include/nn/modules/transformer/transformer_config.h b/infini_train/include/nn/modules/transformer/transformer_config.h index 3a96625d..8c440d16 100644 --- a/infini_train/include/nn/modules/transformer/transformer_config.h +++ b/infini_train/include/nn/modules/transformer/transformer_config.h @@ -30,30 +30,33 @@ enum class NormType { kRMSNorm // RMSNorm }; -enum class MoERouterType { - kTopK // Top-k router. -}; +struct MoEConfig { + enum class RouterScoreFunction { + kSoftmax, + kSigmoid, + }; -enum class MoEDispatcherType { - kLocal, // No cross-rank token exchange - kAllGather // Reserved for expert parallel MoE -}; + enum class DispatcherType { + kAllGather, // Megatron-style AllGather dispatcher. Degenerates to local dispatch when TP=EP=1. + kAllToAll // Megatron-style AllToAll dispatcher for expert parallel MoE. + }; -enum class MoEExpertImpl { - kSequential // Run local experts sequentially -}; + enum class ExpertImpl { + kSequential // Run local experts sequentially + }; -struct MoEConfig { int64_t num_experts = 0; int64_t expert_parallel_size = 1; int64_t router_topk = 1; + bool router_pre_softmax = false; + std::optional router_topk_scaling_factor = std::nullopt; + RouterScoreFunction router_score_function = RouterScoreFunction::kSoftmax; float aux_loss_coeff = 0.0f; std::optional expert_capacity_factor = std::nullopt; bool pad_expert_input_to_capacity = false; int64_t moe_ffn_hidden_size = 0; - MoERouterType router_type = MoERouterType::kTopK; - MoEDispatcherType dispatcher_type = MoEDispatcherType::kLocal; - MoEExpertImpl expert_impl = MoEExpertImpl::kSequential; + DispatcherType dispatcher_type = DispatcherType::kAllGather; + ExpertImpl expert_impl = ExpertImpl::kSequential; }; struct TransformerConfig { diff --git a/infini_train/src/nn/modules/transformer/moe/moe_layer.cc b/infini_train/src/nn/modules/transformer/moe/moe_layer.cc index 8efd51c0..6add37ef 100644 --- a/infini_train/src/nn/modules/transformer/moe/moe_layer.cc +++ b/infini_train/src/nn/modules/transformer/moe/moe_layer.cc @@ -15,8 +15,8 @@ namespace infini_train::nn::moe { MoELayer::MoELayer(const TransformerConfig &config) : CloneableModule(kType), config_(config) { const auto &moe_config = RequireMoEConfig(config_); CHECK(config_.ffn_type == FFNType::kMoE); - CHECK(moe_config.dispatcher_type == MoEDispatcherType::kLocal) - << "Current InfiniTrain MoE implementation supports local dispatch only"; + CHECK(moe_config.dispatcher_type == MoEConfig::DispatcherType::kAllGather) + << "Current InfiniTrain MoE implementation supports AllGather dispatcher only"; modules_[kRouterLayerName] = std::make_shared(config_); modules_[kExpertsLayerName] = std::make_shared(config_); @@ -25,8 +25,9 @@ MoELayer::MoELayer(const TransformerConfig &config) : CloneableModule(kType), co std::vector> MoELayer::Forward(const std::vector> &input_tensors) { CHECK_EQ(input_tensors.size(), 1); auto hidden_states = input_tensors[0]; - auto routing_probs = (*modules_.at(kRouterLayerName))({hidden_states})[0]; - return (*modules_.at(kExpertsLayerName))({hidden_states, routing_probs}); + auto router_output = (*modules_.at(kRouterLayerName))({hidden_states}); + CHECK_EQ(router_output.size(), 2); + return (*modules_.at(kExpertsLayerName))({hidden_states, router_output[0], router_output[1]}); } } // namespace infini_train::nn::moe diff --git a/infini_train/src/nn/modules/transformer/moe/moe_utils.cc b/infini_train/src/nn/modules/transformer/moe/moe_utils.cc index 80ef01c1..976e9eff 100644 --- a/infini_train/src/nn/modules/transformer/moe/moe_utils.cc +++ b/infini_train/src/nn/modules/transformer/moe/moe_utils.cc @@ -2,8 +2,60 @@ #include "glog/logging.h" +#include "infini_train/include/autograd/local_token_dispatcher.h" +#include "infini_train/include/autograd/scatter.h" +#include "infini_train/include/autograd/topk.h" +#include "infini_train/include/dispatcher.h" +#include "infini_train/include/nn/functional.h" + namespace infini_train::nn::moe { +std::vector> +TopkRoutingWithScoreFunction(const std::shared_ptr &logits, int64_t topk, bool use_pre_softmax, + std::optional scaling_factor, + const MoEConfig::RouterScoreFunction &score_function) { + + // Megatron TopKRouter returns dense tensors: + // routing_probs: [num_tokens, num_experts] + // routing_map: [num_tokens, num_experts], bool + std::shared_ptr top_probs; + std::shared_ptr top_indices; + + if (score_function == MoEConfig::RouterScoreFunction::kSoftmax) { + if (use_pre_softmax) { + auto scores = function::Softmax(logits, -1); + auto topk_function = std::make_shared(topk); + top_probs = topk_function->Apply({scores})[0]; + top_indices = topk_function->TopIndices(); + } else { + auto topk_function = std::make_shared(topk); + auto top_scores = topk_function->Apply({logits})[0]; + top_indices = topk_function->TopIndices(); + top_probs = function::Softmax(top_scores, -1); + } + } else if (score_function == MoEConfig::RouterScoreFunction::kSigmoid) { + auto sigmoid_scores = function::Sigmoid(logits); + auto topk_function = std::make_shared(topk); + top_probs = topk_function->Apply({sigmoid_scores})[0]; + top_indices = topk_function->TopIndices(); + if (topk > 1) { + top_probs = top_probs / (top_probs->Sum(-1, true) + 1e-20f); + } + } else { + LOG(FATAL) << "Unsupported MoE router score function"; + } + + if (scaling_factor.has_value()) { + top_probs = top_probs * scaling_factor.value(); + } + + auto routing_probs = std::make_shared(logits->Dims())->Apply({top_probs, top_indices})[0]; + auto routing_map_values = std::make_shared(top_indices->Equals(top_indices)->To(DataType::kBOOL)); + auto routing_map = Dispatcher::Instance().Call>( + {logits->GetDevice().type(), "ScatterForward"}, routing_map_values, top_indices, logits->Dims()); + return {routing_probs, routing_map}; +} + const MoEConfig &RequireMoEConfig(const TransformerConfig &config) { CHECK(config.moe_config.has_value()) << "MoE layer requires TransformerConfig::moe_config"; return config.moe_config.value(); diff --git a/infini_train/src/nn/modules/transformer/moe/router.cc b/infini_train/src/nn/modules/transformer/moe/router.cc index 851c57be..25208684 100644 --- a/infini_train/src/nn/modules/transformer/moe/router.cc +++ b/infini_train/src/nn/modules/transformer/moe/router.cc @@ -6,7 +6,8 @@ #include "glog/logging.h" #include "infini_train/include/autograd/linear.h" -#include "infini_train/include/autograd/topk_mask.h" +#include "infini_train/include/autograd/scatter.h" +#include "infini_train/include/autograd/topk.h" #include "infini_train/include/nn/functional.h" #include "infini_train/include/nn/init.h" #include "infini_train/include/nn/modules/transformer/moe/moe_utils.h" @@ -16,11 +17,9 @@ namespace infini_train::nn::moe { TopKRouter::TopKRouter(const TransformerConfig &config) : CloneableModule(kType), config_(config) { const auto &moe_config = RequireMoEConfig(config_); - CHECK(moe_config.router_type == MoERouterType::kTopK); CHECK_GT(moe_config.num_experts, 0); CHECK_GT(moe_config.router_topk, 0); CHECK_LE(moe_config.router_topk, moe_config.num_experts); - parameters_[kParamWeightName] = std::make_shared(std::vector{moe_config.num_experts, config_.n_embd}, DataType::kFLOAT32, device_) @@ -43,10 +42,16 @@ std::vector> TopKRouter::Forward(const std::vector()->Apply(linear_inputs)[0]; - auto scores = function::Softmax(logits, -1); + const auto &moe_config = RequireMoEConfig(config_); - auto routing_probs = std::make_shared(moe_config.router_topk)->Apply({scores})[0]; - return {routing_probs}; + + auto routing_results + = TopkRoutingWithScoreFunction(logits, moe_config.router_topk, moe_config.router_pre_softmax, + moe_config.router_topk_scaling_factor, moe_config.router_score_function); + + auto routing_probs = routing_results[0]; + auto routing_map = routing_results[1]; + return {routing_probs, routing_map}; } } // namespace infini_train::nn::moe From 69a74fd8e65e5d16eac4db7d79e625233a92104d Mon Sep 17 00:00:00 2001 From: kilinchange Date: Fri, 29 May 2026 09:33:12 +0000 Subject: [PATCH 09/11] feat: implement MoETokenDispatcher base class and MoEAllGatherTokenDispatcher --- infini_train/include/autograd/scatter_add.h | 31 +++++ .../nn/modules/transformer/moe/moe_utils.h | 21 ++++ .../transformer/moe/token_dispatcher.h | 67 ++++++++++ infini_train/src/autograd/scatter_add.cc | 35 ++++++ infini_train/src/kernels/cpu/concat.cc | 16 +-- infini_train/src/kernels/cpu/transform.cc | 11 +- .../src/nn/modules/transformer/moe/experts.cc | 39 ++++-- .../nn/modules/transformer/moe/moe_utils.cc | 118 +++++++++++++++++- .../transformer/moe/token_dispatcher.cc | 95 ++++++++++++++ 9 files changed, 408 insertions(+), 25 deletions(-) create mode 100644 infini_train/include/autograd/scatter_add.h create mode 100644 infini_train/include/nn/modules/transformer/moe/token_dispatcher.h create mode 100644 infini_train/src/autograd/scatter_add.cc create mode 100644 infini_train/src/nn/modules/transformer/moe/token_dispatcher.cc diff --git a/infini_train/include/autograd/scatter_add.h b/infini_train/include/autograd/scatter_add.h new file mode 100644 index 00000000..3adc1586 --- /dev/null +++ b/infini_train/include/autograd/scatter_add.h @@ -0,0 +1,31 @@ +#pragma once + +#include +#include + +#include "infini_train/include/autograd/function.h" + +namespace infini_train { +class Tensor; +} + +namespace infini_train::autograd { + +class ScatterAdd : public Function { +public: + static constexpr char kType[] = "ScatterAddFunction"; + + ScatterAdd(int64_t dim, const std::vector &output_dims) + : Function(kType), dim_(dim), output_dims_(output_dims) {} + + std::vector> Forward(const std::vector> &input_tensors) override; + void SetupContext(const std::vector> &input_tensors, + const std::vector> &output_tensors) override; + std::vector> Backward(const std::vector> &grad_outputs) override; + +private: + int64_t dim_ = 0; + std::vector output_dims_; +}; + +} // namespace infini_train::autograd diff --git a/infini_train/include/nn/modules/transformer/moe/moe_utils.h b/infini_train/include/nn/modules/transformer/moe/moe_utils.h index 6ce26f44..f6941049 100644 --- a/infini_train/include/nn/modules/transformer/moe/moe_utils.h +++ b/infini_train/include/nn/modules/transformer/moe/moe_utils.h @@ -9,11 +9,32 @@ namespace infini_train::nn::moe { +struct PermutationMetadata { + std::shared_ptr sorted_indices; + std::shared_ptr gather_indices; + std::shared_ptr route_indices; + std::shared_ptr tokens_per_expert; + std::vector tokens_per_expert_host; +}; + +struct PermutationResult { + std::shared_ptr permuted_hidden_states; + std::shared_ptr permuted_probs; + PermutationMetadata metadata; +}; + std::vector> TopkRoutingWithScoreFunction(const std::shared_ptr &logits, int64_t topk, bool use_pre_softmax, std::optional scaling_factor, const MoEConfig::RouterScoreFunction &score_function); const MoEConfig &RequireMoEConfig(const TransformerConfig &config); +PermutationMetadata BuildPermutationMetadata(const std::shared_ptr &routing_map); +PermutationResult Permute(const std::shared_ptr &hidden_states_2d, + const std::shared_ptr &routing_probs_2d, + const std::shared_ptr &routing_map_2d); +std::shared_ptr Unpermute(const std::shared_ptr &permuted_hidden_states, + const std::shared_ptr &permuted_probs, const PermutationMetadata &metadata, + const std::vector &restore_shape); } // namespace infini_train::nn::moe diff --git a/infini_train/include/nn/modules/transformer/moe/token_dispatcher.h b/infini_train/include/nn/modules/transformer/moe/token_dispatcher.h new file mode 100644 index 00000000..f9e3c614 --- /dev/null +++ b/infini_train/include/nn/modules/transformer/moe/token_dispatcher.h @@ -0,0 +1,67 @@ +#pragma once + +#include +#include +#include + +#include "infini_train/include/nn/modules/transformer/moe/moe_utils.h" +#include "infini_train/include/nn/modules/transformer/transformer_config.h" +#include "infini_train/include/tensor.h" + +namespace infini_train::nn::moe { + +class MoETokenDispatcher { +public: + virtual ~MoETokenDispatcher() = default; + + const PermutationResult &Dispatch(const std::shared_ptr &tokens, const std::shared_ptr &routing_map, + const std::shared_ptr &probs); + std::shared_ptr Combine(const std::shared_ptr &hidden_states) const; + +protected: + explicit MoETokenDispatcher(const TransformerConfig &config); + + virtual std::vector> DispatchPreprocess(const std::shared_ptr &tokens, + const std::shared_ptr &routing_map, + const std::shared_ptr &probs) + = 0; + virtual std::vector> TokenDispatch(const std::shared_ptr &hidden_states, + const std::shared_ptr &probs) const + = 0; + virtual const PermutationResult &DispatchPostprocess(const std::shared_ptr &hidden_states, + const std::shared_ptr &probs) + = 0; + virtual std::shared_ptr CombinePreprocess(const std::shared_ptr &hidden_states) const = 0; + virtual std::shared_ptr TokenCombine(const std::shared_ptr &hidden_states) const = 0; + virtual std::shared_ptr CombinePostprocess(const std::shared_ptr &hidden_states) const = 0; + + TransformerConfig config_; + PermutationResult dispatch_; + std::vector hidden_dims_; + std::shared_ptr routing_map_; + std::shared_ptr local_map_; + std::shared_ptr local_probs_; + int64_t num_tokens_ = 0; + int64_t hidden_size_ = 0; +}; + +class MoEAllGatherTokenDispatcher : public MoETokenDispatcher { +public: + MoEAllGatherTokenDispatcher(int64_t num_local_experts, const TransformerConfig &config); + +private: + std::vector> DispatchPreprocess(const std::shared_ptr &tokens, + const std::shared_ptr &routing_map, + const std::shared_ptr &probs) override; + std::vector> TokenDispatch(const std::shared_ptr &hidden_states, + const std::shared_ptr &probs) const override; + const PermutationResult &DispatchPostprocess(const std::shared_ptr &hidden_states, + const std::shared_ptr &probs) override; + std::shared_ptr CombinePreprocess(const std::shared_ptr &hidden_states) const override; + std::shared_ptr TokenCombine(const std::shared_ptr &hidden_states) const override; + std::shared_ptr CombinePostprocess(const std::shared_ptr &hidden_states) const override; + + int64_t num_local_experts_ = 0; +}; + +} // namespace infini_train::nn::moe diff --git a/infini_train/src/autograd/scatter_add.cc b/infini_train/src/autograd/scatter_add.cc new file mode 100644 index 00000000..428f4f08 --- /dev/null +++ b/infini_train/src/autograd/scatter_add.cc @@ -0,0 +1,35 @@ +#include "infini_train/include/autograd/scatter_add.h" + +#include "glog/logging.h" + +#include "infini_train/include/dispatcher.h" +#include "infini_train/include/tensor.h" + +namespace infini_train::autograd { + +std::vector> ScatterAdd::Forward(const std::vector> &input_tensors) { + CHECK_EQ(input_tensors.size(), 2); + const auto &values = input_tensors[0]; + const auto &indices = input_tensors[1]; + auto device = values->GetDevice().type(); + auto output = Dispatcher::Instance().Call>({device, "GatherBackward"}, values, indices, + dim_, output_dims_); + return {output}; +} + +void ScatterAdd::SetupContext(const std::vector> &input_tensors, + const std::vector> &) { + saved_tensors_ = {input_tensors[1]}; +} + +std::vector> ScatterAdd::Backward(const std::vector> &grad_outputs) { + CHECK_EQ(grad_outputs.size(), 1); + const auto &grad_output = grad_outputs[0]; + const auto &indices = saved_tensors_[0]; + auto device = grad_output->GetDevice().type(); + auto grad_values + = Dispatcher::Instance().Call>({device, "GatherForward"}, grad_output, indices, dim_); + return {grad_values, nullptr}; +} + +} // namespace infini_train::autograd diff --git a/infini_train/src/kernels/cpu/concat.cc b/infini_train/src/kernels/cpu/concat.cc index b421063f..169cc557 100644 --- a/infini_train/src/kernels/cpu/concat.cc +++ b/infini_train/src/kernels/cpu/concat.cc @@ -1,7 +1,6 @@ -#include +#include #include #include -#include #include #include "glog/logging.h" @@ -42,23 +41,24 @@ std::shared_ptr ConcatForward(const std::vector> const int64_t K_total = std::accumulate(Ks.begin(), Ks.end(), int64_t{0}); output_dims[dim] = K_total; - auto output = std::make_shared(output_dims, DataType::kFLOAT32); + auto output = std::make_shared(output_dims, dtype, device); const int64_t outer_size = std::accumulate(output_dims.begin(), output_dims.begin() + dim, 1LL, std::multiplies()); const int64_t inner_size = std::accumulate(output_dims.begin() + dim + 1, output_dims.end(), 1LL, std::multiplies()); - const size_t elem_size = sizeof(float); + const size_t elem_size = kDataTypeToSize.at(dtype); - float *dst_ptr_base = static_cast(output->DataPtr()); + auto *dst_ptr_base = static_cast(output->DataPtr()); for (int64_t n = 0; n < outer_size; ++n) { int64_t offset_k = 0; - float *dst_block = dst_ptr_base + n * K_total * inner_size; + auto *dst_block = dst_ptr_base + n * K_total * inner_size * elem_size; for (size_t i = 0; i < inputs.size(); ++i) { const int64_t Ki = Ks[i]; - const float *src_ptr = static_cast(inputs[i]->DataPtr()) + n * Ki * inner_size; - float *dst_ptr = dst_block + offset_k * inner_size; + const auto *src_ptr + = static_cast(inputs[i]->DataPtr()) + n * Ki * inner_size * elem_size; + auto *dst_ptr = dst_block + offset_k * inner_size * elem_size; std::memcpy(dst_ptr, src_ptr, static_cast(Ki) * inner_size * elem_size); offset_k += Ki; } diff --git a/infini_train/src/kernels/cpu/transform.cc b/infini_train/src/kernels/cpu/transform.cc index 1a810b44..48063c7a 100644 --- a/infini_train/src/kernels/cpu/transform.cc +++ b/infini_train/src/kernels/cpu/transform.cc @@ -1,4 +1,6 @@ #include +#include +#include #include #include "glog/logging.h" @@ -167,14 +169,15 @@ std::shared_ptr RepeatInterleaveForward(const std::shared_ptr &i output_dims[dim] = dim_size * repeat; auto output = std::make_shared(output_dims, input->Dtype(), input->GetDevice()); - const float *input_ptr = static_cast(input->DataPtr()); - float *output_ptr = static_cast(output->DataPtr()); + const size_t elem_size = kDataTypeToSize.at(input->Dtype()); + const auto *input_ptr = static_cast(input->DataPtr()); + auto *output_ptr = static_cast(output->DataPtr()); for (int64_t o = 0; o < outer; ++o) { for (int64_t i = 0; i < dim_size; ++i) { for (int r = 0; r < repeat; ++r) { - std::memcpy(output_ptr + ((o * dim_size * repeat + i * repeat + r) * inner), - input_ptr + ((o * dim_size + i) * inner), sizeof(float) * inner); + std::memcpy(output_ptr + ((o * dim_size * repeat + i * repeat + r) * inner * elem_size), + input_ptr + ((o * dim_size + i) * inner * elem_size), elem_size * inner); } } } diff --git a/infini_train/src/nn/modules/transformer/moe/experts.cc b/infini_train/src/nn/modules/transformer/moe/experts.cc index 8f3b1be8..7566c48f 100644 --- a/infini_train/src/nn/modules/transformer/moe/experts.cc +++ b/infini_train/src/nn/modules/transformer/moe/experts.cc @@ -6,19 +6,21 @@ #include "glog/logging.h" +#include "infini_train/include/nn/functional.h" #include "infini_train/include/nn/modules/transformer/mlp.h" #include "infini_train/include/nn/modules/transformer/moe/moe_utils.h" +#include "infini_train/include/nn/modules/transformer/moe/token_dispatcher.h" #include "infini_train/include/tensor.h" namespace infini_train::nn::moe { SequentialMLP::SequentialMLP(const TransformerConfig &config) : CloneableModule(kType), config_(config) { const auto &moe_config = RequireMoEConfig(config_); - CHECK(moe_config.expert_impl == MoEExpertImpl::kSequential); + CHECK(moe_config.expert_impl == MoEConfig::ExpertImpl::kSequential); CHECK_EQ(moe_config.expert_parallel_size, 1) << "Current InfiniTrain MoE implementation supports expert_parallel_size=1 only"; - CHECK(moe_config.dispatcher_type == MoEDispatcherType::kLocal) - << "Current InfiniTrain MoE implementation supports local dispatch only"; + CHECK(moe_config.dispatcher_type == MoEConfig::DispatcherType::kAllGather) + << "Current InfiniTrain MoE implementation supports AllGather dispatcher only"; num_local_experts_ = moe_config.num_experts; CHECK_GT(num_local_experts_, 0); @@ -29,22 +31,35 @@ SequentialMLP::SequentialMLP(const TransformerConfig &config) : CloneableModule( } std::vector> SequentialMLP::Forward(const std::vector> &input_tensors) { - CHECK_EQ(input_tensors.size(), 2); + CHECK_EQ(input_tensors.size(), 3); auto hidden_states = input_tensors[0]; auto routing_probs = input_tensors[1]; - CHECK_EQ(routing_probs->Dims().back(), num_local_experts_); + auto routing_map = input_tensors[2]; + std::unique_ptr dispatcher + = std::make_unique(num_local_experts_, config_); + const auto &dispatch = dispatcher->Dispatch(hidden_states, routing_map, routing_probs); - std::shared_ptr output = nullptr; - const int64_t expert_dim = static_cast(routing_probs->Dims().size()) - 1; + std::vector> expert_outputs; + int64_t start = 0; for (int64_t expert_idx = 0; expert_idx < num_local_experts_; ++expert_idx) { + const int64_t num_tokens_for_expert = dispatch.metadata.tokens_per_expert_host[expert_idx]; + const int64_t end = start + num_tokens_for_expert; + if (num_tokens_for_expert == 0) { + start = end; + continue; + } + + auto expert_input = dispatch.permuted_hidden_states->Slice(0, start, end); auto expert_name = std::string(kExpertNamePrefix) + std::to_string(expert_idx); - auto expert_output = (*modules_.at(expert_name))({hidden_states})[0]; - auto expert_prob = routing_probs->Slice(expert_dim, expert_idx, expert_idx + 1); - auto weighted_output = expert_output * expert_prob; - output = output == nullptr ? weighted_output : output + weighted_output; + expert_outputs.push_back((*modules_.at(expert_name))({expert_input})[0]); + start = end; } + CHECK_EQ(start, dispatch.permuted_hidden_states->Dims()[0]); + CHECK(!expert_outputs.empty()) << "No tokens were dispatched to any local expert"; - return {output}; + auto permuted_expert_output + = expert_outputs.size() == 1 ? expert_outputs[0] : nn::function::Concat(expert_outputs, 0); + return {dispatcher->Combine(permuted_expert_output)}; } } // namespace infini_train::nn::moe diff --git a/infini_train/src/nn/modules/transformer/moe/moe_utils.cc b/infini_train/src/nn/modules/transformer/moe/moe_utils.cc index 976e9eff..040b29df 100644 --- a/infini_train/src/nn/modules/transformer/moe/moe_utils.cc +++ b/infini_train/src/nn/modules/transformer/moe/moe_utils.cc @@ -1,9 +1,11 @@ #include "infini_train/include/nn/modules/transformer/moe/moe_utils.h" +#include + #include "glog/logging.h" -#include "infini_train/include/autograd/local_token_dispatcher.h" #include "infini_train/include/autograd/scatter.h" +#include "infini_train/include/autograd/scatter_add.h" #include "infini_train/include/autograd/topk.h" #include "infini_train/include/dispatcher.h" #include "infini_train/include/nn/functional.h" @@ -61,4 +63,118 @@ const MoEConfig &RequireMoEConfig(const TransformerConfig &config) { return config.moe_config.value(); } +PermutationMetadata BuildPermutationMetadata(const std::shared_ptr &routing_map) { + CHECK(routing_map->Dtype() == DataType::kBOOL); + CHECK_EQ(routing_map->Dims().size(), 2); + + const int64_t num_tokens = routing_map->Dims()[0]; + const int64_t num_experts = routing_map->Dims()[1]; + CHECK_GT(num_tokens, 0); + CHECK_GT(num_experts, 0); + + Tensor routing_map_cpu_storage = routing_map->To(Device()); + auto routing_map_cpu = std::make_shared(routing_map_cpu_storage); + const auto *routing_map_ptr = static_cast(routing_map_cpu->DataPtr()); + + std::vector sorted_indices_host; + std::vector route_indices_host; + std::vector tokens_per_expert_host; + sorted_indices_host.reserve(routing_map->NumElements()); + route_indices_host.reserve(routing_map->NumElements()); + tokens_per_expert_host.reserve(num_experts); + + for (int64_t expert_idx = 0; expert_idx < num_experts; ++expert_idx) { + int64_t tokens_for_expert = 0; + for (int64_t token_idx = 0; token_idx < num_tokens; ++token_idx) { + if (routing_map_ptr[token_idx * num_experts + expert_idx]) { + sorted_indices_host.push_back(token_idx); + route_indices_host.push_back(token_idx * num_experts + expert_idx); + ++tokens_for_expert; + } + } + tokens_per_expert_host.push_back(tokens_for_expert); + } + + const int64_t num_dispatched_tokens = static_cast(sorted_indices_host.size()); + auto sorted_indices_cpu + = std::make_shared(std::vector{num_dispatched_tokens}, DataType::kINT64, Device()); + auto route_indices_cpu + = std::make_shared(std::vector{num_dispatched_tokens}, DataType::kINT64, Device()); + auto gather_indices_cpu + = std::make_shared(std::vector{num_dispatched_tokens, 1}, DataType::kINT64, Device()); + auto tokens_per_expert_cpu + = std::make_shared(std::vector{num_experts}, DataType::kINT64, Device()); + + auto *sorted_indices_ptr = static_cast(sorted_indices_cpu->DataPtr()); + auto *route_indices_ptr = static_cast(route_indices_cpu->DataPtr()); + auto *gather_indices_ptr = static_cast(gather_indices_cpu->DataPtr()); + auto *tokens_per_expert_ptr = static_cast(tokens_per_expert_cpu->DataPtr()); + for (int64_t idx = 0; idx < num_dispatched_tokens; ++idx) { + sorted_indices_ptr[idx] = sorted_indices_host[idx]; + route_indices_ptr[idx] = route_indices_host[idx]; + gather_indices_ptr[idx] = sorted_indices_host[idx]; + } + for (int64_t expert_idx = 0; expert_idx < num_experts; ++expert_idx) { + tokens_per_expert_ptr[expert_idx] = tokens_per_expert_host[expert_idx]; + } + + auto to_device = [&](const std::shared_ptr &cpu_tensor) -> std::shared_ptr { + if (routing_map->GetDevice().type() == Device::DeviceType::kCPU) { + return cpu_tensor; + } + return std::make_shared(cpu_tensor->To(routing_map->GetDevice())); + }; + + return {to_device(sorted_indices_cpu), to_device(gather_indices_cpu), to_device(route_indices_cpu), + to_device(tokens_per_expert_cpu), tokens_per_expert_host}; +} + +PermutationResult Permute(const std::shared_ptr &hidden_states_2d, + const std::shared_ptr &routing_probs_2d, + const std::shared_ptr &routing_map_2d) { + CHECK_EQ(hidden_states_2d->Dims().size(), 2); + CHECK(routing_probs_2d->Dims() == routing_map_2d->Dims()); + CHECK(routing_map_2d->Dtype() == DataType::kBOOL); + + const int64_t hidden_size = hidden_states_2d->Dims()[1]; + auto metadata = BuildPermutationMetadata(routing_map_2d); + const int64_t num_dispatched_tokens = metadata.sorted_indices->Dims()[0]; + + std::shared_ptr permuted_hidden_states; + std::shared_ptr permuted_probs; + if (num_dispatched_tokens == 0) { + permuted_hidden_states = std::make_shared(std::vector{0, hidden_size}, + hidden_states_2d->Dtype(), hidden_states_2d->GetDevice()); + permuted_probs = std::make_shared(std::vector{0}, routing_probs_2d->Dtype(), + routing_probs_2d->GetDevice()); + } else { + auto gather_indices = metadata.gather_indices; + if (hidden_size != 1) { + gather_indices = metadata.gather_indices->RepeatInterleave(hidden_size, 1); + } + permuted_hidden_states = hidden_states_2d->Gather(0, gather_indices); + permuted_probs = routing_probs_2d->View({static_cast(routing_probs_2d->NumElements())}) + ->Gather(0, metadata.route_indices); + } + + return {permuted_hidden_states, permuted_probs, metadata}; +} + +std::shared_ptr Unpermute(const std::shared_ptr &permuted_hidden_states, + const std::shared_ptr &permuted_probs, const PermutationMetadata &metadata, + const std::vector &restore_shape) { + CHECK_EQ(permuted_hidden_states->Dims().size(), 2); + CHECK_EQ(permuted_probs->Dims().size(), 1); + CHECK_EQ(permuted_hidden_states->Dims()[0], permuted_probs->Dims()[0]); + CHECK_EQ(restore_shape.size(), 2); + + auto weighted = permuted_hidden_states * permuted_probs->View({permuted_probs->Dims()[0], 1}); + auto scatter_indices = metadata.gather_indices; + const int64_t hidden_size = restore_shape[1]; + if (hidden_size != 1) { + scatter_indices = metadata.gather_indices->RepeatInterleave(hidden_size, 1); + } + return std::make_shared(0, restore_shape)->Apply({weighted, scatter_indices})[0]; +} + } // namespace infini_train::nn::moe diff --git a/infini_train/src/nn/modules/transformer/moe/token_dispatcher.cc b/infini_train/src/nn/modules/transformer/moe/token_dispatcher.cc new file mode 100644 index 00000000..667dba8f --- /dev/null +++ b/infini_train/src/nn/modules/transformer/moe/token_dispatcher.cc @@ -0,0 +1,95 @@ +#include "infini_train/include/nn/modules/transformer/moe/token_dispatcher.h" + +#include +#include + +#include "glog/logging.h" + +namespace infini_train::nn::moe { + +MoETokenDispatcher::MoETokenDispatcher(const TransformerConfig &config) : config_(config) {} + +const PermutationResult &MoETokenDispatcher::Dispatch(const std::shared_ptr &tokens, + const std::shared_ptr &routing_map, + const std::shared_ptr &probs) { + auto preprocessed = DispatchPreprocess(tokens, routing_map, probs); + auto dispatched = TokenDispatch(preprocessed[0], preprocessed[1]); + return DispatchPostprocess(dispatched[0], dispatched[1]); +} + +std::shared_ptr MoETokenDispatcher::Combine(const std::shared_ptr &hidden_states) const { + auto preprocessed = CombinePreprocess(hidden_states); + auto combined = TokenCombine(preprocessed); + return CombinePostprocess(combined); +} + +MoEAllGatherTokenDispatcher::MoEAllGatherTokenDispatcher(int64_t num_local_experts, const TransformerConfig &config) + : MoETokenDispatcher(config), num_local_experts_(num_local_experts) { + CHECK_GT(num_local_experts_, 0); +} + +std::vector> +MoEAllGatherTokenDispatcher::DispatchPreprocess(const std::shared_ptr &tokens, + const std::shared_ptr &routing_map, + const std::shared_ptr &probs) { + CHECK(probs->Dims() == routing_map->Dims()); + CHECK(routing_map->Dtype() == DataType::kBOOL); + CHECK_GE(tokens->Dims().size(), 2); + + hidden_dims_ = tokens->Dims(); + hidden_size_ = hidden_dims_.back(); + CHECK_GT(hidden_size_, 0); + num_tokens_ = tokens->NumElements() / hidden_size_; + CHECK_EQ(probs->Dims().back(), num_local_experts_); + CHECK_EQ(probs->NumElements(), static_cast(num_tokens_ * num_local_experts_)); + + routing_map_ = routing_map->View({num_tokens_, num_local_experts_}); + auto hidden_states_2d = tokens->View({num_tokens_, hidden_size_}); + auto probs_2d = probs->View({num_tokens_, num_local_experts_}); + return {hidden_states_2d, probs_2d}; +} + +std::vector> +MoEAllGatherTokenDispatcher::TokenDispatch(const std::shared_ptr &hidden_states, + const std::shared_ptr &probs) const { + // AllGather dispatcher will gather tokens across TP*EP ranks here. For the current single-rank + // path (tp_size=1, ep_size=1), no communication is required. + return {hidden_states, probs}; +} + +const PermutationResult &MoEAllGatherTokenDispatcher::DispatchPostprocess(const std::shared_ptr &hidden_states, + const std::shared_ptr &probs) { + CHECK(routing_map_ != nullptr); + CHECK_EQ(hidden_states->Dims().size(), 2); + CHECK_EQ(probs->Dims().size(), 2); + CHECK_EQ(hidden_states->Dims()[0], probs->Dims()[0]); + CHECK_EQ(probs->Dims()[1], num_local_experts_); + + // With ep_size=1 all experts are local, so the local expert map/probs are the gathered map/probs. + // Future EP support should slice [local_expert_start, local_expert_end) after AllGather. + local_map_ = routing_map_; + local_probs_ = probs; + dispatch_ = Permute(hidden_states, local_probs_, local_map_); + routing_map_ = nullptr; + return dispatch_; +} + +std::shared_ptr +MoEAllGatherTokenDispatcher::CombinePreprocess(const std::shared_ptr &hidden_states) const { + CHECK(local_map_ != nullptr); + CHECK(local_probs_ != nullptr); + return Unpermute(hidden_states, dispatch_.permuted_probs, dispatch_.metadata, + std::vector{num_tokens_, hidden_size_}); +} + +std::shared_ptr MoEAllGatherTokenDispatcher::TokenCombine(const std::shared_ptr &hidden_states) const { + // AllGather dispatcher will reduce-scatter combined token outputs here. For ep_size=1 this is a no-op. + return hidden_states; +} + +std::shared_ptr +MoEAllGatherTokenDispatcher::CombinePostprocess(const std::shared_ptr &hidden_states) const { + return hidden_states->View(hidden_dims_); +} + +} // namespace infini_train::nn::moe From a8942640202a7a0a16dc0cf627e9b1143f6948aa Mon Sep 17 00:00:00 2001 From: kilinchange Date: Tue, 2 Jun 2026 07:46:09 +0000 Subject: [PATCH 10/11] feat: add tiny_mixtral example --- CMakeLists.txt | 8 + example/tiny_mixtral/checkpoint_loader.cc | 167 ++++++++++++++++++ example/tiny_mixtral/checkpoint_loader.h | 21 +++ example/tiny_mixtral/config.h | 76 ++++++++ example/tiny_mixtral/main.cc | 136 ++++++++++++++ .../modules/transformer/transformer_config.h | 4 +- .../src/nn/modules/transformer/moe/experts.cc | 2 +- .../nn/modules/transformer/moe/moe_layer.cc | 2 +- 8 files changed, 412 insertions(+), 4 deletions(-) create mode 100644 example/tiny_mixtral/checkpoint_loader.cc create mode 100644 example/tiny_mixtral/checkpoint_loader.h create mode 100644 example/tiny_mixtral/config.h create mode 100644 example/tiny_mixtral/main.cc diff --git a/CMakeLists.txt b/CMakeLists.txt index 4c6da822..ac23a8b4 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -199,6 +199,14 @@ add_executable(gpt2 ) link_infini_train_exe(gpt2) +add_executable(tiny_mixtral + example/tiny_mixtral/main.cc + example/common/tiny_shakespeare_dataset.cc + example/common/utils.cc + example/tiny_mixtral/checkpoint_loader.cc +) +link_infini_train_exe(tiny_mixtral) + add_executable(llama3 example/llama3/main.cc example/common/tiny_shakespeare_dataset.cc diff --git a/example/tiny_mixtral/checkpoint_loader.cc b/example/tiny_mixtral/checkpoint_loader.cc new file mode 100644 index 00000000..1e27ac53 --- /dev/null +++ b/example/tiny_mixtral/checkpoint_loader.cc @@ -0,0 +1,167 @@ +#include "example/tiny_mixtral/checkpoint_loader.h" + +#include +#include +#include +#include +#include + +#include "glog/logging.h" + +#include "infini_train/include/datatype.h" +#include "infini_train/include/nn/modules/transformer/transformer.h" +#include "infini_train/include/tensor.h" + +#include "example/common/utils.h" +#include "example/tiny_mixtral/config.h" + +namespace nn = infini_train::nn; + +namespace { + +constexpr int32_t kTinyMixtralLLMCMagic = 20260513; +constexpr int32_t kTinyMixtralLLMCVersion = 2; +constexpr int64_t kLLMCHeaderEntries = 256; + +} // namespace + +namespace tiny_mixtral { + +namespace { + +template +void CompareCheckpointValue(const std::string &name, const T &checkpoint_value, const T &runtime_value) { + CHECK_EQ(checkpoint_value, runtime_value) << name << " value from checkpoint (" << checkpoint_value + << ") is not equal to runtime config value (" << runtime_value << ")"; +} + +} // namespace + +nn::TransformerConfig ConfigFromLLMC(const std::string &filepath) { + std::ifstream ifs(filepath, std::ios::binary); + CHECK(ifs) << "Failed to open tiny Mixtral LLMC file: " << filepath; + const auto header = infini_train::ReadSeveralBytesFromIfstream(kLLMCHeaderEntries * sizeof(int32_t), &ifs); + CHECK(ifs) << "Failed to read tiny Mixtral LLMC header: " << filepath; + CHECK_EQ(infini_train::BytesToType(header, 0 * sizeof(int32_t)), kTinyMixtralLLMCMagic); + CHECK_EQ(infini_train::BytesToType(header, 1 * sizeof(int32_t)), kTinyMixtralLLMCVersion); + + auto config = TinyMixtralConfig(); + config.block_size = infini_train::BytesToType(header, 2 * sizeof(int32_t)); + config.vocab_size = infini_train::BytesToType(header, 3 * sizeof(int32_t)); + config.original_vocab_size = config.vocab_size; + config.n_layer = infini_train::BytesToType(header, 4 * sizeof(int32_t)); + config.n_head = infini_train::BytesToType(header, 5 * sizeof(int32_t)); + config.n_kv_head = infini_train::BytesToType(header, 6 * sizeof(int32_t)); + config.n_embd = infini_train::BytesToType(header, 7 * sizeof(int32_t)); + config.ffn_expansion_ratio = infini_train::BytesToType(header, 9 * sizeof(int32_t)); + // Header slots 10 and 11 store dense-MLP helpers; MoE expert size is stored in moe_ffn_hidden_size. + config.norm_eps = infini_train::BytesToType(header, 12 * sizeof(int32_t)); + config.rope_theta = infini_train::BytesToType(header, 13 * sizeof(int32_t)); + config.use_scaled_rope = infini_train::BytesToType(header, 14 * sizeof(int32_t)) != 0; + + nn::MoEConfig moe_config; + moe_config.num_experts = infini_train::BytesToType(header, 8 * sizeof(int32_t)); + moe_config.expert_parallel_size = 1; + moe_config.router_topk = infini_train::BytesToType(header, 15 * sizeof(int32_t)); + moe_config.moe_ffn_hidden_size = infini_train::BytesToType(header, 16 * sizeof(int32_t)); + moe_config.token_dispatcher_type = nn::MoEConfig::TokenDispatcherType::kAllGather; + moe_config.expert_impl = nn::MoEConfig::ExpertImpl::kSequential; + config.moe_config = moe_config; + SanitizeTinyMixtralConfig(config); + return config; +} + +void CheckLLMCConfig(const std::string &filepath, const nn::TransformerConfig &expected_config) { + SanitizeTinyMixtralConfig(expected_config); + const auto checkpoint_config = ConfigFromLLMC(filepath); + CompareCheckpointValue("block_size", checkpoint_config.block_size, expected_config.block_size); + CompareCheckpointValue("vocab_size", checkpoint_config.vocab_size, expected_config.vocab_size); + CompareCheckpointValue("original_vocab_size", checkpoint_config.original_vocab_size, + expected_config.original_vocab_size); + CompareCheckpointValue("n_layer", checkpoint_config.n_layer, expected_config.n_layer); + CompareCheckpointValue("n_head", checkpoint_config.n_head, expected_config.n_head); + CompareCheckpointValue("n_kv_head", checkpoint_config.n_kv_head, expected_config.n_kv_head); + CompareCheckpointValue("n_embd", checkpoint_config.n_embd, expected_config.n_embd); + CompareCheckpointValue("ffn_expansion_ratio", checkpoint_config.ffn_expansion_ratio, + expected_config.ffn_expansion_ratio); + CompareCheckpointValue("norm_eps", checkpoint_config.norm_eps, expected_config.norm_eps); + CompareCheckpointValue("rope_theta", checkpoint_config.rope_theta, expected_config.rope_theta); + CompareCheckpointValue("use_scaled_rope", checkpoint_config.use_scaled_rope, expected_config.use_scaled_rope); + + CHECK(expected_config.moe_config.has_value()) << "tiny Mixtral runtime config requires MoE config"; + const auto &checkpoint_moe = checkpoint_config.moe_config.value(); + const auto &expected_moe = expected_config.moe_config.value(); + CompareCheckpointValue("num_experts", checkpoint_moe.num_experts, expected_moe.num_experts); + CompareCheckpointValue("router_topk", checkpoint_moe.router_topk, expected_moe.router_topk); + CompareCheckpointValue("moe_ffn_hidden_size", checkpoint_moe.moe_ffn_hidden_size, expected_moe.moe_ffn_hidden_size); +} + +std::shared_ptr LoadFromLLMC(const std::string &filepath, + const nn::TransformerConfig &expected_config) { + CheckLLMCConfig(filepath, expected_config); + auto model = std::make_shared(expected_config); + + std::ifstream ifs(filepath, std::ios::binary); + CHECK(ifs) << "Failed to open tiny Mixtral LLMC file: " << filepath; + const auto header = infini_train::ReadSeveralBytesFromIfstream(kLLMCHeaderEntries * sizeof(int32_t), &ifs); + CHECK(ifs) << "Failed to read tiny Mixtral LLMC header: " << filepath; + CHECK_EQ(infini_train::BytesToType(header, 0 * sizeof(int32_t)), kTinyMixtralLLMCMagic); + CHECK_EQ(infini_train::BytesToType(header, 1 * sizeof(int32_t)), kTinyMixtralLLMCVersion); + + const auto &config = expected_config; + auto state = model->StateDict(); + auto read_tensor_by_state_key = [&](const std::string &name) { + CHECK(state.contains(name)) << "Model state_dict does not contain " << name; + std::shared_ptr tensor = state.at(name); + CHECK(tensor->Dtype() == infini_train::DataType::kFLOAT32) + << "Only float32 tiny Mixtral LLMC files are supported: " << name; + infini_train::ReadMatrixAllFloat(ifs, static_cast(tensor->DataPtr()), tensor->NumElements(), 1); + CHECK(ifs) << "Failed to read tensor " << name; + }; + + auto read_projection_into_packed_qkv = [&](const std::string &packed_qkv_name, int64_t row_offset, int64_t num_rows, + const std::string &projection_name) { + CHECK(state.contains(packed_qkv_name)) << "Model state_dict does not contain " << packed_qkv_name; + std::shared_ptr tensor = state.at(packed_qkv_name); + CHECK(tensor->Dtype() == infini_train::DataType::kFLOAT32) + << "Only float32 tiny Mixtral LLMC files are supported: " << projection_name; + CHECK_EQ(tensor->Dims().size(), 2); + CHECK_GE(row_offset, 0); + CHECK_GT(num_rows, 0); + CHECK_LE(row_offset + num_rows, tensor->Dims()[0]); + const int64_t cols = tensor->Dims()[1]; + auto *data = static_cast(tensor->DataPtr()) + row_offset * cols; + infini_train::ReadMatrixAllFloat(ifs, data, num_rows, cols); + CHECK(ifs) << "Failed to read tensor rows " << projection_name; + }; + + const auto &moe_config = config.moe_config.value(); + read_tensor_by_state_key("transformer.wte.weight"); + for (int64_t layer = 0; layer < config.n_layer; ++layer) { + const std::string prefix = "transformer.h." + std::to_string(layer); + read_tensor_by_state_key(prefix + ".ln_1.weight"); + const auto c_attn_name = prefix + ".attn.c_attn.weight"; + const int64_t head_dim = config.n_embd / config.n_head; + const int64_t q_rows = config.n_head * head_dim; + const int64_t kv_rows = config.n_kv_head * head_dim; + read_projection_into_packed_qkv(c_attn_name, 0, q_rows, c_attn_name + ".q_proj"); + read_projection_into_packed_qkv(c_attn_name, q_rows, kv_rows, c_attn_name + ".k_proj"); + read_projection_into_packed_qkv(c_attn_name, q_rows + kv_rows, kv_rows, c_attn_name + ".v_proj"); + read_tensor_by_state_key(prefix + ".attn.c_proj.weight"); + read_tensor_by_state_key(prefix + ".ln_2.weight"); + read_tensor_by_state_key(prefix + ".mlp.router.weight"); + for (int64_t expert = 0; expert < moe_config.num_experts; ++expert) { + const std::string expert_prefix = prefix + ".mlp.experts.expert_" + std::to_string(expert); + read_tensor_by_state_key(expert_prefix + ".c_fc2.weight"); // Mixtral w1/gate_proj + read_tensor_by_state_key(expert_prefix + ".c_fc.weight"); // Mixtral w3/up_proj + read_tensor_by_state_key(expert_prefix + ".c_proj.weight"); // Mixtral w2/down_proj + } + } + read_tensor_by_state_key("transformer.ln_f.weight"); + read_tensor_by_state_key("lm_head.weight"); + + CHECK_EQ(ifs.peek(), std::ifstream::traits_type::eof()) << "Unexpected trailing bytes in tiny Mixtral LLMC file"; + return model; +} + +} // namespace tiny_mixtral diff --git a/example/tiny_mixtral/checkpoint_loader.h b/example/tiny_mixtral/checkpoint_loader.h new file mode 100644 index 00000000..738538ad --- /dev/null +++ b/example/tiny_mixtral/checkpoint_loader.h @@ -0,0 +1,21 @@ +#pragma once + +#include +#include + +#include "infini_train/include/nn/modules/transformer/transformer_config.h" + +namespace infini_train::nn { +class TransformerModel; +} // namespace infini_train::nn + +namespace tiny_mixtral { + +infini_train::nn::TransformerConfig ConfigFromLLMC(const std::string &filepath); + +void CheckLLMCConfig(const std::string &filepath, const infini_train::nn::TransformerConfig &expected_config); + +std::shared_ptr +LoadFromLLMC(const std::string &filepath, const infini_train::nn::TransformerConfig &expected_config); + +} // namespace tiny_mixtral diff --git a/example/tiny_mixtral/config.h b/example/tiny_mixtral/config.h new file mode 100644 index 00000000..0d7096d4 --- /dev/null +++ b/example/tiny_mixtral/config.h @@ -0,0 +1,76 @@ +#pragma once + +#include "glog/logging.h" + +#include "infini_train/include/nn/modules/transformer/transformer_config.h" + +namespace nn = infini_train::nn; + +namespace tiny_mixtral { + +inline nn::TransformerConfig TinyMixtralConfig() { + nn::TransformerConfig config; + config.block_size = 32768; // Same as Mixtral/Megatron --max-position-embeddings. + config.vocab_size = 128256; // Validation data uses LLaMA3 token ids; real Mixtral uses 32000. + config.original_vocab_size = 128256; + config.n_layer = 2; // Tiny scale; Megatron --num-layers 32. + config.n_head = 4; // Tiny scale; preserves the Megatron 4:1 GQA ratio. + config.n_kv_head = 1; // Tiny scale; Megatron --num-query-groups 8. + config.n_embd = 32; // Tiny scale; Megatron --hidden-size 4096. + config.attention_type = nn::AttentionType::kRoPE; + config.activation_type = nn::MLPType::kSwiGLU; + config.ffn_type = nn::FFNType::kMoE; + config.norm_type = nn::NormType::kRMSNorm; + config.add_bias_linear = false; + config.add_bias_lm_head = false; + config.tie_weights = false; + config.ffn_expansion_ratio = 3.5f; + config.norm_eps = 1e-5f; + config.rope_theta = 1000000.0f; + config.use_scaled_rope = false; + + nn::MoEConfig moe_config; + moe_config.num_experts = 8; + moe_config.expert_parallel_size = 1; // Single-rank validation scale. + moe_config.router_topk = 2; + moe_config.moe_ffn_hidden_size = 112; // Tiny scale; Megatron --ffn-hidden-size 14336. + moe_config.token_dispatcher_type = nn::MoEConfig::TokenDispatcherType::kAllGather; // Single-rank validation path. + moe_config.expert_impl = nn::MoEConfig::ExpertImpl::kSequential; // Local correctness path. + config.moe_config = moe_config; + return config; +} + +inline void SanitizeTinyMixtralConfig(const nn::TransformerConfig &c) { + CHECK_GT(c.block_size, 0); + CHECK_GT(c.vocab_size, 0); + CHECK_GE(c.vocab_size, c.original_vocab_size); + CHECK_GT(c.n_layer, 0); + CHECK_GT(c.n_head, 0); + CHECK_GT(c.n_kv_head, 0); + CHECK_LE(c.n_kv_head, c.n_head); + CHECK_EQ(c.n_head % c.n_kv_head, 0) << "n_head must be divisible by n_kv_head for GQA"; + CHECK_GT(c.n_embd, 0); + CHECK_EQ(c.n_embd % c.n_head, 0) << "n_embd must be divisible by n_head"; + CHECK(c.attention_type == nn::AttentionType::kRoPE) << "tiny Mixtral requires RoPE attention"; + CHECK(c.activation_type == nn::MLPType::kSwiGLU) << "tiny Mixtral requires SwiGLU activation"; + CHECK(c.ffn_type == nn::FFNType::kMoE) << "tiny Mixtral requires MoE FFN"; + CHECK(c.norm_type == nn::NormType::kRMSNorm) << "tiny Mixtral requires RMSNorm"; + CHECK(!c.add_bias_linear) << "tiny Mixtral has no bias in linear layers"; + CHECK(!c.add_bias_lm_head) << "tiny Mixtral has no bias in lm_head"; + CHECK(!c.tie_weights) << "tiny Mixtral does not tie embedding and lm_head weights"; + CHECK(!c.use_scaled_rope) << "tiny Mixtral precision validation keeps scaled RoPE disabled"; + CHECK(c.moe_config.has_value()) << "tiny Mixtral requires MoE config"; + + const auto &moe = c.moe_config.value(); + CHECK_GT(moe.num_experts, 0); + CHECK_EQ(moe.expert_parallel_size, 1) << "tiny Mixtral single-rank validation expects EP=1"; + CHECK_GT(moe.router_topk, 0); + CHECK_LE(moe.router_topk, moe.num_experts); + CHECK_GT(moe.moe_ffn_hidden_size, 0); + CHECK(moe.token_dispatcher_type == nn::MoEConfig::TokenDispatcherType::kAllGather) + << "tiny Mixtral uses the Megatron-style AllGather dispatcher"; + CHECK(moe.expert_impl == nn::MoEConfig::ExpertImpl::kSequential) + << "tiny Mixtral validation uses SequentialMLP experts"; +} + +} // namespace tiny_mixtral diff --git a/example/tiny_mixtral/main.cc b/example/tiny_mixtral/main.cc new file mode 100644 index 00000000..dc2b5136 --- /dev/null +++ b/example/tiny_mixtral/main.cc @@ -0,0 +1,136 @@ +#include +#include +#include +#include +#include +#include + +#include "gflags/gflags.h" +#include "glog/logging.h" + +#include "example/common/tiny_shakespeare_dataset.h" +#include "example/tiny_mixtral/checkpoint_loader.h" +#include "example/tiny_mixtral/config.h" +#include "infini_train/include/core/runtime/device_guard.h" +#include "infini_train/include/dataloader.h" +#include "infini_train/include/device.h" +#include "infini_train/include/nn/modules/loss.h" +#include "infini_train/include/nn/modules/transformer/transformer.h" +#include "infini_train/include/nn/parallel/global.h" +#include "infini_train/include/optimizer.h" +#include "infini_train/include/tensor.h" + +DEFINE_string(input_bin, "", "input .bin to train on"); +DEFINE_uint32(batch_size, 4, "batch size"); +DEFINE_uint32(sequence_length, 64, "sequence length"); +DEFINE_uint32(num_iteration, 10, "number of training iterations"); +DEFINE_double(learning_rate, 1e-4, "SGD learning rate"); +DEFINE_string(llmc_filepath, "", + "optional PyTorch-generated tiny Mixtral LLMC model file path to load before training"); +DEFINE_string(device, "cpu", "Training device: cpu or cuda."); +DEFINE_uint32(log_interval, 1, "Print train loss every N steps. 0 disables step loss logging."); +DEFINE_bool(print_timing, false, "Print training-loop elapsed time and token throughput."); + +namespace { + +using infini_train::Device; +using infini_train::Tensor; + +void ValidateRuntimeFlags(const infini_train::nn::TransformerConfig &config) { + CHECK(!FLAGS_input_bin.empty()) << "tiny Mixtral training requires --input_bin"; + CHECK_GT(FLAGS_batch_size, 0); + CHECK_GT(FLAGS_sequence_length, 0); + CHECK_LE(FLAGS_sequence_length, config.block_size) << "sequence_length must be <= model max positions (block_size)"; +} + +} // namespace + +int main(int argc, char *argv[]) { + gflags::ParseCommandLineFlags(&argc, &argv, true); + google::InitGoogleLogging(argv[0]); + + infini_train::nn::parallel::global::InitAllEnv( + /*nthread_per_process=*/1, + /*tensor_parallel_size=*/1, + /*sequence_parallel_enabled=*/false, + /*pipeline_parallel_size=*/1, + /*virtual_pipeline_parallel_size=*/1); + + infini_train::nn::TransformerConfig model_config = tiny_mixtral::TinyMixtralConfig(); + tiny_mixtral::SanitizeTinyMixtralConfig(model_config); + std::shared_ptr model = nullptr; + if (!FLAGS_llmc_filepath.empty()) { + model = tiny_mixtral::LoadFromLLMC(FLAGS_llmc_filepath, model_config); + } else { + model = std::make_shared(model_config); + } + ValidateRuntimeFlags(model_config); + + Device train_device; + if (FLAGS_device == "cuda") { + train_device = Device(Device::DeviceType::kCUDA, 0); + model->To(train_device); + } else { + CHECK_EQ(FLAGS_device, "cpu") << "Unsupported training device: " << FLAGS_device; + train_device = Device(); + } + + infini_train::DistributedDataLoader train_loader( + std::make_shared(FLAGS_input_bin, FLAGS_sequence_length), FLAGS_batch_size, + /*ddp_rank=*/0, /*ddp_world_size=*/1); + auto train_iter = train_loader.begin(); + + auto loss_fn = std::make_shared(); + auto optimizer + = infini_train::optimizers::SGD::Create(static_cast(FLAGS_learning_rate))(model->Parameters()); + + auto device_impl = infini_train::core::GetDeviceGuardImpl(train_device.type()); + std::vector step_duration_ms; + step_duration_ms.reserve(FLAGS_num_iteration); + const double tokens_per_step = static_cast(FLAGS_batch_size) * FLAGS_sequence_length; + for (uint32_t step = 0; step < FLAGS_num_iteration; ++step) { + device_impl->SynchronizeDevice(train_device); + const auto step_start_time = std::chrono::steady_clock::now(); + + optimizer->ZeroGrad(); + if (train_iter == train_loader.end()) { + train_iter = train_loader.begin(); + } + auto [x_cpu, y_cpu] = *train_iter; + ++train_iter; + auto x = std::make_shared(x_cpu->To(train_device)); + auto y = std::make_shared(y_cpu->To(train_device)); + auto logits = (*model)({x})[0]; + auto loss = (*loss_fn)({logits, y})[0]; + loss->Backward(); + optimizer->Step(); + + device_impl->SynchronizeDevice(train_device); + const auto step_end_time = std::chrono::steady_clock::now(); + const double duration_ms = std::chrono::duration(step_end_time - step_start_time).count(); + step_duration_ms.push_back(duration_ms); + + if (FLAGS_log_interval > 0 && ((step + 1) % FLAGS_log_interval == 0 || step + 1 == FLAGS_num_iteration)) { + auto loss_cpu = loss->To(Device()); + const float lossf = static_cast(loss_cpu.DataPtr())[0]; + std::cout << std::format( + "step {:4d}/{} | train loss {:.6f} | norm -1.0000 | lr {:.2e} | ({:.2f} ms | {:.0f} tok/s)", step + 1, + FLAGS_num_iteration, lossf, FLAGS_learning_rate, duration_ms, tokens_per_step / (duration_ms / 1e3)) + << std::endl; + } + } + if (!step_duration_ms.empty()) { + double duration_sum_ms = 0.0; + for (size_t idx = step_duration_ms.size() > 1 ? 1 : 0; idx < step_duration_ms.size(); ++idx) { + duration_sum_ms += step_duration_ms[idx]; + } + const size_t averaged_steps + = step_duration_ms.size() > 1 ? step_duration_ms.size() - 1 : step_duration_ms.size(); + std::cout << std::format("final {} iters avg: {:.3f}ms", averaged_steps, duration_sum_ms / averaged_steps) + << std::endl; + } + + gflags::ShutDownCommandLineFlags(); + google::ShutdownGoogleLogging(); + return 0; +} diff --git a/infini_train/include/nn/modules/transformer/transformer_config.h b/infini_train/include/nn/modules/transformer/transformer_config.h index 8c440d16..75d1e473 100644 --- a/infini_train/include/nn/modules/transformer/transformer_config.h +++ b/infini_train/include/nn/modules/transformer/transformer_config.h @@ -36,7 +36,7 @@ struct MoEConfig { kSigmoid, }; - enum class DispatcherType { + enum class TokenDispatcherType { kAllGather, // Megatron-style AllGather dispatcher. Degenerates to local dispatch when TP=EP=1. kAllToAll // Megatron-style AllToAll dispatcher for expert parallel MoE. }; @@ -55,7 +55,7 @@ struct MoEConfig { std::optional expert_capacity_factor = std::nullopt; bool pad_expert_input_to_capacity = false; int64_t moe_ffn_hidden_size = 0; - DispatcherType dispatcher_type = DispatcherType::kAllGather; + TokenDispatcherType token_dispatcher_type = TokenDispatcherType::kAllGather; ExpertImpl expert_impl = ExpertImpl::kSequential; }; diff --git a/infini_train/src/nn/modules/transformer/moe/experts.cc b/infini_train/src/nn/modules/transformer/moe/experts.cc index 7566c48f..fa8681da 100644 --- a/infini_train/src/nn/modules/transformer/moe/experts.cc +++ b/infini_train/src/nn/modules/transformer/moe/experts.cc @@ -19,7 +19,7 @@ SequentialMLP::SequentialMLP(const TransformerConfig &config) : CloneableModule( CHECK(moe_config.expert_impl == MoEConfig::ExpertImpl::kSequential); CHECK_EQ(moe_config.expert_parallel_size, 1) << "Current InfiniTrain MoE implementation supports expert_parallel_size=1 only"; - CHECK(moe_config.dispatcher_type == MoEConfig::DispatcherType::kAllGather) + CHECK(moe_config.token_dispatcher_type == MoEConfig::TokenDispatcherType::kAllGather) << "Current InfiniTrain MoE implementation supports AllGather dispatcher only"; num_local_experts_ = moe_config.num_experts; diff --git a/infini_train/src/nn/modules/transformer/moe/moe_layer.cc b/infini_train/src/nn/modules/transformer/moe/moe_layer.cc index 6add37ef..1e15fe81 100644 --- a/infini_train/src/nn/modules/transformer/moe/moe_layer.cc +++ b/infini_train/src/nn/modules/transformer/moe/moe_layer.cc @@ -15,7 +15,7 @@ namespace infini_train::nn::moe { MoELayer::MoELayer(const TransformerConfig &config) : CloneableModule(kType), config_(config) { const auto &moe_config = RequireMoEConfig(config_); CHECK(config_.ffn_type == FFNType::kMoE); - CHECK(moe_config.dispatcher_type == MoEConfig::DispatcherType::kAllGather) + CHECK(moe_config.token_dispatcher_type == MoEConfig::TokenDispatcherType::kAllGather) << "Current InfiniTrain MoE implementation supports AllGather dispatcher only"; modules_[kRouterLayerName] = std::make_shared(config_); From 181fc36a041a408a0d7dc29800f15bdc2321c1b6 Mon Sep 17 00:00:00 2001 From: kilinchange Date: Tue, 2 Jun 2026 10:00:04 +0000 Subject: [PATCH 11/11] test: integrate tiny_mixtral into automated test pipeline --- scripts/run_models_and_profile.bash | 58 ++++++++++++++++++- scripts/test_config.json | 20 +++++-- tests/autograd/test_autograd.cc | 2 +- .../test_autograd_transform_forward.cc | 1 - 4 files changed, 71 insertions(+), 10 deletions(-) diff --git a/scripts/run_models_and_profile.bash b/scripts/run_models_and_profile.bash index e3c67293..94b065d1 100755 --- a/scripts/run_models_and_profile.bash +++ b/scripts/run_models_and_profile.bash @@ -219,6 +219,54 @@ args_string_for_test() { ' "$CONFIG_FILE" | paste -sd' ' - } +args_string_for_moe_test() { + local test_idx="$1" + jq -r --argjson t "$test_idx" ' + (.moe_tests // [])[ $t ].args + | to_entries[] + | "--\(.key)=\(.value|tostring)" + ' "$CONFIG_FILE" | paste -sd' ' - +} + +run_moe_validation() { + local profile_flag="$1" + local tag="moe" + + if [[ ${#SELECTED_TAGS[@]} -gt 0 && -z "${SELECTED_TAGS[$tag]:-}" ]]; then + return + fi + + if [[ ! -f "$MIXTRAL_INPUT_BIN" ]]; then + echo "Error: missing MIXTRAL_INPUT_BIN: $MIXTRAL_INPUT_BIN" >&2 + exit 1 + fi + if [[ ! -f "$MIXTRAL_LLMC_FILEPATH" ]]; then + echo "Error: missing MIXTRAL_LLMC_FILEPATH: $MIXTRAL_LLMC_FILEPATH" >&2 + exit 1 + fi + + local num_moe_tests + num_moe_tests=$(jq '.moe_tests // [] | length' "$CONFIG_FILE") + if [[ "$num_moe_tests" -eq 0 ]]; then + echo "Error: MoE validation requires at least one entry in moe_tests." >&2 + exit 1 + fi + + echo -e "\033[1;36m[TEST GROUP] tag=${tag}, cases=${num_moe_tests}\033[0m" + for ((mi=0; mi