Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
93 changes: 93 additions & 0 deletions src/base/embedding.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,93 @@
#ifndef INFINI_OPS_BASE_EMBEDDING_H_
#define INFINI_OPS_BASE_EMBEDDING_H_

#include <cassert>
#include <cstddef>

#include "data_type.h"
#include "operator.h"
#include "tensor.h"

namespace infini::ops {

// Aligned with InfiniCore and `torch.nn.functional.embedding`.

class Embedding : public Operator<Embedding> {
public:
Embedding(const Tensor input, const Tensor weight, Tensor out)
: input_shape_{input.shape()},
weight_shape_{weight.shape()},
out_shape_{out.shape()},
input_strides_{input.strides()},
weight_strides_{weight.strides()},
out_strides_{out.strides()},
input_dtype_{input.dtype()},
weight_dtype_{weight.dtype()},
out_dtype_{out.dtype()},
num_indices_{NumIndices(input_shape_)},
vocab_size_{weight.size(0)},
embedding_dim_{weight.size(1)} {
assert(weight.ndim() == 2 && "`Embedding` requires 2D `weight`");
assert(out.ndim() == input.ndim() + 1 &&
"`Embedding` output rank must be input rank + 1");

for (Tensor::Size i = 0; i < input.ndim(); ++i) {
assert(
out.size(i) == input.size(i) &&
"`Embedding` output shape must match input shape on non-last dims");
}

assert(out.size(-1) == embedding_dim_ &&
"`Embedding` output last dim must equal `weight` embedding dim");

assert(input_dtype_ == DataType::kInt32 ||
input_dtype_ == DataType::kInt64);
assert(weight_dtype_ == DataType::kFloat32 ||
weight_dtype_ == DataType::kFloat16 ||
weight_dtype_ == DataType::kBFloat16);
assert(out_dtype_ == weight_dtype_ &&
"`Embedding` output dtype must match `weight` dtype");
}

virtual void operator()(const Tensor input, const Tensor weight,
Tensor out) const = 0;

protected:
static Tensor::Size NumIndices(const Tensor::Shape& input_shape) {
Tensor::Size num_indices = 1;

for (Tensor::Size dim : input_shape) {
num_indices *= dim;
}

return num_indices;
}

Tensor::Shape input_shape_;

Tensor::Shape weight_shape_;

Tensor::Shape out_shape_;

Tensor::Strides input_strides_;

Tensor::Strides weight_strides_;

Tensor::Strides out_strides_;

DataType input_dtype_;

DataType weight_dtype_;

DataType out_dtype_;

Tensor::Size num_indices_{0};

Tensor::Size vocab_size_{0};

Tensor::Size embedding_dim_{0};
};

} // namespace infini::ops

#endif
21 changes: 21 additions & 0 deletions src/native/cuda/nvidia/ops/embedding/kernel.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
#ifndef INFINI_OPS_NVIDIA_EMBEDDING_KERNEL_H_
#define INFINI_OPS_NVIDIA_EMBEDDING_KERNEL_H_

#include <utility>

#include "native/cuda/nvidia/caster.cuh"
#include "native/cuda/nvidia/runtime_.h"
#include "native/cuda/ops/embedding/kernel.h"

namespace infini::ops {

template <>
class Operator<Embedding, Device::Type::kNvidia>
: public CudaEmbedding<Runtime<Device::Type::kNvidia>> {
public:
using CudaEmbedding<Runtime<Device::Type::kNvidia>>::CudaEmbedding;
};

} // namespace infini::ops

#endif
171 changes: 171 additions & 0 deletions src/native/cuda/ops/embedding/kernel.cuh
Original file line number Diff line number Diff line change
@@ -0,0 +1,171 @@
#ifndef INFINI_OPS_CUDA_EMBEDDING_KERNEL_CUH_
#define INFINI_OPS_CUDA_EMBEDDING_KERNEL_CUH_

#include <cstddef>
#include <cstdint>
#include <type_traits>

#include "native/cuda/kernel_commons.cuh"

