Skip to content

Commit

Permalink
Optimize the fp-dequantizer to get high memory-BW utilization (#5373)
Browse files Browse the repository at this point in the history
This PR removes the for loop inside the dequantizer kernel and use as
many threads and blocks as needed to dequantize the quantized matrix.
The previous implementation was processing each group per thread block
which can reduce the efficiency when have having smaller group-size and
also processes more data per-thread which is unnecessary and we can use
more parallelism to improve the dequantization performance.

Based on my testing results, for a 4K by 4K matrix, dequantizing from
fp8 to bf16 gives 2.5x speedup (improving the BW efficiency from 1 TB/s
to 2.5 TB/s on Nvidia H100 GPU).

---------

Co-authored-by: Reza Yazdani <reza.yazdani@snowflake.com>
  • Loading branch information
RezaYazdaniAminabadi and sfc-gh-reyazda authored Apr 10, 2024
1 parent cc9e7b9 commit a8b8215
Showing 1 changed file with 68 additions and 100 deletions.
168 changes: 68 additions & 100 deletions csrc/fp_quantizer/quantize.cu
Original file line number Diff line number Diff line change
@@ -219,119 +219,100 @@ __global__ void apply_quantization(T* val,
}

template <typename T,
int unroll,
int q_mantisa_bits,
int total_q_bits = 16,
int _mantisa_bits = 3,
int _exponent_bits = 4>
__global__ void apply_dequantization(uint8_t* val, T* q_val, int group_size)
__global__ void apply_dequantization(uint8_t* val, T* q_val, int group_size, int total_num_elements)
{
int tidx = threadIdx.x;
int wid = tidx >> 5;
int lane = tidx & 0x1f;
int gid = blockIdx.x * quantization::warps + wid;
constexpr uint32_t vector_size = quantization::access_granularity / sizeof(T);
int tidx = (blockIdx.x * blockDim.x + threadIdx.x) * vector_size;

constexpr int quantized_bits = _mantisa_bits + _exponent_bits + 1;
constexpr int q_exponent_bits = total_q_bits - q_mantisa_bits - 1;
constexpr uint16_t _mantisa_mask = (1 << _mantisa_bits) - 1;
constexpr uint16_t _exponent_mask = ((1 << _exponent_bits) - 1) << _mantisa_bits;
constexpr uint16_t _sign_mask = 1 << (_mantisa_bits + _exponent_bits);

constexpr uint32_t vector_size = quantization::access_granularity / sizeof(T);
constexpr uint32_t load_stride = vector_size * hw_warp_size;
const uint32_t thread_offset = lane * vector_size;
const uint32_t thread_load_offset = lane * vector_size * quantized_bits / 8;
const uint32_t base_load_offset =
gid * (group_size * quantized_bits / 8 + 4) + thread_load_offset; // 4-byte scale offset
const uint32_t base_store_offset = gid * group_size + thread_offset;
const uint8_t* load_base_ptr = val + base_load_offset;
const uint32_t g_index = (tidx / group_size);
const uint32_t group_size_bytes = (group_size * quantized_bits / 8);
const uint8_t* load_base_ptr =
val + g_index * (group_size_bytes + 4) + (tidx % group_size) * quantized_bits / 8;

int mantisa_mask = ((1 << q_mantisa_bits) - 1);
mantisa_mask <<= (_mantisa_bits - q_mantisa_bits);

T* store_base_ptr = q_val + base_store_offset;
float scale; //= q_scale[gid];
T* store_base_ptr = q_val + tidx;
float scale;

uint8_t* scale_as_int8 = reinterpret_cast<uint8_t*>(&scale);
if (quantized_bits == 6) {
mem_access::load_global<quantization::quanitzed_access_granularity>(
scale_as_int8,
val + gid * (group_size * quantized_bits / 8 + 4) + (group_size * quantized_bits / 8));
scale_as_int8, val + g_index * (group_size_bytes + 4) + group_size_bytes);
mem_access::load_global<quantization::quanitzed_access_granularity_6bits>(
scale_as_int8 + quantization::quanitzed_access_granularity_6bits,
val + gid * (group_size * quantized_bits / 8 + 4) + (group_size * quantized_bits / 8) +
val + g_index * (group_size_bytes + 4) + group_size_bytes +
quantization::quanitzed_access_granularity_6bits);
} else
mem_access::load_global<quantization::quanitzed_access_granularity>(
scale_as_int8,
val + gid * (group_size * quantized_bits / 8 + 4) + (group_size * quantized_bits / 8));

#pragma unroll
for (int i = 0; i < unroll; i++) {
if (i * load_stride + thread_offset < group_size) {
uint64_t q_buf_in;
uint64_t q_buf_in1;
uint8_t* int8_data = reinterpret_cast<uint8_t*>(&q_buf_in);
uint8_t* int8_data1 = reinterpret_cast<uint8_t*>(&q_buf_in1);
uint32_t loading_offset = i * load_stride * quantized_bits / 8;
if (quantized_bits == 6) {
mem_access::load_global<quantization::quanitzed_access_granularity_6bits>(
int8_data, load_base_ptr + loading_offset);
mem_access::load_global<quantization::quanitzed_access_granularity_6bits>(
int8_data + quantization::quanitzed_access_granularity_6bits,
load_base_ptr + loading_offset +
quantization::quanitzed_access_granularity_6bits);
mem_access::load_global<quantization::quanitzed_access_granularity_6bits>(
int8_data + quantization::quanitzed_access_granularity_6bits * 2,
load_base_ptr + loading_offset +
quantization::quanitzed_access_granularity_6bits * 2);
} else {
scale_as_int8, val + g_index * (group_size_bytes + 4) + group_size_bytes);

if (tidx < total_num_elements) {
uint64_t q_buf_in;
uint64_t q_buf_in1;
uint8_t* int8_data = reinterpret_cast<uint8_t*>(&q_buf_in);
uint8_t* int8_data1 = reinterpret_cast<uint8_t*>(&q_buf_in1);
if (quantized_bits == 6) {
mem_access::load_global<quantization::quanitzed_access_granularity_6bits>(
int8_data, load_base_ptr);
mem_access::load_global<quantization::quanitzed_access_granularity_6bits>(
int8_data + quantization::quanitzed_access_granularity_6bits,
load_base_ptr + quantization::quanitzed_access_granularity_6bits);
mem_access::load_global<quantization::quanitzed_access_granularity_6bits>(
int8_data + quantization::quanitzed_access_granularity_6bits * 2,
load_base_ptr + quantization::quanitzed_access_granularity_6bits * 2);
} else {
mem_access::load_global<quantization::quanitzed_access_granularity>(int8_data,
load_base_ptr);
if (quantized_bits > 4) {
mem_access::load_global<quantization::quanitzed_access_granularity>(
int8_data, load_base_ptr + loading_offset);
if (quantized_bits > 4) {
int8_data + quantization::quanitzed_access_granularity,
load_base_ptr + quantization::quanitzed_access_granularity);
if (quantized_bits == 12) {
mem_access::load_global<quantization::quanitzed_access_granularity>(
int8_data + quantization::quanitzed_access_granularity,
load_base_ptr + loading_offset +
quantization::quanitzed_access_granularity);
if (quantized_bits == 12) {
mem_access::load_global<quantization::quanitzed_access_granularity>(
int8_data1,
load_base_ptr + loading_offset +
quantization::quanitzed_access_granularity * 2);
}
int8_data1, load_base_ptr + quantization::quanitzed_access_granularity * 2);
}
}
T store_buf[vector_size];
uint16_t* q_buf = reinterpret_cast<uint16_t*>(store_buf);
}
T store_buf[vector_size];
uint16_t* q_buf = reinterpret_cast<uint16_t*>(store_buf);
#pragma unroll
for (int j = 0; j < vector_size; j++) {
uint16_t new_data;
if (j < 5 || quantized_bits != 12) {
new_data = (uint16_t)(q_buf_in >> (j * quantized_bits));
} else {
if (j == 5) {
new_data = (uint16_t)(q_buf_in1);
new_data = (uint16_t)((new_data << 4) | (q_buf_in >> 60));
} else
new_data = (uint16_t)(q_buf_in1 >> ((j - 6) * quantized_bits + 8));
}
for (int j = 0; j < vector_size; j++) {
uint16_t new_data;
if (j < 5 || quantized_bits != 12) {
new_data = (uint16_t)(q_buf_in >> (j * quantized_bits));
} else {
if (j == 5) {
new_data = (uint16_t)(q_buf_in1);
new_data = (uint16_t)((new_data << 4) | (q_buf_in >> 60));
} else
new_data = (uint16_t)(q_buf_in1 >> ((j - 6) * quantized_bits + 8));
}

uint16_t sign = (new_data & _sign_mask) >> (_mantisa_bits + _exponent_bits);
uint16_t dst_exponent = (new_data & _exponent_mask) >> _mantisa_bits;
uint16_t dst_mantisa = (new_data & _mantisa_mask);
uint16_t sign = (new_data & _sign_mask) >> (_mantisa_bits + _exponent_bits);
uint16_t dst_exponent = (new_data & _exponent_mask) >> _mantisa_bits;
uint16_t dst_mantisa = (new_data & _mantisa_mask);

if (dst_exponent != (1 << q_exponent_bits) - 1)
dst_exponent = (dst_exponent - ((1 << (_exponent_bits - 1)) - 1)) +
(1 << (q_exponent_bits - 1)) - 1;
if (dst_exponent != (1 << q_exponent_bits) - 1)
dst_exponent = (dst_exponent - ((1 << (_exponent_bits - 1)) - 1)) +
(1 << (q_exponent_bits - 1)) - 1;

q_buf[j] = ((sign << (q_exponent_bits + q_mantisa_bits)) |
(dst_exponent << q_mantisa_bits) |
(dst_mantisa << (q_mantisa_bits - _mantisa_bits)));
float up_cast = conversion::to<float>(store_buf[j]);
store_buf[j] = conversion::to<T>(up_cast * scale);
}
mem_access::store_global<quantization::access_granularity>(
store_base_ptr + i * load_stride, store_buf);
q_buf[j] =
((sign << (q_exponent_bits + q_mantisa_bits)) | (dst_exponent << q_mantisa_bits) |
(dst_mantisa << (q_mantisa_bits - _mantisa_bits)));
float up_cast = conversion::to<float>(store_buf[j]);
store_buf[j] = conversion::to<T>(up_cast * scale);
}
mem_access::store_global<quantization::access_granularity>(store_base_ptr, store_buf);
}
}

@@ -386,12 +367,6 @@ INSTANTIATE_LAUNCH_QUANTIZATION(__nv_bfloat16, 23, 8);
#endif
INSTANTIATE_LAUNCH_QUANTIZATION(__half, 23, 8);

#define LAUNCH_FOR_DEQUANTIZATION_UNROLL(COUNT) \
case COUNT: \
apply_dequantization<T, COUNT, mantisa, 16, CONST_Q_MANTISA_BITS, CONST_Q_EXPONENT_BITS> \
<<<grid, block, 0, stream>>>(val, q_val, group_size); \
break;

template <typename T, int mantisa>
void launch_dequantization(uint8_t* val,
T* q_val,
@@ -401,21 +376,14 @@ void launch_dequantization(uint8_t* val,
int q_exponent_bits,
cudaStream_t stream)
{
const dim3 grid((num_groups + quantization::warps - 1) / quantization::warps);
int blocks = ((num_groups * group_size) - 1) /
(quantization::threads * (quantization::access_granularity / sizeof(T))) +
1;
const dim3 grid(blocks);
const dim3 block(quantization::threads);

constexpr int vals_per_unroll = hw_warp_size * quantization::access_granularity / sizeof(T);
const int copy_unroll = (group_size + vals_per_unroll - 1) / vals_per_unroll;

DEQUANT_SWITCH(q_mantisa_bits * q_exponent_bits, [&] {
switch (copy_unroll) {
LAUNCH_FOR_DEQUANTIZATION_UNROLL(1)
LAUNCH_FOR_DEQUANTIZATION_UNROLL(2)
LAUNCH_FOR_DEQUANTIZATION_UNROLL(3)
LAUNCH_FOR_DEQUANTIZATION_UNROLL(4)
LAUNCH_FOR_DEQUANTIZATION_UNROLL(5)
LAUNCH_FOR_DEQUANTIZATION_UNROLL(6)
}
apply_dequantization<T, mantisa, 16, CONST_Q_MANTISA_BITS, CONST_Q_EXPONENT_BITS>
<<<grid, block, 0, stream>>>(val, q_val, group_size, (num_groups * group_size));
});
}
#define INSTANTIATE_LAUNCH_DEQUANTIZATION(T, mantisa) \

0 comments on commit a8b8215

Please sign in to comment.