diff --git a/csrc/libtorch_stable/quantization/w8a8/fp8/per_token_group_quant.cu b/csrc/libtorch_stable/quantization/w8a8/fp8/per_token_group_quant.cu index 5029e3903752..37a612b43944 100644 --- a/csrc/libtorch_stable/quantization/w8a8/fp8/per_token_group_quant.cu +++ b/csrc/libtorch_stable/quantization/w8a8/fp8/per_token_group_quant.cu @@ -156,6 +156,17 @@ inline int GetGroupsPerBlock(int64_t num_groups) { return 1; } +// Largest of {16, 8, 4} dividing padded_groups_per_row (else 4). ry = 16 / kx. +inline int GetGroupsPerBlockX(int64_t padded_groups_per_row) { + if (padded_groups_per_row % 16 == 0) { + return 16; + } + if (padded_groups_per_row % 8 == 0) { + return 8; + } + return 4; +} + void per_token_group_quant_8bit(const torch::stable::Tensor& input, torch::stable::Tensor& output_q, torch::stable::Tensor& output_s, @@ -247,11 +258,11 @@ void per_token_group_quant_8bit(const torch::stable::Tensor& input, // // Constraints: GROUP_SIZE % (THREADS_PER_GROUP * VEC_SIZE) == 0; for // THREADS_PER_GROUP=8 and bf16/fp16 (VEC_SIZE=16), this means GROUP_SIZE=128. 
-template +template __global__ void per_token_group_quant_8bit_packed_register_kernel( const T* __restrict__ input, void* __restrict__ output_q, - unsigned int* __restrict__ output_s_packed, const int64_t num_groups_padded, - const int groups_per_block, const int padded_groups_per_row, + unsigned int* __restrict__ output_s_packed, const int padded_groups_per_row, const int groups_per_row, const int mn, const int output_q_mn_extent, const int tma_aligned_mn, const int64_t num_scale_elems, const float eps, const float min_8bit, const float max_8bit) { @@ -260,27 +271,25 @@ __global__ void per_token_group_quant_8bit_packed_register_kernel( constexpr int VEC_SIZE = 32 / sizeof(T); // 16 for bf16/fp16 static_assert(GROUP_SIZE == THREADS_PER_GROUP * VEC_SIZE, "GROUP_SIZE must equal THREADS_PER_GROUP * VEC_SIZE"); - // Each group's 8 threads must live in a single warp octet so the - // 0xffu << (threadIdx.x & 24u) shuffle mask selects exactly the lanes - // that share a group. Requires 32 % THREADS_PER_GROUP == 0 and the host - // to launch num_threads as a multiple of THREADS_PER_GROUP (which it does - // via num_threads = groups_per_block * THREADS_PER_GROUP). 
static_assert(32 % THREADS_PER_GROUP == 0, "THREADS_PER_GROUP must divide warp size for the shuffle " "mask to be valid"); + static_assert( + kGroupsPerBlockX > 0 && (kGroupsPerBlockX & (kGroupsPerBlockX - 1)) == 0, + "kGroupsPerBlockX must be a positive power of 2"); + static_assert(kRowsPerBlock > 0, "kRowsPerBlock must be positive"); const int local_group_id = threadIdx.x / THREADS_PER_GROUP; const int lane_id = threadIdx.x % THREADS_PER_GROUP; - const int64_t block_group_id = blockIdx.x * groups_per_block; - const int64_t global_group_id = block_group_id + local_group_id; - if (global_group_id >= num_groups_padded) { + const int sf_k_local = local_group_id % kGroupsPerBlockX; + const int row_local = local_group_id / kGroupsPerBlockX; + const int sf_k_idx = blockIdx.x * kGroupsPerBlockX + sf_k_local; + const int mn_idx = blockIdx.y * kRowsPerBlock + row_local; + + if (mn_idx >= tma_aligned_mn) { return; } - - const int sf_k_idx = - static_cast(global_group_id % padded_groups_per_row); - const int mn_idx = static_cast(global_group_id / padded_groups_per_row); const bool is_valid_group = (mn_idx < mn) && (sf_k_idx < groups_per_row); // Load 16 input elements (32 B) into registers as two adjacent uint4 @@ -443,34 +452,53 @@ void per_token_group_quant_8bit_packed(const torch::stable::Tensor& input, constexpr int THREADS_PER_GROUP = 8; const int64_t padded_groups_per_row = k_num_packed_sfk * 4; - const int64_t num_groups_padded = tma_aligned_mn * padded_groups_per_row; const int64_t num_scale_elems = mn + (k_num_packed_sfk - 1) * tma_aligned_mn; - const int groups_per_block = GetGroupsPerBlock(num_groups_padded); + + STD_TORCH_CHECK(padded_groups_per_row % 4 == 0, + "padded_groups_per_row=", padded_groups_per_row, + " is not a multiple of 4."); + const int kx = GetGroupsPerBlockX(padded_groups_per_row); + const int ry = 16 / kx; + const int64_t blocks_x = padded_groups_per_row / kx; + const int64_t blocks_y = (tma_aligned_mn + ry - 1) / ry; + const int num_threads = 
(kx * ry) * THREADS_PER_GROUP; + // grid.x caps at 2^31 - 1 but grid.y at 65535; TODO tighten blocks_y check. + STD_TORCH_CHECK(blocks_x <= static_cast(INT32_MAX) && + blocks_y <= static_cast(INT32_MAX), + "per_token_group_quant_8bit_packed grid too large: (", + blocks_x, ", ", blocks_y, ")."); auto dst_type = output_q.scalar_type(); - const int64_t num_blocks = num_groups_padded / groups_per_block; - const int num_threads = groups_per_block * THREADS_PER_GROUP; - // CUDA caps grid.x at 2^31 - 1; this fits any realistic shape but guard - // against pathological inputs. - STD_TORCH_CHECK(num_blocks <= static_cast(INT32_MAX), - "per_token_group_quant_8bit_packed grid too large: ", - num_blocks, " blocks (max ", INT32_MAX, ")."); - -#define LAUNCH_REG_KERNEL(T, DST_DTYPE) \ - do { \ - dim3 grid(static_cast(num_blocks)); \ - dim3 block(num_threads); \ - per_token_group_quant_8bit_packed_register_kernel \ - <<>>( \ - static_cast(input.data_ptr()), output_q.data_ptr(), \ - reinterpret_cast(output_s_packed.data_ptr()), \ - num_groups_padded, groups_per_block, \ - static_cast(padded_groups_per_row), \ - static_cast(groups_per_row), static_cast(mn), \ - static_cast(output_q_mn_extent), \ - static_cast(tma_aligned_mn), num_scale_elems, \ - static_cast(eps), static_cast(min_8bit), \ - static_cast(max_8bit)); \ + +#define LAUNCH_REG_KERNEL_INST(T, DST_DTYPE, KX, RY) \ + do { \ + dim3 grid(static_cast(blocks_x), \ + static_cast(blocks_y)); \ + dim3 block(num_threads); \ + per_token_group_quant_8bit_packed_register_kernel \ + <<>>( \ + static_cast(input.data_ptr()), output_q.data_ptr(), \ + reinterpret_cast(output_s_packed.data_ptr()), \ + static_cast(padded_groups_per_row), \ + static_cast(groups_per_row), static_cast(mn), \ + static_cast(output_q_mn_extent), \ + static_cast(tma_aligned_mn), num_scale_elems, \ + static_cast(eps), static_cast(min_8bit), \ + static_cast(max_8bit)); \ + } while (0) + +#define LAUNCH_REG_KERNEL(T, DST_DTYPE) \ + do { \ + if (kx == 16) { \ + 
LAUNCH_REG_KERNEL_INST(T, DST_DTYPE, 16, 1); \ + } else if (kx == 8) { \ + LAUNCH_REG_KERNEL_INST(T, DST_DTYPE, 8, 2); \ + } else if (kx == 4) { \ + LAUNCH_REG_KERNEL_INST(T, DST_DTYPE, 4, 4); \ + } else { \ + STD_TORCH_CHECK(false, "Unsupported kx value ", kx); \ + } \ } while (0) VLLM_STABLE_DISPATCH_HALF_TYPES( @@ -488,6 +516,7 @@ void per_token_group_quant_8bit_packed(const torch::stable::Tensor& input, })); #undef LAUNCH_REG_KERNEL +#undef LAUNCH_REG_KERNEL_INST } void per_token_group_quant_fp8(const torch::stable::Tensor& input,