diff --git a/csrc/quantization/quantizer.cu b/csrc/quantization/quantizer.cu
index 37883410e976..41a11b3cfa53 100644
--- a/csrc/quantization/quantizer.cu
+++ b/csrc/quantization/quantizer.cu
@@ -1,5 +1,6 @@
 #include
 #include "custom_cuda_layers.h"
+#include "memory_access_utils.h"
 
 namespace cg = cooperative_groups;
 
@@ -7,89 +8,77 @@
 __global__ void quantize_kernel(__half* vals, int group_size, int num_bits)
 {
 #if __CUDA_ARCH__ >= 700 || defined(__HIP_PLATFORM_HCC__)
 
-    cg::thread_block b = cg::this_thread_block();
-    cg::thread_block_tile<32> g = cg::tiled_partition<32>(b);
+    cg::thread_block b = cg::this_thread_block();  // tb
+    cg::thread_block_tile<32> g =
+        cg::tiled_partition<32>(b);  // warp, 32 not optimal for AMD which should be 64.
 
     int gid = threadIdx.x >> 5;
     int lane = threadIdx.x & 0x1f;
     int warp_num = blockDim.x >> 5;
     int id = threadIdx.x;
 
-    float2* vals_cast = reinterpret_cast<float2*>(vals);
+    constexpr int granularity = 16;
+    constexpr int vals_per_access = granularity / sizeof(__half);
 
-    float2 data[MAX_REG];
+    __half data[vals_per_access];
 
     int group_id = blockIdx.x;
 
-    {
-        int group_index = id;
-        int reg_count = 0;
-        int offset = group_id * group_size;
-        float max = -10000.0;
-
-        while (group_index < group_size && reg_count < MAX_REG) {
-            data[reg_count] = vals_cast[offset + group_index];
-            __half* data_h = reinterpret_cast<__half*>(&data[reg_count]);
-
-            if (abs((float)data_h[0]) > max) max = abs((float)data_h[0]);
-            if (abs((float)data_h[1]) > max) max = abs((float)data_h[1]);
-            if (abs((float)data_h[2]) > max) max = abs((float)data_h[2]);
-            if (abs((float)data_h[3]) > max) max = abs((float)data_h[3]);
-
-            group_index += blockDim.x;
-            reg_count++;
-        }
+    int thread_index = id * vals_per_access;
+    int reg_count = 0;
+    int offset = group_id * group_size;
+    float max = -10000.0;
 
+    for (int thread_index = id * vals_per_access; thread_index < group_size;
+         thread_index += blockDim.x * vals_per_access) {
+        mem_access::load_global<granularity>(data, vals + offset + thread_index);
 
 #pragma unroll
-        for (int i = 1; i < WARP_SIZE; i <<= 1) {
-            auto temp = g.shfl_xor(max, i);
-            if (max < temp) max = temp;
+        for (int i = 0; i < vals_per_access; i++) {
+            if (abs((float)data[i]) > max) max = abs((float)data[i]);
         }
 
-        __shared__ float partialMax[WARP_SIZE];
-
-        if (lane == 0) partialMax[gid] = max;
-
-        b.sync();
-
-        if (lane < warp_num) max = partialMax[lane];
+    }
 
 #pragma unroll
-        for (int i = 1; i < WARP_SIZE; i <<= 1) {
-            auto temp = g.shfl_down(max, i);
-            if (max < temp) max = temp;
-        }
+    for (int i = 1; i < WARP_SIZE; i <<= 1) {
+        auto temp = g.shfl_xor(max, i);
+        if (max < temp) max = temp;
+    }
+
+    __shared__ float partialMax[WARP_SIZE];
 
-        max = g.shfl(max, 0);
+    if (lane == 0) partialMax[gid] = max;
 
-        float q_scale = (1 << num_bits) / (2 * max + 1e-5);
-        float q_scale_inv = 1 / q_scale;
-        for (int i = 0; i < reg_count; i++) {
-            group_index = i * blockDim.x + id;
-            if (group_index < group_size) {
-                __half2* data_h = reinterpret_cast<__half2*>(&data[i]);
-                float2 q_data[2];
-                q_data[0] = __half22float2(data_h[0]);
-                q_data[1] = __half22float2(data_h[1]);
+    b.sync();
 
-                float2 q_data_int[2];
+    if (lane < warp_num) max = partialMax[lane];
 
-                q_data_int[0].x = roundf(q_data[0].x * q_scale);
-                q_data_int[0].y = roundf(q_data[0].y * q_scale);
-                q_data_int[1].x = roundf(q_data[1].x * q_scale);
-                q_data_int[1].y = roundf(q_data[1].y * q_scale);
+#pragma unroll
+    for (int i = 1; i < WARP_SIZE; i <<= 1) {
+        auto temp = g.shfl_down(max, i);
+        if (max < temp) max = temp;
+    }
 
-                q_data_int[0].x *= q_scale_inv;
-                q_data_int[0].y *= q_scale_inv;
-                q_data_int[1].x *= q_scale_inv;
-                q_data_int[1].y *= q_scale_inv;
+    max = g.shfl(max, 0);
 
-                data_h[0] = __float22half2_rn(q_data_int[0]);
-                data_h[1] = __float22half2_rn(q_data_int[1]);
+    float q_scale = (float)(1 << num_bits) / (2 * max + 1e-5);
+    float q_scale_inv = 1 / q_scale;
+    int q_range_max = (1 << (num_bits - 1)) - 1;
+    int q_range_min = -(1 << (num_bits - 1));
 
-                vals_cast[offset + group_index] = data[i];
-            }
-        }
+    for (int thread_index = id * vals_per_access; thread_index < group_size;
+         thread_index += blockDim.x * vals_per_access) {
+        mem_access::load_global<granularity>(data, vals + offset + thread_index);
+#pragma unroll
+        for (int j = 0; j < vals_per_access; j++) {
+            float q_data;
+            q_data = __half2float(data[j]);
+            q_data = __float2int_rn(q_data * q_scale);
+            q_data = q_data > (q_range_max) ? (q_range_max)
+                                            : (q_data < (q_range_min) ? (q_range_min) : q_data);
+            data[j] = __float2half_rn(q_data * q_scale_inv);
+        }
+        mem_access::store_global<granularity>(vals + offset + thread_index, data);
     }
+
 #endif
 }
@@ -103,31 +92,31 @@ __global__ void quantize_kernel(float* vals, int group_size, int num_bits)
     int warp_num = blockDim.x >> 5;
     int id = threadIdx.x;
 
-    float4* vals_cast = reinterpret_cast<float4*>(vals);
+    constexpr int granularity = 16;
+    constexpr int vals_per_access = granularity / sizeof(float);
 
-    float4 data[MAX_REG];
+    float data[vals_per_access];
 
    int bid = blockIdx.x;
 
-    int group_index = bid * group_size + id;
+    int thread_index = id * vals_per_access;
     int reg_count = 0;
 
-    float max = -10000.0;
+    int offset = bid * group_size;
 
-    while (id < group_size && reg_count < MAX_REG) {
-        float4 data_reg = vals_cast[group_index];
-        data[reg_count] = data_reg;
+    float max = -10000.0;
 
-        if (abs(data_reg.x) > max) max = abs(data_reg.x);
-        if (abs(data_reg.y) > max) max = abs(data_reg.y);
-        if (abs(data_reg.z) > max) max = abs(data_reg.z);
-        if (abs(data_reg.w) > max) max = abs(data_reg.w);
+    for (int thread_index = id * vals_per_access; thread_index < group_size;
+         thread_index += blockDim.x * vals_per_access) {
+        mem_access::load_global<granularity>(data, vals + offset + thread_index);
 
-        group_index += blockDim.x;
-        id += blockDim.x;
-        reg_count++;
+#pragma unroll
+        for (int i = 0; i < vals_per_access; i++) {
+            if (abs(data[i]) > max) max = abs(data[i]);
+        }
     }
 
-    id = threadIdx.x;
+
 #pragma unroll
     for (int i = 1; i < WARP_SIZE; i <<= 1) {
         auto temp = g.shfl_xor(max, i);
@@ -153,25 +142,22 @@ __global__ void quantize_kernel(float* vals, int group_size, int num_bits)
     float q_scale = (1 << num_bits) / (2 * max + 1e-5);
     float q_scale_inv = 1 / q_scale;
 
-    for (int i = 0; i < reg_count; i++) {
-        group_index = i * blockDim.x + id;
-        if (group_index < group_size) {
-            float4 q_data;
-            q_data = data[i];
-            float4 q_data_int;
-            q_data_int.x = roundf(q_data.x * q_scale);
-            q_data_int.y = roundf(q_data.y * q_scale);
-            q_data_int.w = roundf(q_data.w * q_scale);
-            q_data_int.z = roundf(q_data.z * q_scale);
+    int q_range_max = (1 << (num_bits - 1)) - 1;
+    int q_range_min = -(1 << (num_bits - 1));
 
-            q_data.x = q_data_int.x * q_scale_inv;
-            q_data.y = q_data_int.y * q_scale_inv;
-            q_data.w = q_data_int.w * q_scale_inv;
-            q_data.z = q_data_int.z * q_scale_inv;
-
-            vals_cast[group_index + bid * group_size] = q_data;
-        }
+    for (int thread_index = id * vals_per_access; thread_index < group_size;
+         thread_index += blockDim.x * vals_per_access) {
+        mem_access::load_global<granularity>(data, vals + offset + thread_index);
+#pragma unroll
+        for (int j = 0; j < vals_per_access; j++) {
+            float q_data;
+            q_data = __float2int_rn(data[j] * q_scale);
+            q_data = q_data > (q_range_max) ? (q_range_max)
+                                            : (q_data < (q_range_min) ? (q_range_min) : q_data);
+            data[j] = roundf(q_data * q_scale_inv);
+        }
+        mem_access::store_global<granularity>(vals + offset + thread_index, data);
     }
 }
 
@@ -185,8 +171,7 @@ void launch_quantize_kernel(T* vals,
     dim3 grid_dim(group_num);
     dim3 block_dim(1024);
 
-    quantize_kernel<<<grid_dim, block_dim, 0, stream>>>(
-        vals, (total_count / group_num) / 4, num_bits);
+    quantize_kernel<<<grid_dim, block_dim, 0, stream>>>(vals, total_count / group_num, num_bits);
 }
 
 template void launch_quantize_kernel<float>(float* vals,
diff --git a/tests/unit/ops/quantizer/test_quant.py b/tests/unit/ops/quantizer/test_quant.py
index ea6b35860873..1526937dd2bc 100644
--- a/tests/unit/ops/quantizer/test_quant.py
+++ b/tests/unit/ops/quantizer/test_quant.py
@@ -7,7 +7,7 @@ def allclose(x, y):
     assert x.dtype == y.dtype
-    rtol, atol = {torch.float32: (2e-1, 5e-2), torch.float16: (2e-1, 5e-2)}[x.dtype]
+    rtol, atol = {torch.float32: (2e-2, 5e-3), torch.float16: (2e-2, 5e-3)}[x.dtype]
     return torch.allclose(x, y, rtol=rtol, atol=atol)
 
 
@@ -19,7 +19,7 @@ def quantize_dequantize_ref(inputs, bit, num_groups=1):
     input_min = input_flat.amin(-1, keepdim=True)
     input_max = input_flat.amax(-1, keepdim=True)
-    scale = q_range / (2 * torch.max(input_min.abs(), input_max.abs()))
+    scale = q_range / (2 * torch.max(input_min.abs(), input_max.abs() + 1e-5))
     input_flat = (input_flat * scale).round().clamp(-q_range // 2, q_range // 2 - 1)
     # dequantize
     dequant_flat = torch.t(input_flat.to(torch.int8)) / scale.view(-1).to(torch.float16)
@@ -35,22 +35,26 @@ def run_quant_dequant(inputs, groups, bits):
 
 @pytest.mark.inference
-@pytest.mark.parametrize("tensor_shape", [(8, 8), (128, 256)])
-def test_quant_dequant(tensor_shape):
+@pytest.mark.parametrize("tensor_shape", [(16, 4096), (128, 256)])
+# Test with two tensor shapes as (16, 4096) and (128, 256).
+@pytest.mark.parametrize("groups", [1, 16])
+# Test with number of quant groups as 1 and 16.
+# Note that we have an explicit boundary for groups as ((size / groups) - 1) / 4096 + 1) <= MAX_REG.
+def test_quant_dequant(tensor_shape, groups):
+
     input_tensor = torch.rand((tensor_shape), dtype=torch.float16).cuda()
 
-    # test 8bit quant/dequant on tensor partitioned in 1 group.
-    ref_input_8bit_1group = input_tensor.clone().detach()
-    ds_input_8bit_1group = input_tensor.clone().detach()
-    ref_out_8bit_1group = quantize_dequantize_ref(ref_input_8bit_1group, 8)
-    # run_quant_dequant will do quantize then dequantize and return the dequantized value.
-    ds_out_8bit_1group = run_quant_dequant(ds_input_8bit_1group, 1, 8)
-    assert (allclose(ds_out_8bit_1group, ref_out_8bit_1group))
-
-    # test 4bit quant/dequant on tensor partitioned into 16 groups.
-    # Note that we have an explicit boundary for groups as ((size / groups) - 1) / 4096 + 1) <= MAX_REG.
-    ref_input_4bit_16group = input_tensor.clone().detach()
-    ds_input_4bit_16group = input_tensor.clone().detach()
-    ref_out_4bit_16group = quantize_dequantize_ref(ref_input_4bit_16group, 4, 16)
-    ds_out_4bit_16group = run_quant_dequant(ds_input_4bit_16group, 16, 4)
-    assert (allclose(ds_out_4bit_16group, ref_out_4bit_16group))
+    # 8-bit quantization.
+    ref_input_8bit = input_tensor.clone().detach()
+    ds_input_8bit = input_tensor.clone().detach()
+    ref_out_8bit = quantize_dequantize_ref(ref_input_8bit, 8, groups)
+    # run_quant_dequant will do quantize then dequantize, and return the dequantized value.
+    ds_out_8bit = run_quant_dequant(ds_input_8bit, groups, 8)
+    assert (allclose(ds_out_8bit, ref_out_8bit))
+
+    # 4-bit quantization.
+    ref_input_4bit = input_tensor.clone().detach()
+    ds_input_4bit = input_tensor.clone().detach()
+    ref_out_4bit = quantize_dequantize_ref(ref_input_4bit, 4, groups)
+    ds_out_4bit = run_quant_dequant(ds_input_4bit, groups, 4)
+    assert (allclose(ds_out_4bit, ref_out_4bit))
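
Note on the access pattern the refactored kernels depend on: the patch assumes `memory_access_utils.h` provides `mem_access::load_global<granularity>` / `mem_access::store_global<granularity>` helpers that move `granularity` bytes per thread in one vectorized transaction (8 `__half` or 4 `float` values for `granularity = 16`). A minimal sketch of that idea is below; the `sketch` namespace, function names, and the `scale_half8` usage kernel are illustrative assumptions, not the actual DeepSpeed header.

```cuda
// Hedged sketch: 16-byte vectorized global load/store via a uint4 reinterpret.
// This is NOT the real memory_access_utils.h implementation; it only illustrates
// the granularity-based access pattern used in the patched kernels.
#include <cuda_fp16.h>
#include <cstdint>

namespace sketch {

template <int granularity>
__device__ __forceinline__ void load_global(void* dst, const void* src)
{
    static_assert(granularity == 16, "sketch only covers 16-byte accesses");
    // One 16-byte transaction; assumes src is 16-byte aligned.
    *reinterpret_cast<uint4*>(dst) = *reinterpret_cast<const uint4*>(src);
}

template <int granularity>
__device__ __forceinline__ void store_global(void* dst, const void* src)
{
    static_assert(granularity == 16, "sketch only covers 16-byte accesses");
    *reinterpret_cast<uint4*>(dst) = *reinterpret_cast<const uint4*>(src);
}

}  // namespace sketch

// Usage mirroring the kernel's inner loop (hypothetical kernel; assumes vals is
// 16-byte aligned and n is a multiple of vals_per_access).
__global__ void scale_half8(__half* vals, float scale, int n)
{
    constexpr int granularity = 16;
    constexpr int vals_per_access = granularity / sizeof(__half);  // 8 halves

    int idx = (blockIdx.x * blockDim.x + threadIdx.x) * vals_per_access;
    if (idx >= n) return;

    alignas(16) __half buf[vals_per_access];
    sketch::load_global<granularity>(buf, vals + idx);
#pragma unroll
    for (int i = 0; i < vals_per_access; i++) {
        buf[i] = __float2half(__half2float(buf[i]) * scale);
    }
    sketch::store_global<granularity>(vals + idx, buf);
}
```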