diff --git a/csrc/quantization/compressed_tensors/int8_quant_kernels.cu b/csrc/quantization/compressed_tensors/int8_quant_kernels.cu index bf46cce60a23..89a461c0f9aa 100644 --- a/csrc/quantization/compressed_tensors/int8_quant_kernels.cu +++ b/csrc/quantization/compressed_tensors/int8_quant_kernels.cu @@ -3,6 +3,7 @@ #include #include "../../dispatch_utils.h" +#include "../vectorization.cuh" #ifndef USE_ROCM #include @@ -115,8 +116,25 @@ __global__ void static_scaled_int8_quant_kernel( out += token_idx * hidden_size; input += token_idx * hidden_size; - for (int i = tid; i < hidden_size; i += blockDim.x) { - out[i] = float_to_int8_rn(static_cast(input[i]) / scale); + int vec_size = hidden_size / 4; + int rem = hidden_size % 4; + vec4_t const* vec_input = + reinterpret_cast const*>(input); + q8x4_t* vec_out = reinterpret_cast*>(out); + for (int i = tid; i < vec_size; i += blockDim.x) { + vec4_t in_vec = vec_input[i]; + q8x4_t out_vec; + out_vec.x = float_to_int8_rn(static_cast(in_vec.x) / scale); + out_vec.y = float_to_int8_rn(static_cast(in_vec.y) / scale); + out_vec.z = float_to_int8_rn(static_cast(in_vec.z) / scale); + out_vec.w = float_to_int8_rn(static_cast(in_vec.w) / scale); + vec_out[i] = out_vec; + } + + int base = vec_size * 4; + for (int i = tid; i < rem; i += blockDim.x) { + out[base + i] = + float_to_int8_rn(static_cast(input[base + i]) / scale); } }