diff --git a/csrc/fp_quantizer/quantize.cu b/csrc/fp_quantizer/quantize.cu
index 37be6cc0657c..5f0b58f124f0 100644
--- a/csrc/fp_quantizer/quantize.cu
+++ b/csrc/fp_quantizer/quantize.cu
@@ -219,119 +219,100 @@ __global__ void apply_quantization(T* val,
 }
 
 template <typename T,
-          int unroll,
           int q_mantisa_bits,
           int total_q_bits = 16,
           int _mantisa_bits = 3,
           int _exponent_bits = 4>
-__global__ void apply_dequantization(uint8_t* val, T* q_val, int group_size)
+__global__ void apply_dequantization(uint8_t* val, T* q_val, int group_size, int total_num_elements)
 {
-    int tidx = threadIdx.x;
-    int wid = tidx >> 5;
-    int lane = tidx & 0x1f;
-    int gid = blockIdx.x * quantization::warps + wid;
+    constexpr uint32_t vector_size = quantization::access_granularity / sizeof(T);
+    int tidx = (blockIdx.x * blockDim.x + threadIdx.x) * vector_size;
+
     constexpr int quantized_bits = _mantisa_bits + _exponent_bits + 1;
     constexpr int q_exponent_bits = total_q_bits - q_mantisa_bits - 1;
     constexpr uint16_t _mantisa_mask = (1 << _mantisa_bits) - 1;
     constexpr uint16_t _exponent_mask = ((1 << _exponent_bits) - 1) << _mantisa_bits;
     constexpr uint16_t _sign_mask = 1 << (_mantisa_bits + _exponent_bits);
-
-    constexpr uint32_t vector_size = quantization::access_granularity / sizeof(T);
-    constexpr uint32_t load_stride = vector_size * hw_warp_size;
-    const uint32_t thread_offset = lane * vector_size;
-    const uint32_t thread_load_offset = lane * vector_size * quantized_bits / 8;
-    const uint32_t base_load_offset =
-        gid * (group_size * quantized_bits / 8 + 4) + thread_load_offset;  // 4-byte scale offset
-    const uint32_t base_store_offset = gid * group_size + thread_offset;
-    const uint8_t* load_base_ptr = val + base_load_offset;
+    const uint32_t g_index = (tidx / group_size);
+    const uint32_t group_size_bytes = (group_size * quantized_bits / 8);
+    const uint8_t* load_base_ptr =
+        val + g_index * (group_size_bytes + 4) + (tidx % group_size) * quantized_bits / 8;
 
     int mantisa_mask = ((1 << q_mantisa_bits) - 1);
     mantisa_mask <<= (_mantisa_bits - q_mantisa_bits);
 
-    T* store_base_ptr = q_val + base_store_offset;
-    float scale;  //= q_scale[gid];
+    T* store_base_ptr = q_val + tidx;
+    float scale;
 
     uint8_t* scale_as_int8 = reinterpret_cast<uint8_t*>(&scale);
     if (quantized_bits == 6) {
         mem_access::load_global<quantization::quanitzed_access_granularity_6bits>(
-            scale_as_int8,
-            val + gid * (group_size * quantized_bits / 8 + 4) + (group_size * quantized_bits / 8));
+            scale_as_int8, val + g_index * (group_size_bytes + 4) + group_size_bytes);
         mem_access::load_global<quantization::quanitzed_access_granularity_6bits>(
             scale_as_int8 + quantization::quanitzed_access_granularity_6bits,
-            val + gid * (group_size * quantized_bits / 8 + 4) + (group_size * quantized_bits / 8) +
+            val + g_index * (group_size_bytes + 4) + group_size_bytes +
                 quantization::quanitzed_access_granularity_6bits);
     } else
         mem_access::load_global<quantization::quanitzed_access_granularity>(
-            scale_as_int8,
-            val + gid * (group_size * quantized_bits / 8 + 4) + (group_size * quantized_bits / 8));
-
-#pragma unroll
-    for (int i = 0; i < unroll; i++) {
-        if (i * load_stride + thread_offset < group_size) {
-            uint64_t q_buf_in;
-            uint64_t q_buf_in1;
-            uint8_t* int8_data = reinterpret_cast<uint8_t*>(&q_buf_in);
-            uint8_t* int8_data1 = reinterpret_cast<uint8_t*>(&q_buf_in1);
-            uint32_t loading_offset = i * load_stride * quantized_bits / 8;
-            if (quantized_bits == 6) {
-                mem_access::load_global<quantization::quanitzed_access_granularity_6bits>(
-                    int8_data, load_base_ptr + loading_offset);
-                mem_access::load_global<quantization::quanitzed_access_granularity_6bits>(
-                    int8_data + quantization::quanitzed_access_granularity_6bits,
-                    load_base_ptr + loading_offset +
-                        quantization::quanitzed_access_granularity_6bits);
-                mem_access::load_global<quantization::quanitzed_access_granularity_6bits>(
-                    int8_data + quantization::quanitzed_access_granularity_6bits * 2,
-                    load_base_ptr + loading_offset +
-                        quantization::quanitzed_access_granularity_6bits * 2);
-            } else {
+            scale_as_int8, val + g_index * (group_size_bytes + 4) + group_size_bytes);
+
+    if (tidx < total_num_elements) {
+        uint64_t q_buf_in;
+        uint64_t q_buf_in1;
+        uint8_t* int8_data = reinterpret_cast<uint8_t*>(&q_buf_in);
+        uint8_t* int8_data1 = reinterpret_cast<uint8_t*>(&q_buf_in1);
+        if (quantized_bits == 6) {
+            mem_access::load_global<quantization::quanitzed_access_granularity_6bits>(
+                int8_data, load_base_ptr);
+            mem_access::load_global<quantization::quanitzed_access_granularity_6bits>(
+                int8_data + quantization::quanitzed_access_granularity_6bits,
+                load_base_ptr + quantization::quanitzed_access_granularity_6bits);
+            mem_access::load_global<quantization::quanitzed_access_granularity_6bits>(
+                int8_data + quantization::quanitzed_access_granularity_6bits * 2,
+                load_base_ptr + quantization::quanitzed_access_granularity_6bits * 2);
+        } else {
+            mem_access::load_global<quantization::quanitzed_access_granularity>(int8_data,
+                                                                                load_base_ptr);
+            if (quantized_bits > 4) {
                 mem_access::load_global<quantization::quanitzed_access_granularity>(
-                    int8_data, load_base_ptr + loading_offset);
-                if (quantized_bits > 4) {
+                    int8_data + quantization::quanitzed_access_granularity,
+                    load_base_ptr + quantization::quanitzed_access_granularity);
+                if (quantized_bits == 12) {
                     mem_access::load_global<quantization::quanitzed_access_granularity>(
-                        int8_data + quantization::quanitzed_access_granularity,
-                        load_base_ptr + loading_offset +
-                            quantization::quanitzed_access_granularity);
-                    if (quantized_bits == 12) {
-                        mem_access::load_global<quantization::quanitzed_access_granularity>(
-                            int8_data1,
-                            load_base_ptr + loading_offset +
-                                quantization::quanitzed_access_granularity * 2);
-                    }
+                        int8_data1, load_base_ptr + quantization::quanitzed_access_granularity * 2);
                 }
             }
-            T store_buf[vector_size];
-            uint16_t* q_buf = reinterpret_cast<uint16_t*>(store_buf);
+        }
+        T store_buf[vector_size];
+        uint16_t* q_buf = reinterpret_cast<uint16_t*>(store_buf);
 #pragma unroll
-            for (int j = 0; j < vector_size; j++) {
-                uint16_t new_data;
-                if (j < 5 || quantized_bits != 12) {
-                    new_data = (uint16_t)(q_buf_in >> (j * quantized_bits));
-                } else {
-                    if (j == 5) {
-                        new_data = (uint16_t)(q_buf_in1);
-                        new_data = (uint16_t)((new_data << 4) | (q_buf_in >> 60));
-                    } else
-                        new_data = (uint16_t)(q_buf_in1 >> ((j - 6) * quantized_bits + 8));
-                }
+        for (int j = 0; j < vector_size; j++) {
+            uint16_t new_data;
+            if (j < 5 || quantized_bits != 12) {
+                new_data = (uint16_t)(q_buf_in >> (j * quantized_bits));
+            } else {
+                if (j == 5) {
+                    new_data = (uint16_t)(q_buf_in1);
+                    new_data = (uint16_t)((new_data << 4) | (q_buf_in >> 60));
+                } else
+                    new_data = (uint16_t)(q_buf_in1 >> ((j - 6) * quantized_bits + 8));
+            }
 
-                uint16_t sign = (new_data & _sign_mask) >> (_mantisa_bits + _exponent_bits);
-                uint16_t dst_exponent = (new_data & _exponent_mask) >> _mantisa_bits;
-                uint16_t dst_mantisa = (new_data & _mantisa_mask);
+            uint16_t sign = (new_data & _sign_mask) >> (_mantisa_bits + _exponent_bits);
+            uint16_t dst_exponent = (new_data & _exponent_mask) >> _mantisa_bits;
+            uint16_t dst_mantisa = (new_data & _mantisa_mask);
 
-                if (dst_exponent != (1 << q_exponent_bits) - 1)
-                    dst_exponent = (dst_exponent - ((1 << (_exponent_bits - 1)) - 1)) +
-                                   (1 << (q_exponent_bits - 1)) - 1;
+            if (dst_exponent != (1 << q_exponent_bits) - 1)
+                dst_exponent = (dst_exponent - ((1 << (_exponent_bits - 1)) - 1)) +
+                               (1 << (q_exponent_bits - 1)) - 1;
 
-                q_buf[j] = ((sign << (q_exponent_bits + q_mantisa_bits)) |
-                            (dst_exponent << q_mantisa_bits) |
-                            (dst_mantisa << (q_mantisa_bits - _mantisa_bits)));
-                float up_cast = conversion::to<float>(store_buf[j]);
-                store_buf[j] = conversion::to<T>(up_cast * scale);
-            }
-            mem_access::store_global<quantization::access_granularity>(
-                store_base_ptr + i * load_stride, store_buf);
+            q_buf[j] =
+                ((sign << (q_exponent_bits + q_mantisa_bits)) | (dst_exponent << q_mantisa_bits) |
+                 (dst_mantisa << (q_mantisa_bits - _mantisa_bits)));
+            float up_cast = conversion::to<float>(store_buf[j]);
+            store_buf[j] = conversion::to<T>(up_cast * scale);
         }
+        mem_access::store_global<quantization::access_granularity>(store_base_ptr, store_buf);
     }
 }
 
@@ -386,12 +367,6 @@ INSTANTIATE_LAUNCH_QUANTIZATION(__nv_bfloat16, 23, 8);
 #endif
 INSTANTIATE_LAUNCH_QUANTIZATION(__half, 23, 8);
 
-#define LAUNCH_FOR_DEQUANTIZATION_UNROLL(COUNT)                                                  \
-    case COUNT:                                                                                  \
-        apply_dequantization<T, COUNT, mantisa, 16, CONST_Q_MANTISA_BITS, CONST_Q_EXPONENT_BITS> \
-            <<<grid, block, 0, stream>>>(val, q_val, group_size);                                \
-        break;
-
 template <typename T, int mantisa>
 void launch_dequantization(uint8_t* val,
                            T* q_val,
@@ -401,21 +376,14 @@ void launch_dequantization(uint8_t* val,
                           int q_exponent_bits,
                           cudaStream_t stream)
 {
-    const dim3 grid((num_groups + quantization::warps - 1) / quantization::warps);
+    int blocks = ((num_groups * group_size) - 1) /
+                     (quantization::threads * (quantization::access_granularity / sizeof(T))) +
+                 1;
+    const dim3 grid(blocks);
     const dim3 block(quantization::threads);
-
-    constexpr int vals_per_unroll = hw_warp_size * quantization::access_granularity / sizeof(T);
-    const int copy_unroll = (group_size + vals_per_unroll - 1) / vals_per_unroll;
-
     DEQUANT_SWITCH(q_mantisa_bits * q_exponent_bits, [&] {
-        switch (copy_unroll) {
-            LAUNCH_FOR_DEQUANTIZATION_UNROLL(1)
-            LAUNCH_FOR_DEQUANTIZATION_UNROLL(2)
-            LAUNCH_FOR_DEQUANTIZATION_UNROLL(3)
-            LAUNCH_FOR_DEQUANTIZATION_UNROLL(4)
-            LAUNCH_FOR_DEQUANTIZATION_UNROLL(5)
-            LAUNCH_FOR_DEQUANTIZATION_UNROLL(6)
-        }
+        apply_dequantization<T, mantisa, 16, CONST_Q_MANTISA_BITS, CONST_Q_EXPONENT_BITS>
+            <<<grid, block, 0, stream>>>(val, q_val, group_size, (num_groups * group_size));
     });
 }
 
#define INSTANTIATE_LAUNCH_DEQUANTIZATION(T, mantisa) \
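
Note on the new launch geometry: the rewritten `launch_dequantization` flattens the work to one thread per vectorized chunk of `access_granularity / sizeof(T)` elements and sizes the grid by ceil-division over the total element count, with the kernel's `tidx < total_num_elements` guard masking the tail of the last block. Below is a minimal host-side sketch of that arithmetic; the concrete values for `quantization::threads`, `quantization::access_granularity`, and the 2-byte element width are assumptions for illustration, and the helper names (`grid_blocks`, `load_offset`) are hypothetical, not part of the patch.

```cpp
// Standalone sketch of the launch geometry and packed-group addressing used
// by the new kernel. Constants are assumed stand-ins, not taken from the diff.
#include <cassert>
#include <cstdio>

constexpr int kThreads = 1024;     // stand-in for quantization::threads
constexpr int kAccessBytes = 16;   // stand-in for quantization::access_granularity
constexpr int kElemBytes = 2;      // sizeof(__half) or sizeof(__nv_bfloat16)
constexpr int kVectorSize = kAccessBytes / kElemBytes;  // elements per thread: 8

// Grid size computed the way the patch writes it: ceil-division expressed
// as (n - 1) / d + 1 over the flat element count num_groups * group_size.
int grid_blocks(int num_groups, int group_size)
{
    int total = num_groups * group_size;
    return (total - 1) / (kThreads * kVectorSize) + 1;
}

// Byte offset of element `tidx`'s packed data, mirroring the kernel: each
// group stores group_size * quantized_bits / 8 payload bytes followed by a
// 4-byte fp32 scale, so group g starts at g * (group_size_bytes + 4).
long load_offset(long tidx, int group_size, int quantized_bits)
{
    long group_size_bytes = (long)group_size * quantized_bits / 8;
    long g_index = tidx / group_size;
    return g_index * (group_size_bytes + 4) + (tidx % group_size) * quantized_bits / 8;
}

int main()
{
    // 512 groups of 4096 fp16 elements -> 2,097,152 elements total.
    // Each block covers 1024 threads * 8 elements = 8192 elements -> 256 blocks.
    assert(grid_blocks(512, 4096) == 256);

    // A count that does not divide evenly still gets a final partial block;
    // the kernel's bounds check skips the overhanging threads.
    assert(grid_blocks(1, 100) == 1);

    // First element of group 1 at 8 quantized bits: group 0 occupies
    // 4096 payload bytes + 4 scale bytes, so the next group starts at 4100.
    assert(load_offset(4096, 4096, 8) == 4100);

    printf("ok\n");
    return 0;
}
```

The `+ 4` in the group stride is the per-group fp32 scale that the kernel reads into `scale_as_int8` before dequantizing, which is why the old code's comment called it a "4-byte scale offset".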