From 9a4fec744efc86bc57068ebabb72269dea592cc8 Mon Sep 17 00:00:00 2001 From: pengcheng888 <1033693766@qq.com> Date: Fri, 5 Jun 2026 12:32:04 +0800 Subject: [PATCH] feat(operator):support nvidia-embedding-operator. --- src/base/embedding.h | 93 ++++++++++ src/native/cuda/nvidia/ops/embedding/kernel.h | 21 +++ src/native/cuda/ops/embedding/kernel.cuh | 171 ++++++++++++++++++ src/native/cuda/ops/embedding/kernel.h | 131 ++++++++++++++ tests/test_embedding.py | 99 ++++++++++ 5 files changed, 515 insertions(+) create mode 100644 src/base/embedding.h create mode 100644 src/native/cuda/nvidia/ops/embedding/kernel.h create mode 100644 src/native/cuda/ops/embedding/kernel.cuh create mode 100644 src/native/cuda/ops/embedding/kernel.h create mode 100644 tests/test_embedding.py diff --git a/src/base/embedding.h b/src/base/embedding.h new file mode 100644 index 000000000..b83dfda78 --- /dev/null +++ b/src/base/embedding.h @@ -0,0 +1,93 @@ +#ifndef INFINI_OPS_BASE_EMBEDDING_H_ +#define INFINI_OPS_BASE_EMBEDDING_H_ + +#include +#include + +#include "data_type.h" +#include "operator.h" +#include "tensor.h" + +namespace infini::ops { + +// Aligned with InfiniCore and `torch.nn.functional.embedding`. 
+// Abstract operator base: validates arguments with `assert`, caches metadata.
+class Embedding : public Operator {
+ public:
+  Embedding(const Tensor input, const Tensor weight, Tensor out)
+      : input_shape_{input.shape()},
+        weight_shape_{weight.shape()},
+        out_shape_{out.shape()},
+        input_strides_{input.strides()},
+        weight_strides_{weight.strides()},
+        out_strides_{out.strides()},
+        input_dtype_{input.dtype()},
+        weight_dtype_{weight.dtype()},
+        out_dtype_{out.dtype()},
+        num_indices_{NumIndices(input_shape_)},
+        vocab_size_{weight.size(0)},
+        embedding_dim_{weight.size(1)} {
+    assert(weight.ndim() == 2 && "`Embedding` requires 2D `weight`");
+    assert(out.ndim() == input.ndim() + 1 &&
+           "`Embedding` output rank must be input rank + 1");
+    // NOTE(review): `weight.size(1)` is read above, before the 2D check runs.
+    for (Tensor::Size i = 0; i < input.ndim(); ++i) {
+      assert(
+          out.size(i) == input.size(i) &&
+          "`Embedding` output shape must match input shape on non-last dims");
+    }
+    // The trailing dim of `out` holds one embedding vector per index.
+    assert(out.size(-1) == embedding_dim_ &&
+           "`Embedding` output last dim must equal `weight` embedding dim");
+
+    assert(input_dtype_ == DataType::kInt32 ||
+           input_dtype_ == DataType::kInt64);
+    assert(weight_dtype_ == DataType::kFloat32 ||
+           weight_dtype_ == DataType::kFloat16 ||
+           weight_dtype_ == DataType::kBFloat16);
+    assert(out_dtype_ == weight_dtype_ &&
+           "`Embedding` output dtype must match `weight` dtype");
+  }
+  // Pure-virtual launch hook; `CudaEmbedding` overrides it.
+  virtual void operator()(const Tensor input, const Tensor weight,
+                          Tensor out) const = 0;
+
+ protected:
+  static Tensor::Size NumIndices(const Tensor::Shape& input_shape) {
+    Tensor::Size num_indices = 1;
+    // Product of all extents; an empty shape yields 1.
+    for (Tensor::Size dim : input_shape) {
+      num_indices *= dim;
+    }
+
+    return num_indices;
+  }
+  // Shapes, strides and dtypes captured from the constructor arguments.
+  Tensor::Shape input_shape_;
+
+  Tensor::Shape weight_shape_;
+
+  Tensor::Shape out_shape_;
+
+  Tensor::Strides input_strides_;
+
+  Tensor::Strides weight_strides_;
+
+  Tensor::Strides out_strides_;
+
+  DataType input_dtype_;
+
+  DataType weight_dtype_;
+
+  DataType out_dtype_;
+  // Number of lookups, i.e. the element count of `input`.
+  Tensor::Size num_indices_{0};
+  // `weight.size(0)`.
+  Tensor::Size vocab_size_{0};
+  // `weight.size(1)`.
+  Tensor::Size embedding_dim_{0};
+};
+
+}  // namespace infini::ops
+
+#endif
diff --git a/src/native/cuda/nvidia/ops/embedding/kernel.h b/src/native/cuda/nvidia/ops/embedding/kernel.h
new file mode 100644
index 000000000..f05f155ae
--- /dev/null
+++ b/src/native/cuda/nvidia/ops/embedding/kernel.h
@@ -0,0 +1,21 @@
+#ifndef INFINI_OPS_NVIDIA_EMBEDDING_KERNEL_H_
+#define INFINI_OPS_NVIDIA_EMBEDDING_KERNEL_H_
+
+#include
+
+#include "native/cuda/nvidia/caster.cuh"
+#include "native/cuda/nvidia/runtime_.h"
+#include "native/cuda/ops/embedding/kernel.h"
+
+namespace infini::ops {
+// `Operator` specialization that reuses `CudaEmbedding` and its constructors.
+template <>
+class Operator
+    : public CudaEmbedding> {
+ public:
+  using CudaEmbedding>::CudaEmbedding;
+};
+
+}  // namespace infini::ops
+
+#endif
diff --git a/src/native/cuda/ops/embedding/kernel.cuh b/src/native/cuda/ops/embedding/kernel.cuh
new file mode 100644
index 000000000..613cc531e
--- /dev/null
+++ b/src/native/cuda/ops/embedding/kernel.cuh
@@ -0,0 +1,171 @@
+#ifndef INFINI_OPS_CUDA_EMBEDDING_KERNEL_CUH_
+#define INFINI_OPS_CUDA_EMBEDDING_KERNEL_CUH_
+
+#include
+#include
+#include
+
+#include "native/cuda/kernel_commons.cuh"
+
+namespace infini::ops {
+namespace embedding_detail {
+// True when the address of `ptr` is a multiple of `alignment`.
+__forceinline__ __device__ bool IsAligned(const void* ptr, size_t alignment) {
+  return (reinterpret_cast(ptr) % alignment) == 0;
+}
+// Element-by-element copy of one row with column strides; reads use `__ldg`.
+template
+__forceinline__ __device__ void CopyScalar(T* __restrict__ dst,
+                                           const T* __restrict__ src,
+                                           size_t embedding_dim,
+                                           ptrdiff_t dst_col_stride = 1,
+                                           ptrdiff_t src_col_stride = 1) {
+  for (size_t i = 0; i < embedding_dim; ++i) {
+    dst[i * dst_col_stride] = __ldg(&src[i * src_col_stride]);
+  }
+}
+
+// Same as `third_party/InfiniCore/.../embedding/cuda/embedding_kernel.cuh`.
+template
+__forceinline__ __device__ void CopyVectorizedFloat4(
+    float* __restrict__ dst, const float* __restrict__ src,
+    size_t embedding_dim) {
+  const float4* src_vec = reinterpret_cast(src);
+  float4* dst_vec = reinterpret_cast(dst);
+  size_t vec_count = embedding_dim / 4;
+  // Bulk of the row: one `float4` per iteration.
+  for (size_t i = 0; i < vec_count; ++i) {
+    dst_vec[i] = __ldg(&src_vec[i]);
+  }
+  // Scalar tail for the remaining `embedding_dim % 4` floats.
+  size_t remaining = embedding_dim % 4;
+  if (remaining > 0) {
+    size_t offset = vec_count * 4;
+    for (size_t i = 0; i < remaining; ++i) {
+      dst[offset + i] = __ldg(&src[offset + i]);
+    }
+  }
+}
+// `float2` variant. NOTE(review): no alignment check here or at the call site.
+template
+__forceinline__ __device__ void CopyVectorizedFloat2(
+    float* __restrict__ dst, const float* __restrict__ src,
+    size_t embedding_dim) {
+  const float2* src_vec = reinterpret_cast(src);
+  float2* dst_vec = reinterpret_cast(dst);
+  size_t vec_count = embedding_dim / 2;
+
+  for (size_t i = 0; i < vec_count; ++i) {
+    dst_vec[i] = __ldg(&src_vec[i]);
+  }
+  // Odd trailing element, if any.
+  if (embedding_dim % 2 != 0) {
+    dst[embedding_dim - 1] = __ldg(&src[embedding_dim - 1]);
+  }
+}
+// Copies two `T` per `uint32_t`; presumably `sizeof(T) == 2` - TODO confirm.
+template
+__forceinline__ __device__ void CopyVectorized16(T* __restrict__ dst,
+                                                 const T* __restrict__ src,
+                                                 size_t embedding_dim) {
+  const uint32_t* src_vec = reinterpret_cast(src);
+  uint32_t* dst_vec = reinterpret_cast(dst);
+  size_t vec_count = embedding_dim / 2;
+  // NOTE(review): 4-byte alignment of `dst`/`src` is not checked here.
+  for (size_t i = 0; i < vec_count; ++i) {
+    dst_vec[i] = __ldg(&src_vec[i]);
+  }
+
+  if (embedding_dim % 2 != 0) {
+    dst[embedding_dim - 1] = __ldg(&src[embedding_dim - 1]);
+  }
+}
+
+// Contiguous row copy with InfiniCore vectorization strategy.
+template
+__forceinline__ __device__ void CopyRowContiguous(T* __restrict__ dst,
+                                                  const T* __restrict__ src,
+                                                  size_t embedding_dim) {
+  if constexpr (std::is_same_v) {
+    bool aligned_16 = IsAligned(src, 16) && IsAligned(dst, 16);
+    if (aligned_16 && embedding_dim >= 4 && embedding_dim % 4 == 0) {
+      CopyVectorizedFloat4(dst, src, embedding_dim);
+    } else if (embedding_dim >= 2 && embedding_dim % 2 == 0) {
+      CopyVectorizedFloat2(dst, src, embedding_dim);
+    } else {
+      CopyScalar(dst, src, embedding_dim);
+    }
+  } else if constexpr (IsFP16 || IsBFloat16) {
+    if (embedding_dim >= 2 && embedding_dim % 2 == 0) {
+      CopyVectorized16(dst, src, embedding_dim);
+    } else {
+      CopyScalar(dst, src, embedding_dim);
+    }
+  } else {
+    CopyScalar(dst, src, embedding_dim);
+  }
+}
+// Uses the vectorized path only when both column strides equal 1.
+template
+__forceinline__ __device__ void CopyRow(T* __restrict__ dst,
+                                        const T* __restrict__ src,
+                                        size_t embedding_dim,
+                                        ptrdiff_t dst_col_stride,
+                                        ptrdiff_t src_col_stride) {
+  if (dst_col_stride == 1 && src_col_stride == 1) {
+    CopyRowContiguous(dst, src, embedding_dim);
+    return;
+  }
+
+  CopyScalar(dst, src, embedding_dim, dst_col_stride, src_col_stride);
+}
+
+}  // namespace embedding_detail
+// One thread per lookup: copies `weight[indices[idx]]` into output row `idx`.
+template
+__global__ void EmbeddingKernel(
+    T* __restrict__ output, const IndexT* __restrict__ indices,
+    const T* __restrict__ weight, size_t num_indices, size_t input_ndim,
+    const size_t* __restrict__ input_shape,
+    const ptrdiff_t* __restrict__ input_strides, size_t out_ndim,
+    const size_t* __restrict__ out_shape,
+    const ptrdiff_t* __restrict__ out_strides, ptrdiff_t weight_row_stride,
+    ptrdiff_t weight_col_stride, size_t embedding_dim, size_t vocab_size,
+    bool input_contiguous, bool out_contiguous) {
+  size_t idx = blockIdx.x * blockDim.x + threadIdx.x;
+
+  if (idx >= num_indices) {
+    return;
+  }
+  // Resolve where this thread's index lives; `input` may be strided.
+  size_t input_offset =
+      input_contiguous
+          ? idx
+          : IndexToOffset(idx, input_ndim, input_shape, input_strides);
+  IndexT index_val = __ldg(&indices[input_offset]);
+  // NOTE(review): out-of-range indices return early; the row is not written.
+  if (index_val < 0 || static_cast(index_val) >= vocab_size) {
+    return;
+  }
+
+  const T* src = weight + static_cast(index_val) * weight_row_stride;
+  // Contiguous output: row `idx` starts at `idx * embedding_dim`.
+  if (out_contiguous) {
+    T* dst = output + idx * embedding_dim;
+    embedding_detail::CopyRow(dst, src, embedding_dim, 1,
+                                 weight_col_stride);
+    return;
+  }
+  // Strided output: locate the row using every dim except the last.
+  size_t out_prefix_ndim = out_ndim > 0 ? out_ndim - 1 : 0;
+  size_t out_row_offset =
+      IndexToOffset(idx, out_prefix_ndim, out_shape, out_strides);
+  ptrdiff_t out_col_stride = out_strides[out_ndim - 1];
+
+  embedding_detail::CopyRow(output + out_row_offset, src, embedding_dim,
+                               out_col_stride, weight_col_stride);
+}
+
+}  // namespace infini::ops
+
+#endif
diff --git a/src/native/cuda/ops/embedding/kernel.h b/src/native/cuda/ops/embedding/kernel.h
new file mode 100644
index 000000000..25b244858
--- /dev/null
+++ b/src/native/cuda/ops/embedding/kernel.h
@@ -0,0 +1,131 @@
+#ifndef INFINI_OPS_CUDA_EMBEDDING_KERNEL_H_
+#define INFINI_OPS_CUDA_EMBEDDING_KERNEL_H_
+
+#include
+#include
+#include
+#include
+
+#include "base/embedding.h"
+#include "common/generic_utils.h"
+#include "data_type.h"
+#include "dispatcher.h"
+#include "native/cuda/kernel_commons.cuh"
+#include "native/cuda/ops/embedding/kernel.cuh"
+
+namespace infini::ops {
+// `Embedding` for CUDA: uploads metadata once, launches the kernel per call.
+template
+class CudaEmbedding : public Embedding {
+ public:
+  CudaEmbedding(const Tensor input, const Tensor weight, Tensor out)
+      : Embedding{input, weight, out},
+        input_ndim_{input.ndim()},
+        out_ndim_{out.ndim()},
+        is_input_contiguous_{input.IsContiguous()},
+        is_out_contiguous_{out.IsContiguous()},
+        weight_row_stride_{weight.stride(0)},
+        weight_col_stride_{weight.stride(1)} {
+    size_t input_shape_size = input_ndim_ * sizeof(*d_input_shape_);
+    size_t input_strides_size = input_ndim_ * sizeof(*d_input_strides_);
+    size_t out_shape_size = out_ndim_ * sizeof(*d_out_shape_);
+    size_t out_strides_size = out_ndim_ * sizeof(*d_out_strides_);
+    const size_t metadata_size = input_shape_size + input_strides_size +
+                                 out_shape_size + out_strides_size;
+    // Host staging buffer that mirrors the device metadata layout.
+    std::vector metadata(metadata_size);
+
+    Backend::Malloc((void**)&d_metadata_, metadata_size);
+    // Layout: input shape | input strides | out shape | out strides.
+    size_t offset = 0;
+    d_input_shape_ = reinterpret_cast(d_metadata_ + offset);
+    std::memcpy(metadata.data() + offset, input_shape_.data(),
+                input_shape_size);
+    offset += input_shape_size;
+
+    d_input_strides_ = reinterpret_cast(d_metadata_ + offset);
+    std::memcpy(metadata.data() + offset, input_strides_.data(),
+                input_strides_size);
+    offset += input_strides_size;
+
+    d_out_shape_ = reinterpret_cast(d_metadata_ + offset);
+    std::memcpy(metadata.data() + offset, out_shape_.data(), out_shape_size);
+    offset += out_shape_size;
+
+    d_out_strides_ = reinterpret_cast(d_metadata_ + offset);
+    std::memcpy(metadata.data() + offset, out_strides_.data(),
+                out_strides_size);
+    // Single host-to-device upload covering all four arrays.
+    Backend::Memcpy(d_metadata_, metadata.data(), metadata_size,
+                    Backend::MemcpyHostToDevice);
+  }
+  // NOTE(review): owns `d_metadata_`, but copy/move are not disabled here.
+  ~CudaEmbedding() { Backend::Free(d_metadata_); }
+
+  void operator()(const Tensor input, const Tensor weight,
+                  Tensor out) const override {
+    if (num_indices_ == 0) {
+      return;
+    }
+
+    auto cuda_stream =
+        static_cast(stream_ ? stream_ : 0);
+    // Block size by row width: 512 if dim <= 64, 128 if >= 1024, else 256.
+    size_t block_size = 256;
+    if (embedding_dim_ <= 64) {
+      block_size = 512;
+    } else if (embedding_dim_ >= 1024) {
+      block_size = 128;
+    }
+
+    size_t grid_size = utils::CeilDiv(num_indices_, block_size);
+    // Pick `IndexT` and `T` from the runtime dtypes, then launch the kernel.
+    DispatchFunc,
+                 ConcatType, ReducedFloatTypes>>(
+        {static_cast(input_dtype_),
+         static_cast(weight_dtype_)},
+        [&](auto list_tag) {
+          using IndexT =
+              TypeMapType(list_tag)>;
+          using T = TypeMapType(list_tag)>;
+
+          EmbeddingKernel
+              <<>>(
+                  reinterpret_cast(out.data()),
+                  reinterpret_cast(input.data()),
+                  reinterpret_cast(weight.data()), num_indices_,
+                  input_ndim_, d_input_shape_, d_input_strides_, out_ndim_,
+                  d_out_shape_, d_out_strides_, weight_row_stride_,
+                  weight_col_stride_, embedding_dim_, vocab_size_,
+                  is_input_contiguous_, is_out_contiguous_);
+        },
+        "CudaEmbedding::operator()");
+  }
+
+ private:
+  Tensor::Size input_ndim_{0};
+
+  Tensor::Size out_ndim_{0};
+
+  bool is_input_contiguous_{false};
+
+  bool is_out_contiguous_{false};
+
+  Tensor::Stride weight_row_stride_{0};
+
+  Tensor::Stride weight_col_stride_{0};
+  // Single device allocation that holds all shape/stride arrays.
+  std::byte* d_metadata_{nullptr};
+  // The pointers below alias sub-ranges of `d_metadata_`.
+  Tensor::Size* d_input_shape_{nullptr};
+
+  Tensor::Stride* d_input_strides_{nullptr};
+
+  Tensor::Size* d_out_shape_{nullptr};
+
+  Tensor::Stride* d_out_strides_{nullptr};
+};
+
+}  // namespace infini::ops
+
+#endif
diff --git a/tests/test_embedding.py b/tests/test_embedding.py
new file mode 100644
index 000000000..98fe6bad3
--- /dev/null
+++ b/tests/test_embedding.py
@@ -0,0 +1,99 @@
+import infini.ops
+import pytest
+import torch
+
+from tests.utils import (
+    Payload,
+    empty_strided,
+    get_stream,
+    randint_strided,
+    randn_strided,
+)
+
+
+# Format:
+# (input_shape, weight_shape, input_strides, weight_strides, out_strides, input_dtype)
+_TEST_CASES = (
+    ((1, 5), (32000, 4), None, None, None, torch.int64),
+    ((2, 10), (32000, 2048), None, None, None, torch.int32),
+    ((1, 5), (10, 10), None, None, None, torch.int64),
+    ((2, 4), (32, 8), None, None, None, torch.int64),
+    ((2, 4), (32, 8), (8, 1), None, (32, 8, 1), torch.int32),
+    ((2, 4), (32, 8), None, (1, 32), None, torch.int64),
+)
+
+
+@pytest.mark.auto_act_and_assert
+@pytest.mark.parametrize(
+    "input_shape, weight_shape, input_strides, weight_strides, out_strides, input_dtype",
+    _TEST_CASES,
+)
+@pytest.mark.parametrize(
+    ("dtype", "rtol", "atol"),
+    (
+        (torch.float32, 1e-3, 0.0),
+        (torch.float16, 1e-2, 0.0),
+        (torch.bfloat16, 5e-2, 0.0),
+    ),
+)
+def test_embedding(
+    input_shape,
+    weight_shape,
+    input_strides,
+    weight_strides,
+    out_strides,
+    input_dtype,
+    implementation_index,
+    dtype,
+    device,
+    rtol,
+    atol,
+):
+    vocab_size = weight_shape[0]
+    embedding_dim = weight_shape[1]
+    output_shape = (*input_shape, embedding_dim)
+
+    input = randint_strided(
+        1,
+        min(9, vocab_size),
+        input_shape,
+        input_strides,
+        dtype=input_dtype,
+        device=device,
+    )
+    weight = randn_strided(weight_shape, weight_strides, dtype=dtype, device=device)
+    out = empty_strided(output_shape, out_strides, dtype=dtype, device=device)
+
+    return Payload(
+        lambda *args, **kwargs: _embedding(
+            *args, **kwargs, implementation_index=implementation_index
+        ),
+        _torch_embedding,
+        (input, weight),
+        {"out": out},
+        rtol=rtol,
+        atol=atol,
+    )
+
+
+def _embedding(input, weight, *, out=None, implementation_index=0):
+    infini.ops.embedding(
+        input,
+        weight,
+        out,
+        implementation_index=implementation_index,
+        stream=get_stream(input.device),
+    )
+
+    return out
+
+
+def _torch_embedding(input, weight, *, out=None):
+    result = torch.nn.functional.embedding(input, weight)
+    # Write into `out` when given so the reference returns the same tensor.
+    if out is not None:
+        out.copy_(result)
+    else:
+        out = result
+
+    return out