Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
51 changes: 37 additions & 14 deletions src/base/add_rms_norm.h
Original file line number Diff line number Diff line change
Expand Up @@ -13,28 +13,53 @@ class AddRmsNorm : public Operator<AddRmsNorm> {
public:
// TODO: Make `eps` an `std::optional<float>` with a PyTorch-aligned default.
// Also consider the same change for `RmsNorm`.
AddRmsNorm(const Tensor input, const Tensor other, const Tensor weight,
float eps, Tensor out, Tensor rstd_out)
AddRmsNorm(const Tensor input, const Tensor residual, const Tensor weight,
float eps, Tensor out, Tensor residual_out)
: input_shape_{input.shape()},
out_shape_{out.shape()},
input_strides_{input.strides()},
residual_strides_{residual.strides()},
out_strides_{out.strides()},
residual_out_strides_{residual_out.strides()},
eps_{eps},
dim_{input.size(-1)},
ndim_{input.ndim()},
batch_size_{ndim_ == 2 ? input.size(-2) : input.size(-3)},
nhead_{ndim_ == 2 ? 1 : input.size(-2)},
rstd_shape_{static_cast<int64_t>(batch_size_),
static_cast<int64_t>(nhead_)} {
assert(input.dtype() == other.dtype());
dim_{out.size(-1)},
ndim_{out.ndim()},
batch_size_{ndim_ == 2 ? out.size(-2) : out.size(-3)},
nhead_{ndim_ == 2 ? 1 : out.size(-2)} {
assert(ndim_ == 2 || ndim_ == 3);
assert(input.shape() == out.shape());
assert(input.shape() == residual.shape());
assert(input.shape() == residual_out.shape());
assert(weight.ndim() == 1 && weight.size(-1) == dim_);
assert(input.dtype() == out.dtype());
assert(input.dtype() == rstd_out.dtype());
assert(input.dtype() == residual.dtype());
assert(input.dtype() == residual_out.dtype());
assert(input.dtype() == weight.dtype());
// CUDA kernel indexes the normalized dimension with stride 1.
assert(input.stride(-1) == 1);
assert(residual.stride(-1) == 1);
assert(out.stride(-1) == 1);
assert(residual_out.stride(-1) == 1);
assert(weight.stride(-1) == 1);
}

virtual void operator()(const Tensor input, const Tensor other,
virtual void operator()(const Tensor input, const Tensor residual,
const Tensor weight, float eps, Tensor out,
Tensor rstd_out) const = 0;
Tensor residual_out) const = 0;

protected:
Tensor::Shape input_shape_;

Tensor::Shape out_shape_;

Tensor::Strides input_strides_;

Tensor::Strides residual_strides_;

Tensor::Strides out_strides_;

Tensor::Strides residual_out_strides_;

float eps_{1e-6f};

Tensor::Size dim_{0};
Expand All @@ -44,8 +69,6 @@ class AddRmsNorm : public Operator<AddRmsNorm> {
Tensor::Size batch_size_{0};

Tensor::Size nhead_{1};

std::vector<int64_t> rstd_shape_;
};

} // namespace infini::ops
Expand Down
21 changes: 21 additions & 0 deletions src/native/cuda/nvidia/ops/add_rms_norm/kernel.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
#ifndef INFINI_OPS_NVIDIA_ADD_RMS_NORM_KERNEL_H_
#define INFINI_OPS_NVIDIA_ADD_RMS_NORM_KERNEL_H_

#include <utility>

#include "native/cuda/nvidia/caster.cuh"
#include "native/cuda/nvidia/runtime_.h"
#include "native/cuda/ops/add_rms_norm/kernel.h"

namespace infini::ops {

// NVIDIA specialization of `Operator` for `AddRmsNorm`.
//
// All behavior comes from the generic CUDA implementation `CudaAddRmsNorm`,
// instantiated with the NVIDIA `Runtime`; this class adds nothing beyond
// re-exporting the base constructors.
template <>
class Operator<AddRmsNorm, Device::Type::kNvidia>
    : public CudaAddRmsNorm<Runtime<Device::Type::kNvidia>> {
 public:
  // Inherit the base constructors. Inside this class the injected class name
  // `CudaAddRmsNorm` denotes the `CudaAddRmsNorm<Runtime<...>>` base above.
  using CudaAddRmsNorm::CudaAddRmsNorm;
};

} // namespace infini::ops

#endif
80 changes: 80 additions & 0 deletions src/native/cuda/ops/add_rms_norm/kernel.cuh
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
#ifndef INFINI_OPS_CUDA_ADD_RMS_NORM_KERNEL_CUH_
#define INFINI_OPS_CUDA_ADD_RMS_NORM_KERNEL_CUH_

#include <cstddef>
#include <cstdint>
#include <cub/block/block_reduce.cuh>

#include "native/cuda/caster.cuh"
#include "native/cuda/kernel_commons.cuh"

namespace infini::ops {
namespace add_rms_norm_detail {

// Block-wide sum of squares of `count` elements starting at `data_ptr`.
//
// Mirrors the helper in `native/cuda/ops/rms_norm/kernel.cuh`.
//
// Thread `t` accumulates the squares of the elements at indices
// `t, t + block_size, t + 2 * block_size, ...`, each cast to `TCompute`
// first; the per-thread partial sums are then combined with
// `cub::BlockReduce`.
//
// Notes:
//   - Assumes `block_size` equals the launched `blockDim.x`, since it is used
//     both as the loop stride and as the CUB block size.
//   - Per the CUB documentation, the value returned by `BlockReduce::Sum` is
//     only defined on thread 0; callers must not rely on it elsewhere.
//   - Uses statically allocated shared memory for the CUB temporary storage.
template <unsigned int block_size, Device::Type kDev, typename TData,
          typename TCompute>
__device__ __forceinline__ TCompute SumSquared(const TData* data_ptr,
                                               size_t count) {
  using Reducer = cub::BlockReduce<TCompute, block_size>;
  __shared__ typename Reducer::TempStorage reduce_scratch;

  // Per-thread partial sum over this thread's strided slice of the row.
  TCompute partial = 0;
  size_t idx = threadIdx.x;
  while (idx < count) {
    const TCompute elem = Caster<kDev>::template Cast<TCompute>(data_ptr[idx]);
    partial += elem * elem;
    idx += block_size;
  }

  return Reducer(reduce_scratch).Sum(partial);
}

} // namespace add_rms_norm_detail

// Fused residual add + RMS normalization over the innermost dimension.
//
// For every (batch, head) row of length `dim` the kernel computes
//   residual_out[i] = input[i] + residual[i]           (summed in TCompute,
//                                                       stored as TData)
//   y[i] = residual_out[i] * w[i] * rsqrt(mean(residual_out^2) + epsilon)
// where the mean of squares is taken over the stored (rounded) `residual_out`
// values.
//
// Launch layout:
//   - One block per row: `blockIdx.x` is split into
//     `batch_idx = blockIdx.x / nhead` and `head_idx = blockIdx.x % nhead`.
//   - Assumes `blockDim.x == block_size`; the template value is used as the
//     loop stride and as the CUB block size.
//   - Uses only static shared memory (CUB temporary storage plus one
//     `TCompute` scalar).
//
// Parameters:
//   y, residual_out      Output base pointers.
//   input, residual      Input base pointers.
//   stride_*_batch/nhead Element strides of the batch and head dimensions of
//                        the corresponding tensor.
//   w                    Weight vector, indexed `w[0 .. dim)`.
//   nhead                Number of rows per batch entry (must be non-zero).
//   dim                  Length of the normalized dimension.
//   epsilon              Added to the mean of squares before `rsqrtf`.
//
// Preconditions:
//   - Elements of a row are addressed as `ptr[i]`, i.e. the innermost
//     dimension has unit stride in every tensor.
//   - `rsqrtf` is single precision, so `TCompute` is expected to be `float`
//     (or convertible to/from it without loss that matters).
//   - NOTE(review): every pointer is `__restrict__`; in-place use (e.g. `y`
//     aliasing `input`, or `residual_out` aliasing `residual`) would violate
//     that qualifier. Confirm callers never alias these buffers.
template <unsigned int block_size, Device::Type kDev, typename TCompute,
          typename TData, typename TWeight>
__global__ void AddRmsNormKernel(
    TData* __restrict__ y, int64_t stride_y_batch, int64_t stride_y_nhead,
    TData* __restrict__ residual_out, int64_t stride_residual_out_batch,
    int64_t stride_residual_out_nhead, const TData* __restrict__ input,
    int64_t stride_input_batch, int64_t stride_input_nhead,
    const TData* __restrict__ residual, int64_t stride_residual_batch,
    int64_t stride_residual_nhead, const TWeight* __restrict__ w, size_t nhead,
    size_t dim, float epsilon) {
  // Keep the row indices signed so that multiplying by the signed strides is
  // plain signed arithmetic rather than an implicit conversion to unsigned.
  const int64_t batch_idx = static_cast<int64_t>(blockIdx.x / nhead);
  const int64_t head_idx = static_cast<int64_t>(blockIdx.x % nhead);

  TData* y_ptr = y + (batch_idx * stride_y_batch + head_idx * stride_y_nhead);
  const TData* input_ptr =
      input + (batch_idx * stride_input_batch + head_idx * stride_input_nhead);
  const TData* residual_ptr =
      residual +
      (batch_idx * stride_residual_batch + head_idx * stride_residual_nhead);
  TData* residual_out_ptr =
      residual_out + (batch_idx * stride_residual_out_batch +
                      head_idx * stride_residual_out_nhead);

  // Pass 1: residual_out = input + residual, accumulated in TCompute.
  for (size_t i = threadIdx.x; i < dim; i += block_size) {
    TCompute sum_val = Caster<kDev>::template Cast<TCompute>(input_ptr[i]) +
                       Caster<kDev>::template Cast<TCompute>(residual_ptr[i]);
    residual_out_ptr[i] = Caster<kDev>::template Cast<TData>(sum_val);
  }

  // Each thread re-reads exactly the indices it wrote in pass 1 (same start
  // and stride), so no barrier is needed between the store and this read.
  TCompute sum_squared =
      add_rms_norm_detail::SumSquared<block_size, kDev, TData, TCompute>(
          residual_out_ptr, dim);

  // The reduction result is consumed on thread 0 only and broadcast to the
  // rest of the block through shared memory.
  __shared__ TCompute inv_rms;
  if (threadIdx.x == 0) {
    inv_rms = Caster<kDev>::template Cast<TCompute>(rsqrtf(
        sum_squared / Caster<kDev>::template Cast<TCompute>(dim) + epsilon));
  }
  __syncthreads();

  // Pass 2: scale the summed row by the weight and the reciprocal RMS.
  for (size_t i = threadIdx.x; i < dim; i += block_size) {
    TCompute sum_val =
        Caster<kDev>::template Cast<TCompute>(residual_out_ptr[i]);
    y_ptr[i] = Caster<kDev>::template Cast<TData>(
        sum_val * Caster<kDev>::template Cast<TCompute>(w[i]) * inv_rms);
  }
}

} // namespace infini::ops

#endif
76 changes: 76 additions & 0 deletions src/native/cuda/ops/add_rms_norm/kernel.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
#ifndef INFINI_OPS_CUDA_ADD_RMS_NORM_KERNEL_H_
#define INFINI_OPS_CUDA_ADD_RMS_NORM_KERNEL_H_

#include <cassert>
#include <cstdint>

#include "base/add_rms_norm.h"
#include "data_type.h"
#include "dispatcher.h"
#include "native/cuda/kernel_commons.cuh"
#include "native/cuda/ops/add_rms_norm/kernel.cuh"
#include "native/cuda/runtime_utils.h"

namespace infini::ops {

template <typename Backend>
class CudaAddRmsNorm : public AddRmsNorm {
public:
using AddRmsNorm::AddRmsNorm;

void operator()(const Tensor input, const Tensor residual,
const Tensor weight, float eps, Tensor out,
Tensor residual_out) const override {
auto cuda_stream =
static_cast<typename Backend::Stream>(stream_ ? stream_ : 0);

auto stride_input_batch = input_strides_.size() > 1 ? input_strides_[0] : 0;
auto stride_input_nhead =
input_strides_.size() > 1 ? input_strides_[1] : input_strides_[0];
auto stride_residual_batch =
residual_strides_.size() > 1 ? residual_strides_[0] : 0;
auto stride_residual_nhead = residual_strides_.size() > 1
? residual_strides_[1]
: residual_strides_[0];
auto stride_out_batch = out_strides_.size() > 1 ? out_strides_[0] : 0;
auto stride_out_nhead =
out_strides_.size() > 1 ? out_strides_[1] : out_strides_[0];
auto stride_residual_out_batch =
residual_out_strides_.size() > 1 ? residual_out_strides_[0] : 0;
auto stride_residual_out_nhead = residual_out_strides_.size() > 1
? residual_out_strides_[1]
: residual_out_strides_[0];

uint32_t num_blocks = static_cast<uint32_t>(batch_size_ * nhead_);

assert(out.dtype() == input.dtype() && out.dtype() == residual.dtype() &&
out.dtype() == weight.dtype() &&
out.dtype() == residual_out.dtype());

int block_size = RuntimeUtils<Backend::kDeviceType>::GetOptimalBlockSize();

DispatchFunc<ConcatType<List<DataType::kFloat32>, ReducedFloatTypes>,
AllCudaBlockSizes>(
{static_cast<int64_t>(out.dtype()), block_size},
[&](auto list_tag) {
using T = TypeMapType<Backend::kDeviceType, ListGet<0>(list_tag)>;
constexpr int kBlockSize = ListGet<1>(list_tag);

AddRmsNormKernel<kBlockSize, Backend::kDeviceType, float, T, T>
<<<num_blocks, kBlockSize, 0, cuda_stream>>>(
reinterpret_cast<T*>(out.data()), stride_out_batch,
stride_out_nhead, reinterpret_cast<T*>(residual_out.data()),
stride_residual_out_batch, stride_residual_out_nhead,
reinterpret_cast<const T*>(input.data()), stride_input_batch,
stride_input_nhead,
reinterpret_cast<const T*>(residual.data()),
stride_residual_batch, stride_residual_nhead,
reinterpret_cast<const T*>(weight.data()), nhead_, dim_, eps);
},
"CudaAddRmsNorm::operator()");
}
};

} // namespace infini::ops

#endif
Loading
Loading