Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions infini_train/include/autograd/comm.h
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ class ProcessGroup;
} // namespace nn::parallel
} // namespace infini_train

namespace infini_train::autograd {
namespace infini_train::autograd::comm {
class Scatter : public autograd::Function {
public:
static constexpr char kType[] = "ScatterFunction";
Expand Down Expand Up @@ -99,4 +99,4 @@ class ReduceAddCoalesced : public autograd::Function {
std::vector<Device> target_gpus_;
int64_t num_inputs_ = 0;
};
} // namespace infini_train::autograd
} // namespace infini_train::autograd::comm
30 changes: 30 additions & 0 deletions infini_train/include/autograd/gather.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
#pragma once

#include <memory>
#include <vector>

#include "infini_train/include/autograd/function.h"

namespace infini_train {
class Tensor;
}

namespace infini_train::autograd {

class Gather : public Function {
public:
static constexpr char kType[] = "GatherFunction";

Gather(int64_t dim = 0) : Function(kType), dim_(dim) {}

std::vector<std::shared_ptr<Tensor>> Forward(const std::vector<std::shared_ptr<Tensor>> &input_tensors) override;
void SetupContext(const std::vector<std::shared_ptr<Tensor>> &input_tensors,
const std::vector<std::shared_ptr<Tensor>> &output_tensors) override;
std::vector<std::shared_ptr<Tensor>> Backward(const std::vector<std::shared_ptr<Tensor>> &grad_outputs) override;

private:
const int64_t dim_ = 0;
std::vector<int64_t> input_dims_;
};

} // namespace infini_train::autograd
113 changes: 0 additions & 113 deletions infini_train/include/autograd/misc.h

This file was deleted.

30 changes: 30 additions & 0 deletions infini_train/include/autograd/no_op.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
#pragma once

#include <memory>
#include <vector>

#include "infini_train/include/autograd/function.h"

namespace infini_train {
class Tensor;
}

namespace infini_train::autograd {

class NoOp : public Function {
public:
static constexpr char kType[] = "NoOpFunction";

explicit NoOp(const std::vector<int64_t> &output_dims) : Function(kType), output_dims_(output_dims) {}

std::vector<std::shared_ptr<Tensor>> Forward(const std::vector<std::shared_ptr<Tensor>> &input_tensors) override;
void SetupContext(const std::vector<std::shared_ptr<Tensor>> &input_tensors,
const std::vector<std::shared_ptr<Tensor>> &output_tensors) override;
std::vector<std::shared_ptr<Tensor>> Backward(const std::vector<std::shared_ptr<Tensor>> &grad_outputs) override;

private:
const std::vector<int64_t> output_dims_;
std::vector<int64_t> input_dims_;
};

} // namespace infini_train::autograd
29 changes: 29 additions & 0 deletions infini_train/include/autograd/scatter.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
#pragma once

#include <memory>
#include <vector>

#include "infini_train/include/autograd/function.h"

namespace infini_train {
class Tensor;
}

namespace infini_train::autograd {

class Scatter : public Function {
public:
static constexpr char kType[] = "ScatterFunction";

explicit Scatter(const std::vector<int64_t> &output_dims) : Function(kType), output_dims_(output_dims) {}

std::vector<std::shared_ptr<Tensor>> Forward(const std::vector<std::shared_ptr<Tensor>> &input_tensors) override;
void SetupContext(const std::vector<std::shared_ptr<Tensor>> &input_tensors,
const std::vector<std::shared_ptr<Tensor>> &output_tensors) override;
std::vector<std::shared_ptr<Tensor>> Backward(const std::vector<std::shared_ptr<Tensor>> &grad_outputs) override;

private:
std::vector<int64_t> output_dims_;
};

} // namespace infini_train::autograd
66 changes: 66 additions & 0 deletions infini_train/include/autograd/transform.h
Original file line number Diff line number Diff line change
Expand Up @@ -78,4 +78,70 @@ class RepeatInterleave : public Function {
std::vector<int64_t> input_dims_;
};

class Split : public Function {
public:
static constexpr char kType[] = "SplitFunction";

Split(int64_t split_size, int dim = 0) : Function(kType), split_size_(split_size), dim_(dim) {}

std::vector<std::shared_ptr<Tensor>> Forward(const std::vector<std::shared_ptr<Tensor>> &input_tensors) override;
void SetupContext(const std::vector<std::shared_ptr<Tensor>> &input_tensors,
const std::vector<std::shared_ptr<Tensor>> &output_tensors) override;
std::vector<std::shared_ptr<Tensor>> Backward(const std::vector<std::shared_ptr<Tensor>> &grad_outputs) override;

private:
const int64_t split_size_ = 0;
const int dim_ = 0;
std::vector<int64_t> input_dims_;
};

class Stack : public Function {
public:
static constexpr char kType[] = "StackFunction";

Stack(int64_t dim) : Function(kType), dim_(dim) {}

std::vector<std::shared_ptr<Tensor>> Forward(const std::vector<std::shared_ptr<Tensor>> &input_tensors) override;
void SetupContext(const std::vector<std::shared_ptr<Tensor>> &input_tensors,
const std::vector<std::shared_ptr<Tensor>> &output_tensors) override;
std::vector<std::shared_ptr<Tensor>> Backward(const std::vector<std::shared_ptr<Tensor>> &grad_outputs) override;

private:
int64_t dim_ = 0;
std::vector<int64_t> input_dims_;
};

class Concat : public Function {
public:
static constexpr char kType[] = "ConcatFunction";

Concat(int64_t dim) : Function(kType), dim_(dim) {}

std::vector<std::shared_ptr<Tensor>> Forward(const std::vector<std::shared_ptr<Tensor>> &input_tensors) override;
void SetupContext(const std::vector<std::shared_ptr<Tensor>> &input_tensors,
const std::vector<std::shared_ptr<Tensor>> &output_tensors) override;
std::vector<std::shared_ptr<Tensor>> Backward(const std::vector<std::shared_ptr<Tensor>> &grad_outputs) override;

private:
const int64_t dim_ = 0;
std::vector<std::vector<int64_t>> input_dims_list_;
};

class Slice : public Function {
public:
static constexpr char kType[] = "SliceFunction";

Slice(const std::vector<int64_t> &starts, const std::vector<int64_t> &ends, const std::vector<int64_t> &steps)
: Function(kType), starts_(starts), ends_(ends), steps_(steps) {}
std::vector<std::shared_ptr<Tensor>> Forward(const std::vector<std::shared_ptr<Tensor>> &input_tensors) override;
void SetupContext(const std::vector<std::shared_ptr<Tensor>> &input_tensors,
const std::vector<std::shared_ptr<Tensor>> &output_tensors) override;
std::vector<std::shared_ptr<Tensor>> Backward(const std::vector<std::shared_ptr<Tensor>> &grad_outputs) override;

private:
const std::vector<int64_t> starts_;
const std::vector<int64_t> ends_;
const std::vector<int64_t> steps_;
};

} // namespace infini_train::autograd
4 changes: 2 additions & 2 deletions infini_train/src/autograd/comm.cc
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
#include "infini_train/include/nn/parallel/process_group.h"
#include "infini_train/include/tensor.h"

