diff --git a/src/base/add_rms_norm.h b/src/base/add_rms_norm.h index 3c8889176..51dbefafa 100644 --- a/src/base/add_rms_norm.h +++ b/src/base/add_rms_norm.h @@ -13,28 +13,53 @@ class AddRmsNorm : public Operator { public: // TODO: Make `eps` an `std::optional` with a PyTorch-aligned default. // Also consider the same change for `RmsNorm`. - AddRmsNorm(const Tensor input, const Tensor other, const Tensor weight, - float eps, Tensor out, Tensor rstd_out) + AddRmsNorm(const Tensor input, const Tensor residual, const Tensor weight, + float eps, Tensor out, Tensor residual_out) : input_shape_{input.shape()}, + out_shape_{out.shape()}, + input_strides_{input.strides()}, + residual_strides_{residual.strides()}, + out_strides_{out.strides()}, + residual_out_strides_{residual_out.strides()}, eps_{eps}, - dim_{input.size(-1)}, - ndim_{input.ndim()}, - batch_size_{ndim_ == 2 ? input.size(-2) : input.size(-3)}, - nhead_{ndim_ == 2 ? 1 : input.size(-2)}, - rstd_shape_{static_cast(batch_size_), - static_cast(nhead_)} { - assert(input.dtype() == other.dtype()); + dim_{out.size(-1)}, + ndim_{out.ndim()}, + batch_size_{ndim_ == 2 ? out.size(-2) : out.size(-3)}, + nhead_{ndim_ == 2 ? 1 : out.size(-2)} { + assert(ndim_ == 2 || ndim_ == 3); + assert(input.shape() == out.shape()); + assert(input.shape() == residual.shape()); + assert(input.shape() == residual_out.shape()); + assert(weight.ndim() == 1 && weight.size(-1) == dim_); assert(input.dtype() == out.dtype()); - assert(input.dtype() == rstd_out.dtype()); + assert(input.dtype() == residual.dtype()); + assert(input.dtype() == residual_out.dtype()); + assert(input.dtype() == weight.dtype()); + // CUDA kernel indexes the normalized dimension with stride 1. + assert(input.stride(-1) == 1); + assert(residual.stride(-1) == 1); + assert(out.stride(-1) == 1); + assert(residual_out.stride(-1) == 1); + assert(weight.stride(-1) == 1); } - virtual void operator()(const Tensor input, const Tensor other, + virtual void operator()(const Tensor input, const Tensor residual, const Tensor weight, float eps, Tensor out, - Tensor rstd_out) const = 0; + Tensor residual_out) const = 0; protected: Tensor::Shape input_shape_; + Tensor::Shape out_shape_; + + Tensor::Strides input_strides_; + + Tensor::Strides residual_strides_; + + Tensor::Strides out_strides_; + + Tensor::Strides residual_out_strides_; + float eps_{1e-6f}; Tensor::Size dim_{0}; @@ -44,8 +69,6 @@ class AddRmsNorm : public Operator { Tensor::Size batch_size_{0}; Tensor::Size nhead_{1}; - - std::vector rstd_shape_; }; } // namespace infini::ops diff --git a/src/native/cuda/nvidia/ops/add_rms_norm/kernel.h b/src/native/cuda/nvidia/ops/add_rms_norm/kernel.h new file mode 100644 index 000000000..2bb6f6051 --- /dev/null +++ b/src/native/cuda/nvidia/ops/add_rms_norm/kernel.h @@ -0,0 +1,21 @@ +#ifndef INFINI_OPS_NVIDIA_ADD_RMS_NORM_KERNEL_H_ +#define INFINI_OPS_NVIDIA_ADD_RMS_NORM_KERNEL_H_ + +#include + +#include "native/cuda/nvidia/caster.cuh" +#include "native/cuda/nvidia/runtime_.h" +#include "native/cuda/ops/add_rms_norm/kernel.h" + +namespace infini::ops { + +template <> +class Operator + : public CudaAddRmsNorm> { + public: + using CudaAddRmsNorm>::CudaAddRmsNorm; +}; + +} // namespace infini::ops + +#endif diff --git a/src/native/cuda/ops/add_rms_norm/kernel.cuh b/src/native/cuda/ops/add_rms_norm/kernel.cuh new file mode 100644 index 000000000..7fd8215e1 --- /dev/null +++ b/src/native/cuda/ops/add_rms_norm/kernel.cuh @@ -0,0 +1,80 @@ +#ifndef INFINI_OPS_CUDA_ADD_RMS_NORM_KERNEL_CUH_ +#define INFINI_OPS_CUDA_ADD_RMS_NORM_KERNEL_CUH_ + +#include +#include +#include + +#include "native/cuda/caster.cuh" +#include "native/cuda/kernel_commons.cuh" + +namespace infini::ops { +namespace add_rms_norm_detail { + +// Same as `native/cuda/ops/rms_norm/kernel.cuh`. +template +__device__ __forceinline__ TCompute SumSquared(const TData* data_ptr, + size_t count) { + TCompute ss = 0; + for (size_t i = threadIdx.x; i < count; i += block_size) { + TCompute value = Caster::template Cast(data_ptr[i]); + ss += value * value; + } + using BlockReduce = cub::BlockReduce; + __shared__ typename BlockReduce::TempStorage temp_storage; + return BlockReduce(temp_storage).Sum(ss); +} + +} // namespace add_rms_norm_detail + +template +__global__ void AddRmsNormKernel( + TData* __restrict__ y, int64_t stride_y_batch, int64_t stride_y_nhead, + TData* __restrict__ residual_out, int64_t stride_residual_out_batch, + int64_t stride_residual_out_nhead, const TData* __restrict__ input, + int64_t stride_input_batch, int64_t stride_input_nhead, + const TData* __restrict__ residual, int64_t stride_residual_batch, + int64_t stride_residual_nhead, const TWeight* __restrict__ w, size_t nhead, + size_t dim, float epsilon) { + size_t batch_idx = blockIdx.x / nhead; + size_t head_idx = blockIdx.x % nhead; + + auto y_ptr = y + batch_idx * stride_y_batch + head_idx * stride_y_nhead; + auto input_ptr = + input + batch_idx * stride_input_batch + head_idx * stride_input_nhead; + auto residual_ptr = residual + batch_idx * stride_residual_batch + + head_idx * stride_residual_nhead; + auto w_ptr = w; + auto residual_out_ptr = residual_out + batch_idx * stride_residual_out_batch + + head_idx * stride_residual_out_nhead; + + for (size_t i = threadIdx.x; i < dim; i += block_size) { + TCompute sum_val = Caster::template Cast(input_ptr[i]) + + Caster::template Cast(residual_ptr[i]); + residual_out_ptr[i] = Caster::template Cast(sum_val); + } + + TCompute sum_squared = + add_rms_norm_detail::SumSquared( + residual_out_ptr, dim); + + __shared__ TCompute rms; + if (threadIdx.x == 0) { + rms = Caster::template Cast(rsqrtf( + sum_squared / Caster::template Cast(dim) + epsilon)); + } + __syncthreads(); + + for (size_t i = threadIdx.x; i < dim; i += block_size) { + TCompute sum_val = + Caster::template Cast(residual_out_ptr[i]); + y_ptr[i] = Caster::template Cast( + sum_val * Caster::template Cast(w_ptr[i]) * rms); + } +} + +} // namespace infini::ops + +#endif diff --git a/src/native/cuda/ops/add_rms_norm/kernel.h b/src/native/cuda/ops/add_rms_norm/kernel.h new file mode 100644 index 000000000..d0f87a7e3 --- /dev/null +++ b/src/native/cuda/ops/add_rms_norm/kernel.h @@ -0,0 +1,76 @@ +#ifndef INFINI_OPS_CUDA_ADD_RMS_NORM_KERNEL_H_ +#define INFINI_OPS_CUDA_ADD_RMS_NORM_KERNEL_H_ + +#include +#include + +#include "base/add_rms_norm.h" +#include "data_type.h" +#include "dispatcher.h" +#include "native/cuda/kernel_commons.cuh" +#include "native/cuda/ops/add_rms_norm/kernel.cuh" +#include "native/cuda/runtime_utils.h" + +namespace infini::ops { + +template +class CudaAddRmsNorm : public AddRmsNorm { + public: + using AddRmsNorm::AddRmsNorm; + + void operator()(const Tensor input, const Tensor residual, + const Tensor weight, float eps, Tensor out, + Tensor residual_out) const override { + auto cuda_stream = + static_cast(stream_ ? stream_ : 0); + + auto stride_input_batch = input_strides_.size() > 1 ? input_strides_[0] : 0; + auto stride_input_nhead = + input_strides_.size() > 1 ? input_strides_[1] : input_strides_[0]; + auto stride_residual_batch = + residual_strides_.size() > 1 ? residual_strides_[0] : 0; + auto stride_residual_nhead = residual_strides_.size() > 1 + ? residual_strides_[1] + : residual_strides_[0]; + auto stride_out_batch = out_strides_.size() > 1 ? out_strides_[0] : 0; + auto stride_out_nhead = + out_strides_.size() > 1 ? out_strides_[1] : out_strides_[0]; + auto stride_residual_out_batch = + residual_out_strides_.size() > 1 ? residual_out_strides_[0] : 0; + auto stride_residual_out_nhead = residual_out_strides_.size() > 1 + ? residual_out_strides_[1] + : residual_out_strides_[0]; + + uint32_t num_blocks = static_cast(batch_size_ * nhead_); + + assert(out.dtype() == input.dtype() && out.dtype() == residual.dtype() && + out.dtype() == weight.dtype() && + out.dtype() == residual_out.dtype()); + + int block_size = RuntimeUtils::GetOptimalBlockSize(); + + DispatchFunc, ReducedFloatTypes>, + AllCudaBlockSizes>( + {static_cast(out.dtype()), block_size}, + [&](auto list_tag) { + using T = TypeMapType(list_tag)>; + constexpr int kBlockSize = ListGet<1>(list_tag); + + AddRmsNormKernel + <<>>( + reinterpret_cast(out.data()), stride_out_batch, + stride_out_nhead, reinterpret_cast(residual_out.data()), + stride_residual_out_batch, stride_residual_out_nhead, + reinterpret_cast(input.data()), stride_input_batch, + stride_input_nhead, + reinterpret_cast(residual.data()), + stride_residual_batch, stride_residual_nhead, + reinterpret_cast(weight.data()), nhead_, dim_, eps); + }, + "CudaAddRmsNorm::operator()"); + } +}; + +} // namespace infini::ops + +#endif diff --git a/tests/test_add_rms_norm.py b/tests/test_add_rms_norm.py new file mode 100644 index 000000000..1d94f7455 --- /dev/null +++ b/tests/test_add_rms_norm.py @@ -0,0 +1,199 @@ +import infini.ops +import pytest +import torch + +from tests.utils import Payload, empty_strided, get_stream, randn_strided + + +# Format: (input_shape, weight_shape, input_strides, residual_strides, weight_strides, out_strides); +# input/residual/residual_out share input_shape. +_TEST_CASES = ( + ((1, 4), (4,), None, None, None, None), + ((2, 4), (4,), None, None, None, None), + ((2, 2, 4), (4,), None, None, None, None), + ((2, 2, 4), (4,), (12, 8, 1), (12, 8, 1), None, (12, 8, 1)), + ((16, 2048), (2048,), None, None, None, None), + ((16, 2048), (2048,), (4096, 1), (4096, 1), None, (4096, 1)), + ((15, 3584), (3584,), None, None, None, None), + ((4, 4, 2048), (2048,), None, None, None, None), + ((4, 4, 2048), (2048,), (2048, 8192, 1), (2048, 8192, 1), None, (2048, 8192, 1)), + ( + (4, 4, 2048), + (2048,), + (16384, 4096, 1), + (16384, 4096, 1), + None, + (16384, 4096, 1), + ), +) + + +@pytest.mark.auto_act_and_assert +@pytest.mark.parametrize("check_output", ("out", "residual_out")) +@pytest.mark.parametrize( + "input_shape, weight_shape, input_strides, residual_strides, weight_strides, out_strides", + _TEST_CASES, +) +@pytest.mark.parametrize("eps", (1e-6, 1e-5)) +@pytest.mark.parametrize( + ("dtype", "rtol", "atol"), + ( + (torch.float32, 1e-4, 1e-4), + (torch.float16, 1e-2, 1e-2), + (torch.bfloat16, 2e-2, 1e-2), + ), +) +def test_add_rms_norm( + check_output, + input_shape, + weight_shape, + input_strides, + residual_strides, + weight_strides, + out_strides, + eps, + implementation_index, + dtype, + device, + rtol, + atol, +): + input = randn_strided(input_shape, input_strides, dtype=dtype, device=device) + residual = randn_strided(input_shape, residual_strides, dtype=dtype, device=device) + weight = randn_strided(weight_shape, weight_strides, dtype=dtype, device=device) + residual_out = empty_strided(input_shape, input_strides, dtype=dtype, device=device) + out = empty_strided(input_shape, out_strides, dtype=dtype, device=device) + + if check_output == "out": + func = _add_rms_norm_out + ref = _torch_add_rms_norm_out + else: + func = _add_rms_norm_residual_out + ref = _torch_add_rms_norm_residual_out + + return Payload( + lambda *args, **kwargs: func( + *args, **kwargs, implementation_index=implementation_index + ), + ref, + (input, residual, weight), + {"eps": eps, "out": out, "residual_out": residual_out}, + rtol=rtol, + atol=atol, + ) + + +def _add_rms_norm( + input, + residual, + weight, + *, + eps=1e-6, + out=None, + residual_out=None, + implementation_index=0, +): + infini.ops.add_rms_norm( + input, + residual, + weight, + eps, + out, + residual_out, + implementation_index=implementation_index, + stream=get_stream(input.device), + ) + + +def _add_rms_norm_out( + input, + residual, + weight, + *, + eps=1e-6, + out=None, + residual_out=None, + implementation_index=0, +): + _add_rms_norm( + input, + residual, + weight, + eps=eps, + out=out, + residual_out=residual_out, + implementation_index=implementation_index, + ) + + return out + + +def _add_rms_norm_residual_out( + input, + residual, + weight, + *, + eps=1e-6, + out=None, + residual_out=None, + implementation_index=0, +): + _add_rms_norm( + input, + residual, + weight, + eps=eps, + out=out, + residual_out=residual_out, + implementation_index=implementation_index, + ) + + return residual_out + + +def _torch_add_rms_norm( + input, residual, weight, *, eps=1e-6, out=None, residual_out=None +): + """Reference aligned with vLLM `fused_add_rms_norm` (ignoring `variance_size`).""" + orig_dtype = input.dtype + x = input.to(torch.float32) + x = x + residual.to(torch.float32) + add_result = x.to(orig_dtype) + + variance = x.pow(2).mean(dim=-1, keepdim=True) + x = x * torch.rsqrt(variance + eps) + if weight is not None: + x = x.to(weight.dtype) * weight + normalized_result = x.to(orig_dtype) + + if out is not None: + out.copy_(normalized_result) + else: + out = normalized_result + + if residual_out is not None: + residual_out.copy_(add_result) + else: + residual_out = add_result + + return out, residual_out + + +def _torch_add_rms_norm_out( + input, residual, weight, *, eps=1e-6, out=None, residual_out=None +): + out, _ = _torch_add_rms_norm( + input, residual, weight, eps=eps, out=out, residual_out=residual_out + ) + + return out + + +def _torch_add_rms_norm_residual_out( + input, residual, weight, *, eps=1e-6, out=None, residual_out=None +): + _, residual_out = _torch_add_rms_norm( + input, residual, weight, eps=eps, out=out, residual_out=residual_out + ) + + return residual_out