From 9c9777037f7da6205e718c4a071d27bcaf69276d Mon Sep 17 00:00:00 2001 From: xgqdut2016 Date: Thu, 4 Jun 2026 10:06:57 +0800 Subject: [PATCH] issue/1192: success awq_marlin_repack --- include/infiniop/ops/awq_marlin_repack.h | 27 ++ .../ops/awq_marlin_repack/awq_marlin_repack.h | 48 ++ .../ops/awq_marlin_repack/cuda/kernel.cuh | 197 ++++++++ src/infiniop/ops/awq_marlin_repack/info.h | 49 ++ .../ops/awq_marlin_repack/marlin/marlin.cuh | 178 +++++++ .../nvidia/awq_marlin_repack_nvidia.cu | 122 +++++ .../nvidia/awq_marlin_repack_nvidia.cuh | 8 + .../ops/awq_marlin_repack/operator.cc | 101 ++++ test/infiniop/awq_marlin_repack.py | 443 ++++++++++++++++++ test/infiniop/libinfiniop/op_register.py | 34 ++ 10 files changed, 1207 insertions(+) create mode 100644 include/infiniop/ops/awq_marlin_repack.h create mode 100644 src/infiniop/ops/awq_marlin_repack/awq_marlin_repack.h create mode 100644 src/infiniop/ops/awq_marlin_repack/cuda/kernel.cuh create mode 100644 src/infiniop/ops/awq_marlin_repack/info.h create mode 100644 src/infiniop/ops/awq_marlin_repack/marlin/marlin.cuh create mode 100644 src/infiniop/ops/awq_marlin_repack/nvidia/awq_marlin_repack_nvidia.cu create mode 100644 src/infiniop/ops/awq_marlin_repack/nvidia/awq_marlin_repack_nvidia.cuh create mode 100644 src/infiniop/ops/awq_marlin_repack/operator.cc create mode 100644 test/infiniop/awq_marlin_repack.py diff --git a/include/infiniop/ops/awq_marlin_repack.h b/include/infiniop/ops/awq_marlin_repack.h new file mode 100644 index 000000000..017ff5568 --- /dev/null +++ b/include/infiniop/ops/awq_marlin_repack.h @@ -0,0 +1,27 @@ +#ifndef __INFINIOP_AWQ_MARLIN_REPACK_API_H__ +#define __INFINIOP_AWQ_MARLIN_REPACK_API_H__ + +#include "../operator_descriptor.h" +#include + +typedef struct InfiniopDescriptor *infiniopAwqMarlinRepackDescriptor_t; + +__INFINI_C __export infiniStatus_t infiniopCreateAwqMarlinRepackDescriptor(infiniopHandle_t handle, + infiniopAwqMarlinRepackDescriptor_t *desc_ptr, + infiniopTensorDescriptor_t output_desc, + infiniopTensorDescriptor_t input_desc, + int64_t num_bits, + bool is_a_8bit); + +__INFINI_C __export infiniStatus_t infiniopGetAwqMarlinRepackWorkspaceSize(infiniopAwqMarlinRepackDescriptor_t desc, size_t *size); + +__INFINI_C __export infiniStatus_t infiniopAwqMarlinRepack(infiniopAwqMarlinRepackDescriptor_t desc, + void *workspace, + size_t workspace_size, + void *output, + const void *input, + void *stream); + +__INFINI_C __export infiniStatus_t infiniopDestroyAwqMarlinRepackDescriptor(infiniopAwqMarlinRepackDescriptor_t desc); + +#endif diff --git a/src/infiniop/ops/awq_marlin_repack/awq_marlin_repack.h b/src/infiniop/ops/awq_marlin_repack/awq_marlin_repack.h new file mode 100644 index 000000000..e2768173f --- /dev/null +++ b/src/infiniop/ops/awq_marlin_repack/awq_marlin_repack.h @@ -0,0 +1,48 @@ +#ifndef AWQ_MARLIN_REPACK_H +#define AWQ_MARLIN_REPACK_H + +#include "../../operator.h" +#include "info.h" + +#define DESCRIPTOR(NAMESPACE) \ + \ + namespace op::awq_marlin_repack::NAMESPACE { \ + class Descriptor final : public InfiniopDescriptor { \ + struct Opaque; \ + Opaque *_opaque; \ + AwqMarlinRepackInfo _info; \ + size_t _workspace_size; \ + \ + Descriptor( \ + Opaque *opaque, \ + AwqMarlinRepackInfo info, \ + size_t workspace_size, \ + infiniDevice_t device_type, \ + int device_id) \ + : InfiniopDescriptor{device_type, device_id}, \ + _opaque(opaque), \ + _info(info), \ + _workspace_size(workspace_size) {} \ + \ + public: \ + ~Descriptor(); \ + \ + size_t workspaceSize() const { return _workspace_size; } \ + \ + static infiniStatus_t create( \ + infiniopHandle_t handle, \ + Descriptor **desc_ptr, \ + infiniopTensorDescriptor_t output_desc, \ + infiniopTensorDescriptor_t input_desc, \ + int64_t num_bits, \ + bool is_a_8bit); \ + \ + infiniStatus_t calculate( \ + void *workspace, size_t workspace_size, \ + void *output, \ + const void *input, \ + void *stream) const; \ + }; \ + } + +#endif // AWQ_MARLIN_REPACK_H diff --git a/src/infiniop/ops/awq_marlin_repack/cuda/kernel.cuh b/src/infiniop/ops/awq_marlin_repack/cuda/kernel.cuh new file mode 100644 index 000000000..b7288b603 --- /dev/null +++ b/src/infiniop/ops/awq_marlin_repack/cuda/kernel.cuh @@ -0,0 +1,197 @@ +#include "../marlin/marlin.cuh" + +namespace marlin { + +template +__device__ void awq_marlin_repack_kernel( + uint32_t const *__restrict__ b_q_weight_ptr, uint32_t *__restrict__ out_ptr, + int size_k, int size_n) { + constexpr int pack_factor = 32 / num_bits; + + constexpr int target_tile_n_size = tile_n_size / (is_a_8bit ? 2 : 1); + constexpr int target_tile_k_size = tile_k_size * (is_a_8bit ? 2 : 1); + int k_tiles = size_k / target_tile_k_size; + int n_tiles = size_n / target_tile_n_size; + int block_k_tiles = div_ceil(k_tiles, gridDim.x); + + auto start_k_tile = blockIdx.x * block_k_tiles; + if (start_k_tile >= k_tiles) { + return; + } + + int finish_k_tile = min(start_k_tile + block_k_tiles, k_tiles); + + // Wait until the next thread tile has been loaded to shared memory. + auto wait_for_stage = [&]() { + // We only have `stages - 2` active fetches since we are double buffering + // and can only issue the next fetch when it is guaranteed that the previous + // shared memory load is fully complete (as it may otherwise be + // overwritten). + cp_async_wait(); + __syncthreads(); + }; + + extern __shared__ int4 sh[]; + + constexpr int tile_n_ints = target_tile_n_size / pack_factor; + + constexpr int stage_n_threads = tile_n_ints / 4; + constexpr int stage_k_threads = target_tile_k_size; + constexpr int stage_size = stage_k_threads * stage_n_threads; + + auto fetch_to_shared = [&](int pipe, int k_tile_id, int n_tile_id) { + if (n_tile_id >= n_tiles) { + cp_async_fence(); + return; + } + + int first_n = n_tile_id * target_tile_n_size; + int first_n_packed = first_n / pack_factor; + + int4 *sh_ptr = sh + stage_size * pipe; + + if (threadIdx.x < stage_size) { + auto k_id = threadIdx.x / stage_n_threads; + auto n_id = threadIdx.x % stage_n_threads; + + int first_k = k_tile_id * target_tile_k_size; + + cp_async4(&sh_ptr[k_id * stage_n_threads + n_id], + reinterpret_cast( + &(b_q_weight_ptr[(first_k + k_id) * (size_n / pack_factor) + first_n_packed + (n_id * 4)]))); + } + + cp_async_fence(); + }; + + auto repack_tile = [&](int pipe, int k_tile_id, int n_tile_id) { + if (n_tile_id >= n_tiles) { + return; + } + + auto warp_id = threadIdx.x / 32; + auto th_id = threadIdx.x % 32; + + if (warp_id >= 4) { + return; + } + + int tc_col = th_id / 4; + int tc_row = (th_id % 4) * (is_a_8bit ? 4 : 2); + + constexpr int tc_offsets[4] = {0, 1, 8, 9}; + + int cur_n = (warp_id / (is_a_8bit ? 2 : 1)) * 16 + tc_col; + int cur_n_packed = cur_n / pack_factor; + int cur_n_pos = cur_n % pack_factor; + + constexpr int sh_stride = tile_n_ints; + constexpr uint32_t mask = (1 << num_bits) - 1; + + int4 *sh_stage_ptr = sh + stage_size * pipe; + uint32_t *sh_stage_int_ptr = reinterpret_cast(sh_stage_ptr); + + // Undo interleaving + int cur_n_pos_unpacked; + if constexpr (num_bits == 4) { + constexpr int undo_pack[8] = {0, 4, 1, 5, 2, 6, 3, 7}; + cur_n_pos_unpacked = undo_pack[cur_n_pos]; + } else { + constexpr int undo_pack[4] = {0, 2, 1, 3}; + cur_n_pos_unpacked = undo_pack[cur_n_pos]; + } + + uint32_t vals[8]; +#pragma unroll + for (int i = 0; i < 4; i++) { + if constexpr (is_a_8bit) { + int cur_elem = tc_row + i; + + int packed_src_0 = sh_stage_int_ptr[cur_n_packed + (8 / pack_factor) * (warp_id % 2) + sh_stride * cur_elem]; + int packed_src_1 = sh_stage_int_ptr[cur_n_packed + (8 / pack_factor) * (warp_id % 2) + sh_stride * (cur_elem + 16)]; + + vals[i] = (packed_src_0 >> (cur_n_pos_unpacked * num_bits)) & mask; + vals[4 + i] = (packed_src_1 >> (cur_n_pos_unpacked * num_bits)) & mask; + } else { + int cur_elem = tc_row + tc_offsets[i]; + + int packed_src_0 = sh_stage_int_ptr[cur_n_packed + sh_stride * cur_elem]; + int packed_src_1 = sh_stage_int_ptr[cur_n_packed + (8 / pack_factor) + sh_stride * cur_elem]; + + vals[i] = (packed_src_0 >> (cur_n_pos_unpacked * num_bits)) & mask; + vals[4 + i] = (packed_src_1 >> (cur_n_pos_unpacked * num_bits)) & mask; + } + } + + constexpr int tile_size = target_tile_k_size * target_tile_n_size / pack_factor; + int out_offset = (k_tile_id * n_tiles + n_tile_id) * tile_size; + + // Result of: + // https://github.com/NVIDIA/FasterTransformer/blob/main/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h + if constexpr (!is_a_8bit && num_bits == 4) { + int pack_idx[8] = {0, 2, 4, 6, 1, 3, 5, 7}; + + uint32_t res = 0; +#pragma unroll + for (int i = 0; i < 8; i++) { + res |= vals[pack_idx[i]] << (i * 4); + } + + out_ptr[out_offset + th_id * 4 + warp_id] = res; + + } else if constexpr (is_a_8bit && num_bits == 4) { + int pack_idx[8] = {0, 4, 1, 5, 2, 6, 3, 7}; + + uint32_t res = 0; +#pragma unroll + for (int i = 0; i < 8; i++) { + res |= vals[pack_idx[i]] << (i * 4); + } + + out_ptr[out_offset + th_id * 4 + warp_id] = res; + + } else { + constexpr int pack_idx[4] = {0, 2, 1, 3}; + + uint32_t res1 = 0; + uint32_t res2 = 0; +#pragma unroll + for (int i = 0; i < 4; i++) { + const int ii = is_a_8bit ? i : pack_idx[i]; + res1 |= vals[ii] << (i * 8); + res2 |= vals[4 + ii] << (i * 8); + } + + out_ptr[out_offset + th_id * 8 + (warp_id * 2) + 0] = res1; + out_ptr[out_offset + th_id * 8 + (warp_id * 2) + 1] = res2; + } + }; + + auto start_pipes = [&](int k_tile_id, int n_tile_id) { +#pragma unroll + for (int pipe = 0; pipe < repack_stages - 1; pipe++) { + fetch_to_shared(pipe, k_tile_id, n_tile_id + pipe); + } + + wait_for_stage(); + }; +#pragma unroll + for (int k_tile_id = start_k_tile; k_tile_id < finish_k_tile; k_tile_id++) { + int n_tile_id = 0; + + start_pipes(k_tile_id, n_tile_id); + + while (n_tile_id < n_tiles) { +#pragma unroll + for (int pipe = 0; pipe < repack_stages; pipe++) { + fetch_to_shared((pipe + repack_stages - 1) % repack_stages, k_tile_id, + n_tile_id + pipe + repack_stages - 1); + repack_tile(pipe, k_tile_id, n_tile_id + pipe); + wait_for_stage(); + } + n_tile_id += repack_stages; + } + } +} + +} // namespace marlin diff --git a/src/infiniop/ops/awq_marlin_repack/info.h b/src/infiniop/ops/awq_marlin_repack/info.h new file mode 100644 index 000000000..c9dea93fe --- /dev/null +++ b/src/infiniop/ops/awq_marlin_repack/info.h @@ -0,0 +1,49 @@ +#ifndef __AWQ_MARLIN_REPACK_INFO_H__ +#define __AWQ_MARLIN_REPACK_INFO_H__ + +#include "../../../utils.h" +#include "../../tensor.h" +#include "marlin/marlin.cuh" +#include + +#include + +namespace op::awq_marlin_repack { + +class AwqMarlinRepackInfo { + AwqMarlinRepackInfo() = default; + +public: + infiniDtype_t output_dtype, input_dtype; + size_t size_k, size_n; + int64_t num_bits; + bool is_a_8bit; + + static utils::Result create( + infiniopTensorDescriptor_t output_desc, + infiniopTensorDescriptor_t input_desc, + int64_t num_bits, + bool is_a_8bit) { + CHECK_OR_RETURN( + output_desc != nullptr && input_desc != nullptr, + INFINI_STATUS_NULL_POINTER); + const infiniDtype_t output_dtype = output_desc->dtype(); + const infiniDtype_t input_dtype = input_desc->dtype(); + CHECK_DTYPE(input_dtype, INFINI_DTYPE_I32); + CHECK_DTYPE(input_dtype, output_dtype); + + size_t size_k = input_desc->dim(0); + int const pack_factor = 32 / num_bits; + size_t size_n = input_desc->dim(1) * pack_factor; + + CHECK_OR_RETURN(size_k / marlin::tile_size == output_desc->dim(0) || size_n * marlin::tile_size / pack_factor == output_desc->dim(1), + INFINI_STATUS_BAD_TENSOR_SHAPE); + + return utils::Result( + AwqMarlinRepackInfo{output_dtype, input_dtype, size_k, size_n, num_bits, is_a_8bit}); + } +}; + +} // namespace op::awq_marlin_repack + +#endif // __AWQ_MARLIN_REPACK_INFO_H__ diff --git a/src/infiniop/ops/awq_marlin_repack/marlin/marlin.cuh b/src/infiniop/ops/awq_marlin_repack/marlin/marlin.cuh new file mode 100644 index 000000000..3fbb4c463 --- /dev/null +++ b/src/infiniop/ops/awq_marlin_repack/marlin/marlin.cuh @@ -0,0 +1,178 @@ +#pragma once + +#ifndef _marlin_cuh +#define _marlin_cuh + +#include +#include +#include +#include + +#ifndef MARLIN_NAMESPACE_NAME +#define MARLIN_NAMESPACE_NAME marlin +#endif + +template +__device__ __forceinline__ uint32_t __cvta_generic_to_shared(T *ptr) { + size_t smem_addr; + asm volatile( + "cvta.to.shared.u64 %0, %1;" + : "=l"(smem_addr) + : "l"(ptr)); + return static_cast(smem_addr); +} + +namespace MARLIN_NAMESPACE_NAME { + +// Marlin params + +// 8 warps are a good choice since every SM has 4 schedulers and having more +// than 1 warp per schedule allows some more latency hiding. At the same time, +// we want relatively few warps to have many registers per warp and small tiles. +static constexpr int default_threads = 256; + +static constexpr int pipe_stages = 4; // 4 pipeline stages fit into shared memory + +static constexpr int min_thread_n = 64; +static constexpr int min_thread_k = 64; +static constexpr int max_thread_n = 256; + +static constexpr int tile_size = 16; +static constexpr int max_par = 16; + +// Repack params +static constexpr int repack_stages = 8; + +static constexpr int repack_threads = 256; + +static constexpr int tile_k_size = tile_size; +static constexpr int tile_n_size = tile_k_size * 4; + +// Helpers +template +struct Vec { + T elems[n]; + __device__ T &operator[](int i) { return elems[i]; } +}; + +using I4 = Vec; + +constexpr int div_ceil(int a, int b) { return (a + b - 1) / b; } + +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 + +__device__ inline void cp_async1_ca_pred(void *smem_ptr, const void *glob_ptr, + bool pred = true) { + if (pred) { + reinterpret_cast(smem_ptr)[0] = reinterpret_cast(glob_ptr)[0]; + } +} + +__device__ inline void cp_async2_ca_pred(void *smem_ptr, const void *glob_ptr, + bool pred = true) { + if (pred) { + reinterpret_cast(smem_ptr)[0] = reinterpret_cast(glob_ptr)[0]; + } +} + +__device__ inline void cp_async4_ca_pred(void *smem_ptr, const void *glob_ptr, + bool pred = true) { + if (pred) { + reinterpret_cast(smem_ptr)[0] = reinterpret_cast(glob_ptr)[0]; + } +} + +__device__ inline void cp_async4_pred(void *smem_ptr, const void *glob_ptr, + bool pred = true) { + if (pred) { + reinterpret_cast(smem_ptr)[0] = reinterpret_cast(glob_ptr)[0]; + } +} + +__device__ inline void cp_async4(void *smem_ptr, const void *glob_ptr) { + reinterpret_cast(smem_ptr)[0] = reinterpret_cast(glob_ptr)[0]; +} + +__device__ inline void cp_async_fence() {} + +template +__device__ inline void cp_async_wait() {} + +#else + +__device__ inline void cp_async1_ca_pred(void *smem_ptr, const void *glob_ptr, + bool pred = true) { + const int BYTES = 4; + uint32_t smem = static_cast(__cvta_generic_to_shared(smem_ptr)); + asm volatile( + "{\n" + " .reg .pred p;\n" + " setp.ne.b32 p, %0, 0;\n" + " @p cp.async.ca.shared.global [%1], [%2], %3;\n" + "}\n" ::"r"((int)pred), + "r"(smem), "l"(glob_ptr), "n"(BYTES)); +} + +__device__ inline void cp_async2_ca_pred(void *smem_ptr, const void *glob_ptr, + bool pred = true) { + const int BYTES = 8; + uint32_t smem = static_cast(__cvta_generic_to_shared(smem_ptr)); + asm volatile( + "{\n" + " .reg .pred p;\n" + " setp.ne.b32 p, %0, 0;\n" + " @p cp.async.ca.shared.global [%1], [%2], %3;\n" + "}\n" ::"r"((int)pred), + "r"(smem), "l"(glob_ptr), "n"(BYTES)); +} + +__device__ inline void cp_async4_ca_pred(void *smem_ptr, const void *glob_ptr, + bool pred = true) { + const int BYTES = 16; + uint32_t smem = static_cast(__cvta_generic_to_shared(smem_ptr)); + asm volatile( + "{\n" + " .reg .pred p;\n" + " setp.ne.b32 p, %0, 0;\n" + " @p cp.async.ca.shared.global [%1], [%2], %3;\n" + "}\n" ::"r"((int)pred), + "r"(smem), "l"(glob_ptr), "n"(BYTES)); +} + +__device__ inline void cp_async4_pred(void *smem_ptr, const void *glob_ptr, + bool pred = true) { + const int BYTES = 16; + uint32_t smem = static_cast(__cvta_generic_to_shared(smem_ptr)); + asm volatile( + "{\n" + " .reg .pred p;\n" + " setp.ne.b32 p, %0, 0;\n" + " @p cp.async.cg.shared.global [%1], [%2], %3;\n" + "}\n" ::"r"((int)pred), + "r"(smem), "l"(glob_ptr), "n"(BYTES)); +} + +__device__ inline void cp_async4(void *smem_ptr, const void *glob_ptr) { + const int BYTES = 16; + uint32_t smem = static_cast(__cvta_generic_to_shared(smem_ptr)); + asm volatile( + "{\n" + " cp.async.cg.shared.global [%0], [%1], %2;\n" + "}\n" ::"r"(smem), + "l"(glob_ptr), "n"(BYTES)); +} + +__device__ inline void cp_async_fence() { + asm volatile("cp.async.commit_group;\n" ::); +} + +template +__device__ inline void cp_async_wait() { + asm volatile("cp.async.wait_group %0;\n" ::"n"(n)); +} + +#endif + +} // namespace MARLIN_NAMESPACE_NAME + +#endif \ No newline at end of file diff --git a/src/infiniop/ops/awq_marlin_repack/nvidia/awq_marlin_repack_nvidia.cu b/src/infiniop/ops/awq_marlin_repack/nvidia/awq_marlin_repack_nvidia.cu new file mode 100644 index 000000000..ffa7ff44c --- /dev/null +++ b/src/infiniop/ops/awq_marlin_repack/nvidia/awq_marlin_repack_nvidia.cu @@ -0,0 +1,122 @@ +#if defined(ENABLE_NVIDIA_API) +#include "../../../devices/nvidia/nvidia_handle.cuh" +#include "../../../devices/nvidia/nvidia_kernel_common.cuh" +#include "../cuda/kernel.cuh" +#include "awq_marlin_repack_nvidia.cuh" +#include + +template +INFINIOP_CUDA_KERNEL awqMarlinRepackKernel( + uint32_t const *__restrict__ b_q_weight_ptr, uint32_t *__restrict__ out_ptr, + int size_k, int size_n) { + marlin::awq_marlin_repack_kernel( + b_q_weight_ptr, out_ptr, + size_k, size_n); +} + +#define CALL_IF(NUM_BITS, IS_A_8BIT) \ + else if (num_bits == NUM_BITS && is_a_8bit == IS_A_8BIT) { \ + cudaFuncSetAttribute( \ + awqMarlinRepackKernel, \ + cudaFuncAttributeMaxDynamicSharedMemorySize, max_shared_mem); \ + awqMarlinRepackKernel \ + <<>>( \ + b_q_weight_ptr, out_ptr, size_k, size_n); \ + } + +infiniStatus_t awqMarlinRepack(uint32_t *out_ptr, const uint32_t *b_q_weight_ptr, int64_t size_k, + int64_t size_n, int64_t num_bits, + bool is_a_8bit, cudaStream_t stream) { + // Verify compatibility with marlin tile of 16x64 + if (size_k % marlin::tile_k_size != 0) { + std::cout << "size_k = " << size_k << " is not divisible by tile_k_size = " << marlin::tile_k_size << std::endl; + return INFINI_STATUS_BAD_TENSOR_SHAPE; + } + if (size_n % marlin::tile_n_size != 0) { + std::cout << "size_n = " << size_n << " is not divisible by tile_n_size = " << marlin::tile_n_size << std::endl; + return INFINI_STATUS_BAD_TENSOR_SHAPE; + } + if (num_bits != 4 && num_bits != 8) { + std::cout << "num_bits must be 4 or 8. Got = " << num_bits << std::endl; + return INFINI_STATUS_BAD_PARAM; + } + + int const pack_factor = 32 / num_bits; + + // Get dev info + int device_id = 0; + + int blocks; + cudaDeviceGetAttribute(&blocks, cudaDevAttrMultiProcessorCount, device_id); + + int max_shared_mem = 0; + cudaDeviceGetAttribute(&max_shared_mem, + cudaDevAttrMaxSharedMemoryPerBlockOptin, device_id); + assert(max_shared_mem > 0 && "max_shared_mem must be greater than 0"); + + if (false) { + } + CALL_IF(4, false) + CALL_IF(8, false) + CALL_IF(4, true) + CALL_IF(8, true) + else { + assert(false && "Unsupported repack config: num_bits, is_a_8bit"); + } + + return INFINI_STATUS_SUCCESS; +} + +namespace op::awq_marlin_repack::nvidia { + +struct Descriptor::Opaque { + std::shared_ptr internal; +}; + +Descriptor::~Descriptor() { delete _opaque; } + +infiniStatus_t Descriptor::create( + infiniopHandle_t handle_, + Descriptor **desc_ptr, + infiniopTensorDescriptor_t output_desc, + infiniopTensorDescriptor_t input_desc, + int64_t num_bits, + bool is_a_8bit) { + + auto handle = reinterpret_cast(handle_); + auto result = AwqMarlinRepackInfo::create(output_desc, input_desc, num_bits, is_a_8bit); + + size_t workspace_size = 0; + + *desc_ptr = new Descriptor( + new Opaque{handle->internal()}, + result.take(), + workspace_size, + handle->device, handle->device_id); + return INFINI_STATUS_SUCCESS; +} + +infiniStatus_t +Descriptor::calculate( + void *workspace, size_t workspace_size, + void *output, + const void *input, + void *stream_) const { + + cudaStream_t stream = (cudaStream_t)stream_; + + int64_t size_k = static_cast(_info.size_k); + int64_t size_n = static_cast(_info.size_n); + int64_t num_bits = _info.num_bits; + bool is_a_8bit = _info.is_a_8bit; + + awqMarlinRepack((uint32_t *)output, (const uint32_t *)input, size_k, + size_n, num_bits, + is_a_8bit, stream); + return INFINI_STATUS_SUCCESS; +} + +} // namespace op::awq_marlin_repack::nvidia +#endif diff --git a/src/infiniop/ops/awq_marlin_repack/nvidia/awq_marlin_repack_nvidia.cuh b/src/infiniop/ops/awq_marlin_repack/nvidia/awq_marlin_repack_nvidia.cuh new file mode 100644 index 000000000..3cbec6c66 --- /dev/null +++ b/src/infiniop/ops/awq_marlin_repack/nvidia/awq_marlin_repack_nvidia.cuh @@ -0,0 +1,8 @@ +#ifndef __AWQ_MARLIN_REPACK_CUDA_CUH__ +#define __AWQ_MARLIN_REPACK_CUDA_CUH__ + +#include "../awq_marlin_repack.h" + +DESCRIPTOR(nvidia) + +#endif // __AWQ_MARLIN_REPACK_CUDA_CUH__ diff --git a/src/infiniop/ops/awq_marlin_repack/operator.cc b/src/infiniop/ops/awq_marlin_repack/operator.cc new file mode 100644 index 000000000..ea02da87d --- /dev/null +++ b/src/infiniop/ops/awq_marlin_repack/operator.cc @@ -0,0 +1,101 @@ +#include "../../operator.h" +#include "../../handle.h" +#include "infiniop/ops/awq_marlin_repack.h" + +#if defined ENABLE_NVIDIA_API +#include "nvidia/awq_marlin_repack_nvidia.cuh" +#endif + +__INFINI_C infiniStatus_t infiniopCreateAwqMarlinRepackDescriptor( + infiniopHandle_t handle, + infiniopAwqMarlinRepackDescriptor_t *desc_ptr, + infiniopTensorDescriptor_t output_desc, + infiniopTensorDescriptor_t input_desc, + int64_t num_bits, + bool is_a_8bit) { +#define CREATE(CASE, NAMESPACE) \ + case CASE: \ + return op::awq_marlin_repack::NAMESPACE::Descriptor::create( \ + handle, \ + reinterpret_cast(desc_ptr), \ + output_desc, \ + input_desc, \ + num_bits, \ + is_a_8bit) + + switch (handle->device) { +#ifdef ENABLE_NVIDIA_API + CREATE(INFINI_DEVICE_NVIDIA, nvidia); +#endif + + default: + return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED; + } + +#undef CREATE +} + +__INFINI_C infiniStatus_t infiniopGetAwqMarlinRepackWorkspaceSize(infiniopAwqMarlinRepackDescriptor_t desc, + size_t *size) { +#define GET(CASE, NAMESPACE) \ + case CASE: \ + *size = reinterpret_cast(desc)->workspaceSize(); \ + return INFINI_STATUS_SUCCESS + + switch (desc->device_type) { +#ifdef ENABLE_NVIDIA_API + GET(INFINI_DEVICE_NVIDIA, nvidia); +#endif + + default: + return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED; + } +#undef GET +} + +__INFINI_C infiniStatus_t infiniopAwqMarlinRepack( + infiniopAwqMarlinRepackDescriptor_t desc, + void *workspace, + size_t workspace_size, + void *output, + const void *input, + void *stream) { + +#define CALCULATE(CASE, NAMESPACE) \ + case CASE: \ + return reinterpret_cast(desc) \ + ->calculate(workspace, workspace_size, output, input, stream) + + switch (desc->device_type) { +#ifdef ENABLE_NVIDIA_API + CALCULATE(INFINI_DEVICE_NVIDIA, nvidia); +#endif + + default: + return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED; + } + +#undef CALCULATE +} + +__INFINI_C infiniStatus_t +infiniopDestroyAwqMarlinRepackDescriptor(infiniopAwqMarlinRepackDescriptor_t desc) { + +#define DELETE(CASE, NAMESPACE) \ + case CASE: \ + delete reinterpret_cast(desc); \ + return INFINI_STATUS_SUCCESS; + + switch (desc->device_type) { +#ifdef ENABLE_NVIDIA_API + DELETE(INFINI_DEVICE_NVIDIA, nvidia); +#endif + + default: + return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED; + } + +#undef DELETE +} + +// #endif diff --git a/test/infiniop/awq_marlin_repack.py b/test/infiniop/awq_marlin_repack.py new file mode 100644 index 000000000..3d2b5507e --- /dev/null +++ b/test/infiniop/awq_marlin_repack.py @@ -0,0 +1,443 @@ +import torch +import ctypes +from ctypes import c_uint64 +from libinfiniop import ( + LIBINFINIOP, + TestTensor, + TestWorkspace, + get_test_devices, + check_error, + test_operator, + get_args, + debug, + get_tolerance, + profile_operation, + InfiniDtype, + InfiniDtypeNames, + InfiniDeviceNames, + infiniopOperatorDescriptor_t, + to_torch_dtype, +) +import itertools +import numpy +from libinfiniop.scalar_type import scalar_types, ScalarType +from typing import TYPE_CHECKING, Dict, List, Mapping, Optional, Tuple, Union +import numpy as np + + +GPTQ_MARLIN_TILE = 16 +MARLIN_K_CHUNKS = [128] +MARLIN_N_CHUNKS = [64, 256] + +MARLIN_REPACK_NK_FACTORS = [ + (4, 8), + (7, 5), + (13, 11), +] + +def to_iter(x): + return x if isinstance(x, (list, tuple)) else (x,) + + +_TEST_CASES = list( + itertools.product( + to_iter(MARLIN_K_CHUNKS), + to_iter(MARLIN_N_CHUNKS), + to_iter([scalar_types.uint4]), + to_iter([True, False]), + to_iter(MARLIN_REPACK_NK_FACTORS), + to_iter([128]), + ) +) + +_TENSOR_DTYPES = [InfiniDtype.F16, InfiniDtype.BF16, InfiniDtype.F32] + +_TOLERANCE_MAP = { + InfiniDtype.F16: {"atol": 1e-3, "rtol": 1e-2}, + InfiniDtype.BF16: {"atol": 5e-3, "rtol": 5e-2}, + InfiniDtype.F32: {"atol": 3e-5, "rtol": 1e-5}, +} + +DEBUG = False +PROFILE = False +NUM_PRERUN = 10 +NUM_ITERATIONS = 1000 + + +def quantize_weights( + w: torch.Tensor, + quant_type: ScalarType, + group_size: int | None, + zero_points: bool = False, + ref_zero_points_after_scales: bool = False, +): + assert quant_type.is_integer(), ( + "Floating point quantization may work but has not been tested" + ) + assert not zero_points or group_size is not None, ( + "to have group zero points, group_size must be provided " + "(-1 group_size is channelwise)" + ) + + orig_device = w.device + orig_type = w.dtype + size_k, size_n = w.shape + + assert w.is_floating_point(), "w must be float" + + if group_size == -1: + group_size = size_k + + # Reshape to [groupsize, -1] + if group_size is not None and group_size < size_k: + w = w.reshape((-1, group_size, size_n)) + w = w.permute(1, 0, 2) + w = w.reshape((group_size, -1)) + + # Compute scale for each group + max_val = torch.max(w, 0, keepdim=True).values + min_val = torch.min(w, 0, keepdim=True).values + + max_q_val = quant_type.max() + min_q_val = quant_type.min() + + w_s = torch.Tensor([1.0]).to(w.device) # unscaled case + maybe_w_zp = None + if group_size is not None: + if zero_points: + assert not quant_type.is_signed() and quant_type.max() > 0 + w_s = (max_val - min_val).clamp(min=1e-5) / quant_type.max() + maybe_w_zp = ( + torch.round(torch.abs(min_val / w_s)).clamp(min_q_val, max_q_val).int() + ) + else: + # If the bias is such that there are no possible negative/positive + # values, set the max value to inf to avoid divide by 0 + w_s = torch.max( + abs(max_val / (max_q_val if max_q_val != 0 else torch.inf)), + abs(min_val / (min_q_val if min_q_val != 0 else torch.inf)), + ) + + # Quantize + w_q = torch.round(w / w_s).int() + (maybe_w_zp if zero_points else 0) + w_q = torch.clamp(w_q, min_q_val, max_q_val) + + # Compute ref (dequantized) + # For some kernels (namely Machete) the zero-points are applied after the + # scales are applied, for this case computing the reference in similar way + # allows us to use tighter error tolerances in our unit tests. + if ref_zero_points_after_scales and maybe_w_zp is not None: + w_ref = w_q.to(orig_type) * w_s - maybe_w_zp.to(orig_type) * w_s + else: + w_ref = (w_q - (maybe_w_zp if zero_points else 0)).to(orig_type) * w_s + + if quant_type.has_bias(): + w_q += quant_type.bias + + # Restore original shapes + if group_size is not None and group_size < size_k: + + def reshape_w(w): + w = w.reshape((group_size, -1, size_n)) + w = w.permute(1, 0, 2) + w = w.reshape((size_k, size_n)).contiguous() + return w + + w_q = reshape_w(w_q) + w_ref = reshape_w(w_ref) + w_s = w_s.reshape((-1, size_n)).contiguous() + + if maybe_w_zp is not None: + maybe_w_zp = maybe_w_zp.reshape((-1, size_n)).contiguous() + maybe_w_zp = maybe_w_zp.to(device=orig_device) + + return ( + w_ref.to(device=orig_device), + w_q.to(device=orig_device), + w_s if group_size is not None else None, + maybe_w_zp, + ) + +def get_pack_factor(num_bits): + assert 32 % num_bits == 0, f"Unsupported num_bits = {num_bits}" + return 32 // num_bits + +def pack_cols( + q_w: torch.Tensor, + num_bits: int, + size_k: int, + size_n: int, +): + assert q_w.shape == (size_k, size_n) + + pack_factor = get_pack_factor(num_bits) + assert size_n % pack_factor == 0 + + orig_device = q_w.device + + q_w = q_w.cpu().numpy().astype(numpy.uint32) + + q_res = numpy.zeros((size_k, size_n // pack_factor), dtype=numpy.uint32) + + for i in range(pack_factor): + q_res |= q_w[:, i::pack_factor] << num_bits * i + + q_res = torch.from_numpy(q_res.astype(numpy.int32)).to(orig_device) + q_res = q_res.contiguous() + + return q_res + + +def awq_pack( + q_w: torch.Tensor, + num_bits: int, + size_k: int, + size_n: int, +): + assert q_w.shape == (size_k, size_n) + + # Interleave column dim (for the dequantize code) and pack it to int32 + if num_bits == 4: + interleave = numpy.array([0, 2, 4, 6, 1, 3, 5, 7]) + elif num_bits == 8: + interleave = numpy.array([0, 2, 1, 3]) + else: + raise Exception("num_bits must be 4 or 8, got {}".format(num_bits)) + + q_w = q_w.reshape((-1, len(interleave)))[:, interleave].ravel() + q_w = q_w.reshape((-1, size_n)).contiguous() + + return pack_cols(q_w, num_bits, size_k, size_n) + +def get_weight_perm(num_bits: int, is_a_8bit: bool = False): + perm_list: list[int] = [] + if is_a_8bit: + for i in range(32): + perm1 = [] + col = i // 4 + for block in [0, 1]: + for row in [ + 4 * (i % 4), + 4 * (i % 4) + 1, + 4 * (i % 4) + 2, + 4 * (i % 4) + 3, + 4 * (i % 4 + 4), + 4 * (i % 4 + 4) + 1, + 4 * (i % 4 + 4) + 2, + 4 * (i % 4 + 4) + 3, + ]: + perm1.append(16 * row + col + 8 * block) + for j in range(2): + perm_list.extend([p + 512 * j for p in perm1]) + else: + for i in range(32): + perm1 = [] + col = i // 4 + for block in [0, 1]: + for row in [ + 2 * (i % 4), + 2 * (i % 4) + 1, + 2 * (i % 4 + 4), + 2 * (i % 4 + 4) + 1, + ]: + perm1.append(16 * row + col + 8 * block) + for j in range(4): + perm_list.extend([p + 256 * j for p in perm1]) + + perm = np.array(perm_list) + + if num_bits == 4: + if is_a_8bit: # noqa: SIM108 + interleave = np.array([0, 4, 1, 5, 2, 6, 3, 7]) + else: + interleave = np.array([0, 2, 4, 6, 1, 3, 5, 7]) + elif num_bits == 8: + if is_a_8bit: # noqa: SIM108 + interleave = np.array([0, 1, 2, 3]) + else: + interleave = np.array([0, 2, 1, 3]) + else: + raise Exception("num_bits must be 4 or 8, got {}".format(num_bits)) + + perm = perm.reshape((-1, len(interleave)))[:, interleave].ravel() + perm = torch.from_numpy(perm) + return perm + +def marlin_permute_weights( + q_w, size_k, size_n, perm, tile=GPTQ_MARLIN_TILE, is_a_8bit=False +): + assert q_w.shape == (size_k, size_n) + assert size_k % tile == 0, f"size_k = {size_k}, tile = {tile}" + assert size_n % tile == 0, f"size_k = {size_n}, tile = {tile}" + + if is_a_8bit: + # Permute weights to 32x32 marlin tiles + q_w = q_w.reshape((size_k // (tile * 2), tile * 2, size_n // tile, tile)) + else: + # Permute weights to 16x64 marlin tiles + q_w = q_w.reshape((size_k // tile, tile, size_n // tile, tile)) + q_w = q_w.permute((0, 2, 1, 3)) + q_w = q_w.reshape((size_k // tile, size_n * tile)) + + q_w = q_w.reshape((-1, perm.numel()))[:, perm].reshape(q_w.shape) + + return q_w + +def marlin_weights(q_w, size_k, size_n, num_bits, perm, is_a_8bit=False): + # Permute + q_w = marlin_permute_weights(q_w, size_k, size_n, perm, is_a_8bit=is_a_8bit) + + # Pack + pack_factor = get_pack_factor(num_bits) + orig_device = q_w.device + + q_w = q_w.cpu().numpy().astype(np.uint32) + + q_packed = np.zeros((q_w.shape[0], q_w.shape[1] // pack_factor), dtype=np.uint32) + for i in range(pack_factor): + q_packed |= q_w[:, i::pack_factor] << num_bits * i + + q_packed = torch.from_numpy(q_packed.astype(np.int32)).to(orig_device) + + return q_packed + + +def awq_marlin_repack_torch(b_weight, size_k, size_n, group_size, quant_type, is_a_8bit): + # Quantize + w_ref, q_w, s, zp = quantize_weights( + b_weight, quant_type, group_size, zero_points=True + ) + + # Pack to AWQ format + q_w_awq = awq_pack(q_w, quant_type.size_bits, size_k, size_n) + + # Pack to Marlin format + weight_perm = get_weight_perm(quant_type.size_bits, is_a_8bit) + marlin_q_w_1 = marlin_weights( + q_w, size_k, size_n, quant_type.size_bits, weight_perm, is_a_8bit + ) + return marlin_q_w_1 + + +def test( + handle, + device, + k_chunk, + n_chunk, + quant_type, + is_a_8bit, + nk_factors, + group_size=128, + dtype=InfiniDtype.F16, + sync=None, +): + print( + f"Testing awq_marlin_repack on {device} with k_chunk:{k_chunk}, n_chunk:{n_chunk}, is_a_8bit:{is_a_8bit}, nk_factors:{nk_factors}, group_size:{group_size}, dtype:{InfiniDtypeNames[dtype]}" + ) + n_factor, k_factor = nk_factors + + size_k = k_chunk * k_factor + size_n = n_chunk * n_factor + + + b_weight = TestTensor((size_k, size_n), None, dtype, device) + + w_ref, q_w, s, zp = quantize_weights( + b_weight.torch_tensor(), quant_type, group_size, zero_points=True + ) + + # Pack to AWQ format + q_w_awq = awq_pack(q_w, quant_type.size_bits, size_k, size_n) + + ans = awq_marlin_repack_torch(b_weight.torch_tensor(), size_k, size_n, group_size, quant_type, is_a_8bit) + + input = TestTensor( + q_w_awq.shape, + q_w_awq.stride(), + InfiniDtype.I32, + device, + mode="manual", + set_tensor=q_w_awq, + ) + output = TestTensor(ans.shape, None, InfiniDtype.I32, device, mode="zeros") + + if sync is not None: + sync() + + descriptor = infiniopOperatorDescriptor_t() + check_error( + LIBINFINIOP.infiniopCreateAwqMarlinRepackDescriptor( + handle, + ctypes.byref(descriptor), + output.descriptor, + input.descriptor, + quant_type.size_bits, + is_a_8bit, + ) + ) + + # Invalidate descriptors (same pattern as other tests) + for tensor in [ + output, + input, + ]: + tensor.destroy_desc() + + workspace_size = c_uint64(0) + check_error( + LIBINFINIOP.infiniopGetAwqMarlinRepackWorkspaceSize( + descriptor, ctypes.byref(workspace_size) + ) + ) + workspace = TestWorkspace(workspace_size.value, device) + + def lib_awq_marlin_repack(): + check_error( + LIBINFINIOP.infiniopAwqMarlinRepack( + descriptor, + workspace.data(), + workspace_size.value, + output.data(), + input.data(), + None, + ) + ) + + lib_awq_marlin_repack() + + atol, rtol = get_tolerance(_TOLERANCE_MAP, dtype) + if DEBUG: + debug(output.actual_tensor(), ans, atol=atol, rtol=rtol) + assert torch.allclose(output.actual_tensor(), ans, atol=atol, rtol=rtol) + + if PROFILE: + profile_operation( + "PyTorch", + lambda: awq_marlin_repack_torch(b_weight.torch_tensor(), size_k, size_n, group_size, quant_type, is_a_8bit), + device, + NUM_PRERUN, + NUM_ITERATIONS, + ) + profile_operation( + " lib", + lambda: lib_awq_marlin_repack(), + device, + NUM_PRERUN, + NUM_ITERATIONS, + ) + + check_error(LIBINFINIOP.infiniopDestroyAwqMarlinRepackDescriptor(descriptor)) + + +if __name__ == "__main__": + args = get_args() + + DEBUG = args.debug + PROFILE = args.profile + NUM_PRERUN = args.num_prerun + NUM_ITERATIONS = args.num_iterations + + for device in get_test_devices(args): + test_operator(device, test, _TEST_CASES, _TENSOR_DTYPES) + + print("\033[92mTest passed!\033[0m") diff --git a/test/infiniop/libinfiniop/op_register.py b/test/infiniop/libinfiniop/op_register.py index ad41f7d23..47f3beb26 100644 --- a/test/infiniop/libinfiniop/op_register.py +++ b/test/infiniop/libinfiniop/op_register.py @@ -1453,6 +1453,40 @@ def awq_marlin_gemm_(lib): ] +@OpRegister.operator +def awq_marlin_repack_(lib): + lib.infiniopCreateAwqMarlinRepackDescriptor.restype = c_int32 + lib.infiniopCreateAwqMarlinRepackDescriptor.argtypes = [ + infiniopHandle_t, + POINTER(infiniopOperatorDescriptor_t), + infiniopTensorDescriptor_t, + infiniopTensorDescriptor_t, + c_int64, + c_bool, + ] + + lib.infiniopGetAwqMarlinRepackWorkspaceSize.restype = c_int32 + lib.infiniopGetAwqMarlinRepackWorkspaceSize.argtypes = [ + infiniopOperatorDescriptor_t, + POINTER(c_size_t), + ] + + lib.infiniopAwqMarlinRepack.restype = c_int32 + lib.infiniopAwqMarlinRepack.argtypes = [ + infiniopOperatorDescriptor_t, + c_void_p, + c_size_t, + c_void_p, + c_void_p, + c_void_p, + ] + + lib.infiniopDestroyAwqMarlinGemmDescriptor.restype = c_int32 + lib.infiniopDestroyAwqMarlinGemmDescriptor.argtypes = [ + infiniopOperatorDescriptor_t, + ] + + @OpRegister.operator def softplus_(lib): lib.infiniopCreateSoftplusDescriptor.restype = c_int32