diff --git a/src/infiniop/ops/add/kunlun/add_kunlun.xpu b/src/infiniop/ops/add/kunlun/add_kunlun.xpu index 44d534762..a27f065ed 100644 --- a/src/infiniop/ops/add/kunlun/add_kunlun.xpu +++ b/src/infiniop/ops/add/kunlun/add_kunlun.xpu @@ -21,7 +21,7 @@ infiniStatus_t Descriptor::create( const auto &a_shape = a_desc->shape(); const auto &b_shape = b_desc->shape(); - CHECK_DTYPE(dtype, INFINI_DTYPE_F16, INFINI_DTYPE_F32, INFINI_DTYPE_BF16, INFINI_DTYPE_I32); + CHECK_DTYPE(dtype, INFINI_DTYPE_F16, INFINI_DTYPE_F32, INFINI_DTYPE_BF16, INFINI_DTYPE_I32, INFINI_DTYPE_I64); CHECK_SAME_SHAPE(c_shape, a_shape, b_shape); @@ -51,6 +51,8 @@ infiniStatus_t Descriptor::calculate( return _device_info->calculate<8, AddOp, float>(_info, workspace, output, inputs, stream); case INFINI_DTYPE_I32: return _device_info->calculate<8, AddOp, int32_t>(_info, workspace, output, inputs, stream); + case INFINI_DTYPE_I64: + return _device_info->calculate<8, AddOp, int64_t>(_info, workspace, output, inputs, stream); default: return INFINI_STATUS_BAD_TENSOR_DTYPE; } diff --git a/src/infiniop/ops/add/kunlun/kernel.h b/src/infiniop/ops/add/kunlun/kernel.h index 984a00afa..43094af9b 100644 --- a/src/infiniop/ops/add/kunlun/kernel.h +++ b/src/infiniop/ops/add/kunlun/kernel.h @@ -12,12 +12,18 @@ typedef struct AddOp { T b = inputs[1]; return a + b; } - // bfloat16 特化版本(使用 float 计算精度) + // bfloat16 - cast to flloat inline __device__ bfloat16_t operator()(const bfloat16_t *inputs) const { float a_f = __bfloat162float(inputs[0]); float b_f = __bfloat162float(inputs[1]); return __float2bfloat16(a_f + b_f); } + // int64_t - cast to int32_t + inline __device__ int64_t operator()(const int64_t *inputs) const { + int32_t a = static_cast(inputs[0]); + int32_t b = static_cast(inputs[1]); + return static_cast(a + b); + } } AddOp; } // namespace op::add::kunlun