Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion src/infiniop/ops/add/kunlun/add_kunlun.xpu
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ infiniStatus_t Descriptor::create(
const auto &a_shape = a_desc->shape();
const auto &b_shape = b_desc->shape();

CHECK_DTYPE(dtype, INFINI_DTYPE_F16, INFINI_DTYPE_F32, INFINI_DTYPE_BF16, INFINI_DTYPE_I32);
CHECK_DTYPE(dtype, INFINI_DTYPE_F16, INFINI_DTYPE_F32, INFINI_DTYPE_BF16, INFINI_DTYPE_I32, INFINI_DTYPE_I64);

CHECK_SAME_SHAPE(c_shape, a_shape, b_shape);

Expand Down Expand Up @@ -51,6 +51,8 @@ infiniStatus_t Descriptor::calculate(
return _device_info->calculate<8, AddOp, float>(_info, workspace, output, inputs, stream);
case INFINI_DTYPE_I32:
return _device_info->calculate<8, AddOp, int32_t>(_info, workspace, output, inputs, stream);
case INFINI_DTYPE_I64:
return _device_info->calculate<8, AddOp, int64_t>(_info, workspace, output, inputs, stream);
default:
return INFINI_STATUS_BAD_TENSOR_DTYPE;
}
Expand Down
8 changes: 7 additions & 1 deletion src/infiniop/ops/add/kunlun/kernel.h
Original file line number Diff line number Diff line change
Expand Up @@ -12,12 +12,18 @@ typedef struct AddOp {
T b = inputs[1];
return a + b;
}
// bfloat16 特化版本(使用 float 计算精度)
// bfloat16 - cast to flloat
inline __device__ bfloat16_t operator()(const bfloat16_t *inputs) const {
float a_f = __bfloat162float(inputs[0]);
float b_f = __bfloat162float(inputs[1]);
return __float2bfloat16(a_f + b_f);
}
// int64_t - cast to int32_t
inline __device__ int64_t operator()(const int64_t *inputs) const {
int32_t a = static_cast<int32_t>(inputs[0]);
int32_t b = static_cast<int32_t>(inputs[1]);
return static_cast<int64_t>(a + b);
}
} AddOp;
} // namespace op::add::kunlun

Expand Down
Loading