Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions infini_train/include/core/backend_type_map.h
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,9 @@ template <Device::DeviceType Dev, DataType DType> struct BackendTypeMap;
// -----------------------------------------------------------------------------
#define INFINI_REGISTER_STANDARD_BACKEND_TYPES(DEV) \
namespace infini_train::core { \
template <> struct BackendTypeMap<DEV, DataType::kBOOL> { \
using type = bool; \
}; \
template <> struct BackendTypeMap<DEV, DataType::kUINT8> { \
using type = uint8_t; \
}; \
Expand Down
17 changes: 10 additions & 7 deletions infini_train/include/datatype.h
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,7 @@ struct alignas(2) BF16 {
// DataType enum and metadata tables
// -----------------------------------------------------------------------------
enum class DataType : int8_t {
kBOOL,
kUINT8,
kINT8,
kUINT16,
Expand All @@ -99,16 +100,18 @@ enum class DataType : int8_t {
};

inline const std::unordered_map<DataType, size_t> kDataTypeToSize = {
{DataType::kUINT8, 1}, {DataType::kINT8, 1}, {DataType::kUINT16, 2}, {DataType::kINT16, 2},
{DataType::kUINT32, 4}, {DataType::kINT32, 4}, {DataType::kUINT64, 8}, {DataType::kINT64, 8},
{DataType::kBFLOAT16, 2}, {DataType::kFLOAT16, 2}, {DataType::kFLOAT32, 4}, {DataType::kFLOAT64, 8},
{DataType::kBOOL, 1}, {DataType::kUINT8, 1}, {DataType::kINT8, 1}, {DataType::kUINT16, 2},
{DataType::kINT16, 2}, {DataType::kUINT32, 4}, {DataType::kINT32, 4}, {DataType::kUINT64, 8},
{DataType::kINT64, 8}, {DataType::kBFLOAT16, 2}, {DataType::kFLOAT16, 2}, {DataType::kFLOAT32, 4},
{DataType::kFLOAT64, 8},
};

inline const std::unordered_map<DataType, std::string> kDataTypeToDesc = {
{DataType::kUINT8, "uint8"}, {DataType::kINT8, "int8"}, {DataType::kUINT16, "uint16"},
{DataType::kINT16, "int16"}, {DataType::kUINT32, "uint32"}, {DataType::kINT32, "int32"},
{DataType::kUINT64, "uint64"}, {DataType::kINT64, "int64"}, {DataType::kBFLOAT16, "bf16"},
{DataType::kFLOAT16, "fp16"}, {DataType::kFLOAT32, "fp32"}, {DataType::kFLOAT64, "fp64"},
{DataType::kBOOL, "bool"}, {DataType::kUINT8, "uint8"}, {DataType::kINT8, "int8"},
{DataType::kUINT16, "uint16"}, {DataType::kINT16, "int16"}, {DataType::kUINT32, "uint32"},
{DataType::kINT32, "int32"}, {DataType::kUINT64, "uint64"}, {DataType::kINT64, "int64"},
{DataType::kBFLOAT16, "bf16"}, {DataType::kFLOAT16, "fp16"}, {DataType::kFLOAT32, "fp32"},
{DataType::kFLOAT64, "fp64"},
};

// =============================================================================
Expand Down
5 changes: 4 additions & 1 deletion infini_train/include/dtype_dispatch.h
Original file line number Diff line number Diff line change
Expand Up @@ -180,10 +180,11 @@ namespace infini_train {
#define INFINI_FLOATING_TYPES DataType::kFLOAT32, DataType::kFLOAT64
#define INFINI_REDUCED_FLOATING_TYPES DataType::kFLOAT16, DataType::kBFLOAT16
#define INFINI_ALL_FLOATING_TYPES INFINI_FLOATING_TYPES, INFINI_REDUCED_FLOATING_TYPES
#define INFINI_LOGICAL_TYPES DataType::kBOOL
#define INFINI_SIGNED_INTEGRAL_TYPES DataType::kINT8, DataType::kINT16, DataType::kINT32, DataType::kINT64
#define INFINI_UNSIGNED_INTEGRAL_TYPES DataType::kUINT8, DataType::kUINT16, DataType::kUINT32, DataType::kUINT64
#define INFINI_ALL_INTEGRAL_TYPES INFINI_SIGNED_INTEGRAL_TYPES, INFINI_UNSIGNED_INTEGRAL_TYPES
#define INFINI_ALL_TYPES INFINI_ALL_FLOATING_TYPES, INFINI_ALL_INTEGRAL_TYPES
#define INFINI_ALL_NUMERIC_TYPES INFINI_ALL_FLOATING_TYPES, INFINI_ALL_INTEGRAL_TYPES
#define INFINI_8_BIT_TYPES DataType::kINT8, DataType::kUINT8
#define INFINI_16_BIT_TYPES DataType::kINT16, DataType::kUINT16, DataType::kFLOAT16, DataType::kBFLOAT16
#define INFINI_32_BIT_TYPES DataType::kINT32, DataType::kUINT32, DataType::kFLOAT32
Expand Down Expand Up @@ -242,6 +243,7 @@ auto DispatchByTypeMap(DataType dtype, Functor &&func, std::string_view context_
} \
}

CASE_FOR_TYPE(DataType::kBOOL)
CASE_FOR_TYPE(DataType::kUINT8)
CASE_FOR_TYPE(DataType::kINT8)
CASE_FOR_TYPE(DataType::kUINT16)
Expand Down Expand Up @@ -290,6 +292,7 @@ struct TypeMapDispatcher {
break; \
}

CASE_FOR_TYPE(DataType::kBOOL)
CASE_FOR_TYPE(DataType::kUINT8)
CASE_FOR_TYPE(DataType::kINT8)
CASE_FOR_TYPE(DataType::kUINT16)
Expand Down
3 changes: 2 additions & 1 deletion infini_train/src/kernels/cpu/cast.cc
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,8 @@ std::shared_ptr<Tensor> Cast(std::shared_ptr<Tensor> input, DataType dtype) {
auto device = input->GetDevice();
auto dst_tensor = std::make_shared<Tensor>(input->Dims(), dtype, device);

core::cpu::DispatchCpuFunc<DataTypeList<INFINI_ALL_TYPES>, DataTypeList<INFINI_ALL_TYPES>>(
core::cpu::DispatchCpuFunc<DataTypeList<INFINI_ALL_NUMERIC_TYPES, INFINI_LOGICAL_TYPES>,
DataTypeList<INFINI_ALL_NUMERIC_TYPES, INFINI_LOGICAL_TYPES>>(
{dtype, input->Dtype()},
[=]<typename Tdst, typename Tsrc>() {
auto dst = static_cast<Tdst *>(dst_tensor->DataPtr());
Expand Down
2 changes: 1 addition & 1 deletion infini_train/src/kernels/cpu/fill.cc
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@

namespace infini_train::kernels::cpu {
void Fill(std::shared_ptr<Tensor> tensor, Scalar scalar) {
core::cpu::DispatchCpuFunc<INFINI_ALL_TYPES>(
core::cpu::DispatchCpuFunc<INFINI_ALL_NUMERIC_TYPES>(
tensor->Dtype(),
[=]<typename T>() {
auto data = reinterpret_cast<T *>(tensor->DataPtr());
Expand Down
3 changes: 2 additions & 1 deletion infini_train/src/kernels/cuda/cast.cu
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,8 @@ std::shared_ptr<Tensor> Cast(std::shared_ptr<Tensor> input, DataType dtype) {
dim3 grid_dims(CEIL_DIV(num_elements, block_dims.x));
const size_t step = grid_dims.x * block_dims.x;

core::cuda::DispatchCudaFunc<DataTypeList<INFINI_ALL_TYPES>, DataTypeList<INFINI_ALL_TYPES>>(
core::cuda::DispatchCudaFunc<DataTypeList<INFINI_ALL_NUMERIC_TYPES, INFINI_LOGICAL_TYPES>,
DataTypeList<INFINI_ALL_NUMERIC_TYPES, INFINI_LOGICAL_TYPES>>(
{dtype, input->Dtype()},
[=]<typename Tdst, typename Tsrc>() {
auto dst = static_cast<Tdst *>(dst_tensor->DataPtr());
Expand Down
4 changes: 2 additions & 2 deletions infini_train/src/kernels/cuda/concat.cu
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,7 @@ std::shared_ptr<Tensor> ConcatForward(const std::vector<std::shared_ptr<Tensor>>
int threads_per_block = 256;
int num_blocks = static_cast<int>((total + threads_per_block - 1) / threads_per_block);

core::cuda::DispatchCudaFunc<INFINI_ALL_TYPES>(
core::cuda::DispatchCudaFunc<INFINI_ALL_NUMERIC_TYPES>(
dtype,
[=, &inputs, &host_offsets]<typename T>() {
std::vector<const T *> host_input_ptrs;
Expand Down Expand Up @@ -208,7 +208,7 @@ std::vector<std::shared_ptr<Tensor>> ConcatBackward(const std::shared_ptr<Tensor
int threads_per_block = 256;
int num_blocks = static_cast<int>((total + threads_per_block - 1) / threads_per_block);

core::cuda::DispatchCudaFunc<INFINI_ALL_TYPES>(
core::cuda::DispatchCudaFunc<INFINI_ALL_NUMERIC_TYPES>(
dtype,
[=, &grads, &host_offsets]<typename T>() {
std::vector<T *> host_ptrs;
Expand Down
28 changes: 14 additions & 14 deletions infini_train/src/kernels/cuda/elementwise.cu
Original file line number Diff line number Diff line change
Expand Up @@ -1018,7 +1018,7 @@ std::shared_ptr<Tensor> EqualsForward(const std::shared_ptr<Tensor> &a, const st
DISPATCH(a->Dtype(),
return BinaryForward(a, b,
[] __device__(auto x, auto y) { return (x == y) ? decltype(x){1} : decltype(x){0}; });
, INFINI_ALL_TYPES)
, INFINI_ALL_NUMERIC_TYPES)
}

std::shared_ptr<Tensor> EqualsScalarForward(const std::shared_ptr<Tensor> &a, float scalar) {
Expand All @@ -1033,7 +1033,7 @@ std::shared_ptr<Tensor> EqualsScalarForward(const std::shared_ptr<Tensor> &a, fl
std::shared_ptr<Tensor> LtForward(const std::shared_ptr<Tensor> &a, const std::shared_ptr<Tensor> &b) {
DISPATCH(a->Dtype(), return BinaryForward(
a, b, [] __device__(auto x, auto y) { return x < y ? decltype(x){1} : decltype(x){0}; });
, INFINI_ALL_TYPES)
, INFINI_ALL_NUMERIC_TYPES)
}

std::shared_ptr<Tensor> LtScalarForward(const std::shared_ptr<Tensor> &a, float scalar) {
Expand All @@ -1042,14 +1042,14 @@ std::shared_ptr<Tensor> LtScalarForward(const std::shared_ptr<Tensor> &a, float
return (x < static_cast<decltype(x)>(scalar)) ? decltype(x){1}
: decltype(x){0};
});
, INFINI_ALL_TYPES)
, INFINI_ALL_NUMERIC_TYPES)
}

std::shared_ptr<Tensor> LeForward(const std::shared_ptr<Tensor> &a, const std::shared_ptr<Tensor> &b) {
DISPATCH(a->Dtype(),
return BinaryForward(a, b,
[] __device__(auto x, auto y) { return (x <= y) ? decltype(x){1} : decltype(x){0}; });
, INFINI_ALL_TYPES)
, INFINI_ALL_NUMERIC_TYPES)
}

std::shared_ptr<Tensor> LeScalarForward(const std::shared_ptr<Tensor> &a, float scalar) {
Expand All @@ -1058,13 +1058,13 @@ std::shared_ptr<Tensor> LeScalarForward(const std::shared_ptr<Tensor> &a, float
return (x <= static_cast<decltype(x)>(scalar)) ? decltype(x){1}
: decltype(x){0};
});
, INFINI_ALL_TYPES)
, INFINI_ALL_NUMERIC_TYPES)
}

std::shared_ptr<Tensor> GtForward(const std::shared_ptr<Tensor> &a, const std::shared_ptr<Tensor> &b) {
DISPATCH(a->Dtype(), return BinaryForward(
a, b, [] __device__(auto x, auto y) { return x > y ? decltype(x){1} : decltype(x){0}; });
, INFINI_ALL_TYPES)
, INFINI_ALL_NUMERIC_TYPES)
}

std::shared_ptr<Tensor> GtScalarForward(const std::shared_ptr<Tensor> &a, float scalar) {
Expand All @@ -1073,14 +1073,14 @@ std::shared_ptr<Tensor> GtScalarForward(const std::shared_ptr<Tensor> &a, float
return (x > static_cast<decltype(x)>(scalar)) ? decltype(x){1}
: decltype(x){0};
});
, INFINI_ALL_TYPES)
, INFINI_ALL_NUMERIC_TYPES)
}

std::shared_ptr<Tensor> GeForward(const std::shared_ptr<Tensor> &a, const std::shared_ptr<Tensor> &b) {
DISPATCH(a->Dtype(),
return BinaryForward(a, b,
[] __device__(auto x, auto y) { return (x >= y) ? decltype(x){1} : decltype(x){0}; });
, INFINI_ALL_TYPES)
, INFINI_ALL_NUMERIC_TYPES)
}

std::shared_ptr<Tensor> GeScalarForward(const std::shared_ptr<Tensor> &a, float scalar) {
Expand All @@ -1089,7 +1089,7 @@ std::shared_ptr<Tensor> GeScalarForward(const std::shared_ptr<Tensor> &a, float
return (x >= static_cast<decltype(x)>(scalar)) ? decltype(x){1}
: decltype(x){0};
});
, INFINI_ALL_TYPES)
, INFINI_ALL_NUMERIC_TYPES)
}

std::shared_ptr<Tensor> OrForward(const std::shared_ptr<Tensor> &a, const std::shared_ptr<Tensor> &b) {
Expand All @@ -1098,7 +1098,7 @@ std::shared_ptr<Tensor> OrForward(const std::shared_ptr<Tensor> &a, const std::s
return (x != decltype(x){0} || y != decltype(y){0}) ? decltype(x){1}
: decltype(x){0};
});
, INFINI_ALL_TYPES)
, INFINI_ALL_NUMERIC_TYPES)
}

std::shared_ptr<Tensor> AndForward(const std::shared_ptr<Tensor> &a, const std::shared_ptr<Tensor> &b) {
Expand All @@ -1107,7 +1107,7 @@ std::shared_ptr<Tensor> AndForward(const std::shared_ptr<Tensor> &a, const std::
return (x != decltype(x){0} && y != decltype(y){0}) ? decltype(x){1}
: decltype(x){0};
});
, INFINI_ALL_TYPES)
, INFINI_ALL_NUMERIC_TYPES)
}

std::shared_ptr<Tensor> AddForward(const std::shared_ptr<Tensor> &a, const std::shared_ptr<Tensor> &b) {
Expand All @@ -1125,19 +1125,19 @@ std::pair<std::shared_ptr<Tensor>, std::shared_ptr<Tensor>> AddBackward(const st
std::shared_ptr<Tensor> AddScalarForward(const std::shared_ptr<Tensor> &a, float scalar) {
DISPATCH(a->Dtype(),
return UnaryForward(a, [scalar] __device__(auto x) { return Add(x, static_cast<decltype(x)>(scalar)); });
, INFINI_ALL_TYPES)
, INFINI_ALL_NUMERIC_TYPES)
}

std::shared_ptr<Tensor> AddScalarBackward(const std::shared_ptr<Tensor> &grad_output) {
DISPATCH(grad_output->Dtype(),
return UnaryBackward(grad_output, nullptr,
[] __device__(auto x) { return common::cuda::Cast<decltype(x)>(1); });
, INFINI_ALL_TYPES)
, INFINI_ALL_NUMERIC_TYPES)
}

std::shared_ptr<Tensor> SubForward(const std::shared_ptr<Tensor> &a, const std::shared_ptr<Tensor> &b) {
DISPATCH(a->Dtype(), return BinaryForward(a, b, [] __device__(auto x, auto y) { return Sub(x, y); });
, INFINI_ALL_TYPES)
, INFINI_ALL_NUMERIC_TYPES)
}

std::pair<std::shared_ptr<Tensor>, std::shared_ptr<Tensor>> SubBackward(const std::shared_ptr<Tensor> &grad_output,
Expand Down
2 changes: 1 addition & 1 deletion infini_train/src/kernels/cuda/fill.cu
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ void Fill(std::shared_ptr<Tensor> tensor, Scalar scalar) {
infini_train::core::GetDeviceGuardImpl(device.type())->GetStream(device))
->cuda_stream();

core::cuda::DispatchCudaFunc<INFINI_ALL_TYPES>(
core::cuda::DispatchCudaFunc<INFINI_ALL_NUMERIC_TYPES>(
tensor->Dtype(),
[=]<typename T>() {
const T casted_value = scalar.to<T>();
Expand Down
4 changes: 2 additions & 2 deletions infini_train/src/kernels/cuda/slice.cu
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,7 @@ std::shared_ptr<Tensor> SliceForward(const std::shared_ptr<Tensor> &input, const
int threads_per_block = 256;
int num_blocks = (total_elements + threads_per_block - 1) / threads_per_block;

core::cuda::DispatchCudaFunc<INFINI_ALL_TYPES>(
core::cuda::DispatchCudaFunc<INFINI_ALL_NUMERIC_TYPES>(
dtype,
[=]<typename T>() {
SliceForwardKernel<<<num_blocks, threads_per_block, 0, stream>>>(
Expand Down Expand Up @@ -185,7 +185,7 @@ std::shared_ptr<Tensor> SliceBackward(const std::shared_ptr<Tensor> &grad_output
int threads_per_block = 256;
int num_blocks = (total_elements + threads_per_block - 1) / threads_per_block;

core::cuda::DispatchCudaFunc<INFINI_ALL_TYPES>(
core::cuda::DispatchCudaFunc<INFINI_ALL_NUMERIC_TYPES>(
grad_output_dtype,
[=]<typename T>() {
SliceBackwardKernel<<<num_blocks, threads_per_block, 0, stream>>>(
Expand Down
4 changes: 2 additions & 2 deletions infini_train/src/kernels/cuda/split.cu
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ std::vector<std::shared_ptr<Tensor>> SplitForward(const std::shared_ptr<Tensor>
infini_train::core::GetDeviceGuardImpl(device.type())->GetStream(device))
->cuda_stream();

core::cuda::DispatchCudaFunc<INFINI_ALL_TYPES>(
core::cuda::DispatchCudaFunc<INFINI_ALL_NUMERIC_TYPES>(
dtype,
[=]<typename T>() {
SplitForwardKernel<<<num_blocks, threads_per_block, 0, cuda_stream>>>(
Expand Down Expand Up @@ -166,7 +166,7 @@ std::shared_ptr<Tensor> SplitBackward(const std::vector<int64_t> &input_dims, in
CHECK_GE(dim, 0) << "Currently we do not support negative dimension";
CHECK_LT(dim, input_dims.size());

return core::cuda::DispatchCudaFunc<INFINI_ALL_TYPES>(
return core::cuda::DispatchCudaFunc<INFINI_ALL_NUMERIC_TYPES>(
grad_outputs[0]->Dtype(),
[=]<typename T>() { return LaunchSplitBackward<T>(input_dims, split_size, dim, grad_outputs); },
"CUDA SplitBackward");
Expand Down
4 changes: 2 additions & 2 deletions infini_train/src/kernels/cuda/stack.cu
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ std::shared_ptr<Tensor> StackForward(const std::vector<std::shared_ptr<Tensor>>
int threads_per_block = 256;
int num_blocks = (total + threads_per_block - 1) / threads_per_block;

core::cuda::DispatchCudaFunc<INFINI_ALL_TYPES>(
core::cuda::DispatchCudaFunc<INFINI_ALL_NUMERIC_TYPES>(
dtype,
[=]<typename T>() {
std::vector<const T *> host_input_ptrs;
Expand Down Expand Up @@ -129,7 +129,7 @@ std::vector<std::shared_ptr<Tensor>> StackBackward(const std::vector<int64_t> &i
int threads_per_block = 256;
int num_blocks = (total + threads_per_block - 1) / threads_per_block;

core::cuda::DispatchCudaFunc<INFINI_ALL_TYPES>(
core::cuda::DispatchCudaFunc<INFINI_ALL_NUMERIC_TYPES>(
dtype,
[=]<typename T>() {
std::vector<T *> host_ptrs;
Expand Down
Loading
Loading