From a8b821535aa0b254efa681d51b4951734ca021cc Mon Sep 17 00:00:00 2001
From: Reza Yazdani <44502768+RezaYazdaniAminabadi@users.noreply.github.com>
Date: Wed, 10 Apr 2024 11:50:47 -0700
Subject: [PATCH] Optimize the fp-dequantizer to get high memory-BW utilization (#5373)

This PR removes the for loop inside the dequantizer kernel and uses as many
threads and blocks as needed to dequantize the quantized matrix. The previous
implementation processed one group per thread block, which reduces efficiency
when the group size is small, and it also processed more data per thread than
necessary; using more parallelism improves the dequantization performance.

Based on my testing results, for a 4K-by-4K matrix, dequantizing from fp8 to
bf16 gives a 2.5x speedup (improving the BW efficiency from 1 TB/s to 2.5 TB/s
on an Nvidia H100 GPU).

(A simplified sketch of the new thread-to-element mapping follows the diff
below.)

---------

Co-authored-by: Reza Yazdani
---
 csrc/fp_quantizer/quantize.cu | 168 ++++++++++++++--------------------
 1 file changed, 68 insertions(+), 100 deletions(-)

diff --git a/csrc/fp_quantizer/quantize.cu b/csrc/fp_quantizer/quantize.cu
index 37be6cc0657c..5f0b58f124f0 100644
--- a/csrc/fp_quantizer/quantize.cu
+++ b/csrc/fp_quantizer/quantize.cu
@@ -219,119 +219,100 @@ __global__ void apply_quantization(T* val,
 }
 
 template
-__global__ void apply_dequantization(uint8_t* val, T* q_val, int group_size)
+__global__ void apply_dequantization(uint8_t* val, T* q_val, int group_size, int total_num_elements)
 {
-    int tidx = threadIdx.x;
-    int wid = tidx >> 5;
-    int lane = tidx & 0x1f;
-    int gid = blockIdx.x * quantization::warps + wid;
+    constexpr uint32_t vector_size = quantization::access_granularity / sizeof(T);
+    int tidx = (blockIdx.x * blockDim.x + threadIdx.x) * vector_size;
+
     constexpr int quantized_bits = _mantisa_bits + _exponent_bits + 1;
     constexpr int q_exponent_bits = total_q_bits - q_mantisa_bits - 1;
     constexpr uint16_t _mantisa_mask = (1 << _mantisa_bits) - 1;
     constexpr uint16_t _exponent_mask = ((1 << _exponent_bits) - 1) << _mantisa_bits;
     constexpr uint16_t _sign_mask = 1 << (_mantisa_bits + _exponent_bits);
-
-    constexpr uint32_t vector_size = quantization::access_granularity / sizeof(T);
-    constexpr uint32_t load_stride = vector_size * hw_warp_size;
-    const uint32_t thread_offset = lane * vector_size;
-    const uint32_t thread_load_offset = lane * vector_size * quantized_bits / 8;
-    const uint32_t base_load_offset =
-        gid * (group_size * quantized_bits / 8 + 4) + thread_load_offset;  // 4-byte scale offset
-    const uint32_t base_store_offset = gid * group_size + thread_offset;
-    const uint8_t* load_base_ptr = val + base_load_offset;
+    const uint32_t g_index = (tidx / group_size);
+    const uint32_t group_size_bytes = (group_size * quantized_bits / 8);
+    const uint8_t* load_base_ptr =
+        val + g_index * (group_size_bytes + 4) + (tidx % group_size) * quantized_bits / 8;
 
     int mantisa_mask = ((1 << q_mantisa_bits) - 1);
     mantisa_mask <<= (_mantisa_bits - q_mantisa_bits);
 
-    T* store_base_ptr = q_val + base_store_offset;
-    float scale;  //= q_scale[gid];
+    T* store_base_ptr = q_val + tidx;
+    float scale;
 
     uint8_t* scale_as_int8 = reinterpret_cast(&scale);
     if (quantized_bits == 6) {
         mem_access::load_global(
-            scale_as_int8,
-            val + gid * (group_size * quantized_bits / 8 + 4) + (group_size * quantized_bits / 8));
+            scale_as_int8, val + g_index * (group_size_bytes + 4) + group_size_bytes);
         mem_access::load_global(
             scale_as_int8 + quantization::quanitzed_access_granularity_6bits,
-            val + gid * (group_size * quantized_bits / 8 + 4) + (group_size * quantized_bits / 8) +
+            val + g_index * (group_size_bytes + 4) + group_size_bytes +
                 quantization::quanitzed_access_granularity_6bits);
     } else
         mem_access::load_global(
-            scale_as_int8,
-            val + gid * (group_size * quantized_bits / 8 + 4) + (group_size * quantized_bits / 8));
-
-#pragma unroll
-    for (int i = 0; i < unroll; i++) {
-        if (i * load_stride + thread_offset < group_size) {
-            uint64_t q_buf_in;
-            uint64_t q_buf_in1;
-            uint8_t* int8_data = reinterpret_cast(&q_buf_in);
-            uint8_t* int8_data1 = reinterpret_cast(&q_buf_in1);
-            uint32_t loading_offset = i * load_stride * quantized_bits / 8;
-            if (quantized_bits == 6) {
-                mem_access::load_global(
-                    int8_data, load_base_ptr + loading_offset);
-                mem_access::load_global(
-                    int8_data + quantization::quanitzed_access_granularity_6bits,
-                    load_base_ptr + loading_offset +
-                        quantization::quanitzed_access_granularity_6bits);
-                mem_access::load_global(
-                    int8_data + quantization::quanitzed_access_granularity_6bits * 2,
-                    load_base_ptr + loading_offset +
-                        quantization::quanitzed_access_granularity_6bits * 2);
-            } else {
+            scale_as_int8, val + g_index * (group_size_bytes + 4) + group_size_bytes);
+
+    if (tidx < total_num_elements) {
+        uint64_t q_buf_in;
+        uint64_t q_buf_in1;
+        uint8_t* int8_data = reinterpret_cast(&q_buf_in);
+        uint8_t* int8_data1 = reinterpret_cast(&q_buf_in1);
+        if (quantized_bits == 6) {
+            mem_access::load_global(
+                int8_data, load_base_ptr);
+            mem_access::load_global(
+                int8_data + quantization::quanitzed_access_granularity_6bits,
+                load_base_ptr + quantization::quanitzed_access_granularity_6bits);
+            mem_access::load_global(
+                int8_data + quantization::quanitzed_access_granularity_6bits * 2,
+                load_base_ptr + quantization::quanitzed_access_granularity_6bits * 2);
+        } else {
+            mem_access::load_global(int8_data,
+                                    load_base_ptr);
+            if (quantized_bits > 4) {
                 mem_access::load_global(
-                    int8_data, load_base_ptr + loading_offset);
-                if (quantized_bits > 4) {
+                    int8_data + quantization::quanitzed_access_granularity,
+                    load_base_ptr + quantization::quanitzed_access_granularity);
+                if (quantized_bits == 12) {
                     mem_access::load_global(
-                        int8_data + quantization::quanitzed_access_granularity,
-                        load_base_ptr + loading_offset +
-                            quantization::quanitzed_access_granularity);
-                    if (quantized_bits == 12) {
-                        mem_access::load_global(
-                            int8_data1,
-                            load_base_ptr + loading_offset +
-                                quantization::quanitzed_access_granularity * 2);
-                    }
+                        int8_data1, load_base_ptr + quantization::quanitzed_access_granularity * 2);
                 }
             }
-            T store_buf[vector_size];
-            uint16_t* q_buf = reinterpret_cast(store_buf);
+        }
+        T store_buf[vector_size];
+        uint16_t* q_buf = reinterpret_cast(store_buf);
 #pragma unroll
-            for (int j = 0; j < vector_size; j++) {
-                uint16_t new_data;
-                if (j < 5 || quantized_bits != 12) {
-                    new_data = (uint16_t)(q_buf_in >> (j * quantized_bits));
-                } else {
-                    if (j == 5) {
-                        new_data = (uint16_t)(q_buf_in1);
-                        new_data = (uint16_t)((new_data << 4) | (q_buf_in >> 60));
-                    } else
-                        new_data = (uint16_t)(q_buf_in1 >> ((j - 6) * quantized_bits + 8));
-                }
+        for (int j = 0; j < vector_size; j++) {
+            uint16_t new_data;
+            if (j < 5 || quantized_bits != 12) {
+                new_data = (uint16_t)(q_buf_in >> (j * quantized_bits));
+            } else {
+                if (j == 5) {
+                    new_data = (uint16_t)(q_buf_in1);
+                    new_data = (uint16_t)((new_data << 4) | (q_buf_in >> 60));
+                } else
+                    new_data = (uint16_t)(q_buf_in1 >> ((j - 6) * quantized_bits + 8));
+            }
 
-                uint16_t sign = (new_data & _sign_mask) >> (_mantisa_bits + _exponent_bits);
-                uint16_t dst_exponent = (new_data & _exponent_mask) >> _mantisa_bits;
-                uint16_t dst_mantisa = (new_data & _mantisa_mask);
+            uint16_t sign = (new_data & _sign_mask) >> (_mantisa_bits + _exponent_bits);
+            uint16_t dst_exponent = (new_data & _exponent_mask) >> _mantisa_bits;
+            uint16_t dst_mantisa = (new_data & _mantisa_mask);
 
-                if (dst_exponent != (1 << q_exponent_bits) - 1)
-                    dst_exponent = (dst_exponent - ((1 << (_exponent_bits - 1)) - 1)) +
-                                   (1 << (q_exponent_bits - 1)) - 1;
+            if (dst_exponent != (1 << q_exponent_bits) - 1)
+                dst_exponent = (dst_exponent - ((1 << (_exponent_bits - 1)) - 1)) +
+                               (1 << (q_exponent_bits - 1)) - 1;
 
-                q_buf[j] = ((sign << (q_exponent_bits + q_mantisa_bits)) |
-                            (dst_exponent << q_mantisa_bits) |
-                            (dst_mantisa << (q_mantisa_bits - _mantisa_bits)));
-                float up_cast = conversion::to(store_buf[j]);
-                store_buf[j] = conversion::to(up_cast * scale);
-            }
-            mem_access::store_global(
-                store_base_ptr + i * load_stride, store_buf);
+            q_buf[j] =
+                ((sign << (q_exponent_bits + q_mantisa_bits)) | (dst_exponent << q_mantisa_bits) |
+                 (dst_mantisa << (q_mantisa_bits - _mantisa_bits)));
+            float up_cast = conversion::to(store_buf[j]);
+            store_buf[j] = conversion::to(up_cast * scale);
         }
+        mem_access::store_global(store_base_ptr, store_buf);
     }
 }
@@ -386,12 +367,6 @@ INSTANTIATE_LAUNCH_QUANTIZATION(__nv_bfloat16, 23, 8);
 #endif
 INSTANTIATE_LAUNCH_QUANTIZATION(__half, 23, 8);
 
-#define LAUNCH_FOR_DEQUANTIZATION_UNROLL(COUNT) \
-    case COUNT:                                 \
-        apply_dequantization                    \
-            <<>>(val, q_val, group_size);       \
-        break;
-
 template
 void launch_dequantization(uint8_t* val,
                            T* q_val,
@@ -401,21 +376,14 @@ void launch_dequantization(uint8_t* val,
                            int q_exponent_bits,
                            cudaStream_t stream)
 {
-    const dim3 grid((num_groups + quantization::warps - 1) / quantization::warps);
+    int blocks = ((num_groups * group_size) - 1) /
+                     (quantization::threads * (quantization::access_granularity / sizeof(T))) +
+                 1;
+    const dim3 grid(blocks);
     const dim3 block(quantization::threads);
-
-    constexpr int vals_per_unroll = hw_warp_size * quantization::access_granularity / sizeof(T);
-    const int copy_unroll = (group_size + vals_per_unroll - 1) / vals_per_unroll;
-
     DEQUANT_SWITCH(q_mantisa_bits * q_exponent_bits, [&] {
-        switch (copy_unroll) {
-            LAUNCH_FOR_DEQUANTIZATION_UNROLL(1)
-            LAUNCH_FOR_DEQUANTIZATION_UNROLL(2)
-            LAUNCH_FOR_DEQUANTIZATION_UNROLL(3)
-            LAUNCH_FOR_DEQUANTIZATION_UNROLL(4)
-            LAUNCH_FOR_DEQUANTIZATION_UNROLL(5)
-            LAUNCH_FOR_DEQUANTIZATION_UNROLL(6)
-        }
+        apply_dequantization
+            <<>>(val, q_val, group_size, (num_groups * group_size));
    });
 }
 
 #define INSTANTIATE_LAUNCH_DEQUANTIZATION(T, mantisa) \
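For illustration, here is a minimal, self-contained CUDA sketch of the flat thread-to-element mapping and grid-sizing rule the new kernel uses. It is only a sketch, not the DeepSpeed kernel: dequant_sketch, kAccessGranularity, and kThreads are placeholder names, the payload is modeled as plain int8 values followed by a 4-byte per-group float scale instead of the real fp6/fp8/fp12 bit fields, and group_size is assumed to be a multiple of the per-thread vector width.

#include <cuda_bf16.h>
#include <cuda_runtime.h>
#include <cstdint>
#include <cstring>

constexpr int kAccessGranularity = 16;  // bytes stored per thread (assumed value)
constexpr int kThreads = 256;           // threads per block (assumed value)

__global__ void dequant_sketch(const uint8_t* q_in,
                               __nv_bfloat16* out,
                               int group_size,
                               int total_num_elements)
{
    constexpr int vector_size = kAccessGranularity / sizeof(__nv_bfloat16);  // 8 outputs per thread
    // One flat output index per thread; no per-group loop as in the old kernel.
    const int tidx = (blockIdx.x * blockDim.x + threadIdx.x) * vector_size;
    if (tidx >= total_num_elements) return;

    // Locate this element's group and read the per-group scale stored after the payload.
    const int g_index = tidx / group_size;
    const uint8_t* group_base = q_in + g_index * (group_size + sizeof(float));
    float scale;
    memcpy(&scale, group_base + group_size, sizeof(float));  // alignment-safe copy

    // Dequantize vector_size contiguous elements (group_size is assumed to be a
    // multiple of vector_size, so a vector never straddles a group boundary).
    const uint8_t* load_ptr = group_base + (tidx % group_size);
#pragma unroll
    for (int j = 0; j < vector_size; j++) {
        const float v = static_cast<float>(static_cast<int8_t>(load_ptr[j])) * scale;
        out[tidx + j] = __float2bfloat16(v);
    }
}

void launch_dequant_sketch(const uint8_t* q_in,
                           __nv_bfloat16* out,
                           int num_groups,
                           int group_size,
                           cudaStream_t stream)
{
    constexpr int vector_size = kAccessGranularity / sizeof(__nv_bfloat16);
    const int total = num_groups * group_size;
    // Same sizing rule as the patch: enough blocks so that every vector_size-wide
    // chunk of the output is owned by exactly one thread.
    const int blocks = (total - 1) / (kThreads * vector_size) + 1;
    dequant_sketch<<<blocks, kThreads, 0, stream>>>(q_in, out, group_size, total);
}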
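Design note on the launch configuration: the grid now scales with the total number of elements rather than with the number of groups, so parallelism no longer shrinks when the group size is small. As a worked example, assuming quantization::threads == 256 and quantization::access_granularity == 16 bytes (8 bf16 values per thread), a 4096 x 4096 matrix gives blocks = (16,777,216 - 1) / (256 * 8) + 1 = 8192, independent of group_size.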