diff --git a/src/base/silu.h b/src/base/silu.h index 090c4f0a0..185e5ceac 100644 --- a/src/base/silu.h +++ b/src/base/silu.h @@ -1,37 +1,60 @@ #ifndef INFINI_OPS_BASE_SILU_H_ #define INFINI_OPS_BASE_SILU_H_ +#include + +#include "data_type.h" #include "operator.h" namespace infini::ops { +// Aligned with InfiniCore and `torch.nn.functional.silu`: + class Silu : public Operator { public: Silu(const Tensor input, Tensor out) - : input_shape_{input.shape()}, - input_strides_{input.strides()}, + : ndim_{out.ndim()}, + output_size_{out.numel()}, input_type_{input.dtype()}, + out_type_{out.dtype()}, + input_shape_{input.shape()}, out_shape_{out.shape()}, + input_strides_{input.strides()}, out_strides_{out.strides()}, - out_type_{out.dtype()}, - device_index_{out.device().index()} {} + is_input_contiguous_{input.IsContiguous()}, + is_out_contiguous_{out.IsContiguous()} { + assert(input.shape() == out.shape() && + "`Silu` requires `input` and `out` to have the same shape"); + assert(input_type_ == out_type_ && + "`Silu` requires `input` and `out` to have the same dtype"); + assert(input_type_ == DataType::kFloat16 || + input_type_ == DataType::kBFloat16 || + input_type_ == DataType::kFloat32 || + input_type_ == DataType::kFloat64); + } virtual void operator()(const Tensor input, Tensor out) const = 0; protected: - Tensor::Shape input_shape_; + Tensor::Size ndim_{0}; - Tensor::Strides input_strides_; + Tensor::Size output_size_{0}; DataType input_type_; + DataType out_type_; + + Tensor::Shape input_shape_; + Tensor::Shape out_shape_; + Tensor::Strides input_strides_; + Tensor::Strides out_strides_; - DataType out_type_; + bool is_input_contiguous_{false}; - int device_index_{0}; + bool is_out_contiguous_{false}; }; } // namespace infini::ops diff --git a/src/native/cuda/nvidia/ops/silu/kernel.h b/src/native/cuda/nvidia/ops/silu/kernel.h new file mode 100644 index 000000000..7f5001f71 --- /dev/null +++ b/src/native/cuda/nvidia/ops/silu/kernel.h @@ -0,0 +1,21 @@ +#ifndef INFINI_OPS_NVIDIA_SILU_KERNEL_H_ +#define INFINI_OPS_NVIDIA_SILU_KERNEL_H_ + +#include + +#include "native/cuda/nvidia/caster.cuh" +#include "native/cuda/nvidia/runtime_.h" +#include "native/cuda/ops/silu/kernel.h" + +namespace infini::ops { + +template <> +class Operator + : public CudaSilu> { + public: + using CudaSilu>::CudaSilu; +}; + +} // namespace infini::ops + +#endif diff --git a/src/native/cuda/ops/silu/kernel.cuh b/src/native/cuda/ops/silu/kernel.cuh new file mode 100644 index 000000000..330f2bec3 --- /dev/null +++ b/src/native/cuda/ops/silu/kernel.cuh @@ -0,0 +1,52 @@ +#ifndef INFINI_OPS_CUDA_SILU_KERNEL_CUH_ +#define INFINI_OPS_CUDA_SILU_KERNEL_CUH_ + +#include + +#include "native/cuda/kernel_commons.cuh" + +namespace infini::ops { +namespace silu_detail { + +// Same semantics as `third_party/InfiniCore/.../silu/cuda/kernel.cuh::SiluOp`. +template +__device__ __forceinline__ T Silu(const T& x) { + if constexpr (IsFP16 || IsBFloat16) { + float xf = Caster::template Cast(x); + float sigf = __frcp_rn(__fadd_rn(1.0f, __expf(-xf))); + return Caster::template Cast(__fmul_rn(xf, sigf)); + } else if constexpr (std::is_same_v) { + return __fmul_rn(x, __frcp_rn(__fadd_rn(1.0f, __expf(-x)))); + } else { + return x / (T{1} + exp(-x)); + } +} + +} // namespace silu_detail + +template +__global__ void SiluKernel(T* __restrict__ out, const T* __restrict__ input, + const size_t* __restrict__ out_shape, + const size_t* __restrict__ input_shape, + const ptrdiff_t* __restrict__ out_strides, + const ptrdiff_t* __restrict__ input_strides, + size_t output_size, size_t ndim, bool out_contiguous, + bool input_contiguous) { + size_t idx = blockIdx.x * blockDim.x + threadIdx.x; + + if (idx >= output_size) { + return; + } + + size_t out_idx = + out_contiguous ? idx : IndexToOffset(idx, ndim, out_shape, out_strides); + size_t input_idx = input_contiguous + ? idx + : IndexToOffset(idx, ndim, input_shape, input_strides); + + out[out_idx] = silu_detail::Silu(input[input_idx]); +} + +} // namespace infini::ops + +#endif diff --git a/src/native/cuda/ops/silu/kernel.h b/src/native/cuda/ops/silu/kernel.h new file mode 100644 index 000000000..a3bc46689 --- /dev/null +++ b/src/native/cuda/ops/silu/kernel.h @@ -0,0 +1,96 @@ +#ifndef INFINI_OPS_CUDA_SILU_KERNEL_H_ +#define INFINI_OPS_CUDA_SILU_KERNEL_H_ + +#include +#include +#include +#include + +#include "base/silu.h" +#include "common/generic_utils.h" +#include "data_type.h" +#include "dispatcher.h" +#include "native/cuda/kernel_commons.cuh" +#include "native/cuda/ops/silu/kernel.cuh" +#include "native/cuda/runtime_utils.h" + +namespace infini::ops { + +template +class CudaSilu : public Silu { + public: + CudaSilu(const Tensor input, Tensor out) : Silu{input, out} { + size_t shape_size = ndim_ * sizeof(*d_input_shape_); + size_t strides_size = ndim_ * sizeof(*d_input_strides_); + const size_t metadata_size = 2 * (shape_size + strides_size); + std::vector metadata(metadata_size); + + Backend::Malloc((void**)&d_metadata_, metadata_size); + + size_t offset = 0; + d_input_shape_ = reinterpret_cast(d_metadata_ + offset); + std::memcpy(metadata.data() + offset, input_shape_.data(), shape_size); + offset += shape_size; + + d_out_shape_ = reinterpret_cast(d_metadata_ + offset); + std::memcpy(metadata.data() + offset, out_shape_.data(), shape_size); + offset += shape_size; + + d_input_strides_ = reinterpret_cast(d_metadata_ + offset); + std::memcpy(metadata.data() + offset, input_strides_.data(), strides_size); + offset += strides_size; + + d_out_strides_ = reinterpret_cast(d_metadata_ + offset); + std::memcpy(metadata.data() + offset, out_strides_.data(), strides_size); + + Backend::Memcpy(d_metadata_, metadata.data(), metadata_size, + Backend::MemcpyHostToDevice); + } + + ~CudaSilu() { Backend::Free(d_metadata_); } + + void operator()(const Tensor input, Tensor out) const override { + if (output_size_ == 0) { + return; + } + + int block_size = RuntimeUtils::GetOptimalBlockSize(); + DispatchFunc( + {static_cast(out_type_), block_size}, + [&](auto list_tag) { + using T = TypeMapType(list_tag)>; + constexpr int kBlockSize = ListGet<1>(list_tag); + + auto cuda_stream = + static_cast(stream_ ? stream_ : 0); + dim3 blockDims( + std::min(static_cast(block_size), output_size_)); + dim3 gridDims(utils::CeilDiv(output_size_, blockDims.x)); + + T* d_out = reinterpret_cast(out.data()); + const T* d_input = reinterpret_cast(input.data()); + + SiluKernel + <<>>( + d_out, d_input, d_out_shape_, d_input_shape_, d_out_strides_, + d_input_strides_, output_size_, ndim_, is_out_contiguous_, + is_input_contiguous_); + }, + "CudaSilu::operator()"); + } + + private: + std::byte* d_metadata_{nullptr}; + + Tensor::Size* d_input_shape_{nullptr}; + + Tensor::Size* d_out_shape_{nullptr}; + + Tensor::Stride* d_input_strides_{nullptr}; + + Tensor::Stride* d_out_strides_{nullptr}; +}; + +} // namespace infini::ops + +#endif diff --git a/tests/test_silu.py b/tests/test_silu.py new file mode 100644 index 000000000..bb204aee8 --- /dev/null +++ b/tests/test_silu.py @@ -0,0 +1,69 @@ +import infini.ops +import pytest +import torch + +from tests.utils import Payload, empty_strided, get_stream, randn_strided + + +@pytest.mark.auto_act_and_assert +@pytest.mark.parametrize( + "shape, input_strides, out_strides", + ( + ((2, 4), None, None), + ((128, 64), None, None), + ((2, 4, 8), None, None), + ((4, 48, 6), None, None), + ((1, 2048), (4096, 1), (4096, 1)), + ((8, 16, 32), None, None), + ((16, 5632), None, None), + ((4, 4, 5632), None, None), + ), +) +@pytest.mark.parametrize( + ("dtype", "rtol", "atol"), + ( + (torch.float32, 1e-5, 1e-5), + (torch.float16, 1e-3, 1e-3), + (torch.bfloat16, 1e-2, 5e-3), + ), +) +def test_silu( + shape, + input_strides, + out_strides, + implementation_index, + dtype, + device, + rtol, + atol, +): + input = randn_strided(shape, input_strides, dtype=dtype, device=device) + out = empty_strided(shape, out_strides, dtype=dtype, device=device) + + return Payload( + lambda *args, **kwargs: _silu( + *args, **kwargs, implementation_index=implementation_index + ), + _torch_silu, + (input, out), + {}, + rtol=rtol, + atol=atol, + ) + + +def _silu(input, out, implementation_index=0): + infini.ops.silu( + input, + out, + implementation_index=implementation_index, + stream=get_stream(input.device), + ) + + return out + + +def _torch_silu(input, out): + out.copy_(input * torch.sigmoid(input)) + + return out