Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
109 changes: 69 additions & 40 deletions csrc/libtorch_stable/quantization/w8a8/fp8/per_token_group_quant.cu
Original file line number Diff line number Diff line change
Expand Up @@ -156,6 +156,17 @@ inline int GetGroupsPerBlock(int64_t num_groups) {
return 1;
}

// Largest divisor of padded_groups_per_row that is <= 16. ry = 16 / kx.
inline int GetGroupsPerBlockX(int64_t padded_groups_per_row) {
if (padded_groups_per_row % 16 == 0) {
return 16;
}
if (padded_groups_per_row % 8 == 0) {
return 8;
}
return 4;
}

void per_token_group_quant_8bit(const torch::stable::Tensor& input,
torch::stable::Tensor& output_q,
torch::stable::Tensor& output_s,
Expand Down Expand Up @@ -247,11 +258,11 @@ void per_token_group_quant_8bit(const torch::stable::Tensor& input,
//
// Constraints: GROUP_SIZE % (THREADS_PER_GROUP * VEC_SIZE) == 0; for
// THREADS_PER_GROUP=8 and bf16/fp16 (VEC_SIZE=16), this means GROUP_SIZE=128.
template <typename T, typename DST_DTYPE, int GROUP_SIZE>
template <typename T, typename DST_DTYPE, int GROUP_SIZE, int kGroupsPerBlockX,
int kRowsPerBlock>
__global__ void per_token_group_quant_8bit_packed_register_kernel(
const T* __restrict__ input, void* __restrict__ output_q,
unsigned int* __restrict__ output_s_packed, const int64_t num_groups_padded,
const int groups_per_block, const int padded_groups_per_row,
unsigned int* __restrict__ output_s_packed, const int padded_groups_per_row,
const int groups_per_row, const int mn, const int output_q_mn_extent,
const int tma_aligned_mn, const int64_t num_scale_elems, const float eps,
const float min_8bit, const float max_8bit) {
Expand All @@ -260,27 +271,25 @@ __global__ void per_token_group_quant_8bit_packed_register_kernel(
constexpr int VEC_SIZE = 32 / sizeof(T); // 16 for bf16/fp16
static_assert(GROUP_SIZE == THREADS_PER_GROUP * VEC_SIZE,
"GROUP_SIZE must equal THREADS_PER_GROUP * VEC_SIZE");
// Each group's 8 threads must live in a single warp octet so the
// 0xffu << (threadIdx.x & 24u) shuffle mask selects exactly the lanes
// that share a group. Requires 32 % THREADS_PER_GROUP == 0 and the host
// to launch num_threads as a multiple of THREADS_PER_GROUP (which it does
// via num_threads = groups_per_block * THREADS_PER_GROUP).
static_assert(32 % THREADS_PER_GROUP == 0,
"THREADS_PER_GROUP must divide warp size for the shuffle "
"mask to be valid");
static_assert(
kGroupsPerBlockX > 0 && (kGroupsPerBlockX & (kGroupsPerBlockX - 1)) == 0,
"kGroupsPerBlockX must be a positive power of 2");
static_assert(kRowsPerBlock > 0, "kRowsPerBlock must be positive");

const int local_group_id = threadIdx.x / THREADS_PER_GROUP;
const int lane_id = threadIdx.x % THREADS_PER_GROUP;

const int64_t block_group_id = blockIdx.x * groups_per_block;
const int64_t global_group_id = block_group_id + local_group_id;
if (global_group_id >= num_groups_padded) {
const int sf_k_local = local_group_id % kGroupsPerBlockX;
const int row_local = local_group_id / kGroupsPerBlockX;
const int sf_k_idx = blockIdx.x * kGroupsPerBlockX + sf_k_local;
const int mn_idx = blockIdx.y * kRowsPerBlock + row_local;

if (mn_idx >= tma_aligned_mn) {
return;
}

const int sf_k_idx =
static_cast<int>(global_group_id % padded_groups_per_row);
const int mn_idx = static_cast<int>(global_group_id / padded_groups_per_row);
const bool is_valid_group = (mn_idx < mn) && (sf_k_idx < groups_per_row);

// Load 16 input elements (32 B) into registers as two adjacent uint4
Expand Down Expand Up @@ -443,34 +452,53 @@ void per_token_group_quant_8bit_packed(const torch::stable::Tensor& input,

constexpr int THREADS_PER_GROUP = 8;
const int64_t padded_groups_per_row = k_num_packed_sfk * 4;
const int64_t num_groups_padded = tma_aligned_mn * padded_groups_per_row;
const int64_t num_scale_elems = mn + (k_num_packed_sfk - 1) * tma_aligned_mn;
const int groups_per_block = GetGroupsPerBlock(num_groups_padded);

STD_TORCH_CHECK(padded_groups_per_row % 4 == 0,
"padded_groups_per_row=", padded_groups_per_row,
" is not a multiple of 4.");
const int kx = GetGroupsPerBlockX(padded_groups_per_row);
const int ry = 16 / kx;
const int64_t blocks_x = padded_groups_per_row / kx;
const int64_t blocks_y = (tma_aligned_mn + ry - 1) / ry;
const int num_threads = (kx * ry) * THREADS_PER_GROUP;
// CUDA caps grid.x and grid.y at 2^31 - 1; guard against pathological inputs.
STD_TORCH_CHECK(blocks_x <= static_cast<int64_t>(INT32_MAX) &&
blocks_y <= static_cast<int64_t>(INT32_MAX),
"per_token_group_quant_8bit_packed grid too large: (",
blocks_x, ", ", blocks_y, ").");

auto dst_type = output_q.scalar_type();
const int64_t num_blocks = num_groups_padded / groups_per_block;
const int num_threads = groups_per_block * THREADS_PER_GROUP;
// CUDA caps grid.x at 2^31 - 1; this fits any realistic shape but guard
// against pathological inputs.
STD_TORCH_CHECK(num_blocks <= static_cast<int64_t>(INT32_MAX),
"per_token_group_quant_8bit_packed grid too large: ",
num_blocks, " blocks (max ", INT32_MAX, ").");

#define LAUNCH_REG_KERNEL(T, DST_DTYPE) \
do { \
dim3 grid(static_cast<unsigned int>(num_blocks)); \
dim3 block(num_threads); \
per_token_group_quant_8bit_packed_register_kernel<T, DST_DTYPE, 128> \
<<<grid, block, 0, stream>>>( \
static_cast<const T*>(input.data_ptr()), output_q.data_ptr(), \
reinterpret_cast<unsigned int*>(output_s_packed.data_ptr()), \
num_groups_padded, groups_per_block, \
static_cast<int>(padded_groups_per_row), \
static_cast<int>(groups_per_row), static_cast<int>(mn), \
static_cast<int>(output_q_mn_extent), \
static_cast<int>(tma_aligned_mn), num_scale_elems, \
static_cast<float>(eps), static_cast<float>(min_8bit), \
static_cast<float>(max_8bit)); \

#define LAUNCH_REG_KERNEL_INST(T, DST_DTYPE, KX, RY) \
do { \
dim3 grid(static_cast<unsigned int>(blocks_x), \
static_cast<unsigned int>(blocks_y)); \
dim3 block(num_threads); \
per_token_group_quant_8bit_packed_register_kernel<T, DST_DTYPE, 128, KX, \
RY> \
<<<grid, block, 0, stream>>>( \
static_cast<const T*>(input.data_ptr()), output_q.data_ptr(), \
reinterpret_cast<unsigned int*>(output_s_packed.data_ptr()), \
static_cast<int>(padded_groups_per_row), \
static_cast<int>(groups_per_row), static_cast<int>(mn), \
static_cast<int>(output_q_mn_extent), \
static_cast<int>(tma_aligned_mn), num_scale_elems, \
static_cast<float>(eps), static_cast<float>(min_8bit), \
static_cast<float>(max_8bit)); \
} while (0)

#define LAUNCH_REG_KERNEL(T, DST_DTYPE) \
do { \
if (kx == 16) { \
LAUNCH_REG_KERNEL_INST(T, DST_DTYPE, 16, 1); \
} else if (kx == 8) { \
LAUNCH_REG_KERNEL_INST(T, DST_DTYPE, 8, 2); \
} else if (kx == 4) { \
LAUNCH_REG_KERNEL_INST(T, DST_DTYPE, 4, 4); \
} else { \
STD_TORCH_CHECK(false, "Unsupported kx value ", kx); \
} \
} while (0)

VLLM_STABLE_DISPATCH_HALF_TYPES(
Expand All @@ -488,6 +516,7 @@ void per_token_group_quant_8bit_packed(const torch::stable::Tensor& input,
}));

#undef LAUNCH_REG_KERNEL
#undef LAUNCH_REG_KERNEL_INST
}

void per_token_group_quant_fp8(const torch::stable::Tensor& input,
Expand Down
Loading