From 88f46bd3d873dabbe02d983f3f4c2b0eb167f13f Mon Sep 17 00:00:00 2001 From: kilinchange Date: Wed, 3 Jun 2026 02:04:56 +0000 Subject: [PATCH] feat: add bool datatype --- infini_train/include/core/backend_type_map.h | 3 +++ infini_train/include/datatype.h | 17 +++++++----- infini_train/include/dtype_dispatch.h | 5 +++- infini_train/src/kernels/cpu/cast.cc | 3 ++- infini_train/src/kernels/cpu/fill.cc | 2 +- infini_train/src/kernels/cuda/cast.cu | 3 ++- infini_train/src/kernels/cuda/concat.cu | 4 +-- infini_train/src/kernels/cuda/elementwise.cu | 28 ++++++++++---------- infini_train/src/kernels/cuda/fill.cu | 2 +- infini_train/src/kernels/cuda/slice.cu | 4 +-- infini_train/src/kernels/cuda/split.cu | 4 +-- infini_train/src/kernels/cuda/stack.cu | 4 +-- infini_train/src/kernels/cuda/transform.cu | 22 +++++++-------- infini_train/src/utils/precision_checker.cc | 2 ++ 14 files changed, 58 insertions(+), 45 deletions(-) diff --git a/infini_train/include/core/backend_type_map.h b/infini_train/include/core/backend_type_map.h index f67b8da7..38fea110 100644 --- a/infini_train/include/core/backend_type_map.h +++ b/infini_train/include/core/backend_type_map.h @@ -48,6 +48,9 @@ template struct BackendTypeMap; // ----------------------------------------------------------------------------- #define INFINI_REGISTER_STANDARD_BACKEND_TYPES(DEV) \ namespace infini_train::core { \ + template <> struct BackendTypeMap { \ + using type = bool; \ + }; \ template <> struct BackendTypeMap { \ using type = uint8_t; \ }; \ diff --git a/infini_train/include/datatype.h b/infini_train/include/datatype.h index cf637300..6efa849c 100644 --- a/infini_train/include/datatype.h +++ b/infini_train/include/datatype.h @@ -84,6 +84,7 @@ struct alignas(2) BF16 { // DataType enum and metadata tables // ----------------------------------------------------------------------------- enum class DataType : int8_t { + kBOOL, kUINT8, kINT8, kUINT16, @@ -99,16 +100,18 @@ enum class DataType : int8_t { }; inline const std::unordered_map kDataTypeToSize = { - {DataType::kUINT8, 1}, {DataType::kINT8, 1}, {DataType::kUINT16, 2}, {DataType::kINT16, 2}, - {DataType::kUINT32, 4}, {DataType::kINT32, 4}, {DataType::kUINT64, 8}, {DataType::kINT64, 8}, - {DataType::kBFLOAT16, 2}, {DataType::kFLOAT16, 2}, {DataType::kFLOAT32, 4}, {DataType::kFLOAT64, 8}, + {DataType::kBOOL, 1}, {DataType::kUINT8, 1}, {DataType::kINT8, 1}, {DataType::kUINT16, 2}, + {DataType::kINT16, 2}, {DataType::kUINT32, 4}, {DataType::kINT32, 4}, {DataType::kUINT64, 8}, + {DataType::kINT64, 8}, {DataType::kBFLOAT16, 2}, {DataType::kFLOAT16, 2}, {DataType::kFLOAT32, 4}, + {DataType::kFLOAT64, 8}, }; inline const std::unordered_map kDataTypeToDesc = { - {DataType::kUINT8, "uint8"}, {DataType::kINT8, "int8"}, {DataType::kUINT16, "uint16"}, - {DataType::kINT16, "int16"}, {DataType::kUINT32, "uint32"}, {DataType::kINT32, "int32"}, - {DataType::kUINT64, "uint64"}, {DataType::kINT64, "int64"}, {DataType::kBFLOAT16, "bf16"}, - {DataType::kFLOAT16, "fp16"}, {DataType::kFLOAT32, "fp32"}, {DataType::kFLOAT64, "fp64"}, + {DataType::kBOOL, "bool"}, {DataType::kUINT8, "uint8"}, {DataType::kINT8, "int8"}, + {DataType::kUINT16, "uint16"}, {DataType::kINT16, "int16"}, {DataType::kUINT32, "uint32"}, + {DataType::kINT32, "int32"}, {DataType::kUINT64, "uint64"}, {DataType::kINT64, "int64"}, + {DataType::kBFLOAT16, "bf16"}, {DataType::kFLOAT16, "fp16"}, {DataType::kFLOAT32, "fp32"}, + {DataType::kFLOAT64, "fp64"}, }; // ============================================================================= diff --git a/infini_train/include/dtype_dispatch.h b/infini_train/include/dtype_dispatch.h index e3db38b8..8bd5054b 100644 --- a/infini_train/include/dtype_dispatch.h +++ b/infini_train/include/dtype_dispatch.h @@ -180,10 +180,11 @@ namespace infini_train { #define INFINI_FLOATING_TYPES DataType::kFLOAT32, DataType::kFLOAT64 #define INFINI_REDUCED_FLOATING_TYPES DataType::kFLOAT16, DataType::kBFLOAT16 #define INFINI_ALL_FLOATING_TYPES INFINI_FLOATING_TYPES, INFINI_REDUCED_FLOATING_TYPES +#define INFINI_LOGICAL_TYPES DataType::kBOOL #define INFINI_SIGNED_INTEGRAL_TYPES DataType::kINT8, DataType::kINT16, DataType::kINT32, DataType::kINT64 #define INFINI_UNSIGNED_INTEGRAL_TYPES DataType::kUINT8, DataType::kUINT16, DataType::kUINT32, DataType::kUINT64 #define INFINI_ALL_INTEGRAL_TYPES INFINI_SIGNED_INTEGRAL_TYPES, INFINI_UNSIGNED_INTEGRAL_TYPES -#define INFINI_ALL_TYPES INFINI_ALL_FLOATING_TYPES, INFINI_ALL_INTEGRAL_TYPES +#define INFINI_ALL_NUMERIC_TYPES INFINI_ALL_FLOATING_TYPES, INFINI_ALL_INTEGRAL_TYPES #define INFINI_8_BIT_TYPES DataType::kINT8, DataType::kUINT8 #define INFINI_16_BIT_TYPES DataType::kINT16, DataType::kUINT16, DataType::kFLOAT16, DataType::kBFLOAT16 #define INFINI_32_BIT_TYPES DataType::kINT32, DataType::kUINT32, DataType::kFLOAT32 @@ -242,6 +243,7 @@ auto DispatchByTypeMap(DataType dtype, Functor &&func, std::string_view context_ } \ } + CASE_FOR_TYPE(DataType::kBOOL) CASE_FOR_TYPE(DataType::kUINT8) CASE_FOR_TYPE(DataType::kINT8) CASE_FOR_TYPE(DataType::kUINT16) @@ -290,6 +292,7 @@ struct TypeMapDispatcher { break; \ } + CASE_FOR_TYPE(DataType::kBOOL) CASE_FOR_TYPE(DataType::kUINT8) CASE_FOR_TYPE(DataType::kINT8) CASE_FOR_TYPE(DataType::kUINT16) diff --git a/infini_train/src/kernels/cpu/cast.cc b/infini_train/src/kernels/cpu/cast.cc index 114a5597..c3a2e595 100644 --- a/infini_train/src/kernels/cpu/cast.cc +++ b/infini_train/src/kernels/cpu/cast.cc @@ -13,7 +13,8 @@ std::shared_ptr Cast(std::shared_ptr input, DataType dtype) { auto device = input->GetDevice(); auto dst_tensor = std::make_shared(input->Dims(), dtype, device); - core::cpu::DispatchCpuFunc, DataTypeList>( + core::cpu::DispatchCpuFunc, + DataTypeList>( {dtype, input->Dtype()}, [=]() { auto dst = static_cast(dst_tensor->DataPtr()); diff --git a/infini_train/src/kernels/cpu/fill.cc b/infini_train/src/kernels/cpu/fill.cc index 5f8b7cd3..7bcda6bd 100644 --- a/infini_train/src/kernels/cpu/fill.cc +++ b/infini_train/src/kernels/cpu/fill.cc @@ -8,7 +8,7 @@ namespace infini_train::kernels::cpu { void Fill(std::shared_ptr tensor, Scalar scalar) { - core::cpu::DispatchCpuFunc( + core::cpu::DispatchCpuFunc( tensor->Dtype(), [=]() { auto data = reinterpret_cast(tensor->DataPtr()); diff --git a/infini_train/src/kernels/cuda/cast.cu b/infini_train/src/kernels/cuda/cast.cu index 16190912..96a70ae2 100644 --- a/infini_train/src/kernels/cuda/cast.cu +++ b/infini_train/src/kernels/cuda/cast.cu @@ -34,7 +34,8 @@ std::shared_ptr Cast(std::shared_ptr input, DataType dtype) { dim3 grid_dims(CEIL_DIV(num_elements, block_dims.x)); const size_t step = grid_dims.x * block_dims.x; - core::cuda::DispatchCudaFunc, DataTypeList>( + core::cuda::DispatchCudaFunc, + DataTypeList>( {dtype, input->Dtype()}, [=]() { auto dst = static_cast(dst_tensor->DataPtr()); diff --git a/infini_train/src/kernels/cuda/concat.cu b/infini_train/src/kernels/cuda/concat.cu index c158a5c3..a7fa7490 100644 --- a/infini_train/src/kernels/cuda/concat.cu +++ b/infini_train/src/kernels/cuda/concat.cu @@ -103,7 +103,7 @@ std::shared_ptr ConcatForward(const std::vector> int threads_per_block = 256; int num_blocks = static_cast((total + threads_per_block - 1) / threads_per_block); - core::cuda::DispatchCudaFunc( + core::cuda::DispatchCudaFunc( dtype, [=, &inputs, &host_offsets]() { std::vector host_input_ptrs; @@ -208,7 +208,7 @@ std::vector> ConcatBackward(const std::shared_ptr((total + threads_per_block - 1) / threads_per_block); - core::cuda::DispatchCudaFunc( + core::cuda::DispatchCudaFunc( dtype, [=, &grads, &host_offsets]() { std::vector host_ptrs; diff --git a/infini_train/src/kernels/cuda/elementwise.cu b/infini_train/src/kernels/cuda/elementwise.cu index fe63e0b2..92ce9915 100644 --- a/infini_train/src/kernels/cuda/elementwise.cu +++ b/infini_train/src/kernels/cuda/elementwise.cu @@ -1018,7 +1018,7 @@ std::shared_ptr EqualsForward(const std::shared_ptr &a, const st DISPATCH(a->Dtype(), return BinaryForward(a, b, [] __device__(auto x, auto y) { return (x == y) ? decltype(x){1} : decltype(x){0}; }); - , INFINI_ALL_TYPES) + , INFINI_ALL_NUMERIC_TYPES) } std::shared_ptr EqualsScalarForward(const std::shared_ptr &a, float scalar) { @@ -1033,7 +1033,7 @@ std::shared_ptr EqualsScalarForward(const std::shared_ptr &a, fl std::shared_ptr LtForward(const std::shared_ptr &a, const std::shared_ptr &b) { DISPATCH(a->Dtype(), return BinaryForward( a, b, [] __device__(auto x, auto y) { return x < y ? decltype(x){1} : decltype(x){0}; }); - , INFINI_ALL_TYPES) + , INFINI_ALL_NUMERIC_TYPES) } std::shared_ptr LtScalarForward(const std::shared_ptr &a, float scalar) { @@ -1042,14 +1042,14 @@ std::shared_ptr LtScalarForward(const std::shared_ptr &a, float return (x < static_cast(scalar)) ? decltype(x){1} : decltype(x){0}; }); - , INFINI_ALL_TYPES) + , INFINI_ALL_NUMERIC_TYPES) } std::shared_ptr LeForward(const std::shared_ptr &a, const std::shared_ptr &b) { DISPATCH(a->Dtype(), return BinaryForward(a, b, [] __device__(auto x, auto y) { return (x <= y) ? decltype(x){1} : decltype(x){0}; }); - , INFINI_ALL_TYPES) + , INFINI_ALL_NUMERIC_TYPES) } std::shared_ptr LeScalarForward(const std::shared_ptr &a, float scalar) { @@ -1058,13 +1058,13 @@ std::shared_ptr LeScalarForward(const std::shared_ptr &a, float return (x <= static_cast(scalar)) ? decltype(x){1} : decltype(x){0}; }); - , INFINI_ALL_TYPES) + , INFINI_ALL_NUMERIC_TYPES) } std::shared_ptr GtForward(const std::shared_ptr &a, const std::shared_ptr &b) { DISPATCH(a->Dtype(), return BinaryForward( a, b, [] __device__(auto x, auto y) { return x > y ? decltype(x){1} : decltype(x){0}; }); - , INFINI_ALL_TYPES) + , INFINI_ALL_NUMERIC_TYPES) } std::shared_ptr GtScalarForward(const std::shared_ptr &a, float scalar) { @@ -1073,14 +1073,14 @@ std::shared_ptr GtScalarForward(const std::shared_ptr &a, float return (x > static_cast(scalar)) ? decltype(x){1} : decltype(x){0}; }); - , INFINI_ALL_TYPES) + , INFINI_ALL_NUMERIC_TYPES) } std::shared_ptr GeForward(const std::shared_ptr &a, const std::shared_ptr &b) { DISPATCH(a->Dtype(), return BinaryForward(a, b, [] __device__(auto x, auto y) { return (x >= y) ? decltype(x){1} : decltype(x){0}; }); - , INFINI_ALL_TYPES) + , INFINI_ALL_NUMERIC_TYPES) } std::shared_ptr GeScalarForward(const std::shared_ptr &a, float scalar) { @@ -1089,7 +1089,7 @@ std::shared_ptr GeScalarForward(const std::shared_ptr &a, float return (x >= static_cast(scalar)) ? decltype(x){1} : decltype(x){0}; }); - , INFINI_ALL_TYPES) + , INFINI_ALL_NUMERIC_TYPES) } std::shared_ptr OrForward(const std::shared_ptr &a, const std::shared_ptr &b) { @@ -1098,7 +1098,7 @@ std::shared_ptr OrForward(const std::shared_ptr &a, const std::s return (x != decltype(x){0} || y != decltype(y){0}) ? decltype(x){1} : decltype(x){0}; }); - , INFINI_ALL_TYPES) + , INFINI_ALL_NUMERIC_TYPES) } std::shared_ptr AndForward(const std::shared_ptr &a, const std::shared_ptr &b) { @@ -1107,7 +1107,7 @@ std::shared_ptr AndForward(const std::shared_ptr &a, const std:: return (x != decltype(x){0} && y != decltype(y){0}) ? decltype(x){1} : decltype(x){0}; }); - , INFINI_ALL_TYPES) + , INFINI_ALL_NUMERIC_TYPES) } std::shared_ptr AddForward(const std::shared_ptr &a, const std::shared_ptr &b) { @@ -1125,19 +1125,19 @@ std::pair, std::shared_ptr> AddBackward(const st std::shared_ptr AddScalarForward(const std::shared_ptr &a, float scalar) { DISPATCH(a->Dtype(), return UnaryForward(a, [scalar] __device__(auto x) { return Add(x, static_cast(scalar)); }); - , INFINI_ALL_TYPES) + , INFINI_ALL_NUMERIC_TYPES) } std::shared_ptr AddScalarBackward(const std::shared_ptr &grad_output) { DISPATCH(grad_output->Dtype(), return UnaryBackward(grad_output, nullptr, [] __device__(auto x) { return common::cuda::Cast(1); }); - , INFINI_ALL_TYPES) + , INFINI_ALL_NUMERIC_TYPES) } std::shared_ptr SubForward(const std::shared_ptr &a, const std::shared_ptr &b) { DISPATCH(a->Dtype(), return BinaryForward(a, b, [] __device__(auto x, auto y) { return Sub(x, y); }); - , INFINI_ALL_TYPES) + , INFINI_ALL_NUMERIC_TYPES) } std::pair, std::shared_ptr> SubBackward(const std::shared_ptr &grad_output, diff --git a/infini_train/src/kernels/cuda/fill.cu b/infini_train/src/kernels/cuda/fill.cu index f5532779..3ddead5c 100644 --- a/infini_train/src/kernels/cuda/fill.cu +++ b/infini_train/src/kernels/cuda/fill.cu @@ -28,7 +28,7 @@ void Fill(std::shared_ptr tensor, Scalar scalar) { infini_train::core::GetDeviceGuardImpl(device.type())->GetStream(device)) ->cuda_stream(); - core::cuda::DispatchCudaFunc( + core::cuda::DispatchCudaFunc( tensor->Dtype(), [=]() { const T casted_value = scalar.to(); diff --git a/infini_train/src/kernels/cuda/slice.cu b/infini_train/src/kernels/cuda/slice.cu index 35bd2ac5..d030d73a 100644 --- a/infini_train/src/kernels/cuda/slice.cu +++ b/infini_train/src/kernels/cuda/slice.cu @@ -92,7 +92,7 @@ std::shared_ptr SliceForward(const std::shared_ptr &input, const int threads_per_block = 256; int num_blocks = (total_elements + threads_per_block - 1) / threads_per_block; - core::cuda::DispatchCudaFunc( + core::cuda::DispatchCudaFunc( dtype, [=]() { SliceForwardKernel<<>>( @@ -185,7 +185,7 @@ std::shared_ptr SliceBackward(const std::shared_ptr &grad_output int threads_per_block = 256; int num_blocks = (total_elements + threads_per_block - 1) / threads_per_block; - core::cuda::DispatchCudaFunc( + core::cuda::DispatchCudaFunc( grad_output_dtype, [=]() { SliceBackwardKernel<<>>( diff --git a/infini_train/src/kernels/cuda/split.cu b/infini_train/src/kernels/cuda/split.cu index f208695f..bda0dd70 100644 --- a/infini_train/src/kernels/cuda/split.cu +++ b/infini_train/src/kernels/cuda/split.cu @@ -59,7 +59,7 @@ std::vector> SplitForward(const std::shared_ptr infini_train::core::GetDeviceGuardImpl(device.type())->GetStream(device)) ->cuda_stream(); - core::cuda::DispatchCudaFunc( + core::cuda::DispatchCudaFunc( dtype, [=]() { SplitForwardKernel<<>>( @@ -166,7 +166,7 @@ std::shared_ptr SplitBackward(const std::vector &input_dims, in CHECK_GE(dim, 0) << "Currently we do not support negative dimension"; CHECK_LT(dim, input_dims.size()); - return core::cuda::DispatchCudaFunc( + return core::cuda::DispatchCudaFunc( grad_outputs[0]->Dtype(), [=]() { return LaunchSplitBackward(input_dims, split_size, dim, grad_outputs); }, "CUDA SplitBackward"); diff --git a/infini_train/src/kernels/cuda/stack.cu b/infini_train/src/kernels/cuda/stack.cu index 562fa5ec..841940ea 100644 --- a/infini_train/src/kernels/cuda/stack.cu +++ b/infini_train/src/kernels/cuda/stack.cu @@ -61,7 +61,7 @@ std::shared_ptr StackForward(const std::vector> int threads_per_block = 256; int num_blocks = (total + threads_per_block - 1) / threads_per_block; - core::cuda::DispatchCudaFunc( + core::cuda::DispatchCudaFunc( dtype, [=]() { std::vector host_input_ptrs; @@ -129,7 +129,7 @@ std::vector> StackBackward(const std::vector &i int threads_per_block = 256; int num_blocks = (total + threads_per_block - 1) / threads_per_block; - core::cuda::DispatchCudaFunc( + core::cuda::DispatchCudaFunc( dtype, [=]() { std::vector host_ptrs; diff --git a/infini_train/src/kernels/cuda/transform.cu b/infini_train/src/kernels/cuda/transform.cu index 2bb35598..88f0e10f 100644 --- a/infini_train/src/kernels/cuda/transform.cu +++ b/infini_train/src/kernels/cuda/transform.cu @@ -47,7 +47,7 @@ std::shared_ptr TrilForward(const std::shared_ptr &input, int64_ infini_train::core::GetDeviceGuardImpl(device.type())->GetStream(device)) ->cuda_stream(); - core::cuda::DispatchCudaFunc( + core::cuda::DispatchCudaFunc( input->Dtype(), [=]() { TrilForwardKernel<<>>( @@ -90,7 +90,7 @@ std::shared_ptr TrilBackward(const std::shared_ptr &grad_output, infini_train::core::GetDeviceGuardImpl(device.type())->GetStream(device)) ->cuda_stream(); - core::cuda::DispatchCudaFunc( + core::cuda::DispatchCudaFunc( dtype, [=]() { grad_input->Fill(0.0); @@ -135,7 +135,7 @@ std::shared_ptr TriuForward(const std::shared_ptr &input, int64_ infini_train::core::GetDeviceGuardImpl(device.type())->GetStream(device)) ->cuda_stream(); - core::cuda::DispatchCudaFunc( + core::cuda::DispatchCudaFunc( input->Dtype(), [=]() { TriuForwardKernel<<>>( @@ -177,7 +177,7 @@ std::shared_ptr TriuBackward(const std::shared_ptr &grad_output, infini_train::core::GetDeviceGuardImpl(device.type())->GetStream(device)) ->cuda_stream(); - core::cuda::DispatchCudaFunc( + core::cuda::DispatchCudaFunc( dtype, [=]() { grad_input->Fill(0.0); @@ -269,7 +269,7 @@ std::shared_ptr TransposeForward(const std::shared_ptr &input, i int threads_per_block = 256; int num_blocks = (num_elements + threads_per_block - 1) / threads_per_block; - core::cuda::DispatchCudaFunc( + core::cuda::DispatchCudaFunc( dtype, [=]() { output->Fill(0.0); @@ -371,7 +371,7 @@ std::shared_ptr MaskForward(const std::shared_ptr &input, const int64_t inner = input->NumElements() / rows; int num_blocks = static_cast((input->NumElements() + threads_per_block - 1) / threads_per_block); - core::cuda::DispatchCudaFunc( + core::cuda::DispatchCudaFunc( dtype, [=]() { MaskLeadsForwardKernel<<>>( @@ -384,7 +384,7 @@ std::shared_ptr MaskForward(const std::shared_ptr &input, const int64_t batch_size = input->NumElements() / mask_size; int num_blocks = static_cast((input->NumElements() + threads_per_block - 1) / threads_per_block); - core::cuda::DispatchCudaFunc( + core::cuda::DispatchCudaFunc( dtype, [=]() { MaskForwardKernel<<>>( @@ -435,7 +435,7 @@ std::shared_ptr MaskBackward(const std::shared_ptr &grad_output, int64_t inner = grad_output->NumElements() / rows; int num_blocks = static_cast((grad_output->NumElements() + threads_per_block - 1) / threads_per_block); - core::cuda::DispatchCudaFunc( + core::cuda::DispatchCudaFunc( dtype, [=]() { grad_input->Fill(0.0); @@ -449,7 +449,7 @@ std::shared_ptr MaskBackward(const std::shared_ptr &grad_output, int64_t batch_size = grad_output->NumElements() / mask_size; int num_blocks = static_cast((grad_output->NumElements() + threads_per_block - 1) / threads_per_block); - core::cuda::DispatchCudaFunc( + core::cuda::DispatchCudaFunc( dtype, [=]() { grad_input->Fill(0.0); @@ -504,7 +504,7 @@ std::shared_ptr RepeatInterleaveForward(const std::shared_ptr &i infini_train::core::GetDeviceGuardImpl(device.type())->GetStream(device)) ->cuda_stream(); - core::cuda::DispatchCudaFunc( + core::cuda::DispatchCudaFunc( input->Dtype(), [=]() { RepeatInterleaveForwardKernel<<>>( @@ -562,7 +562,7 @@ std::shared_ptr RepeatInterleaveBackward(const std::shared_ptr & infini_train::core::GetDeviceGuardImpl(device.type())->GetStream(device)) ->cuda_stream(); - core::cuda::DispatchCudaFunc( + core::cuda::DispatchCudaFunc( grad_output->Dtype(), [=]() { grad_input->Fill(0.0); diff --git a/infini_train/src/utils/precision_checker.cc b/infini_train/src/utils/precision_checker.cc index d2cbd16a..2965284e 100644 --- a/infini_train/src/utils/precision_checker.cc +++ b/infini_train/src/utils/precision_checker.cc @@ -193,6 +193,8 @@ std::string FormatShape(const std::vector &shape) { std::string DataTypeToString(DataType dtype) { switch (dtype) { + case DataType::kBOOL: + return "bool"; case DataType::kFLOAT32: return "float32"; case DataType::kFLOAT16: