173 changes: 79 additions & 94 deletions csrc/quantization/quantizer.cu
@@ -1,95 +1,84 @@
 #include <math.h>
 #include "custom_cuda_layers.h"
+#include "memory_access_utils.h"
 
 namespace cg = cooperative_groups;
 
 __global__ void quantize_kernel(__half* vals, int group_size, int num_bits)
 {
 #if __CUDA_ARCH__ >= 700 || defined(__HIP_PLATFORM_HCC__)
 
-    cg::thread_block b = cg::this_thread_block();
-    cg::thread_block_tile<32> g = cg::tiled_partition<32>(b);
+    cg::thread_block b = cg::this_thread_block();  // tb
+    cg::thread_block_tile<32> g =
+        cg::tiled_partition<32>(b);  // warp, 32 not optimal for AMD which should be 64.
 
     int gid = threadIdx.x >> 5;
     int lane = threadIdx.x & 0x1f;
     int warp_num = blockDim.x >> 5;
     int id = threadIdx.x;
 
-    float2* vals_cast = reinterpret_cast<float2*>(vals);
+    constexpr int granularity = 16;
+    constexpr int vals_per_access = granularity / sizeof(__half);
 
-    float2 data[MAX_REG];
+    __half data[vals_per_access];
 
     int group_id = blockIdx.x;
 
-    {
-        int group_index = id;
-        int reg_count = 0;
-        int offset = group_id * group_size;
-        float max = -10000.0;
-
-        while (group_index < group_size && reg_count < MAX_REG) {
-            data[reg_count] = vals_cast[offset + group_index];
-            __half* data_h = reinterpret_cast<__half*>(&data[reg_count]);
-
-            if (abs((float)data_h[0]) > max) max = abs((float)data_h[0]);
-            if (abs((float)data_h[1]) > max) max = abs((float)data_h[1]);
-            if (abs((float)data_h[2]) > max) max = abs((float)data_h[2]);
-            if (abs((float)data_h[3]) > max) max = abs((float)data_h[3]);
-
-            group_index += blockDim.x;
-            reg_count++;
-        }
+    int thread_index = id * vals_per_access;
+    int reg_count = 0;
+    int offset = group_id * group_size;
+    float max = -10000.0;
+    for (int thread_index = id * vals_per_access; thread_index < group_size;
+         thread_index += blockDim.x * vals_per_access) {
+        mem_access::load_global<granularity>(data, vals + offset + thread_index);
 
 #pragma unroll
-        for (int i = 1; i < WARP_SIZE; i <<= 1) {
-            auto temp = g.shfl_xor(max, i);
-            if (max < temp) max = temp;
+        for (int i = 0; i < vals_per_access; i++) {
+            if (abs((float)data[i]) > max) max = abs((float)data[i]);
         }
-        __shared__ float partialMax[WARP_SIZE];
-
-        if (lane == 0) partialMax[gid] = max;
-
-        b.sync();
-
-        if (lane < warp_num) max = partialMax[lane];
+    }
 
 #pragma unroll
-        for (int i = 1; i < WARP_SIZE; i <<= 1) {
-            auto temp = g.shfl_down(max, i);
-            if (max < temp) max = temp;
-        }
+    for (int i = 1; i < WARP_SIZE; i <<= 1) {
+        auto temp = g.shfl_xor(max, i);
+        if (max < temp) max = temp;
+    }
+    __shared__ float partialMax[WARP_SIZE];
 
-        max = g.shfl(max, 0);
+    if (lane == 0) partialMax[gid] = max;
 
-        float q_scale = (1 << num_bits) / (2 * max + 1e-5);
-        float q_scale_inv = 1 / q_scale;
-        for (int i = 0; i < reg_count; i++) {
-            group_index = i * blockDim.x + id;
-            if (group_index < group_size) {
-                __half2* data_h = reinterpret_cast<__half2*>(&data[i]);
-                float2 q_data[2];
-                q_data[0] = __half22float2(data_h[0]);
-                q_data[1] = __half22float2(data_h[1]);
+    b.sync();
 
-                float2 q_data_int[2];
+    if (lane < warp_num) max = partialMax[lane];
 
-                q_data_int[0].x = roundf(q_data[0].x * q_scale);
-                q_data_int[0].y = roundf(q_data[0].y * q_scale);
-                q_data_int[1].x = roundf(q_data[1].x * q_scale);
-                q_data_int[1].y = roundf(q_data[1].y * q_scale);
+#pragma unroll
+    for (int i = 1; i < WARP_SIZE; i <<= 1) {
+        auto temp = g.shfl_down(max, i);
+        if (max < temp) max = temp;
+    }
 
-                q_data_int[0].x *= q_scale_inv;
-                q_data_int[0].y *= q_scale_inv;
-                q_data_int[1].x *= q_scale_inv;
-                q_data_int[1].y *= q_scale_inv;
+    max = g.shfl(max, 0);
 
-                data_h[0] = __float22half2_rn(q_data_int[0]);
-                data_h[1] = __float22half2_rn(q_data_int[1]);
+    float q_scale = (float)(1 << num_bits) / (2 * max + 1e-5);
+    float q_scale_inv = 1 / q_scale;
+    int q_range_max = (1 << (num_bits - 1)) - 1;
+    int q_range_min = -(1 << (num_bits - 1));
 
-                vals_cast[offset + group_index] = data[i];
-            }
-        }
-    }
+    for (int thread_index = id * vals_per_access; thread_index < group_size;
+         thread_index += blockDim.x * vals_per_access) {
+        mem_access::load_global<granularity>(data, vals + offset + thread_index);
+#pragma unroll
+        for (int j = 0; j < vals_per_access; j++) {
+            float q_data;
+            q_data = __half2float(data[j]);
+            q_data = __float2int_rn(q_data * q_scale);
+            q_data = q_data > (q_range_max) ? (q_range_max)
+                                            : (q_data < (q_range_min) ? (q_range_min) : q_data);
+            data[j] = __float2half_rn(q_data * q_scale_inv);
+        }
+        mem_access::store_global<granularity>(vals + offset + thread_index, data);
+    }
 
 #endif
 }

@@ -103,31 +92,31 @@ __global__ void quantize_kernel(float* vals, int group_size, int num_bits)
     int warp_num = blockDim.x >> 5;
     int id = threadIdx.x;
 
-    float4* vals_cast = reinterpret_cast<float4*>(vals);
+    constexpr int granularity = 16;
+    constexpr int vals_per_access = granularity / sizeof(float);
 
-    float4 data[MAX_REG];
+    float data[vals_per_access];
 
     int bid = blockIdx.x;
 
-    int group_index = bid * group_size + id;
+    int thread_index = id * vals_per_access;
 
     int reg_count = 0;
 
-    float max = -10000.0;
+    int offset = bid * group_size;
+    float max = -10000.0;
 
-    while (id < group_size && reg_count < MAX_REG) {
-        float4 data_reg = vals_cast[group_index];
-        data[reg_count] = data_reg;
-
-        if (abs(data_reg.x) > max) max = abs(data_reg.x);
-        if (abs(data_reg.y) > max) max = abs(data_reg.y);
-        if (abs(data_reg.z) > max) max = abs(data_reg.z);
-        if (abs(data_reg.w) > max) max = abs(data_reg.w);
-
-        group_index += blockDim.x;
-        id += blockDim.x;
-        reg_count++;
+    for (int thread_index = id * vals_per_access; thread_index < group_size;
+         thread_index += blockDim.x * vals_per_access) {
+        mem_access::load_global<granularity>(data, vals + offset + thread_index);
+
+#pragma unroll
+        for (int i = 0; i < vals_per_access; i++) {
+            if (abs(data[i]) > max) max = abs(data[i]);
+        }
     }
-    id = threadIdx.x;
 
 #pragma unroll
     for (int i = 1; i < WARP_SIZE; i <<= 1) {
         auto temp = g.shfl_xor(max, i);
@@ -153,25 +142,22 @@

     float q_scale = (1 << num_bits) / (2 * max + 1e-5);
     float q_scale_inv = 1 / q_scale;
-    for (int i = 0; i < reg_count; i++) {
-        group_index = i * blockDim.x + id;
-        if (group_index < group_size) {
-            float4 q_data;
-            q_data = data[i];
-
-            float4 q_data_int;
-            q_data_int.x = roundf(q_data.x * q_scale);
-            q_data_int.y = roundf(q_data.y * q_scale);
-            q_data_int.w = roundf(q_data.w * q_scale);
-            q_data_int.z = roundf(q_data.z * q_scale);
-
-            q_data.x = q_data_int.x * q_scale_inv;
-            q_data.y = q_data_int.y * q_scale_inv;
-            q_data.w = q_data_int.w * q_scale_inv;
-            q_data.z = q_data_int.z * q_scale_inv;
-
-            vals_cast[group_index + bid * group_size] = q_data;
-        }
+    int q_range_max = (1 << (num_bits - 1)) - 1;
+    int q_range_min = -(1 << (num_bits - 1));
+
+    for (int thread_index = id * vals_per_access; thread_index < group_size;
+         thread_index += blockDim.x * vals_per_access) {
+        mem_access::load_global<granularity>(data, vals + offset + thread_index);
+#pragma unroll
+        for (int j = 0; j < vals_per_access; j++) {
+            float q_data;
+            q_data = __float2int_rn(data[j] * q_scale);
+            q_data = q_data > (q_range_max) ? (q_range_max)
+                                            : (q_data < (q_range_min) ? (q_range_min) : q_data);
+            data[j] = roundf(q_data * q_scale_inv);
+        }
+        mem_access::store_global<granularity>(vals + offset + thread_index, data);
     }
 }

@@ -185,8 +171,7 @@ void launch_quantize_kernel(T* vals,
     dim3 grid_dim(group_num);
     dim3 block_dim(1024);
 
-    quantize_kernel<<<grid_dim, block_dim, 0, stream>>>(
-        vals, (total_count / group_num) / 4, num_bits);
+    quantize_kernel<<<grid_dim, block_dim, 0, stream>>>(vals, total_count / group_num, num_bits);
 }
 
 template void launch_quantize_kernel(float* vals,
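Net effect of the kernel changes: the MAX_REG register-caching scheme is replaced by two streaming passes over each group (one to find the per-group absolute max, one to quantize), using 16-byte vectorized accesses via mem_access::load_global / store_global, and the launcher now passes the group size in elements (total_count / group_num) rather than in float2/float4 chunks, since the vectorization factor is derived inside the kernel from granularity / sizeof(T). The math itself is a per-group symmetric quantize-dequantize round trip. A minimal PyTorch sketch of that round trip, for review reference only (`fake_quantize` is a hypothetical helper mirroring the kernel's scale rule, not DeepSpeed API):

```python
import torch

def fake_quantize(x: torch.Tensor, num_bits: int, num_groups: int) -> torch.Tensor:
    """Per-group symmetric quantize + dequantize, mirroring the CUDA kernel."""
    x_flat = x.float().reshape(num_groups, -1)
    group_max = x_flat.abs().amax(dim=-1, keepdim=True)  # per-group abs max
    q_scale = (1 << num_bits) / (2 * group_max + 1e-5)   # same scale rule as the kernel
    q_min = -(1 << (num_bits - 1))                       # e.g. -128 for 8 bits
    q_max = (1 << (num_bits - 1)) - 1                    # e.g. +127 for 8 bits
    q = (x_flat * q_scale).round().clamp(q_min, q_max)   # quantize, round-to-nearest
    return (q / q_scale).reshape(x.shape).to(x.dtype)    # dequantize back in place
```

For 8 bits this snaps each group onto 256 levels spanning roughly [-max, max] and writes the dequantized values back over the input, which is what the updated test below compares against its reference implementation.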
42 changes: 23 additions & 19 deletions tests/unit/ops/quantizer/test_quant.py
@@ -7,7 +7,7 @@

 def allclose(x, y):
     assert x.dtype == y.dtype
-    rtol, atol = {torch.float32: (2e-1, 5e-2), torch.float16: (2e-1, 5e-2)}[x.dtype]
+    rtol, atol = {torch.float32: (2e-2, 5e-3), torch.float16: (2e-2, 5e-3)}[x.dtype]
     return torch.allclose(x, y, rtol=rtol, atol=atol)


@@ -19,7 +19,7 @@ def quantize_dequantize_ref(inputs, bit, num_groups=1):
     input_min = input_flat.amin(-1, keepdim=True)
     input_max = input_flat.amax(-1, keepdim=True)
 
-    scale = q_range / (2 * torch.max(input_min.abs(), input_max.abs()))
+    scale = q_range / (2 * torch.max(input_min.abs(), input_max.abs() + 1e-5))
     input_flat = (input_flat * scale).round().clamp(-q_range // 2, q_range // 2 - 1)
     # dequantize
     dequant_flat = torch.t(input_flat.to(torch.int8)) / scale.view(-1).to(torch.float16)
@@ -35,22 +35,26 @@ def run_quant_dequant(inputs, groups, bits):


 @pytest.mark.inference
-@pytest.mark.parametrize("tensor_shape", [(8, 8), (128, 256)])
-def test_quant_dequant(tensor_shape):
+@pytest.mark.parametrize("tensor_shape", [(16, 4096), (128, 256)])
+# Test with two tensor shapes as (16, 4096) and (128, 256).
+@pytest.mark.parametrize("groups", [1, 16])
+# Test with number of quant groups as 1 and 16.
+# Note that we have an explicit boundary for groups as (((size / groups) - 1) / 4096 + 1) <= MAX_REG.
+def test_quant_dequant(tensor_shape, groups):
 
     input_tensor = torch.rand((tensor_shape), dtype=torch.float16).cuda()
 
-    # test 8bit quant/dequant on tensor partitioned in 1 group.
-    ref_input_8bit_1group = input_tensor.clone().detach()
-    ds_input_8bit_1group = input_tensor.clone().detach()
-    ref_out_8bit_1group = quantize_dequantize_ref(ref_input_8bit_1group, 8)
-    # run_quant_dequant will do quantize then dequantize and return the dequantized value.
-    ds_out_8bit_1group = run_quant_dequant(ds_input_8bit_1group, 1, 8)
-    assert (allclose(ds_out_8bit_1group, ref_out_8bit_1group))
-
-    # test 4bit quant/dequant on tensor partitioned into 16 groups.
-    # Note that we have an explicit boundary for groups as (((size / groups) - 1) / 4096 + 1) <= MAX_REG.
-    ref_input_4bit_16group = input_tensor.clone().detach()
-    ds_input_4bit_16group = input_tensor.clone().detach()
-    ref_out_4bit_16group = quantize_dequantize_ref(ref_input_4bit_16group, 4, 16)
-    ds_out_4bit_16group = run_quant_dequant(ds_input_4bit_16group, 16, 4)
-    assert (allclose(ds_out_4bit_16group, ref_out_4bit_16group))
+    # 8-bit quantization.
+    ref_input_8bit = input_tensor.clone().detach()
+    ds_input_8bit = input_tensor.clone().detach()
+    ref_out_8bit = quantize_dequantize_ref(ref_input_8bit, 8, groups)
+    # run_quant_dequant will do quantize then dequantize, and return the dequantized value.
+    ds_out_8bit = run_quant_dequant(ds_input_8bit, groups, 8)
+    assert (allclose(ds_out_8bit, ref_out_8bit))
+
+    # 4-bit quantization.
+    ref_input_4bit = input_tensor.clone().detach()
+    ds_input_4bit = input_tensor.clone().detach()
+    ref_out_4bit = quantize_dequantize_ref(ref_input_4bit, 4, groups)
+    ds_out_4bit = run_quant_dequant(ds_input_4bit, groups, 4)
+    assert (allclose(ds_out_4bit, ref_out_4bit))
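
With the two parametrize decorators the test now covers a 2 x 2 grid — shapes (16, 4096) and (128, 256), each with 1 and 16 groups — at both 8 and 4 bits, instead of the two hard-coded cases; the tighter (2e-2, 5e-3) tolerances rely on the kernel and the reference applying the same scale-round-clamp sequence. A quick sketch of how one parametrization maps onto the refactored kernel launch (arithmetic inferred from the launch_quantize_kernel change above; illustrative only):

```python
# tensor_shape = (16, 4096), groups = 16, fp16 input
total_count = 16 * 4096                  # 65,536 elements in the tensor
group_num = 16                           # one CUDA block per group
group_size = total_count // group_num    # 4,096 elements handled per block
vals_per_access = 16 // 2                # 16-byte access / sizeof(fp16) = 8 values
per_iteration = 1024 * vals_per_access   # a 1024-thread block covers 8,192 values
assert group_size <= per_iteration       # each kernel loop finishes in one iteration here
```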