namespace infini::ops {
namespace embedding_detail {

__forceinline__ __device__ bool IsAligned(const void* ptr, size_t alignment) {
return (reinterpret_cast<size_t>(ptr) % alignment) == 0;
}

template <typename T>
__forceinline__ __device__ void CopyScalar(T* __restrict__ dst,
const T* __restrict__ src,
size_t embedding_dim,
ptrdiff_t dst_col_stride = 1,
ptrdiff_t src_col_stride = 1) {
for (size_t i = 0; i < embedding_dim; ++i) {
dst[i * dst_col_stride] = __ldg(&src[i * src_col_stride]);
}
}

// Same as `third_party/InfiniCore/.../embedding/cuda/embedding_kernel.cuh`.
template <typename T>
__forceinline__ __device__ void CopyVectorizedFloat4(
float* __restrict__ dst, const float* __restrict__ src,
size_t embedding_dim) {
const float4* src_vec = reinterpret_cast<const float4*>(src);
float4* dst_vec = reinterpret_cast<float4*>(dst);
size_t vec_count = embedding_dim / 4;

for (size_t i = 0; i < vec_count; ++i) {
dst_vec[i] = __ldg(&src_vec[i]);
}

size_t remaining = embedding_dim % 4;
if (remaining > 0) {
size_t offset = vec_count * 4;
for (size_t i = 0; i < remaining; ++i) {
dst[offset + i] = __ldg(&src[offset + i]);
}
}
}

template <typename T>
__forceinline__ __device__ void CopyVectorizedFloat2(
float* __restrict__ dst, const float* __restrict__ src,
size_t embedding_dim) {
const float2* src_vec = reinterpret_cast<const float2*>(src);
float2* dst_vec = reinterpret_cast<float2*>(dst);
size_t vec_count = embedding_dim / 2;

for (size_t i = 0; i < vec_count; ++i) {
dst_vec[i] = __ldg(&src_vec[i]);
}

if (embedding_dim % 2 != 0) {
dst[embedding_dim - 1] = __ldg(&src[embedding_dim - 1]);
}
}

template <Device::Type kDev, typename T>
__forceinline__ __device__ void CopyVectorized16(T* __restrict__ dst,
const T* __restrict__ src,
size_t embedding_dim) {
const uint32_t* src_vec = reinterpret_cast<const uint32_t*>(src);
uint32_t* dst_vec = reinterpret_cast<uint32_t*>(dst);
size_t vec_count = embedding_dim / 2;

for (size_t i = 0; i < vec_count; ++i) {
dst_vec[i] = __ldg(&src_vec[i]);
}

if (embedding_dim % 2 != 0) {
dst[embedding_dim - 1] = __ldg(&src[embedding_dim - 1]);
}
}

// Contiguous row copy with InfiniCore vectorization strategy.
template <Device::Type kDev, typename T>
__forceinline__ __device__ void CopyRowContiguous(T* __restrict__ dst,
const T* __restrict__ src,
size_t embedding_dim) {
if constexpr (std::is_same_v<T, float>) {
bool aligned_16 = IsAligned(src, 16) && IsAligned(dst, 16);
if (aligned_16 && embedding_dim >= 4 && embedding_dim % 4 == 0) {
CopyVectorizedFloat4<T>(dst, src, embedding_dim);
} else if (embedding_dim >= 2 && embedding_dim % 2 == 0) {
CopyVectorizedFloat2<T>(dst, src, embedding_dim);
} else {
CopyScalar(dst, src, embedding_dim);
}
} else if constexpr (IsFP16<kDev, T> || IsBFloat16<kDev, T>) {
if (embedding_dim >= 2 && embedding_dim % 2 == 0) {
CopyVectorized16<kDev, T>(dst, src, embedding_dim);
} else {
CopyScalar(dst, src, embedding_dim);
}
} else {
CopyScalar(dst, src, embedding_dim);
}
}

template <Device::Type kDev, typename T>
__forceinline__ __device__ void CopyRow(T* __restrict__ dst,
const T* __restrict__ src,
size_t embedding_dim,
ptrdiff_t dst_col_stride,
ptrdiff_t src_col_stride) {
if (dst_col_stride == 1 && src_col_stride == 1) {
CopyRowContiguous<kDev>(dst, src, embedding_dim);
return;
}

CopyScalar(dst, src, embedding_dim, dst_col_stride, src_col_stride);
}

} // namespace embedding_detail

template <Device::Type kDev, typename T, typename IndexT>
__global__ void EmbeddingKernel(
T* __restrict__ output, const IndexT* __restrict__ indices,
const T* __restrict__ weight, size_t num_indices, size_t input_ndim,
const size_t* __restrict__ input_shape,
const ptrdiff_t* __restrict__ input_strides, size_t out_ndim,
const size_t* __restrict__ out_shape,
const ptrdiff_t* __restrict__ out_strides, ptrdiff_t weight_row_stride,
ptrdiff_t weight_col_stride, size_t embedding_dim, size_t vocab_size,
bool input_contiguous, bool out_contiguous) {
size_t idx = blockIdx.x * blockDim.x + threadIdx.x;

if (idx >= num_indices) {
return;
}

size_t input_offset =
input_contiguous
? idx
: IndexToOffset(idx, input_ndim, input_shape, input_strides);
IndexT index_val = __ldg(&indices[input_offset]);

if (index_val < 0 || static_cast<size_t>(index_val) >= vocab_size) {
return;
}

const T* src = weight + static_cast<size_t>(index_val) * weight_row_stride;

if (out_contiguous) {
T* dst = output + idx * embedding_dim;
embedding_detail::CopyRow<kDev>(dst, src, embedding_dim, 1,
weight_col_stride);
return;
}

size_t out_prefix_ndim = out_ndim > 0 ? out_ndim - 1 : 0;
size_t out_row_offset =
IndexToOffset(idx, out_prefix_ndim, out_shape, out_strides);
ptrdiff_t out_col_stride = out_strides[out_ndim - 1];

embedding_detail::CopyRow<kDev>(output + out_row_offset, src, embedding_dim,
out_col_stride, weight_col_stride);
}

} // namespace infini::ops

#endif
Loading
Loading