namespace infini_train::autograd {
namespace infini_train::autograd::comm {

Scatter::Scatter(const std::vector<Device> &target_gpus, int64_t dim,
const infini_train::nn::parallel::ProcessGroup *pg)
Expand Down Expand Up @@ -122,4 +122,4 @@ std::vector<std::shared_ptr<Tensor>>
ReduceAddCoalesced::Backward(const std::vector<std::shared_ptr<Tensor>> &grad_outputs) {
return std::make_shared<Broadcast>(target_gpus_)->Apply(grad_outputs);
}
} // namespace infini_train::autograd
} // namespace infini_train::autograd::comm
37 changes: 37 additions & 0 deletions infini_train/src/autograd/gather.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
#include "infini_train/include/autograd/gather.h"

#include "glog/logging.h"

#include "infini_train/include/dispatcher.h"
#include "infini_train/include/tensor.h"

namespace infini_train::autograd {
std::vector<std::shared_ptr<Tensor>> Gather::Forward(const std::vector<std::shared_ptr<Tensor>> &input_tensors) {
CHECK_EQ(input_tensors.size(), 2);
const auto &input = input_tensors[0];
const auto &index = input_tensors[1];

auto device = input->GetDevice().type();
auto kernel = Dispatcher::Instance().GetKernel({device, "GatherForward"});
return {kernel.Call<std::shared_ptr<Tensor>>(input, index, dim_)};
}

void Gather::SetupContext(const std::vector<std::shared_ptr<Tensor>> &input_tensors,
const std::vector<std::shared_ptr<Tensor>> &) {
const auto &input = input_tensors[0];
const auto &index = input_tensors[1];
input_dims_ = input->Dims();
saved_tensors_ = {index};
}

std::vector<std::shared_ptr<Tensor>> Gather::Backward(const std::vector<std::shared_ptr<Tensor>> &grad_outputs) {
CHECK_EQ(grad_outputs.size(), 1);
const auto &grad_output = grad_outputs[0];
const auto &index = saved_tensors_[0];

auto device = grad_outputs[0]->GetDevice();
auto kernel = Dispatcher::Instance().GetKernel({device.type(), "GatherBackward"});
return {kernel.Call<std::shared_ptr<Tensor>>(grad_output, index, dim_, input_dims_), nullptr};
}

} // namespace infini_train::autograd
Loading
Loading