diff --git a/infini_train/include/autograd/comm.h b/infini_train/include/autograd/comm.h index ec3cfe4a..c67372ee 100644 --- a/infini_train/include/autograd/comm.h +++ b/infini_train/include/autograd/comm.h @@ -15,7 +15,7 @@ class ProcessGroup; } // namespace nn::parallel } // namespace infini_train -namespace infini_train::autograd { +namespace infini_train::autograd::comm { class Scatter : public autograd::Function { public: static constexpr char kType[] = "ScatterFunction"; @@ -99,4 +99,4 @@ class ReduceAddCoalesced : public autograd::Function { std::vector target_gpus_; int64_t num_inputs_ = 0; }; -} // namespace infini_train::autograd +} // namespace infini_train::autograd::comm diff --git a/infini_train/include/autograd/gather.h b/infini_train/include/autograd/gather.h new file mode 100644 index 00000000..0fb44c51 --- /dev/null +++ b/infini_train/include/autograd/gather.h @@ -0,0 +1,30 @@ +#pragma once + +#include +#include + +#include "infini_train/include/autograd/function.h" + +namespace infini_train { +class Tensor; +} + +namespace infini_train::autograd { + +class Gather : public Function { +public: + static constexpr char kType[] = "GatherFunction"; + + Gather(int64_t dim = 0) : Function(kType), dim_(dim) {} + + std::vector> Forward(const std::vector> &input_tensors) override; + void SetupContext(const std::vector> &input_tensors, + const std::vector> &output_tensors) override; + std::vector> Backward(const std::vector> &grad_outputs) override; + +private: + const int64_t dim_ = 0; + std::vector input_dims_; +}; + +} // namespace infini_train::autograd diff --git a/infini_train/include/autograd/misc.h b/infini_train/include/autograd/misc.h deleted file mode 100644 index ccfca22d..00000000 --- a/infini_train/include/autograd/misc.h +++ /dev/null @@ -1,113 +0,0 @@ -#pragma once - -#include -#include - -#include "infini_train/include/autograd/function.h" - -namespace infini_train { -class Tensor; -} - -namespace infini_train::autograd { -class Split : public Function { -public: - static constexpr char kType[] = "SplitFunction"; - - Split(int64_t split_size, int dim = 0) : Function(kType), split_size_(split_size), dim_(dim) {} - - std::vector> Forward(const std::vector> &input_tensors) override; - void SetupContext(const std::vector> &input_tensors, - const std::vector> &output_tensors) override; - std::vector> Backward(const std::vector> &grad_outputs) override; - -private: - const int64_t split_size_ = 0; - const int dim_ = 0; - std::vector input_dims_; -}; - -// FIXME(zbl): This function aligns with torch.gather -// Currently named IndexGather to avoid conflict with communication operators -// Should be renamed to Gather later for interface consistency -class IndexGather : public Function { -public: - static constexpr char kType[] = "IndexGatherFunction"; - - IndexGather(int64_t dim = 0) : Function(kType), dim_(dim) {} - - std::vector> Forward(const std::vector> &input_tensors) override; - void SetupContext(const std::vector> &input_tensors, - const std::vector> &output_tensors) override; - std::vector> Backward(const std::vector> &grad_outputs) override; - -private: - const int64_t dim_ = 0; - std::vector input_dims_; -}; - -class NoOp : public Function { -public: - static constexpr char kType[] = "NoOpFunction"; - - explicit NoOp(const std::vector &output_dims) : Function(kType), output_dims_(output_dims) {} - - std::vector> Forward(const std::vector> &input_tensors) override; - void SetupContext(const std::vector> &input_tensors, - const std::vector> &output_tensors) override; - std::vector> Backward(const std::vector> &grad_outputs) override; - -private: - const std::vector output_dims_; - std::vector input_dims_; -}; - -class Slice : public Function { -public: - static constexpr char kType[] = "SliceFunction"; - - Slice(const std::vector &starts, const std::vector &ends, const std::vector &steps) - : Function(kType), starts_(starts), ends_(ends), steps_(steps) {} - std::vector> Forward(const std::vector> &input_tensors) override; - void SetupContext(const std::vector> &input_tensors, - const std::vector> &output_tensors) override; - std::vector> Backward(const std::vector> &grad_outputs) override; - -private: - const std::vector starts_; - const std::vector ends_; - const std::vector steps_; -}; - -class Stack : public Function { -public: - static constexpr char kType[] = "StackFunction"; - - Stack(int64_t dim) : Function(kType), dim_(dim) {} - - std::vector> Forward(const std::vector> &input_tensors) override; - void SetupContext(const std::vector> &input_tensors, - const std::vector> &output_tensors) override; - std::vector> Backward(const std::vector> &grad_outputs) override; - -private: - int64_t dim_ = 0; - std::vector input_dims_; -}; - -class Concat : public Function { -public: - static constexpr char kType[] = "ConcatFunction"; - - Concat(int64_t dim) : Function(kType), dim_(dim) {} - - std::vector> Forward(const std::vector> &input_tensors) override; - void SetupContext(const std::vector> &input_tensors, - const std::vector> &output_tensors) override; - std::vector> Backward(const std::vector> &grad_outputs) override; - -private: - const int64_t dim_ = 0; - std::vector> input_dims_list_; -}; -} // namespace infini_train::autograd diff --git a/infini_train/include/autograd/no_op.h b/infini_train/include/autograd/no_op.h new file mode 100644 index 00000000..a097393d --- /dev/null +++ b/infini_train/include/autograd/no_op.h @@ -0,0 +1,30 @@ +#pragma once + +#include +#include + +#include "infini_train/include/autograd/function.h" + +namespace infini_train { +class Tensor; +} + +namespace infini_train::autograd { + +class NoOp : public Function { +public: + static constexpr char kType[] = "NoOpFunction"; + + explicit NoOp(const std::vector &output_dims) : Function(kType), output_dims_(output_dims) {} + + std::vector> Forward(const std::vector> &input_tensors) override; + void SetupContext(const std::vector> &input_tensors, + const std::vector> &output_tensors) override; + std::vector> Backward(const std::vector> &grad_outputs) override; + +private: + const std::vector output_dims_; + std::vector input_dims_; +}; + +} // namespace infini_train::autograd diff --git a/infini_train/include/autograd/scatter.h b/infini_train/include/autograd/scatter.h new file mode 100644 index 00000000..3d6f830a --- /dev/null +++ b/infini_train/include/autograd/scatter.h @@ -0,0 +1,29 @@ +#pragma once + +#include +#include + +#include "infini_train/include/autograd/function.h" + +namespace infini_train { +class Tensor; +} + +namespace infini_train::autograd { + +class Scatter : public Function { +public: + static constexpr char kType[] = "ScatterFunction"; + + explicit Scatter(const std::vector &output_dims) : Function(kType), output_dims_(output_dims) {} + + std::vector> Forward(const std::vector> &input_tensors) override; + void SetupContext(const std::vector> &input_tensors, + const std::vector> &output_tensors) override; + std::vector> Backward(const std::vector> &grad_outputs) override; + +private: + std::vector output_dims_; +}; + +} // namespace infini_train::autograd diff --git a/infini_train/include/autograd/transform.h b/infini_train/include/autograd/transform.h index 92ce71ea..88b7d56e 100644 --- a/infini_train/include/autograd/transform.h +++ b/infini_train/include/autograd/transform.h @@ -78,4 +78,70 @@ class RepeatInterleave : public Function { std::vector input_dims_; }; +class Split : public Function { +public: + static constexpr char kType[] = "SplitFunction"; + + Split(int64_t split_size, int dim = 0) : Function(kType), split_size_(split_size), dim_(dim) {} + + std::vector> Forward(const std::vector> &input_tensors) override; + void SetupContext(const std::vector> &input_tensors, + const std::vector> &output_tensors) override; + std::vector> Backward(const std::vector> &grad_outputs) override; + +private: + const int64_t split_size_ = 0; + const int dim_ = 0; + std::vector input_dims_; +}; + +class Stack : public Function { +public: + static constexpr char kType[] = "StackFunction"; + + Stack(int64_t dim) : Function(kType), dim_(dim) {} + + std::vector> Forward(const std::vector> &input_tensors) override; + void SetupContext(const std::vector> &input_tensors, + const std::vector> &output_tensors) override; + std::vector> Backward(const std::vector> &grad_outputs) override; + +private: + int64_t dim_ = 0; + std::vector input_dims_; +}; + +class Concat : public Function { +public: + static constexpr char kType[] = "ConcatFunction"; + + Concat(int64_t dim) : Function(kType), dim_(dim) {} + + std::vector> Forward(const std::vector> &input_tensors) override; + void SetupContext(const std::vector> &input_tensors, + const std::vector> &output_tensors) override; + std::vector> Backward(const std::vector> &grad_outputs) override; + +private: + const int64_t dim_ = 0; + std::vector> input_dims_list_; +}; + +class Slice : public Function { +public: + static constexpr char kType[] = "SliceFunction"; + + Slice(const std::vector &starts, const std::vector &ends, const std::vector &steps) + : Function(kType), starts_(starts), ends_(ends), steps_(steps) {} + std::vector> Forward(const std::vector> &input_tensors) override; + void SetupContext(const std::vector> &input_tensors, + const std::vector> &output_tensors) override; + std::vector> Backward(const std::vector> &grad_outputs) override; + +private: + const std::vector starts_; + const std::vector ends_; + const std::vector steps_; +}; + } // namespace infini_train::autograd diff --git a/infini_train/src/autograd/comm.cc b/infini_train/src/autograd/comm.cc index d524088a..325422b3 100644 --- a/infini_train/src/autograd/comm.cc +++ b/infini_train/src/autograd/comm.cc @@ -8,7 +8,7 @@ #include "infini_train/include/nn/parallel/process_group.h" #include "infini_train/include/tensor.h" -namespace infini_train::autograd { +namespace infini_train::autograd::comm { Scatter::Scatter(const std::vector &target_gpus, int64_t dim, const infini_train::nn::parallel::ProcessGroup *pg) @@ -122,4 +122,4 @@ std::vector> ReduceAddCoalesced::Backward(const std::vector> &grad_outputs) { return std::make_shared(target_gpus_)->Apply(grad_outputs); } -} // namespace infini_train::autograd +} // namespace infini_train::autograd::comm diff --git a/infini_train/src/autograd/gather.cc b/infini_train/src/autograd/gather.cc new file mode 100644 index 00000000..a30cb013 --- /dev/null +++ b/infini_train/src/autograd/gather.cc @@ -0,0 +1,37 @@ +#include "infini_train/include/autograd/gather.h" + +#include "glog/logging.h" + +#include "infini_train/include/dispatcher.h" +#include "infini_train/include/tensor.h" + +namespace infini_train::autograd { +std::vector> Gather::Forward(const std::vector> &input_tensors) { + CHECK_EQ(input_tensors.size(), 2); + const auto &input = input_tensors[0]; + const auto &index = input_tensors[1]; + + auto device = input->GetDevice().type(); + auto kernel = Dispatcher::Instance().GetKernel({device, "GatherForward"}); + return {kernel.Call>(input, index, dim_)}; +} + +void Gather::SetupContext(const std::vector> &input_tensors, + const std::vector> &) { + const auto &input = input_tensors[0]; + const auto &index = input_tensors[1]; + input_dims_ = input->Dims(); + saved_tensors_ = {index}; +} + +std::vector> Gather::Backward(const std::vector> &grad_outputs) { + CHECK_EQ(grad_outputs.size(), 1); + const auto &grad_output = grad_outputs[0]; + const auto &index = saved_tensors_[0]; + + auto device = grad_outputs[0]->GetDevice(); + auto kernel = Dispatcher::Instance().GetKernel({device.type(), "GatherBackward"}); + return {kernel.Call>(grad_output, index, dim_, input_dims_), nullptr}; +} + +} // namespace infini_train::autograd diff --git a/infini_train/src/autograd/misc.cc b/infini_train/src/autograd/misc.cc deleted file mode 100644 index 601258eb..00000000 --- a/infini_train/src/autograd/misc.cc +++ /dev/null @@ -1,147 +0,0 @@ -#include "infini_train/include/autograd/misc.h" - -#include "glog/logging.h" - -#include "infini_train/include/dispatcher.h" -#include "infini_train/include/tensor.h" - -namespace infini_train::autograd { -std::vector> Split::Forward(const std::vector> &input_tensors) { - CHECK_EQ(input_tensors.size(), 1); - const auto &input = input_tensors[0]; - - auto device = input->GetDevice().type(); - return {Dispatcher::Instance().Call>>({device, "SplitForward"}, input, - split_size_, dim_)}; -} - -void Split::SetupContext(const std::vector> &input_tensors, - const std::vector> &) { - const auto &input = input_tensors[0]; - input_dims_ = input->Dims(); -} - -std::vector> Split::Backward(const std::vector> &grad_outputs) { - auto device = grad_outputs[0]->GetDevice(); - return {Dispatcher::Instance().Call>({device.type(), "SplitBackward"}, input_dims_, - split_size_, dim_, grad_outputs)}; -} - -std::vector> IndexGather::Forward(const std::vector> &input_tensors) { - CHECK_EQ(input_tensors.size(), 2); - const auto &input = input_tensors[0]; - const auto &index = input_tensors[1]; - - auto device = input->GetDevice().type(); - auto kernel = Dispatcher::Instance().GetKernel({device, "IndexGatherForward"}); - return {kernel.Call>(input, index, dim_)}; -} - -void IndexGather::SetupContext(const std::vector> &input_tensors, - const std::vector> &) { - const auto &input = input_tensors[0]; - const auto &index = input_tensors[1]; - input_dims_ = input->Dims(); - saved_tensors_ = {index}; -} - -std::vector> IndexGather::Backward(const std::vector> &grad_outputs) { - CHECK_EQ(grad_outputs.size(), 1); - const auto &grad_output = grad_outputs[0]; - const auto &index = saved_tensors_[0]; - - auto device = grad_outputs[0]->GetDevice(); - auto kernel = Dispatcher::Instance().GetKernel({device.type(), "IndexGatherBackward"}); - return {kernel.Call>(grad_output, index, dim_, input_dims_)}; -} - -std::vector> NoOp::Forward(const std::vector> &input_tensors) { - CHECK_EQ(input_tensors.size(), 1); - const auto &input = input_tensors[0]; - - auto device = input->GetDevice().type(); - return {Dispatcher::Instance().Call>({device, "NoOpForward"}, input, output_dims_)}; -} - -void NoOp::SetupContext(const std::vector> &input_tensors, - const std::vector> &output_tensors) { - const auto &input = input_tensors[0]; - input_dims_ = input->Dims(); -} - -std::vector> NoOp::Backward(const std::vector> &grad_outputs) { - CHECK_EQ(grad_outputs.size(), 1); - const auto &grad_output = grad_outputs[0]; - - auto device = grad_output->GetDevice().type(); - return {Dispatcher::Instance().Call>({device, "NoOpBackward"}, input_dims_, grad_output)}; -} - -std::vector> Slice::Forward(const std::vector> &input_tensors) { - CHECK_EQ(input_tensors.size(), 1); - const auto &input = input_tensors[0]; - - auto device = input->GetDevice().type(); - return { - Dispatcher::Instance().Call>({device, "SliceForward"}, input, starts_, ends_, steps_)}; -} - -void Slice::SetupContext(const std::vector> &input_tensors, - const std::vector> &) { - // FIXME(dcj): only input's dim need to be saved - const auto &input = input_tensors[0]; - saved_tensors_ = {input}; -} - -std::vector> Slice::Backward(const std::vector> &grad_outputs) { - CHECK_EQ(saved_tensors_.size(), 1); - const auto &input = saved_tensors_[0]; - const auto &grad_output = grad_outputs[0]; - - auto device = input->GetDevice().type(); - return {Dispatcher::Instance().Call>({device, "SliceBackward"}, grad_output, input, starts_, - ends_, steps_)}; -} - -std::vector> Stack::Forward(const std::vector> &input_tensors) { - CHECK_GE(input_tensors.size(), 2); - const auto device = input_tensors[0]->GetDevice().type(); - - return {Dispatcher::Instance().Call>({device, "StackForward"}, input_tensors, dim_)}; -} - -void Stack::SetupContext(const std::vector> &input_tensors, - const std::vector> &) { - const auto &input = input_tensors[0]; - input_dims_ = input->Dims(); -} - -std::vector> Stack::Backward(const std::vector> &grad_outputs) { - const auto &grad_output = grad_outputs[0]; - - auto device = grad_output->GetDevice().type(); - return {Dispatcher::Instance().Call>>({device, "StackBackward"}, input_dims_, - dim_, grad_output)}; -} - -std::vector> Concat::Forward(const std::vector> &input_tensors) { - CHECK_GE(input_tensors.size(), 2); - const auto device = input_tensors[0]->GetDevice().type(); - - auto kernel = Dispatcher::Instance().GetKernel({device, "ConcatForward"}); - return {kernel.Call>(input_tensors, dim_)}; -} - -void Concat::SetupContext(const std::vector> &input_tensors, - const std::vector> &) { - for (auto input : input_tensors) { input_dims_list_.push_back(input->Dims()); } -} - -std::vector> Concat::Backward(const std::vector> &grad_outputs) { - const auto &grad_output = grad_outputs[0]; - - auto device = grad_output->GetDevice().type(); - auto kernel = Dispatcher::Instance().GetKernel({device, "ConcatBackward"}); - return kernel.Call>>(grad_output, input_dims_list_, dim_); -} -} // namespace infini_train::autograd diff --git a/infini_train/src/autograd/no_op.cc b/infini_train/src/autograd/no_op.cc new file mode 100644 index 00000000..b4247dec --- /dev/null +++ b/infini_train/src/autograd/no_op.cc @@ -0,0 +1,31 @@ +#include "infini_train/include/autograd/no_op.h" + +#include "glog/logging.h" + +#include "infini_train/include/dispatcher.h" +#include "infini_train/include/tensor.h" + +namespace infini_train::autograd { +std::vector> NoOp::Forward(const std::vector> &input_tensors) { + CHECK_EQ(input_tensors.size(), 1); + const auto &input = input_tensors[0]; + + auto device = input->GetDevice().type(); + return {Dispatcher::Instance().Call>({device, "NoOpForward"}, input, output_dims_)}; +} + +void NoOp::SetupContext(const std::vector> &input_tensors, + const std::vector> &output_tensors) { + const auto &input = input_tensors[0]; + input_dims_ = input->Dims(); +} + +std::vector> NoOp::Backward(const std::vector> &grad_outputs) { + CHECK_EQ(grad_outputs.size(), 1); + const auto &grad_output = grad_outputs[0]; + + auto device = grad_output->GetDevice().type(); + return {Dispatcher::Instance().Call>({device, "NoOpBackward"}, input_dims_, grad_output)}; +} + +} // namespace infini_train::autograd diff --git a/infini_train/src/autograd/scatter.cc b/infini_train/src/autograd/scatter.cc new file mode 100644 index 00000000..472fd543 --- /dev/null +++ b/infini_train/src/autograd/scatter.cc @@ -0,0 +1,34 @@ +#include "infini_train/include/autograd/scatter.h" + +#include "glog/logging.h" + +#include "infini_train/include/dispatcher.h" +#include "infini_train/include/tensor.h" + +namespace infini_train::autograd { + +std::vector> Scatter::Forward(const std::vector> &input_tensors) { + CHECK_EQ(input_tensors.size(), 2); + const auto &values = input_tensors[0]; + const auto &indices = input_tensors[1]; + auto device = values->GetDevice().type(); + return {Dispatcher::Instance().Call>({device, "ScatterForward"}, values, indices, + output_dims_)}; +} + +void Scatter::SetupContext(const std::vector> &input_tensors, + const std::vector> &) { + saved_tensors_ = {input_tensors[1]}; +} + +std::vector> Scatter::Backward(const std::vector> &grad_outputs) { + CHECK_EQ(grad_outputs.size(), 1); + const auto &grad_output = grad_outputs[0]; + const auto &indices = saved_tensors_[0]; + auto device = grad_output->GetDevice().type(); + auto grad_values + = Dispatcher::Instance().Call>({device, "ScatterBackward"}, grad_output, indices); + return {grad_values, nullptr}; +} + +} // namespace infini_train::autograd diff --git a/infini_train/src/autograd/transform.cc b/infini_train/src/autograd/transform.cc index 4fae05bb..e38d5616 100644 --- a/infini_train/src/autograd/transform.cc +++ b/infini_train/src/autograd/transform.cc @@ -89,4 +89,94 @@ RepeatInterleave::Backward(const std::vector> &grad_outp return {Dispatcher::Instance().Call>({device, "RepeatInterleaveBackward"}, grad_output, input_dims_, dim_)}; } + +std::vector> Split::Forward(const std::vector> &input_tensors) { + CHECK_EQ(input_tensors.size(), 1); + const auto &input = input_tensors[0]; + + auto device = input->GetDevice().type(); + return {Dispatcher::Instance().Call>>({device, "SplitForward"}, input, + split_size_, dim_)}; +} + +void Split::SetupContext(const std::vector> &input_tensors, + const std::vector> &) { + const auto &input = input_tensors[0]; + input_dims_ = input->Dims(); +} + +std::vector> Split::Backward(const std::vector> &grad_outputs) { + auto device = grad_outputs[0]->GetDevice(); + return {Dispatcher::Instance().Call>({device.type(), "SplitBackward"}, input_dims_, + split_size_, dim_, grad_outputs)}; +} + +std::vector> Stack::Forward(const std::vector> &input_tensors) { + CHECK_GE(input_tensors.size(), 2); + const auto device = input_tensors[0]->GetDevice().type(); + + return {Dispatcher::Instance().Call>({device, "StackForward"}, input_tensors, dim_)}; +} + +void Stack::SetupContext(const std::vector> &input_tensors, + const std::vector> &) { + const auto &input = input_tensors[0]; + input_dims_ = input->Dims(); +} + +std::vector> Stack::Backward(const std::vector> &grad_outputs) { + const auto &grad_output = grad_outputs[0]; + + auto device = grad_output->GetDevice().type(); + return {Dispatcher::Instance().Call>>({device, "StackBackward"}, input_dims_, + dim_, grad_output)}; +} + +std::vector> Concat::Forward(const std::vector> &input_tensors) { + CHECK_GE(input_tensors.size(), 2); + const auto device = input_tensors[0]->GetDevice().type(); + + auto kernel = Dispatcher::Instance().GetKernel({device, "ConcatForward"}); + return {kernel.Call>(input_tensors, dim_)}; +} + +void Concat::SetupContext(const std::vector> &input_tensors, + const std::vector> &) { + for (auto input : input_tensors) { input_dims_list_.push_back(input->Dims()); } +} + +std::vector> Concat::Backward(const std::vector> &grad_outputs) { + const auto &grad_output = grad_outputs[0]; + + auto device = grad_output->GetDevice().type(); + auto kernel = Dispatcher::Instance().GetKernel({device, "ConcatBackward"}); + return kernel.Call>>(grad_output, input_dims_list_, dim_); +} + +std::vector> Slice::Forward(const std::vector> &input_tensors) { + CHECK_EQ(input_tensors.size(), 1); + const auto &input = input_tensors[0]; + + auto device = input->GetDevice().type(); + return { + Dispatcher::Instance().Call>({device, "SliceForward"}, input, starts_, ends_, steps_)}; +} + +void Slice::SetupContext(const std::vector> &input_tensors, + const std::vector> &) { + // FIXME(dcj): only input's dim need to be saved + const auto &input = input_tensors[0]; + saved_tensors_ = {input}; +} + +std::vector> Slice::Backward(const std::vector> &grad_outputs) { + CHECK_EQ(saved_tensors_.size(), 1); + const auto &input = saved_tensors_[0]; + const auto &grad_output = grad_outputs[0]; + + auto device = input->GetDevice().type(); + return {Dispatcher::Instance().Call>({device, "SliceBackward"}, grad_output, input, starts_, + ends_, steps_)}; +} + } // namespace infini_train::autograd diff --git a/infini_train/src/kernels/cpu/gather.cc b/infini_train/src/kernels/cpu/gather.cc index b59fd45f..af39fc0f 100644 --- a/infini_train/src/kernels/cpu/gather.cc +++ b/infini_train/src/kernels/cpu/gather.cc @@ -8,11 +8,8 @@ #include "infini_train/include/tensor.h" namespace infini_train::kernels::cpu { -// FIXME(zbl): This kernel aligns with torch.gather -// Currently named IndexGather to avoid conflict with communication operators -// Should be renamed to Gather later for interface consistency -std::shared_ptr IndexGatherForward(const std::shared_ptr &input, const std::shared_ptr &index, - int64_t dim) { +std::shared_ptr GatherForward(const std::shared_ptr &input, const std::shared_ptr &index, + int64_t dim) { const auto &in_dims = input->Dims(); const auto &idx_dims = index->Dims(); CHECK_EQ(in_dims.size(), idx_dims.size()); @@ -103,9 +100,8 @@ std::shared_ptr IndexGatherForward(const std::shared_ptr &input, return out; } -std::shared_ptr IndexGatherBackward(const std::shared_ptr &grad_output, - const std::shared_ptr &index, int64_t dim, - const std::vector &input_dims) { +std::shared_ptr GatherBackward(const std::shared_ptr &grad_output, const std::shared_ptr &index, + int64_t dim, const std::vector &input_dims) { const auto &in_dims = input_dims; const auto &idx_dims = index->Dims(); CHECK_EQ(in_dims.size(), idx_dims.size()); @@ -199,7 +195,7 @@ std::shared_ptr IndexGatherBackward(const std::shared_ptr &grad_ #define REGISTER_CPU_GATHER_KERNEL(kernel_name) \ REGISTER_KERNEL(infini_train::Device::DeviceType::kCPU, kernel_name, infini_train::kernels::cpu::kernel_name) -REGISTER_CPU_GATHER_KERNEL(IndexGatherForward) -REGISTER_CPU_GATHER_KERNEL(IndexGatherBackward) +REGISTER_CPU_GATHER_KERNEL(GatherForward) +REGISTER_CPU_GATHER_KERNEL(GatherBackward) #undef REGISTER_CPU_GATHER_KERNEL diff --git a/infini_train/src/kernels/cpu/scatter.cc b/infini_train/src/kernels/cpu/scatter.cc new file mode 100644 index 00000000..1a9cf62e --- /dev/null +++ b/infini_train/src/kernels/cpu/scatter.cc @@ -0,0 +1,86 @@ +#include +#include +#include +#include + +#include "glog/logging.h" + +#include "infini_train/include/dispatcher.h" +#include "infini_train/include/tensor.h" + +namespace infini_train::kernels::cpu { + +std::shared_ptr ScatterForward(const std::shared_ptr &values, const std::shared_ptr &indices, + const std::vector &output_dims) { + CHECK(indices->Dtype() == DataType::kINT64) << "CPU ScatterForward expects int64 indices"; + CHECK(values->Dims() == indices->Dims()); + CHECK(!output_dims.empty()); + CHECK_EQ(values->Dims().size(), output_dims.size()); + CHECK_GT(values->Dims().back(), 0); + CHECK_GT(output_dims.back(), 0); + + const int64_t topk = values->Dims().back(); + const int64_t num_experts = output_dims.back(); + const int64_t rows = values->NumElements() / topk; + size_t output_numel = 1; + for (const auto dim : output_dims) { output_numel *= static_cast(dim); } + CHECK_EQ(output_numel, static_cast(rows * num_experts)); + + auto output = std::make_shared(output_dims, values->Dtype(), values->GetDevice()); + std::memset(output->DataPtr(), 0, output->SizeInBytes()); + + const size_t elem_size = kDataTypeToSize.at(values->Dtype()); + const auto *src = static_cast(values->DataPtr()); + auto *dst = static_cast(output->DataPtr()); + const auto *idx = static_cast(indices->DataPtr()); + for (int64_t row = 0; row < rows; ++row) { + for (int64_t selected = 0; selected < topk; ++selected) { + const int64_t expert_idx = idx[row * topk + selected]; + CHECK_GE(expert_idx, 0); + CHECK_LT(expert_idx, num_experts); + std::memcpy(dst + (row * num_experts + expert_idx) * elem_size, src + (row * topk + selected) * elem_size, + elem_size); + } + } + + return output; +} + +std::shared_ptr ScatterBackward(const std::shared_ptr &grad_output, + const std::shared_ptr &indices) { + CHECK(indices->Dtype() == DataType::kINT64) << "CPU ScatterBackward expects int64 indices"; + CHECK_GE(grad_output->Dims().size(), 1); + CHECK_GE(indices->Dims().size(), 1); + + const int64_t num_experts = grad_output->Dims().back(); + const int64_t topk = indices->Dims().back(); + const int64_t rows = indices->NumElements() / topk; + CHECK_EQ(grad_output->NumElements(), static_cast(rows * num_experts)); + + auto grad_values = std::make_shared(indices->Dims(), grad_output->Dtype(), grad_output->GetDevice()); + const size_t elem_size = kDataTypeToSize.at(grad_output->Dtype()); + const auto *src = static_cast(grad_output->DataPtr()); + auto *dst = static_cast(grad_values->DataPtr()); + const auto *idx = static_cast(indices->DataPtr()); + for (int64_t row = 0; row < rows; ++row) { + for (int64_t selected = 0; selected < topk; ++selected) { + const int64_t expert_idx = idx[row * topk + selected]; + CHECK_GE(expert_idx, 0); + CHECK_LT(expert_idx, num_experts); + std::memcpy(dst + (row * topk + selected) * elem_size, src + (row * num_experts + expert_idx) * elem_size, + elem_size); + } + } + + return grad_values; +} + +} // namespace infini_train::kernels::cpu + +#define REGISTER_CPU_SCATTER_KERNEL(kernel_name) \ + REGISTER_KERNEL(infini_train::Device::DeviceType::kCPU, kernel_name, infini_train::kernels::cpu::kernel_name) + +REGISTER_CPU_SCATTER_KERNEL(ScatterForward) +REGISTER_CPU_SCATTER_KERNEL(ScatterBackward) + +#undef REGISTER_CPU_SCATTER_KERNEL diff --git a/infini_train/src/kernels/cuda/comm.cu b/infini_train/src/kernels/cuda/comm.cu index b4bdafd8..6300ffdb 100644 --- a/infini_train/src/kernels/cuda/comm.cu +++ b/infini_train/src/kernels/cuda/comm.cu @@ -9,7 +9,7 @@ #include "infini_train/include/nn/functional.h" #include "infini_train/include/tensor.h" -namespace infini_train::kernels::cuda { +namespace infini_train::kernels::cuda::comm { std::vector> Broadcast(const std::vector> &input_tensors, const std::vector &devices) { @@ -69,11 +69,11 @@ std::shared_ptr Gather(const std::vector> &tenso auto view_kernel = Dispatcher::Instance().GetKernel({destination.type(), "NoOpForward"}); return view_kernel.Call>(gathered_tensor, new_dims); } -} // namespace infini_train::kernels::cuda +} // namespace infini_train::kernels::cuda::comm #define REGISTER_CUDA_COMM_KERNEL(kernel_name) \ - REGISTER_KERNEL(infini_train::Device::DeviceType::kCUDA, Comm##kernel_name, \ - infini_train::kernels::cuda::kernel_name) + REGISTER_KERNEL(infini_train::Device::DeviceType::kCUDA, kernel_name, \ + infini_train::kernels::cuda::comm::kernel_name) REGISTER_CUDA_COMM_KERNEL(Broadcast) REGISTER_CUDA_COMM_KERNEL(Scatter) diff --git a/infini_train/src/kernels/cuda/gather.cu b/infini_train/src/kernels/cuda/gather.cu index 12d0567d..5216f28e 100644 --- a/infini_train/src/kernels/cuda/gather.cu +++ b/infini_train/src/kernels/cuda/gather.cu @@ -9,15 +9,11 @@ #include "infini_train/src/core/runtime/cuda/cuda_runtime_common.h" namespace infini_train::kernels::cuda { -// FIXME(zbl): This kernel aligns with torch.gather -// Currently named IndexGather to avoid conflict with communication operators -// Should be renamed to Gather later for interface consistency template -__global__ void IndexGatherForwardKernel(const T *__restrict__ input, const int64_t *__restrict__ norm_index, - T *__restrict__ output, const int64_t *__restrict__ out_dims, - const int64_t *__restrict__ in_strides, - const int64_t *__restrict__ out_strides, int num_dims, int gather_dim, - int64_t dim_size_gather, int64_t total_elements) { +__global__ void GatherForwardKernel(const T *__restrict__ input, const int64_t *__restrict__ norm_index, + T *__restrict__ output, const int64_t *__restrict__ out_dims, + const int64_t *__restrict__ in_strides, const int64_t *__restrict__ out_strides, + int num_dims, int gather_dim, int64_t dim_size_gather, int64_t total_elements) { int64_t out_idx = blockIdx.x * blockDim.x + threadIdx.x; if (out_idx >= total_elements) { return; @@ -43,8 +39,8 @@ __global__ void IndexGatherForwardKernel(const T *__restrict__ input, const int6 output[out_idx] = input[in_linear]; } -std::shared_ptr IndexGatherForward(const std::shared_ptr &input, const std::shared_ptr &index, - int64_t dim) { +std::shared_ptr GatherForward(const std::shared_ptr &input, const std::shared_ptr &index, + int64_t dim) { const auto &in_dims = input->Dims(); const auto &idx_dims = index->Dims(); CHECK_EQ(in_dims.size(), idx_dims.size()); @@ -103,23 +99,22 @@ std::shared_ptr IndexGatherForward(const std::shared_ptr &input, core::cuda::DispatchCudaFunc( dtype, [=]() { - IndexGatherForwardKernel<<>>( + GatherForwardKernel<<>>( static_cast(input->DataPtr()), static_cast(index->DataPtr()), static_cast(out->DataPtr()), out_dims_dev, in_strides_dev, out_strides_dev, (int)num_dims, (int)dim, gather_dim_size, total_elements); }, - "CUDA IndexGatherForward"); + "CUDA GatherForward"); CUDA_CHECK(cudaFreeAsync(dev_buf, stream)); return out; } template -__global__ void IndexGatherBackwardKernel(const T *__restrict__ grad_output, const int64_t *__restrict__ index, - T *__restrict__ grad_input, const int64_t *__restrict__ out_dims, - const int64_t *__restrict__ in_strides, - const int64_t *__restrict__ out_strides, int num_dims, int gather_dim, - int64_t dim_size_gather, int64_t total_elements) { +__global__ void GatherBackwardKernel(const T *__restrict__ grad_output, const int64_t *__restrict__ index, + T *__restrict__ grad_input, const int64_t *__restrict__ out_dims, + const int64_t *__restrict__ in_strides, const int64_t *__restrict__ out_strides, + int num_dims, int gather_dim, int64_t dim_size_gather, int64_t total_elements) { int64_t out_idx = blockIdx.x * blockDim.x + threadIdx.x; if (out_idx >= total_elements) { return; @@ -149,9 +144,8 @@ __global__ void IndexGatherBackwardKernel(const T *__restrict__ grad_output, con atomicAdd(&grad_input[in_linear], grad_output[out_idx]); } -std::shared_ptr IndexGatherBackward(const std::shared_ptr &grad_output, - const std::shared_ptr &index, int64_t dim, - const std::vector &input_dims) { +std::shared_ptr GatherBackward(const std::shared_ptr &grad_output, const std::shared_ptr &index, + int64_t dim, const std::vector &input_dims) { const auto &in_dims = input_dims; const auto &idx_dims = index->Dims(); CHECK_EQ(in_dims.size(), idx_dims.size()); @@ -210,12 +204,12 @@ std::shared_ptr IndexGatherBackward(const std::shared_ptr &grad_ core::cuda::DispatchCudaFunc( dtype, [=]() { - IndexGatherBackwardKernel<<>>( + GatherBackwardKernel<<>>( static_cast(grad_output->DataPtr()), static_cast(index->DataPtr()), static_cast(grad_input->DataPtr()), out_dims_dev, in_strides_dev, out_strides_dev, (int)num_dims, (int)dim, gather_dim_size, total_elements); }, - "CUDA IndexGatherBackward"); + "CUDA GatherBackward"); CUDA_CHECK(cudaFreeAsync(dev_buf, stream)); return grad_input; @@ -226,7 +220,7 @@ std::shared_ptr IndexGatherBackward(const std::shared_ptr &grad_ #define REGISTER_CUDA_GATHER_KERNEL(kernel_name) \ REGISTER_KERNEL(infini_train::Device::DeviceType::kCUDA, kernel_name, infini_train::kernels::cuda::kernel_name) -REGISTER_CUDA_GATHER_KERNEL(IndexGatherForward) -REGISTER_CUDA_GATHER_KERNEL(IndexGatherBackward) +REGISTER_CUDA_GATHER_KERNEL(GatherForward) +REGISTER_CUDA_GATHER_KERNEL(GatherBackward) #undef REGISTER_CUDA_GATHER_KERNEL diff --git a/infini_train/src/kernels/cuda/scatter.cu b/infini_train/src/kernels/cuda/scatter.cu new file mode 100644 index 00000000..c58658ff --- /dev/null +++ b/infini_train/src/kernels/cuda/scatter.cu @@ -0,0 +1,115 @@ +#include "glog/logging.h" + +#include "infini_train/include/common/cuda/common_cuda.h" +#include "infini_train/include/core/runtime/device_guard.h" +#include "infini_train/include/dispatcher.h" +#include "infini_train/include/tensor.h" + +#include "infini_train/src/core/runtime/cuda/cuda_dispatch.h" +#include "infini_train/src/core/runtime/cuda/cuda_runtime_common.h" + +namespace infini_train::kernels::cuda { + +template +__global__ void ScatterForwardKernel(const T *__restrict__ values, const int64_t *__restrict__ indices, + T *__restrict__ output, int64_t rows, int64_t topk, int64_t num_experts) { + int64_t idx = blockIdx.x * blockDim.x + threadIdx.x; + const int64_t total = rows * topk; + if (idx >= total) { + return; + } + + const int64_t row = idx / topk; + const int64_t expert_idx = indices[idx]; + output[row * num_experts + expert_idx] = values[idx]; +} + +std::shared_ptr ScatterForward(const std::shared_ptr &values, const std::shared_ptr &indices, + const std::vector &output_dims) { + CHECK(indices->Dtype() == DataType::kINT64) << "CUDA ScatterForward expects int64 indices"; + CHECK(values->Dims() == indices->Dims()); + CHECK(!output_dims.empty()); + CHECK_EQ(values->Dims().size(), output_dims.size()); + CHECK_GT(values->Dims().back(), 0); + CHECK_GT(output_dims.back(), 0); + + const int64_t topk = values->Dims().back(); + const int64_t num_experts = output_dims.back(); + CHECK_GT(num_experts, 0); + const int64_t rows = values->NumElements() / topk; + const int64_t output_numel = std::accumulate(output_dims.begin(), output_dims.end(), 1, std::multiplies()); + CHECK_EQ(output_numel, static_cast(rows * num_experts)); + + auto output = std::make_shared(output_dims, values->Dtype(), values->GetDevice()); + + auto device = values->GetDevice(); + const auto &stream = dynamic_cast( + infini_train::core::GetDeviceGuardImpl(device.type())->GetStream(device)) + ->cuda_stream(); + CUDA_CHECK(cudaMemsetAsync(output->DataPtr(), 0, output->SizeInBytes(), stream)); + const int threads = 256; + const int blocks = static_cast(((rows * topk) + threads - 1) / threads); + core::cuda::DispatchCudaFunc( + values->Dtype(), + [=]() { + ScatterForwardKernel<<>>( + static_cast(values->DataPtr()), static_cast(indices->DataPtr()), + static_cast(output->DataPtr()), rows, topk, num_experts); + }, + "CUDA ScatterForward"); + return output; +} + +template +__global__ void ScatterBackwardKernel(const T *__restrict__ grad_output, const int64_t *__restrict__ indices, + T *__restrict__ grad_values, int64_t rows, int64_t topk, int64_t num_experts) { + int64_t idx = blockIdx.x * blockDim.x + threadIdx.x; + const int64_t total = rows * topk; + if (idx >= total) { + return; + } + const int64_t row = idx / topk; + const int64_t expert_idx = indices[idx]; + grad_values[idx] = grad_output[row * num_experts + expert_idx]; +} + +std::shared_ptr ScatterBackward(const std::shared_ptr &grad_output, + const std::shared_ptr &indices) { + CHECK(indices->Dtype() == DataType::kINT64) << "CUDA ScatterBackward expects int64 indices"; + CHECK_GE(grad_output->Dims().size(), 1); + CHECK_GE(indices->Dims().size(), 1); + const int64_t num_experts = grad_output->Dims().back(); + const int64_t topk = indices->Dims().back(); + const int64_t rows = indices->NumElements() / topk; + CHECK_EQ(grad_output->NumElements(), static_cast(rows * num_experts)); + + auto grad_values = std::make_shared(indices->Dims(), grad_output->Dtype(), grad_output->GetDevice()); + + auto device = grad_output->GetDevice(); + const auto &stream = dynamic_cast( + infini_train::core::GetDeviceGuardImpl(device.type())->GetStream(device)) + ->cuda_stream(); + const int threads = 256; + const int blocks = static_cast(((rows * topk) + threads - 1) / threads); + + core::cuda::DispatchCudaFunc( + grad_output->Dtype(), + [=]() { + ScatterBackwardKernel<<>>( + static_cast(grad_output->DataPtr()), static_cast(indices->DataPtr()), + static_cast(grad_values->DataPtr()), rows, topk, num_experts); + }, + "CUDA ScatterBackward"); + + return grad_values; +} + +} // namespace infini_train::kernels::cuda + +#define REGISTER_CUDA_SCATTER_KERNEL(kernel_name) \ + REGISTER_KERNEL(infini_train::Device::DeviceType::kCUDA, kernel_name, infini_train::kernels::cuda::kernel_name) + +REGISTER_CUDA_SCATTER_KERNEL(ScatterForward) +REGISTER_CUDA_SCATTER_KERNEL(ScatterBackward) + +#undef REGISTER_CUDA_SCATTER_KERNEL diff --git a/infini_train/src/nn/functional.cc b/infini_train/src/nn/functional.cc index b02f185a..c33e2368 100644 --- a/infini_train/src/nn/functional.cc +++ b/infini_train/src/nn/functional.cc @@ -6,7 +6,6 @@ #include "infini_train/include/autograd/activations.h" #include "infini_train/include/autograd/elementwise.h" -#include "infini_train/include/autograd/misc.h" #include "infini_train/include/autograd/reduction.h" #include "infini_train/include/autograd/softmax.h" #include "infini_train/include/autograd/transform.h" diff --git a/infini_train/src/nn/parallel/parallel_functional.cc b/infini_train/src/nn/parallel/parallel_functional.cc index 31db11ec..ffd218d7 100644 --- a/infini_train/src/nn/parallel/parallel_functional.cc +++ b/infini_train/src/nn/parallel/parallel_functional.cc @@ -44,7 +44,7 @@ std::vector>> Scatter(const std::vector &devices, int dim) { std::vector>> output_tensors; for (const auto &tensor : input_tensors) { - output_tensors.emplace_back(std::make_shared(devices, dim)->Apply({tensor})); + output_tensors.emplace_back(std::make_shared(devices, dim)->Apply({tensor})); } std::vector>> transposed_output_tensors; transposed_output_tensors.resize(devices.size()); @@ -59,7 +59,7 @@ std::vector> Gather(const std::vector> gather_tensors; for (const auto &tensor : tensors) { gather_tensors.push_back(tensor[0]); } - return std::make_shared(target_device, dim)->Apply(gather_tensors); + return std::make_shared(target_device, dim)->Apply(gather_tensors); } std::vector>> @@ -67,7 +67,7 @@ BroadcastCoalescedReshape(const std::vector> &tensors, c if (tensors.empty()) { return {}; } - auto tensor_copies = std::make_shared(devices)->Apply(tensors); + auto tensor_copies = std::make_shared(devices)->Apply(tensors); std::vector>> tensor_copies_reshaped(devices.size()); for (int replica_idx = 0; replica_idx < devices.size(); ++replica_idx) { tensor_copies_reshaped[replica_idx].resize(tensors.size()); diff --git a/infini_train/src/tensor.cc b/infini_train/src/tensor.cc index f7947030..44860a0f 100644 --- a/infini_train/src/tensor.cc +++ b/infini_train/src/tensor.cc @@ -13,8 +13,9 @@ #include "infini_train/include/autograd/elementwise.h" #include "infini_train/include/autograd/function.h" #include "infini_train/include/autograd/function_hook.h" +#include "infini_train/include/autograd/gather.h" #include "infini_train/include/autograd/matmul.h" -#include "infini_train/include/autograd/misc.h" +#include "infini_train/include/autograd/no_op.h" #include "infini_train/include/autograd/outer.h" #include "infini_train/include/autograd/reduction.h" #include "infini_train/include/autograd/transform.h" @@ -356,7 +357,7 @@ std::vector> Tensor::Split(int split_size, int dim) { std::shared_ptr Tensor::Gather(int dim, const std::shared_ptr &index) { CHECK(GetDevice() == index->GetDevice()) << "index must be on the same device as input."; - return std::make_shared(dim)->Apply({shared_from_this(), index})[0]; + return std::make_shared(dim)->Apply({shared_from_this(), index})[0]; } std::shared_ptr Tensor::RepeatInterleave(int64_t repeat, int64_t dim) { diff --git a/tests/autograd/test_autograd.cc b/tests/autograd/test_autograd.cc index 5f6f2f54..1d6d129a 100644 --- a/tests/autograd/test_autograd.cc +++ b/tests/autograd/test_autograd.cc @@ -8,8 +8,8 @@ #include "infini_train/include/autograd/function.h" #include "infini_train/include/autograd/linear.h" #include "infini_train/include/autograd/matmul.h" -#include "infini_train/include/autograd/misc.h" #include "infini_train/include/autograd/normalization.h" +#include "infini_train/include/autograd/no_op.h" #include "infini_train/include/autograd/outer.h" #include "infini_train/include/autograd/reduction.h" #include "infini_train/include/autograd/softmax.h" diff --git a/tests/autograd/test_autograd_transform_forward.cc b/tests/autograd/test_autograd_transform_forward.cc index 1c156c68..680d8b0d 100644 --- a/tests/autograd/test_autograd_transform_forward.cc +++ b/tests/autograd/test_autograd_transform_forward.cc @@ -2,7 +2,6 @@ #include "gtest/gtest.h" -#include "infini_train/include/autograd/misc.h" #include "infini_train/include/autograd/transform.h" #include "infini_train/include/nn/parallel/global.h" #include "infini_train/include/tensor.h"