diff --git a/csrc/quantization/gptq/autogptq_cuda_256.cpp b/csrc/quantization/gptq/autogptq_cuda_256.cpp new file mode 100644 index 000000000000..9229eea8c512 --- /dev/null +++ b/csrc/quantization/gptq/autogptq_cuda_256.cpp @@ -0,0 +1,72 @@ +#include +#include +#include + +void vecquant2matmul_cuda( + torch::Tensor vec, torch::Tensor mat, torch::Tensor mul, + torch::Tensor scales, torch::Tensor zeros, + torch::Tensor g_idx +); + +void vecquant2matmul( + torch::Tensor vec, torch::Tensor mat, torch::Tensor mul, + torch::Tensor scales, torch::Tensor zeros, + torch::Tensor g_idx +) { + const at::cuda::OptionalCUDAGuard device_guard(device_of(vec)); + vecquant2matmul_cuda(vec, mat, mul, scales, zeros, g_idx); +} + +void vecquant3matmul_cuda( + torch::Tensor vec, torch::Tensor mat, torch::Tensor mul, + torch::Tensor scales, torch::Tensor zeros, + torch::Tensor g_idx +); + +void vecquant3matmul( + torch::Tensor vec, torch::Tensor mat, torch::Tensor mul, + torch::Tensor scales, torch::Tensor zeros, + torch::Tensor g_idx +) { + const at::cuda::OptionalCUDAGuard device_guard(device_of(vec)); + vecquant3matmul_cuda(vec, mat, mul, scales, zeros, g_idx); +} + +void vecquant4matmul_cuda( + torch::Tensor vec, torch::Tensor mat, torch::Tensor mul, + torch::Tensor scales, torch::Tensor zeros, + torch::Tensor g_idx +); + +void vecquant4matmul( + torch::Tensor vec, torch::Tensor mat, torch::Tensor mul, + torch::Tensor scales, torch::Tensor zeros, + torch::Tensor g_idx +) { + const at::cuda::OptionalCUDAGuard device_guard(device_of(vec)); + vecquant4matmul_cuda(vec, mat, mul, scales, zeros, g_idx); +} + +void vecquant8matmul_cuda( + torch::Tensor vec, torch::Tensor mat, torch::Tensor mul, + torch::Tensor scales, torch::Tensor zeros, + torch::Tensor g_idx +); + +void vecquant8matmul( + torch::Tensor vec, torch::Tensor mat, torch::Tensor mul, + torch::Tensor scales, torch::Tensor zeros, + torch::Tensor g_idx +) { + const at::cuda::OptionalCUDAGuard device_guard(device_of(vec)); + vecquant8matmul_cuda(vec, mat, mul, scales, zeros, g_idx); +} + + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { + m.def("vecquant2matmul", &vecquant2matmul, "Vector 2-bit Quantized Matrix Multiplication (CUDA) (desc_act)"); + m.def("vecquant3matmul", &vecquant3matmul, "Vector 3-bit Quantized Matrix Multiplication (CUDA) (desc_act)"); + m.def("vecquant4matmul", &vecquant4matmul, "Vector 4-bit Quantized Matrix Multiplication (CUDA) (desc_act)"); + m.def("vecquant8matmul", &vecquant8matmul, "Vector 8-bit Quantized Matrix Multiplication (CUDA) (desc_act)"); + +} \ No newline at end of file diff --git a/csrc/quantization/gptq/autogptq_cuda_64.cpp b/csrc/quantization/gptq/autogptq_cuda_64.cpp new file mode 100644 index 000000000000..f3e8502c0654 --- /dev/null +++ b/csrc/quantization/gptq/autogptq_cuda_64.cpp @@ -0,0 +1,71 @@ +#include +#include +#include + +void vecquant2matmul_cuda( + torch::Tensor vec, torch::Tensor mat, torch::Tensor mul, + torch::Tensor scales, torch::Tensor zeros, + torch::Tensor g_idx +); + +void vecquant2matmul( + torch::Tensor vec, torch::Tensor mat, torch::Tensor mul, + torch::Tensor scales, torch::Tensor zeros, + torch::Tensor g_idx +) { + const at::cuda::OptionalCUDAGuard device_guard(device_of(vec)); + vecquant2matmul_cuda(vec, mat, mul, scales, zeros, g_idx); +} + +void vecquant3matmul_cuda( + torch::Tensor vec, torch::Tensor mat, torch::Tensor mul, + torch::Tensor scales, torch::Tensor zeros, + torch::Tensor g_idx +); + +void vecquant3matmul( + torch::Tensor vec, torch::Tensor mat, torch::Tensor mul, + torch::Tensor scales, torch::Tensor zeros, + torch::Tensor g_idx +) { + const at::cuda::OptionalCUDAGuard device_guard(device_of(vec)); + vecquant3matmul_cuda(vec, mat, mul, scales, zeros, g_idx); +} + +void vecquant4matmul_cuda( + torch::Tensor vec, torch::Tensor mat, torch::Tensor mul, + torch::Tensor scales, torch::Tensor zeros, + torch::Tensor g_idx +); + +void vecquant4matmul( + torch::Tensor vec, torch::Tensor mat, torch::Tensor mul, + torch::Tensor scales, torch::Tensor zeros, + torch::Tensor g_idx +) { + const at::cuda::OptionalCUDAGuard device_guard(device_of(vec)); + vecquant4matmul_cuda(vec, mat, mul, scales, zeros, g_idx); +} + +void vecquant8matmul_cuda( + torch::Tensor vec, torch::Tensor mat, torch::Tensor mul, + torch::Tensor scales, torch::Tensor zeros, + torch::Tensor g_idx +); + +void vecquant8matmul( + torch::Tensor vec, torch::Tensor mat, torch::Tensor mul, + torch::Tensor scales, torch::Tensor zeros, + torch::Tensor g_idx +) { + const at::cuda::OptionalCUDAGuard device_guard(device_of(vec)); + vecquant8matmul_cuda(vec, mat, mul, scales, zeros, g_idx); +} + + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { + m.def("vecquant2matmul", &vecquant2matmul, "Vector 2-bit Quantized Matrix Multiplication (CUDA) (desc_act)"); + m.def("vecquant3matmul", &vecquant3matmul, "Vector 3-bit Quantized Matrix Multiplication (CUDA) (desc_act)"); + m.def("vecquant4matmul", &vecquant4matmul, "Vector 4-bit Quantized Matrix Multiplication (CUDA) (desc_act)"); + m.def("vecquant8matmul", &vecquant8matmul, "Vector 8-bit Quantized Matrix Multiplication (CUDA) (desc_act)"); +} \ No newline at end of file diff --git a/csrc/quantization/gptq/autogptq_cuda_kernel_256.cu b/csrc/quantization/gptq/autogptq_cuda_kernel_256.cu new file mode 100644 index 000000000000..656c8fea4fe6 --- /dev/null +++ b/csrc/quantization/gptq/autogptq_cuda_kernel_256.cu @@ -0,0 +1,654 @@ +#include +#include +#include +#include +#include + +// atomicAdd for double-precision floating-point numbers on hardware with +// compute capability < 6.0 from: +// https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#atomic-functions +// #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 600 +// __device__ double atomicAdd( +// double* address, +// double val +// ) { +// unsigned long long int* address_as_ull = (unsigned long long int*)address; +// unsigned long long int old = *address_as_ull, assumed; +// +// do { +// assumed = old; +// old = atomicCAS( +// address_as_ull, +// assumed, +// __double_as_longlong(val + __longlong_as_double(assumed)) +// ); +// +// // Note: uses integer comparison to avoid hang in case of NaN (since NaN != NaN) +// } while (assumed != old); +// +// return __longlong_as_double(old); +// } +// #endif + +#if (defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 700) || defined(USE_ROCM) +// adapted from https://github.com/torch/cutorch/blob/master/lib/THC/THCAtomics.cuh + +__device__ __forceinline__ void atomicAdd(c10::Half* address, c10::Half val) { + unsigned int *address_as_ui = reinterpret_cast(reinterpret_cast(address) - (reinterpret_cast(address) & 2)); + unsigned int old = *address_as_ui; + unsigned int assumed; + + do { + assumed = old; + unsigned short hsum = reinterpret_cast(address) & 2 ? (old >> 16) : (old & 0xffff); + hsum += val; + old = reinterpret_cast(address) & 2 + ? (old & 0xffff) | (hsum << 16) + : (old & 0xffff0000) | hsum; + old = atomicCAS(address_as_ui, assumed, old); + + // Note: uses integer comparison to avoid hang in case of NaN (since NaN != NaN) + } while (assumed != old); +} +__device__ __forceinline__ void atomicAdd(__half* address, c10::Half val) { + unsigned int * address_as_ui = (unsigned int *) ((char *)address - ((size_t)address & 2)); + unsigned int old = *address_as_ui; + unsigned int assumed; + + do { + assumed = old; + __half_raw hsum; + hsum.x = (size_t)address & 2 ? (old >> 16) : (old & 0xffff); + half tmpres = __hadd(hsum, val); + hsum = __half_raw(tmpres); + old = (size_t)address & 2 ? (old & 0xffff) | (hsum.x << 16) : (old & 0xffff0000) | hsum.x; + old = atomicCAS(address_as_ui, assumed, old); + } while (assumed != old); +} +#endif + + +template +__global__ void VecQuant2MatMulKernel( + const scalar_t* __restrict__ vec, + const int* __restrict__ mat, + scalar_t* __restrict__ mul, + const scalar_t* __restrict__ scales, + const int* __restrict__ zeros, + const int* __restrict__ g_idx, + int batch, + int vec_height, + int height, + int width, + int zero_width +); + +template +__global__ void VecQuant3MatMulKernel( + const scalar_t* __restrict__ vec, + const int* __restrict__ mat, + scalar_t* __restrict__ mul, + const scalar_t* __restrict__ scales, + const int* __restrict__ zeros, + const int* __restrict__ g_idx, + int batch, + int vec_height, + int height, + int width, + int zero_width +); + +template +__global__ void VecQuant4MatMulKernel( + const scalar_t* __restrict__ vec, + const int* __restrict__ mat, + scalar_t* __restrict__ mul, + const scalar_t* __restrict__ scales, + const int* __restrict__ zeros, + const int* __restrict__ g_idx, + int batch, + int vec_height, + int height, + int width, + int zero_width +); + +template +__global__ void VecQuant8MatMulKernel( + const scalar_t* __restrict__ vec, + const int* __restrict__ mat, + scalar_t* __restrict__ mul, + const scalar_t* __restrict__ scales, + const int* __restrict__ zeros, + const int* __restrict__ g_idx, + int batch, + int vec_height, + int height, + int width, + int zero_width +); + +template +__global__ void VecQuant2MatMulKernel_old( + const scalar_t* __restrict__ vec, + const int* __restrict__ mat, + scalar_t* __restrict__ mul, + const scalar_t* __restrict__ scales, + const int* __restrict__ zeros, + int batch, + int vec_height, + int height, + int width, + int zero_width, + int groupsize +); + +template +__global__ void VecQuant3MatMulKernel_old( + const scalar_t* __restrict__ vec, + const int* __restrict__ mat, + scalar_t* __restrict__ mul, + const scalar_t* __restrict__ scales, + const int* __restrict__ zeros, + int batch, + int vec_height, + int height, + int width, + int zero_width, + int groupsize +); + +template +__global__ void VecQuant4MatMulKernel_old( + const scalar_t* __restrict__ vec, + const int* __restrict__ mat, + scalar_t* __restrict__ mul, + const scalar_t* __restrict__ scales, + const int* __restrict__ zeros, + int batch, + int vec_height, + int height, + int width, + int zero_width, + int groupsize +); + +template +__global__ void VecQuant8MatMulKernel_old( + const scalar_t* __restrict__ vec, + const int* __restrict__ mat, + scalar_t* __restrict__ mul, + const scalar_t* __restrict__ scales, + const int* __restrict__ zeros, + int batch, + int vec_height, + int height, + int width, + int zero_width, + int groupsize +); + +__global__ void VecQuant2MatMulKernelFaster_old( + const half2* __restrict__ vec, + const int* __restrict__ mat, + float* __restrict__ mul, + const float* __restrict__ scales, + const int* __restrict__ zeros, + int batch, + int vec_height, + int height, + int width, + int zero_width, + int groupsize +); + +__global__ void VecQuant3MatMulKernelFaster_old( + const half2* __restrict__ vec, + const int* __restrict__ mat, + float* __restrict__ mul, + const float* __restrict__ scales, + const int* __restrict__ zeros, + int batch, + int vec_height, + int height, + int width, + int zero_width, + int groupsize +); + +__global__ void VecQuant4MatMulKernelFaster_old( + const half2* __restrict__ vec, + const int* __restrict__ mat, + float* __restrict__ mul, + const float* __restrict__ scales, + const int* __restrict__ zeros, + int batch, + int vec_height, + int height, + int width, + int zero_width, + int groupsize +); + + +const int BLOCKWIDTH = 256; +const int BLOCKHEIGHT2 = 16; +const int BLOCKHEIGHT3 = 24; +const int BLOCKHEIGHT4 = 32; +const int BLOCKHEIGHT8 = 64; + +__device__ inline unsigned int as_unsigned(int i) { + return *reinterpret_cast(&i); +} + +__device__ inline int as_int(int i) { + return *reinterpret_cast(&i); +} + + +void vecquant2matmul_cuda( + torch::Tensor vec, + torch::Tensor mat, + torch::Tensor mul, + torch::Tensor scales, + torch::Tensor zeros, + torch::Tensor g_idx +) { + int batch = vec.size(0); + int vec_height = vec.size(1); + int height = mat.size(0); + int width = mat.size(1); + int zero_width = zeros.size(1); + + dim3 blocks( + (height + BLOCKHEIGHT2 - 1) / BLOCKHEIGHT2, + (width + BLOCKWIDTH - 1) / BLOCKWIDTH + ); + dim3 threads(BLOCKWIDTH); + + AT_DISPATCH_FLOATING_TYPES( + vec.type(), "vecquant2matmul_cuda", ([&] { + VecQuant2MatMulKernel<<>>( + vec.data(), mat.data(), mul.data(), + scales.data(), zeros.data(), g_idx.data(), + batch, vec_height, height, width, zero_width + ); + }) + ); +} + +template +__global__ void VecQuant2MatMulKernel( + const scalar_t* __restrict__ vec, + const int* __restrict__ mat, + scalar_t* __restrict__ mul, + const scalar_t* __restrict__ scales, + const int* __restrict__ zeros, + const int* __restrict__ g_idx, + int batch, + int vec_height, + int height, + int width, + int zero_width +) { + int h = BLOCKHEIGHT2 * blockIdx.x; + int w = BLOCKWIDTH * blockIdx.y + threadIdx.x; + + __shared__ scalar_t blockvec[BLOCKWIDTH]; + int i = width * h + w; + int g_h = h * 16; + int k; + unsigned int g; + scalar_t w_tmp; + + int z_w = w / 16; + int z_mod = (w % 16) * 2; + + float weight[BLOCKWIDTH]; + + for (k = 0; k < BLOCKWIDTH; ++k){ + int k_w = (k / 16); + int k_bit = (k % 16) * 2; + + g = as_int(g_idx[g_h + k]); + scalar_t scale = scales[g * width + w]; + scalar_t zero = scalar_t((as_unsigned(zeros[g * zero_width + z_w]) >> z_mod & 0x3) + 1); + + w_tmp = ((as_unsigned(mat[i + (k_w * width)]) >> k_bit) & 0x3); + + weight[k] = scale * (w_tmp - zero); + } + + scalar_t res; + for (int b = 0; b < batch; ++b){ + res = 0; + + blockvec[threadIdx.x] = vec[b * vec_height + blockIdx.x * BLOCKWIDTH + threadIdx.x]; + __syncthreads(); + for (k = 0; k < BLOCKWIDTH; ++k){ + res += weight[k] * blockvec[k]; + } + atomicAdd(&mul[b * width + w], res); + __syncthreads(); + } +} + +void vecquant3matmul_cuda( + torch::Tensor vec, + torch::Tensor mat, + torch::Tensor mul, + torch::Tensor scales, + torch::Tensor zeros, + torch::Tensor g_idx +) { + int batch = vec.size(0); + int vec_height = vec.size(1); + int height = mat.size(0); + int width = mat.size(1); + int zero_width = zeros.size(1); + + dim3 blocks( + (height + BLOCKHEIGHT3 - 1) / BLOCKHEIGHT3, + (width + BLOCKWIDTH - 1) / BLOCKWIDTH + ); + dim3 threads(BLOCKWIDTH); + + AT_DISPATCH_FLOATING_TYPES( + vec.type(), "vecquant3matmul_cuda", ([&] { + VecQuant3MatMulKernel<<>>( + vec.data(), mat.data(), mul.data(), + scales.data(), zeros.data(), g_idx.data(), + batch, vec_height, height, width, zero_width + ); + }) + ); +} + +template +__global__ void VecQuant3MatMulKernel( + const scalar_t* __restrict__ vec, + const int* __restrict__ mat, + scalar_t* __restrict__ mul, + const scalar_t* __restrict__ scales, + const int* __restrict__ zeros, + const int* __restrict__ g_idx, + int batch, + int vec_height, + int height, + int width, + int zero_width +) { + int h = BLOCKHEIGHT3 * blockIdx.x; + int w = BLOCKWIDTH * blockIdx.y + threadIdx.x; + + __shared__ scalar_t blockvec[BLOCKWIDTH]; + int i = width * h + w; + int g_h = (h / 3) * 32; + int k; + unsigned int g; + scalar_t w_tmp; + + int z_w = (w / 32) * 3; + int z_mod = w % 32; + int z_bit; + unsigned int z_tmp; + if (z_mod != 10){ + if (z_mod != 21){ + z_bit = z_mod; + if (z_bit > 21){ + z_bit -= 22; + z_bit *= 3; + z_bit += 2; + z_w += 2; + } else if (z_bit > 10){ + z_bit -= 11; + z_bit *= 3; + z_bit += 1; + z_w += 1; + } else { + z_bit *= 3; + } + } else { + z_w += 1; + } + } + + float weight[BLOCKWIDTH]; + + for (k = 0; k < BLOCKWIDTH; ++k){ + int k_w = (k / 32) * 3; + int k_mod = k % 32; + int k_bit; + + if (k_mod != 10){ + if (k_mod != 21){ + k_bit = k_mod; + if (k_bit > 21){ + k_bit -= 22; + k_bit *= 3; + k_bit += 2; + k_w += 2; + } else if (k_bit > 10){ + k_bit -= 11; + k_bit *= 3; + k_bit += 1; + k_w += 1; + } else { + k_bit *= 3; + } + } else { + k_w += 1; + } + } + + g = as_int(g_idx[g_h + k]); + scalar_t scale = scales[g * width + w]; + scalar_t zero; + if (z_mod == 10) { + z_tmp = (as_unsigned(zeros[g * zero_width + z_w]) >> 30) | ((as_unsigned(zeros[g * zero_width + (z_w + 1)]) << 2) & 0x4); + zero = scalar_t((z_tmp) + 1); + } else if (z_mod == 21){ + z_tmp = (as_unsigned(zeros[g * zero_width + z_w]) >> 31) | ((as_unsigned(zeros[g * zero_width + (z_w + 1)]) << 1) & 0x6); + zero = scalar_t((z_tmp) + 1); + } else { + zero = scalar_t(((as_unsigned(zeros[g * zero_width + z_w]) >> z_bit) & 0x7) + 1); + } + + if (k_mod == 10) { + w_tmp = (as_unsigned(mat[i + (k_w * width)]) >> 30) | ((as_unsigned(mat[i + ((k_w + 1)* width)]) << 2) & 0x4); + } else if (k_mod == 21){ + w_tmp = (as_unsigned(mat[i + (k_w * width)]) >> 31) | ((as_unsigned(mat[i + ((k_w + 1)* width)]) << 1) & 0x6); + } else { + w_tmp = ((as_unsigned(mat[i + (k_w * width)]) >> k_bit) & 0x7); + } + weight[k] = scale * (w_tmp - zero); + } + + scalar_t res; + for (int b = 0; b < batch; ++b){ + res = 0; + + blockvec[threadIdx.x] = vec[b * vec_height + blockIdx.x * BLOCKWIDTH + threadIdx.x]; + __syncthreads(); + for (k = 0; k < BLOCKWIDTH; ++k){ + res += weight[k] * blockvec[k]; + } + atomicAdd(&mul[b * width + w], res); + __syncthreads(); + } +} + +void vecquant4matmul_cuda( + torch::Tensor vec, + torch::Tensor mat, + torch::Tensor mul, + torch::Tensor scales, + torch::Tensor zeros, + torch::Tensor g_idx +) { + int batch = vec.size(0); + int vec_height = vec.size(1); + int height = mat.size(0); + int width = mat.size(1); + int zero_width = zeros.size(1); + + dim3 blocks( + (height + BLOCKHEIGHT4 - 1) / BLOCKHEIGHT4, + (width + BLOCKWIDTH - 1) / BLOCKWIDTH + ); + dim3 threads(BLOCKWIDTH); + + AT_DISPATCH_FLOATING_TYPES( + vec.type(), "vecquant4matmul_cuda", ([&] { + VecQuant4MatMulKernel<<>>( + vec.data(), mat.data(), mul.data(), + scales.data(), zeros.data(), g_idx.data(), + batch, vec_height, height, width, zero_width + ); + }) + ); +} + +template +__global__ void VecQuant4MatMulKernel( + const scalar_t* __restrict__ vec, + const int* __restrict__ mat, + scalar_t* __restrict__ mul, + const scalar_t* __restrict__ scales, + const int* __restrict__ zeros, + const int* __restrict__ g_idx, + int batch, + int vec_height, + int height, + int width, + int zero_width +) { + int h = BLOCKHEIGHT4 * blockIdx.x; + int w = BLOCKWIDTH * blockIdx.y + threadIdx.x; + + __shared__ scalar_t blockvec[BLOCKWIDTH]; + int i = width * h + w; + int g_h = h * 8; + int k; + unsigned int g; + scalar_t w_tmp; + + + int z_w = w / 8; + int z_mod = (w % 8) * 4; + + float weight[BLOCKWIDTH]; + + for (k = 0; k < BLOCKWIDTH; ++k){ + int k_w = (k / 8); + int k_bit = (k % 8) * 4; + + g = as_int(g_idx[g_h + k]); + scalar_t scale = scales[g * width + w]; + scalar_t zero = scalar_t(((as_unsigned(zeros[g * zero_width + z_w]) >> z_mod) & 0xF) + 1); + + w_tmp = ((as_unsigned(mat[i + (k_w * width)]) >> k_bit) & 0xF); + + weight[k] = scale * (w_tmp - zero); + } + + scalar_t res; + for (int b = 0; b < batch; ++b){ + res = 0; + + blockvec[threadIdx.x] = vec[b * vec_height + blockIdx.x * BLOCKWIDTH + threadIdx.x]; + __syncthreads(); + for (k = 0; k < BLOCKWIDTH; ++k){ + res += weight[k] * blockvec[k]; + } + atomicAdd(&mul[b * width + w], res); + __syncthreads(); + } +} + +void vecquant8matmul_cuda( + torch::Tensor vec, + torch::Tensor mat, + torch::Tensor mul, + torch::Tensor scales, + torch::Tensor zeros, + torch::Tensor g_idx +) { + int batch = vec.size(0); + int vec_height = vec.size(1); + int height = mat.size(0); + int width = mat.size(1); + int zero_width = zeros.size(1); + + dim3 blocks( + (height + BLOCKHEIGHT8 - 1) / BLOCKHEIGHT8, + (width + BLOCKWIDTH - 1) / BLOCKWIDTH + ); + dim3 threads(BLOCKWIDTH); + + AT_DISPATCH_FLOATING_TYPES( + vec.type(), "vecquant8matmul_cuda", ([&] { + VecQuant8MatMulKernel<<>>( + vec.data(), mat.data(), mul.data(), + scales.data(), zeros.data(), g_idx.data(), + batch, vec_height, height, width, zero_width + ); + }) + ); +} + +template +__global__ void VecQuant8MatMulKernel( + const scalar_t* __restrict__ vec, + const int* __restrict__ mat, + scalar_t* __restrict__ mul, + const scalar_t* __restrict__ scales, + const int* __restrict__ zeros, + const int* __restrict__ g_idx, + int batch, + int vec_height, + int height, + int width, + int zero_width +) { + int h = BLOCKHEIGHT8 * blockIdx.x; + int w = BLOCKWIDTH * blockIdx.y + threadIdx.x; + + __shared__ scalar_t blockvec[BLOCKWIDTH]; + int i = width * h + w; + int g_h = h * 4; + int k; + unsigned int g; + scalar_t w_tmp; + + int z_w = w / 4; + int z_mod = (w % 4) * 8; + + float weight[BLOCKWIDTH]; + + for (k = 0; k < BLOCKWIDTH; ++k){ + int k_w = (k / 4); + int k_bit = (k % 4) * 8; + + g = as_int(g_idx[g_h + k]); + scalar_t scale = scales[g * width + w]; + scalar_t zero = scalar_t(((as_unsigned(zeros[g * zero_width + z_w]) >> z_mod) & 0xFF) + 1); + + w_tmp = ((as_unsigned(mat[i + (k_w * width)]) >> k_bit) & 0xFF); + + weight[k] = scale * (w_tmp - zero); + } + + scalar_t res; + for (int b = 0; b < batch; ++b){ + res = 0; + + blockvec[threadIdx.x] = vec[b * vec_height + blockIdx.x * BLOCKWIDTH + threadIdx.x]; + __syncthreads(); + for (k = 0; k < BLOCKWIDTH; ++k){ + res += weight[k] * blockvec[k]; + } + atomicAdd(&mul[b * width + w], res); + __syncthreads(); + } +} \ No newline at end of file diff --git a/csrc/quantization/gptq/autogptq_cuda_kernel_64.cu b/csrc/quantization/gptq/autogptq_cuda_kernel_64.cu new file mode 100644 index 000000000000..30ba40c291f6 --- /dev/null +++ b/csrc/quantization/gptq/autogptq_cuda_kernel_64.cu @@ -0,0 +1,655 @@ +#include +#include +#include +#include +#include + +// atomicAdd for double-precision floating-point numbers on hardware with +// compute capability < 6.0 from: +// https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#atomic-functions +// #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 600 +// __device__ double atomicAdd( +// double* address, +// double val +// ) { +// unsigned long long int* address_as_ull = (unsigned long long int*)address; +// unsigned long long int old = *address_as_ull, assumed; +// +// do { +// assumed = old; +// old = atomicCAS( +// address_as_ull, +// assumed, +// __double_as_longlong(val + __longlong_as_double(assumed)) +// ); +// +// // Note: uses integer comparison to avoid hang in case of NaN (since NaN != NaN) +// } while (assumed != old); +// +// return __longlong_as_double(old); +// } +// #endif + + +#if (defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 700) || defined(USE_ROCM) +// adapted from https://github.com/torch/cutorch/blob/master/lib/THC/THCAtomics.cuh +__device__ __forceinline__ void atomicAdd(c10::Half* address, c10::Half val) { + unsigned int *address_as_ui = reinterpret_cast(reinterpret_cast(address) - (reinterpret_cast(address) & 2)); + unsigned int old = *address_as_ui; + unsigned int assumed; + + do { + assumed = old; + unsigned short hsum = reinterpret_cast(address) & 2 ? (old >> 16) : (old & 0xffff); + hsum += val; + old = reinterpret_cast(address) & 2 + ? (old & 0xffff) | (hsum << 16) + : (old & 0xffff0000) | hsum; + old = atomicCAS(address_as_ui, assumed, old); + + // Note: uses integer comparison to avoid hang in case of NaN (since NaN != NaN) + } while (assumed != old); +} +__device__ __forceinline__ void atomicAdd(__half* address, c10::Half val) { + unsigned int * address_as_ui = (unsigned int *) ((char *)address - ((size_t)address & 2)); + unsigned int old = *address_as_ui; + unsigned int assumed; + + do { + assumed = old; + __half_raw hsum; + hsum.x = (size_t)address & 2 ? (old >> 16) : (old & 0xffff); + half tmpres = __hadd(hsum, val); + hsum = __half_raw(tmpres); + old = (size_t)address & 2 ? (old & 0xffff) | (hsum.x << 16) : (old & 0xffff0000) | hsum.x; + old = atomicCAS(address_as_ui, assumed, old); + } while (assumed != old); +} +#endif + + +template +__global__ void VecQuant2MatMulKernel( + const scalar_t* __restrict__ vec, + const int* __restrict__ mat, + scalar_t* __restrict__ mul, + const scalar_t* __restrict__ scales, + const int* __restrict__ zeros, + const int* __restrict__ g_idx, + int batch, + int vec_height, + int height, + int width, + int zero_width +); + +template +__global__ void VecQuant3MatMulKernel( + const scalar_t* __restrict__ vec, + const int* __restrict__ mat, + scalar_t* __restrict__ mul, + const scalar_t* __restrict__ scales, + const int* __restrict__ zeros, + const int* __restrict__ g_idx, + int batch, + int vec_height, + int height, + int width, + int zero_width +); + +template +__global__ void VecQuant4MatMulKernel( + const scalar_t* __restrict__ vec, + const int* __restrict__ mat, + scalar_t* __restrict__ mul, + const scalar_t* __restrict__ scales, + const int* __restrict__ zeros, + const int* __restrict__ g_idx, + int batch, + int vec_height, + int height, + int width, + int zero_width +); + + +template +__global__ void VecQuant8MatMulKernel( + const scalar_t* __restrict__ vec, + const int* __restrict__ mat, + scalar_t* __restrict__ mul, + const scalar_t* __restrict__ scales, + const int* __restrict__ zeros, + const int* __restrict__ g_idx, + int batch, + int vec_height, + int height, + int width, + int zero_width +); + +template +__global__ void VecQuant2MatMulKernel_old( + const scalar_t* __restrict__ vec, + const int* __restrict__ mat, + scalar_t* __restrict__ mul, + const scalar_t* __restrict__ scales, + const int* __restrict__ zeros, + int batch, + int vec_height, + int height, + int width, + int zero_width, + int groupsize +); + +template +__global__ void VecQuant3MatMulKernel_old( + const scalar_t* __restrict__ vec, + const int* __restrict__ mat, + scalar_t* __restrict__ mul, + const scalar_t* __restrict__ scales, + const int* __restrict__ zeros, + int batch, + int vec_height, + int height, + int width, + int zero_width, + int groupsize +); + +template +__global__ void VecQuant4MatMulKernel_old( + const scalar_t* __restrict__ vec, + const int* __restrict__ mat, + scalar_t* __restrict__ mul, + const scalar_t* __restrict__ scales, + const int* __restrict__ zeros, + int batch, + int vec_height, + int height, + int width, + int zero_width, + int groupsize +); + +template +__global__ void VecQuant8MatMulKernel_old( + const scalar_t* __restrict__ vec, + const int* __restrict__ mat, + scalar_t* __restrict__ mul, + const scalar_t* __restrict__ scales, + const int* __restrict__ zeros, + int batch, + int vec_height, + int height, + int width, + int zero_width, + int groupsize +); + +__global__ void VecQuant2MatMulKernelFaster_old( + const half2* __restrict__ vec, + const int* __restrict__ mat, + float* __restrict__ mul, + const float* __restrict__ scales, + const int* __restrict__ zeros, + int batch, + int vec_height, + int height, + int width, + int zero_width, + int groupsize +); + +__global__ void VecQuant3MatMulKernelFaster_old( + const half2* __restrict__ vec, + const int* __restrict__ mat, + float* __restrict__ mul, + const float* __restrict__ scales, + const int* __restrict__ zeros, + int batch, + int vec_height, + int height, + int width, + int zero_width, + int groupsize +); + +__global__ void VecQuant4MatMulKernelFaster_old( + const half2* __restrict__ vec, + const int* __restrict__ mat, + float* __restrict__ mul, + const float* __restrict__ scales, + const int* __restrict__ zeros, + int batch, + int vec_height, + int height, + int width, + int zero_width, + int groupsize +); + + +const int BLOCKWIDTH = 64; +const int BLOCKHEIGHT2 = 4; +const int BLOCKHEIGHT3 = 6; +const int BLOCKHEIGHT4 = 8; +const int BLOCKHEIGHT8 = 16; + +__device__ inline unsigned int as_unsigned(int i) { + return *reinterpret_cast(&i); +} + +__device__ inline int as_int(int i) { + return *reinterpret_cast(&i); +} + + +void vecquant2matmul_cuda( + torch::Tensor vec, + torch::Tensor mat, + torch::Tensor mul, + torch::Tensor scales, + torch::Tensor zeros, + torch::Tensor g_idx +) { + int batch = vec.size(0); + int vec_height = vec.size(1); + int height = mat.size(0); + int width = mat.size(1); + int zero_width = zeros.size(1); + + dim3 blocks( + (height + BLOCKHEIGHT2 - 1) / BLOCKHEIGHT2, + (width + BLOCKWIDTH - 1) / BLOCKWIDTH + ); + dim3 threads(BLOCKWIDTH); + + AT_DISPATCH_FLOATING_TYPES( + vec.type(), "vecquant2matmul_cuda", ([&] { + VecQuant2MatMulKernel<<>>( + vec.data(), mat.data(), mul.data(), + scales.data(), zeros.data(), g_idx.data(), + batch, vec_height, height, width, zero_width + ); + }) + ); +} + +template +__global__ void VecQuant2MatMulKernel( + const scalar_t* __restrict__ vec, + const int* __restrict__ mat, + scalar_t* __restrict__ mul, + const scalar_t* __restrict__ scales, + const int* __restrict__ zeros, + const int* __restrict__ g_idx, + int batch, + int vec_height, + int height, + int width, + int zero_width +) { + int h = BLOCKHEIGHT2 * blockIdx.x; + int w = BLOCKWIDTH * blockIdx.y + threadIdx.x; + + __shared__ scalar_t blockvec[BLOCKWIDTH]; + int i = width * h + w; + int g_h = h * 16; + int k; + unsigned int g; + scalar_t w_tmp; + + int z_w = w / 16; + int z_mod = (w % 16) * 2; + + float weight[BLOCKWIDTH]; + + for (k = 0; k < BLOCKWIDTH; ++k){ + int k_w = (k / 16); + int k_bit = (k % 16) * 2; + + g = as_int(g_idx[g_h + k]); + scalar_t scale = scales[g * width + w]; + scalar_t zero = scalar_t((as_unsigned(zeros[g * zero_width + z_w]) >> z_mod & 0x3) + 1); + + w_tmp = ((as_unsigned(mat[i + (k_w * width)]) >> k_bit) & 0x3); + + weight[k] = scale * (w_tmp - zero); + } + + scalar_t res; + for (int b = 0; b < batch; ++b){ + res = 0; + + blockvec[threadIdx.x] = vec[b * vec_height + blockIdx.x * BLOCKWIDTH + threadIdx.x]; + __syncthreads(); + for (k = 0; k < BLOCKWIDTH; ++k){ + res += weight[k] * blockvec[k]; + } + atomicAdd(&mul[b * width + w], res); + __syncthreads(); + } +} + +void vecquant3matmul_cuda( + torch::Tensor vec, + torch::Tensor mat, + torch::Tensor mul, + torch::Tensor scales, + torch::Tensor zeros, + torch::Tensor g_idx +) { + int batch = vec.size(0); + int vec_height = vec.size(1); + int height = mat.size(0); + int width = mat.size(1); + int zero_width = zeros.size(1); + + dim3 blocks( + (height + BLOCKHEIGHT3 - 1) / BLOCKHEIGHT3, + (width + BLOCKWIDTH - 1) / BLOCKWIDTH + ); + dim3 threads(BLOCKWIDTH); + + AT_DISPATCH_FLOATING_TYPES( + vec.type(), "vecquant3matmul_cuda", ([&] { + VecQuant3MatMulKernel<<>>( + vec.data(), mat.data(), mul.data(), + scales.data(), zeros.data(), g_idx.data(), + batch, vec_height, height, width, zero_width + ); + }) + ); +} + +template +__global__ void VecQuant3MatMulKernel( + const scalar_t* __restrict__ vec, + const int* __restrict__ mat, + scalar_t* __restrict__ mul, + const scalar_t* __restrict__ scales, + const int* __restrict__ zeros, + const int* __restrict__ g_idx, + int batch, + int vec_height, + int height, + int width, + int zero_width +) { + int h = BLOCKHEIGHT3 * blockIdx.x; + int w = BLOCKWIDTH * blockIdx.y + threadIdx.x; + + __shared__ scalar_t blockvec[BLOCKWIDTH]; + int i = width * h + w; + int g_h = (h / 3) * 32; + int k; + unsigned int g; + scalar_t w_tmp; + + int z_w = (w / 32) * 3; + int z_mod = w % 32; + int z_bit; + unsigned int z_tmp; + if (z_mod != 10){ + if (z_mod != 21){ + z_bit = z_mod; + if (z_bit > 21){ + z_bit -= 22; + z_bit *= 3; + z_bit += 2; + z_w += 2; + } else if (z_bit > 10){ + z_bit -= 11; + z_bit *= 3; + z_bit += 1; + z_w += 1; + } else { + z_bit *= 3; + } + } else { + z_w += 1; + } + } + + float weight[BLOCKWIDTH]; + + for (k = 0; k < BLOCKWIDTH; ++k){ + int k_w = (k / 32) * 3; + int k_mod = k % 32; + int k_bit; + + if (k_mod != 10){ + if (k_mod != 21){ + k_bit = k_mod; + if (k_bit > 21){ + k_bit -= 22; + k_bit *= 3; + k_bit += 2; + k_w += 2; + } else if (k_bit > 10){ + k_bit -= 11; + k_bit *= 3; + k_bit += 1; + k_w += 1; + } else { + k_bit *= 3; + } + } else { + k_w += 1; + } + } + + g = as_int(g_idx[g_h + k]); + scalar_t scale = scales[g * width + w]; + scalar_t zero; + if (z_mod == 10) { + z_tmp = (as_unsigned(zeros[g * zero_width + z_w]) >> 30) | ((as_unsigned(zeros[g * zero_width + (z_w + 1)]) << 2) & 0x4); + zero = scalar_t((z_tmp) + 1); + } else if (z_mod == 21){ + z_tmp = (as_unsigned(zeros[g * zero_width + z_w]) >> 31) | ((as_unsigned(zeros[g * zero_width + (z_w + 1)]) << 1) & 0x6); + zero = scalar_t((z_tmp) + 1); + } else { + zero = scalar_t(((as_unsigned(zeros[g * zero_width + z_w]) >> z_bit) & 0x7) + 1); + } + + if (k_mod == 10) { + w_tmp = (as_unsigned(mat[i + (k_w * width)]) >> 30) | ((as_unsigned(mat[i + ((k_w + 1)* width)]) << 2) & 0x4); + } else if (k_mod == 21){ + w_tmp = (as_unsigned(mat[i + (k_w * width)]) >> 31) | ((as_unsigned(mat[i + ((k_w + 1)* width)]) << 1) & 0x6); + } else { + w_tmp = ((as_unsigned(mat[i + (k_w * width)]) >> k_bit) & 0x7); + } + weight[k] = scale * (w_tmp - zero); + } + + scalar_t res; + for (int b = 0; b < batch; ++b){ + res = 0; + + blockvec[threadIdx.x] = vec[b * vec_height + blockIdx.x * BLOCKWIDTH + threadIdx.x]; + __syncthreads(); + for (k = 0; k < BLOCKWIDTH; ++k){ + res += weight[k] * blockvec[k]; + } + atomicAdd(&mul[b * width + w], res); + __syncthreads(); + } +} + +void vecquant4matmul_cuda( + torch::Tensor vec, + torch::Tensor mat, + torch::Tensor mul, + torch::Tensor scales, + torch::Tensor zeros, + torch::Tensor g_idx +) { + int batch = vec.size(0); + int vec_height = vec.size(1); + int height = mat.size(0); + int width = mat.size(1); + int zero_width = zeros.size(1); + + dim3 blocks( + (height + BLOCKHEIGHT4 - 1) / BLOCKHEIGHT4, + (width + BLOCKWIDTH - 1) / BLOCKWIDTH + ); + dim3 threads(BLOCKWIDTH); + + AT_DISPATCH_FLOATING_TYPES( + vec.type(), "vecquant4matmul_cuda", ([&] { + VecQuant4MatMulKernel<<>>( + vec.data(), mat.data(), mul.data(), + scales.data(), zeros.data(), g_idx.data(), + batch, vec_height, height, width, zero_width + ); + }) + ); +} + +template +__global__ void VecQuant4MatMulKernel( + const scalar_t* __restrict__ vec, + const int* __restrict__ mat, + scalar_t* __restrict__ mul, + const scalar_t* __restrict__ scales, + const int* __restrict__ zeros, + const int* __restrict__ g_idx, + int batch, + int vec_height, + int height, + int width, + int zero_width +) { + int h = BLOCKHEIGHT4 * blockIdx.x; + int w = BLOCKWIDTH * blockIdx.y + threadIdx.x; + + __shared__ scalar_t blockvec[BLOCKWIDTH]; + int i = width * h + w; + int g_h = h * 8; + int k; + unsigned int g; + scalar_t w_tmp; + + + int z_w = w / 8; + int z_mod = (w % 8) * 4; + + float weight[BLOCKWIDTH]; + + for (k = 0; k < BLOCKWIDTH; ++k){ + int k_w = (k / 8); + int k_bit = (k % 8) * 4; + + g = as_int(g_idx[g_h + k]); + scalar_t scale = scales[g * width + w]; + scalar_t zero = scalar_t(((as_unsigned(zeros[g * zero_width + z_w]) >> z_mod) & 0xF) + 1); + + w_tmp = ((as_unsigned(mat[i + (k_w * width)]) >> k_bit) & 0xF); + + weight[k] = scale * (w_tmp - zero); + } + + scalar_t res; + for (int b = 0; b < batch; ++b){ + res = 0; + + blockvec[threadIdx.x] = vec[b * vec_height + blockIdx.x * BLOCKWIDTH + threadIdx.x]; + __syncthreads(); + for (k = 0; k < BLOCKWIDTH; ++k){ + res += weight[k] * blockvec[k]; + } + atomicAdd(&mul[b * width + w], res); + __syncthreads(); + } +} + +void vecquant8matmul_cuda( + torch::Tensor vec, + torch::Tensor mat, + torch::Tensor mul, + torch::Tensor scales, + torch::Tensor zeros, + torch::Tensor g_idx +) { + int batch = vec.size(0); + int vec_height = vec.size(1); + int height = mat.size(0); + int width = mat.size(1); + int zero_width = zeros.size(1); + + dim3 blocks( + (height + BLOCKHEIGHT8 - 1) / BLOCKHEIGHT8, + (width + BLOCKWIDTH - 1) / BLOCKWIDTH + ); + dim3 threads(BLOCKWIDTH); + + AT_DISPATCH_FLOATING_TYPES( + vec.type(), "vecquant8matmul_cuda", ([&] { + VecQuant8MatMulKernel<<>>( + vec.data(), mat.data(), mul.data(), + scales.data(), zeros.data(), g_idx.data(), + batch, vec_height, height, width, zero_width + ); + }) + ); +} + +template +__global__ void VecQuant8MatMulKernel( + const scalar_t* __restrict__ vec, + const int* __restrict__ mat, + scalar_t* __restrict__ mul, + const scalar_t* __restrict__ scales, + const int* __restrict__ zeros, + const int* __restrict__ g_idx, + int batch, + int vec_height, + int height, + int width, + int zero_width +) { + int h = BLOCKHEIGHT8 * blockIdx.x; + int w = BLOCKWIDTH * blockIdx.y + threadIdx.x; + + __shared__ scalar_t blockvec[BLOCKWIDTH]; + int i = width * h + w; + int g_h = h * 4; + int k; + unsigned int g; + scalar_t w_tmp; + + int z_w = w / 4; + int z_mod = (w % 4) * 8; + + float weight[BLOCKWIDTH]; + + for (k = 0; k < BLOCKWIDTH; ++k){ + int k_w = (k / 4); + int k_bit = (k % 4) * 8; + + g = as_int(g_idx[g_h + k]); + scalar_t scale = scales[g * width + w]; + scalar_t zero = scalar_t(((as_unsigned(zeros[g * zero_width + z_w]) >> z_mod) & 0xFF) + 1); + + w_tmp = ((as_unsigned(mat[i + (k_w * width)]) >> k_bit) & 0xFF); + + weight[k] = scale * (w_tmp - zero); + } + + scalar_t res; + for (int b = 0; b < batch; ++b){ + res = 0; + + blockvec[threadIdx.x] = vec[b * vec_height + blockIdx.x * BLOCKWIDTH + threadIdx.x]; + __syncthreads(); + for (k = 0; k < BLOCKWIDTH; ++k){ + res += weight[k] * blockvec[k]; + } + atomicAdd(&mul[b * width + w], res); + __syncthreads(); + } +} \ No newline at end of file diff --git a/requirements.txt b/requirements.txt index 92ba0a716c45..62fd3f986d4a 100644 --- a/requirements.txt +++ b/requirements.txt @@ -12,3 +12,4 @@ fastapi uvicorn[standard] pydantic == 1.10.13 # Required for OpenAI server. aioprometheus[starlette] +triton==2.0.0 \ No newline at end of file diff --git a/setup.py b/setup.py index 811d494e7a01..0336f4cbb554 100644 --- a/setup.py +++ b/setup.py @@ -237,6 +237,18 @@ def get_torch_arch_list() -> Set[str]: ) ext_modules.append(vllm_extension) +autogptq_extentions = [ + CUDAExtension("autogptq_cuda_64", [ + "csrc/quantization/gptq/autogptq_cuda_64.cpp", + "csrc/quantization/gptq//autogptq_cuda_kernel_64.cu" + ]), + CUDAExtension("autogptq_cuda_256", [ + "csrc/quantization/gptq/autogptq_cuda_256.cpp", + "csrc/quantization/gptq/autogptq_cuda_kernel_256.cu" + ]) +] +ext_modules.extend(autogptq_extentions) + def get_path(*filepath) -> str: return os.path.join(ROOT_DIR, *filepath) diff --git a/vllm/model_executor/layers/linear.py b/vllm/model_executor/layers/linear.py index 5190de65d795..2976af52a576 100644 --- a/vllm/model_executor/layers/linear.py +++ b/vllm/model_executor/layers/linear.py @@ -278,8 +278,8 @@ def weight_loader(self, # If quantized, we need to adjust the offset and size to account # for the packing. if packed_dim == output_dim: - shard_size = shard_size // param.pack_factor - shard_offset = shard_offset // param.pack_factor + shard_size = shard_size // param.storage_bits_size * param.weight_bits + shard_offset = shard_offset // param.storage_bits_size * param.weight_bits loaded_weight_shard = loaded_weight.narrow( output_dim, shard_offset, shard_size) self.weight_loader(param, loaded_weight_shard, shard_id) @@ -295,8 +295,8 @@ def weight_loader(self, # for the packing. packed_dim = getattr(param, "packed_dim", None) if packed_dim == output_dim: - shard_size = shard_size // param.pack_factor - shard_offset = shard_offset // param.pack_factor + shard_size = shard_size // param.storage_bits_size * param.weight_bits + shard_offset = shard_offset // param.storage_bits_size * param.weight_bits param_data = param_data.narrow(output_dim, shard_offset, shard_size) start_idx = tp_rank * shard_size @@ -395,8 +395,8 @@ def weight_loader(self, # If quantized, we need to adjust the offset and size to account # for the packing. if packed_dim == output_dim: - shard_size = shard_size // param.pack_factor - shard_offset = shard_offset // param.pack_factor + shard_size = shard_size // param.storage_bits_size * param.weight_bits + shard_offset = shard_offset // param.storage_bits_size * param.weight_bits loaded_weight_shard = loaded_weight.narrow( output_dim, shard_offset, shard_size) self.weight_loader(param, loaded_weight_shard, shard_id) @@ -419,8 +419,8 @@ def weight_loader(self, # for the packing. packed_dim = getattr(param, "packed_dim", None) if packed_dim == output_dim: - shard_size = shard_size // param.pack_factor - shard_offset = shard_offset // param.pack_factor + shard_size = shard_size // param.storage_bits_size * param.weight_bits + shard_offset = shard_offset // param.storage_bits_size * param.weight_bits param_data = param_data.narrow(output_dim, shard_offset, shard_size) shard_id = tp_rank // self.num_kv_head_replicas diff --git a/vllm/model_executor/layers/quantization/awq.py b/vllm/model_executor/layers/quantization/awq.py index 831576b1d7cd..3b3b3e25e0ef 100644 --- a/vllm/model_executor/layers/quantization/awq.py +++ b/vllm/model_executor/layers/quantization/awq.py @@ -8,6 +8,8 @@ set_weight_attrs) from vllm.model_executor.layers.quantization.base_config import QuantizationConfig +STORAGE_BITS_SIZE = 32 + class AWQConfig(QuantizationConfig): """Config class for AWQ. @@ -29,7 +31,7 @@ def __init__( raise ValueError( "Currently, only 4-bit weight quantization is supported for " f"AWQ, but got {self.weight_bits} bits.") - self.pack_factor = 32 // self.weight_bits + self.pack_factor = STORAGE_BITS_SIZE // self.weight_bits def __repr__(self) -> str: return (f"AWQConfig(weight_bits={self.weight_bits}, " @@ -86,7 +88,7 @@ def create_weights(self, input_size_per_partition: int, "The input size is not aligned with the quantized " "weight shape. This can be caused by too large " "tensor parallel size.") - if output_size_per_partition % self.quant_config.pack_factor != 0: + if output_size_per_partition % STORAGE_BITS_SIZE != 0: raise ValueError( "The output size is not aligned with the quantized " "weight shape. This can be caused by too large " @@ -95,7 +97,8 @@ def create_weights(self, input_size_per_partition: int, qweight = Parameter( torch.empty( input_size_per_partition, - output_size_per_partition // self.quant_config.pack_factor, + output_size_per_partition // STORAGE_BITS_SIZE * + self.quant_config.weight_bits, device="cuda", dtype=torch.int32, ), @@ -106,12 +109,14 @@ def create_weights(self, input_size_per_partition: int, "input_dim": 0, "output_dim": 1, "packed_dim": 1, - "pack_factor": self.quant_config.pack_factor, + "weight_bits": self.quant_config.weight_bits, + "storage_bits_size": STORAGE_BITS_SIZE, }) qzeros = Parameter( torch.empty( input_size_per_partition // self.quant_config.group_size, - output_size_per_partition // self.quant_config.pack_factor, + output_size_per_partition // STORAGE_BITS_SIZE * + self.quant_config.weight_bits, device="cuda", dtype=torch.int32, ), @@ -122,7 +127,8 @@ def create_weights(self, input_size_per_partition: int, "input_dim": 0, "output_dim": 1, "packed_dim": 1, - "pack_factor": self.quant_config.pack_factor, + "weight_bits": self.quant_config.weight_bits, + "storage_bits_size": STORAGE_BITS_SIZE, }) scales = Parameter( torch.empty( diff --git a/vllm/model_executor/layers/quantization/gptq.py b/vllm/model_executor/layers/quantization/gptq.py index 8fe96e7ddb98..6f1c060f69df 100644 --- a/vllm/model_executor/layers/quantization/gptq.py +++ b/vllm/model_executor/layers/quantization/gptq.py @@ -4,12 +4,42 @@ import torch from torch.nn.parameter import Parameter - from vllm._C import ops -from vllm.model_executor.layers.linear import (LinearMethodBase, - set_weight_attrs) -from vllm.model_executor.layers.quantization.base_config import ( - QuantizationConfig) +from vllm.logger import init_logger +from vllm.model_executor.layers.linear import LinearMethodBase, set_weight_attrs +from vllm.model_executor.layers.quantization.base_config import QuantizationConfig +from vllm.model_executor.layers.quantization.triton_utils.kernels import ( + QuantLinearInferenceOnlyFunction, ) + +logger = init_logger(__name__) + +try: + import autogptq_cuda_64 + import autogptq_cuda_256 + + _autogptq_cuda_available = True +except ImportError: + logger.warning("CUDA extension not installed.") + autogptq_cuda_256 = None + autogptq_cuda_64 = None + _autogptq_cuda_available = False + +# The bit width to which the quantized weight is packed needs to align with your quantization code. The quantized parameters, represented in lower precision, are packed in an int32 in GPTQ. +STORAGE_BITS_SIZE = 32 + + +class GPTQLinearKernel(Enum): + + TRITON = enum.auto() + EXLLAMA = enum.auto() + CUDA = enum.auto() + + +class ExllamaState(Enum): + + UNUSED = enum.auto() + UNINITIALIZED = enum.auto() + READY = enum.auto() class GPTQConfig(QuantizationConfig): @@ -18,26 +48,33 @@ class GPTQConfig(QuantizationConfig): Reference: https://arxiv.org/abs/2210.17323 """ - def __init__( - self, - weight_bits: int, - group_size: int, - desc_act: bool, - ) -> None: + def __init__(self, weight_bits: int, group_size: int, desc_act: bool, + use_triton: bool, disable_exllama: bool) -> None: self.weight_bits = weight_bits self.group_size = group_size self.desc_act = desc_act - self.pack_factor = 32 // self.weight_bits - # exllama kernel v1 only supports 4 bit - if self.weight_bits != 4: + self.use_triton = use_triton + self.disable_exllama = disable_exllama + # The Exllama kernel only supports 4 bits. Under 4-bit quantization, it will be used if disable_exllama is False; + # otherwise, the Triton or CUDA kernel will be used for quantization precision other than 4 bits. + if self.weight_bits in [2, 4, 8]: + self.kernel_type = GPTQLinearKernel.TRITON if self.use_triton else GPTQLinearKernel.CUDA + if self.weight_bits == 4: + self.kernel_type = GPTQLinearKernel.EXLLAMA if not disable_exllama else self.kernel_type + elif self.weight_bits == 3: + self.kernel_type = GPTQLinearKernel.CUDA + else: raise ValueError( - "Currently, only 4-bit weight quantization is supported for " + "Currently, only 2, 3, 4, and 8-bit weight quantization is supported for" f"GPTQ, but got {self.weight_bits} bits.") + self.maxq = 2**self.weight_bits - 1 def __repr__(self) -> str: return (f"GPTQConfig(weight_bits={self.weight_bits}, " f"group_size={self.group_size}, " - f"desc_act={self.desc_act})") + f"desc_act={self.desc_act}, " + f"use_triton={self.use_triton}, " + f"disable_exllama={self.disable_exllama}") @classmethod def get_name(cls) -> str: @@ -61,7 +98,10 @@ def from_config(cls, config: Dict[str, Any]) -> "GPTQConfig": weight_bits = cls.get_from_keys(config, ["bits"]) group_size = cls.get_from_keys(config, ["group_size"]) desc_act = cls.get_from_keys(config, ["desc_act"]) - return cls(weight_bits, group_size, desc_act) + use_triton = cls.get_from_keys(config, ["use_triton"]) + disable_exllama = cls.get_from_keys(config, ["disable_exllama"]) + return cls(weight_bits, group_size, desc_act, use_triton, + disable_exllama) def get_linear_method(self) -> "GPTQLinearMethod": return GPTQLinearMethod(self) @@ -70,13 +110,6 @@ def get_scaled_act_names(self) -> List[str]: return [] -class ExllamaState(Enum): - - UNUSED = enum.auto() - UNINITIALIZED = enum.auto() - READY = enum.auto() - - class GPTQLinearMethod(LinearMethodBase): """Linear method for GPTQ. @@ -101,12 +134,36 @@ def create_weights( "The input size is not aligned with the quantized " "weight shape. This can be caused by too large " "tensor parallel size.") - if output_size_per_partition % self.quant_config.pack_factor != 0: + if output_size_per_partition % STORAGE_BITS_SIZE != 0: raise ValueError( "The output size is not aligned with the quantized " "weight shape. This can be caused by too large " "tensor parallel size.") + if self.quant_config.kernel_type == GPTQLinearKernel.CUDA: + if self.quant_config.weight_bits in [2, 4, 8]: + self.wf = torch.tensor(list( + range(0, STORAGE_BITS_SIZE, + self.quant_config.weight_bits)), + dtype=torch.int32).unsqueeze(0) + elif self.quant_config.weight_bits == 3: + # Under 3-bit quantization, packing is different from other cases because many 3 bits won't fit neatly into a 32-bit width integer. + # There could be one bit that spans across two 32-bit vectors. + # Here, we use a tensor to track the offset of each bit in three 32-bit vectors, which can neatly accommodate 32 3-bit weights. + self.wf = torch.tensor([ + [0, 3, 6, 9, 12, 15, 18, 21, 24, 27, 30, 0], + [0, 1, 4, 7, 10, 13, 16, 19, 22, 25, 28, 31], + [0, 2, 5, 8, 11, 14, 17, 20, 23, 26, 29, 0], + ], + dtype=torch.int32).reshape(1, 3, 12) + self.autogptq_cuda_available = _autogptq_cuda_available + + self.autogptq_cuda = autogptq_cuda_256 + if input_size_per_partition % 256 != 0 or output_size_per_partition % 256 != 0: + self.autogptq_cuda = autogptq_cuda_64 + if input_size_per_partition % 64 != 0 or output_size_per_partition % 64 != 0: + self.autogptq_cuda_available = False + if self.quant_config.group_size != -1: group_size = self.quant_config.group_size else: @@ -125,7 +182,8 @@ def create_weights( qweight = Parameter( torch.empty( - input_size_per_partition // self.quant_config.pack_factor, + input_size_per_partition // STORAGE_BITS_SIZE * + self.quant_config.weight_bits, output_size_per_partition, device="cuda", dtype=torch.int32, @@ -137,7 +195,8 @@ def create_weights( "input_dim": 0, "output_dim": 1, "packed_dim": 0, - "pack_factor": self.quant_config.pack_factor, + "weight_bits": self.quant_config.weight_bits, + "storage_bits_size": STORAGE_BITS_SIZE }) g_idx = Parameter( torch.tensor( @@ -155,7 +214,8 @@ def create_weights( qzeros = Parameter( torch.empty( scale_and_zero_size, - output_size_per_partition // self.quant_config.pack_factor, + output_size_per_partition // STORAGE_BITS_SIZE * + self.quant_config.weight_bits, device="cuda", dtype=torch.int32, ), @@ -166,7 +226,8 @@ def create_weights( "input_dim": scale_and_zero_input_dim, "output_dim": 1, "packed_dim": 1, - "pack_factor": self.quant_config.pack_factor, + "weight_bits": self.quant_config.weight_bits, + "storage_bits_size": STORAGE_BITS_SIZE }) scales = Parameter( torch.empty( @@ -186,30 +247,158 @@ def create_weights( "g_idx": g_idx, "qzeros": qzeros, "scales": scales, - "exllama_state": exllama_state, + "exllama_state": + exllama_state, # when use_triton is true or quantization precision is not equal to 4-bit, exllama state will be ignored } def apply_weights(self, weights: Dict[str, Any], x: torch.Tensor, bias: Optional[torch.Tensor] = None) -> torch.Tensor: - qweight = weights["qweight"] - out_shape = x.shape[:-1] + (qweight.shape[-1], ) + scales = weights["scales"] + qzeros = weights["qzeros"] + out_shape = x.shape[:-1] + (weights["qweight"].shape[-1], ) reshaped_x = x.reshape(-1, x.shape[-1]) - # exllama needs to shuffle the weight after the weight is loaded - # here we do the shuffle on first forward pass - if weights["exllama_state"] == ExllamaState.UNINITIALIZED: - if self.quant_config.desc_act: - weights["g_idx"] = torch.argsort(weights["g_idx"]).to( - torch.int) + if self.quant_config.kernel_type == GPTQLinearKernel.EXLLAMA: + # exllama needs to shuffle the weight after the weight is loaded + # here we do the shuffle on first forward pass + if weights["exllama_state"] == ExllamaState.UNINITIALIZED: + if self.quant_config.desc_act: + weights["g_idx"] = torch.argsort(weights["g_idx"]).to( + torch.int) + else: + weights["g_idx"] = torch.empty((1, 1), device="meta") + weights["exllama_state"] = ExllamaState.READY + ops.gptq_shuffle(weights["qweight"], weights["g_idx"]) + + output = ops.gptq_gemm( + reshaped_x, weights["qweight"], qzeros, scales, + weights["g_idx"], + weights["exllama_state"] == ExllamaState.READY) + elif self.quant_config.kernel_type == GPTQLinearKernel.TRITON: + quant_linear_fn = QuantLinearInferenceOnlyFunction + output = quant_linear_fn.apply(reshaped_x, weights["qweight"], + scales, qzeros, weights["g_idx"], + self.quant_config.weight_bits, + self.quant_config.maxq) + output = output.half().reshape(out_shape) + else: + self.kernel_switch_threshold = 128 + if reshaped_x.device.type == "cuda" and self.autogptq_cuda_available and ( + self.kernel_switch_threshold == 0 + or reshaped_x.shape[0] < self.kernel_switch_threshold): + output = torch.zeros(out_shape, + device=reshaped_x.device, + dtype=torch.float32) + if self.quant_config.weight_bits == 2: + self.autogptq_cuda.vecquant2matmul(reshaped_x.float(), + weights["qweight"], + output, scales.float(), + qzeros, + weights["g_idx"]) + elif self.quant_config.weight_bits == 3: + self.autogptq_cuda.vecquant3matmul(reshaped_x.float(), + weights["qweight"], + output, scales.float(), + qzeros, + weights["g_idx"]) + elif self.quant_config.weight_bits == 4: + self.autogptq_cuda.vecquant4matmul(reshaped_x.float(), + weights["qweight"], + output, scales.float(), + qzeros, + weights["g_idx"]) + elif self.quant_config.weight_bits == 8: + self.autogptq_cuda.vecquant8matmul(reshaped_x.float(), + weights["qweight"], + output, scales.float(), + qzeros, + weights["g_idx"]) + else: + raise NotImplementedError( + "Only 2,3,4,8 bits are supported.") else: - weights["g_idx"] = torch.empty((1, 1), device="meta") - weights["exllama_state"] = ExllamaState.READY - ops.gptq_shuffle(weights["qweight"], weights["g_idx"]) - output = ops.gptq_gemm(reshaped_x, weights["qweight"], - weights["qzeros"], weights["scales"], - weights["g_idx"], - weights["exllama_state"] == ExllamaState.READY) + if self.wf.device != qzeros.device: + self.wf = self.wf.to(qzeros.device) + + if self.quant_config.weight_bits in [2, 4, 8]: + zeros = torch.bitwise_right_shift( + torch.unsqueeze(qzeros, 2).expand( + -1, -1, 32 // self.quant_config.weight_bits), + self.wf.unsqueeze(0)).to( + torch.int16 if self.quant_config.weight_bits == + 8 else torch.int8) + zeros = torch.bitwise_and( + zeros, (2**self.quant_config.weight_bits) - 1) + + zeros = zeros + 1 + zeros = zeros.reshape(scales.shape) + + weight = torch.bitwise_right_shift( + torch.unsqueeze(weights["qweight"], 1).expand( + -1, 32 // self.quant_config.weight_bits, -1), + self.wf.unsqueeze(-1)).to( + torch.int16 if self.quant_config.weight_bits == + 8 else torch.int8) + weight = torch.bitwise_and( + weight, (2**self.quant_config.weight_bits) - 1) + elif self.quant_config.weight_bits == 3: + zeros = qzeros.reshape(qzeros.shape[0], + qzeros.shape[1] // 3, 3, + 1).expand(-1, -1, -1, 12) + zeros = (zeros >> self.wf.unsqueeze(0)) + zeros[:, :, 0, 10] = (zeros[:, :, 0, 10] & 0x3) | ( + (zeros[:, :, 1, 0] << 2) & 0x4) + zeros[:, :, 1, 11] = (zeros[:, :, 1, 11] & 0x1) | ( + (zeros[:, :, 2, 0] << 1) & 0x6) + zeros = zeros & 0x7 + zeros = torch.cat([ + zeros[:, :, 0, :11], zeros[:, :, 1, 1:12], + zeros[:, :, 2, 1:11] + ], + dim=2) + + zeros = zeros + 1 + zeros = zeros.reshape(scales.shape) + + weight = weights["qweight"].reshape( + weights["qweight"].shape[0] // 3, 3, 1, + weights["qweight"].shape[1]).expand(-1, -1, 12, -1) + weight = (weight >> self.wf.unsqueeze(-1)) & 0x7 + weight[:, 0, 10] = (weight[:, 0, 10] & 0x3) | ( + (weight[:, 1, 0] << 2) & 0x4) + weight[:, 1, 11] = (weight[:, 1, 11] & 0x1) | ( + (weight[:, 2, 0] << 1) & 0x6) + weight = weight & 0x7 + weight = torch.cat([ + weight[:, 0, :11], weight[:, 1, 1:12], weight[:, 2, + 1:11] + ], + dim=1) + else: + raise NotImplementedError( + "Only 2,3,4,8 bits are supported.") + + weight = weight.reshape(weight.shape[0] * weight.shape[1], + weight.shape[2]) + num_itr = weights["g_idx"].shape[0] // x.shape[-1] + if num_itr == 1: + weights = (scales[weights["g_idx"].long()] * + (weight - zeros[weights["g_idx"].long()])) + else: + num_dim = weights["g_idx"].shape[0] // num_itr + weights = [] + for i in range(num_itr): + scale_i = scales[:, i * num_dim:(i + 1) * num_dim] + weight_i = weight[:, i * num_dim:(i + 1) * num_dim] + zeros_i = zeros[:, i * num_dim:(i + 1) * num_dim] + g_idx_i = weights["g_idx"][i * num_dim:(i + 1) * + num_dim] + weights.append(scale_i[g_idx_i.long()] * + (weight_i - zeros_i[g_idx_i.long()])) + weights = torch.cat(weights, dim=1) + output = torch.matmul(x, weights) + output = output.to(x.dtype) if bias is not None: output = output + bias return output.reshape(out_shape) diff --git a/vllm/model_executor/layers/quantization/squeezellm.py b/vllm/model_executor/layers/quantization/squeezellm.py index 1932bd145076..a92a208775d7 100644 --- a/vllm/model_executor/layers/quantization/squeezellm.py +++ b/vllm/model_executor/layers/quantization/squeezellm.py @@ -9,6 +9,8 @@ from vllm.model_executor.layers.quantization.base_config import QuantizationConfig from vllm.utils import is_hip +STORAGE_BITS_SIZE = 32 + class SqueezeLLMConfig(QuantizationConfig): """Config class for SqueezeLLM. @@ -27,7 +29,7 @@ def __init__( "Currently, only 4-bit weight quantization is supported for " f"SqueezeLLM, but got {self.weight_bits} bits.") - self.pack_factor = 32 // self.weight_bits + self.pack_factor = STORAGE_BITS_SIZE // self.weight_bits def __repr__(self) -> str: return f"SqueezeLLMConfig(weight_bits={self.weight_bits})" @@ -71,14 +73,15 @@ def create_weights(self, input_size_per_partition: int, output_size_per_partition: int, input_size: int, output_size: int, params_dtype: torch.dtype) -> Dict[str, Any]: - if input_size_per_partition % self.quant_config.pack_factor != 0: + if input_size_per_partition % STORAGE_BITS_SIZE != 0: raise ValueError( "The input size is not aligned with the quantized " "weight shape. This can be caused by too large " "tensor parallel size.") qweight = Parameter( torch.empty( - input_size_per_partition // self.quant_config.pack_factor, + input_size_per_partition // STORAGE_BITS_SIZE * + self.quant_config.weight_bits, output_size_per_partition, device="cuda", dtype=torch.int32, @@ -90,7 +93,8 @@ def create_weights(self, input_size_per_partition: int, "input_dim": 0, "output_dim": 1, "packed_dim": 0, - "pack_factor": self.quant_config.pack_factor, + "weight_bits": self.quant_config.weight_bits, + "storage_bits_size": STORAGE_BITS_SIZE, }) lookup_table = Parameter( torch.empty( diff --git a/vllm/model_executor/layers/quantization/triton_utils/__init__.py b/vllm/model_executor/layers/quantization/triton_utils/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/vllm/model_executor/layers/quantization/triton_utils/custom_autotune.py b/vllm/model_executor/layers/quantization/triton_utils/custom_autotune.py new file mode 100644 index 000000000000..bf89b5dea184 --- /dev/null +++ b/vllm/model_executor/layers/quantization/triton_utils/custom_autotune.py @@ -0,0 +1,198 @@ +import builtins +import math +import time +from typing import Dict + +import triton + +# code based https://github.com/fpgaminer/GPTQ-triton +""" +Mostly the same as the autotuner in Triton, but with a few changes like using 40 runs instead of 100. +""" + + +class CustomizedTritonAutoTuner(triton.KernelInterface): + + def __init__(self, + fn, + arg_names, + configs, + key, + reset_to_zero, + prune_configs_by: Dict = None, + nearest_power_of_two: bool = False): + if not configs: + self.configs = [triton.Config({}, num_warps=4, num_stages=2)] + else: + self.configs = configs + self.key_idx = [arg_names.index(k) for k in key] + self.nearest_power_of_two = nearest_power_of_two + self.cache = {} + # hook to reset all required tensor to zeros before relaunching a kernel + self.hook = lambda args: 0 + if reset_to_zero is not None: + self.reset_idx = [arg_names.index(k) for k in reset_to_zero] + + def _hook(args): + for i in self.reset_idx: + args[i].zero_() + + self.hook = _hook + self.arg_names = arg_names + # prune configs + if prune_configs_by: + perf_model, top_k = prune_configs_by[ + 'perf_model'], prune_configs_by['top_k'] + if 'early_config_prune' in prune_configs_by: + early_config_prune = prune_configs_by['early_config_prune'] + else: + perf_model, top_k, early_config_prune = None, None, None + self.perf_model, self.configs_top_k = perf_model, top_k + self.early_config_prune = early_config_prune + self.fn = fn + + def _bench(self, *args, config, **meta): + # check for conflicts, i.e. meta-parameters both provided + # as kwargs and by the autotuner + conflicts = meta.keys() & config.kwargs.keys() + if conflicts: + raise ValueError( + f"Conflicting meta-parameters: {', '.join(conflicts)}." + " Make sure that you don't re-define auto-tuned symbols.") + # augment meta-parameters with tunable ones + current = dict(meta, **config.kwargs) + + def kernel_call(): + if config.pre_hook: + config.pre_hook(self.nargs) + self.hook(args) + self.fn.run(*args, + num_warps=config.num_warps, + num_stages=config.num_stages, + **current) + + try: + # In testings using only 40 reps seems to be close enough and it appears to be what PyTorch uses + # PyTorch also sets fast_flush to True, but I didn't see any speedup so I'll leave the default + return triton.testing.do_bench(kernel_call, + quantiles=(0.5, 0.2, 0.8), + rep=40) + except triton.OutOfResources: + return (float('inf'), float('inf'), float('inf')) + + def run(self, *args, **kwargs): + self.nargs = dict(zip(self.arg_names, args)) + if len(self.configs) > 1: + key = tuple(args[i] for i in self.key_idx) + + # This reduces the amount of autotuning by rounding the keys to the nearest power of two + # In my testing this gives decent results, and greatly reduces the amount of tuning required + if self.nearest_power_of_two: + key = tuple([2**int(math.log2(x) + 0.5) for x in key]) + + if key not in self.cache: + # prune configs + pruned_configs = self.prune_configs(kwargs) + bench_start = time.time() + timings = { + config: self._bench(*args, config=config, **kwargs) + for config in pruned_configs + } + bench_end = time.time() + self.bench_time = bench_end - bench_start + self.cache[key] = builtins.min(timings, key=timings.get) + self.hook(args) + self.configs_timings = timings + config = self.cache[key] + else: + config = self.configs[0] + self.best_config = config + if config.pre_hook is not None: + config.pre_hook(self.nargs) + return self.fn.run(*args, + num_warps=config.num_warps, + num_stages=config.num_stages, + **kwargs, + **config.kwargs) + + def prune_configs(self, kwargs): + pruned_configs = self.configs + if self.early_config_prune: + pruned_configs = self.early_config_prune(self.configs, self.nargs) + if self.perf_model: + top_k = self.configs_top_k + if isinstance(top_k, float) and top_k <= 1.0: + top_k = int(len(self.configs) * top_k) + if len(pruned_configs) > top_k: + est_timing = { + config: self.perf_model(**self.nargs, + **kwargs, + **config.kwargs, + num_stages=config.num_stages, + num_warps=config.num_warps) + for config in pruned_configs + } + pruned_configs = sorted(est_timing.keys(), + key=lambda x: est_timing[x])[:top_k] + return pruned_configs + + def warmup(self, *args, **kwargs): + self.nargs = dict(zip(self.arg_names, args)) + for config in self.prune_configs(kwargs): + self.fn.warmup( + *args, + num_warps=config.num_warps, + num_stages=config.num_stages, + **kwargs, + **config.kwargs, + ) + self.nargs = None + + +def autotune(configs, + key, + prune_configs_by=None, + reset_to_zero=None, + nearest_power_of_two=False): + + def decorator(fn): + return CustomizedTritonAutoTuner(fn, fn.arg_names, configs, key, + reset_to_zero, prune_configs_by, + nearest_power_of_two) + + return decorator + + +def matmul248_kernel_config_pruner(configs, nargs): + """ + The main purpose of this function is to shrink BLOCK_SIZE_* when the corresponding dimension is smaller. + """ + m = max(2**int(math.ceil(math.log2(nargs['M']))), 16) + n = max(2**int(math.ceil(math.log2(nargs['N']))), 16) + k = max(2**int(math.ceil(math.log2(nargs['K']))), 16) + + used = set() + for config in configs: + block_size_m = min(m, config.kwargs['BLOCK_SIZE_M']) + block_size_n = min(n, config.kwargs['BLOCK_SIZE_N']) + block_size_k = min(k, config.kwargs['BLOCK_SIZE_K']) + group_size_m = config.kwargs['GROUP_SIZE_M'] + + if (block_size_m, block_size_n, block_size_k, group_size_m, + config.num_stages, config.num_warps) in used: + continue + + used.add((block_size_m, block_size_n, block_size_k, group_size_m, + config.num_stages, config.num_warps)) + yield triton.Config( + { + 'BLOCK_SIZE_M': block_size_m, + 'BLOCK_SIZE_N': block_size_n, + 'BLOCK_SIZE_K': block_size_k, + 'GROUP_SIZE_M': group_size_m + }, + num_stages=config.num_stages, + num_warps=config.num_warps) + + +__all__ = ["autotune"] diff --git a/vllm/model_executor/layers/quantization/triton_utils/kernels.py b/vllm/model_executor/layers/quantization/triton_utils/kernels.py new file mode 100644 index 000000000000..35241af6c3d1 --- /dev/null +++ b/vllm/model_executor/layers/quantization/triton_utils/kernels.py @@ -0,0 +1,188 @@ +import torch +from torch.cuda.amp import custom_fwd +from logging import getLogger + +import triton +import triton.language as tl + +from . import custom_autotune + +logger = getLogger(__name__) + +# Adapted from https://github.com/fpgaminer/GPTQ-triton and https://github.com/PanQiWei/AutoGPTQ + + +@custom_autotune.autotune( + configs=[ + triton.Config( + { + 'BLOCK_SIZE_M': 64, + 'BLOCK_SIZE_N': 256, + 'BLOCK_SIZE_K': 32, + 'GROUP_SIZE_M': 8 + }, + num_stages=4, + num_warps=4), + triton.Config( + { + 'BLOCK_SIZE_M': 128, + 'BLOCK_SIZE_N': 128, + 'BLOCK_SIZE_K': 32, + 'GROUP_SIZE_M': 8 + }, + num_stages=4, + num_warps=4), + triton.Config( + { + 'BLOCK_SIZE_M': 64, + 'BLOCK_SIZE_N': 128, + 'BLOCK_SIZE_K': 32, + 'GROUP_SIZE_M': 8 + }, + num_stages=4, + num_warps=4), + triton.Config( + { + 'BLOCK_SIZE_M': 128, + 'BLOCK_SIZE_N': 32, + 'BLOCK_SIZE_K': 32, + 'GROUP_SIZE_M': 8 + }, + num_stages=4, + num_warps=4), + triton.Config( + { + 'BLOCK_SIZE_M': 64, + 'BLOCK_SIZE_N': 64, + 'BLOCK_SIZE_K': 32, + 'GROUP_SIZE_M': 8 + }, + num_stages=4, + num_warps=4), + triton.Config( + { + 'BLOCK_SIZE_M': 64, + 'BLOCK_SIZE_N': 128, + 'BLOCK_SIZE_K': 32, + 'GROUP_SIZE_M': 8 + }, + num_stages=2, + num_warps=8) + ], + key=['M', 'N', 'K'], + nearest_power_of_two=True, + prune_configs_by={ + 'early_config_prune': custom_autotune.matmul248_kernel_config_pruner, + 'perf_model': None, + 'top_k': None, + }, +) +@triton.jit +def quant_matmul_248_kernel( + a_ptr, b_ptr, c_ptr, scales_ptr, zeros_ptr, g_ptr, M, N, K, bits, maxq, + stride_am, stride_ak, stride_bk, stride_bn, stride_cm, stride_cn, + stride_scales, stride_zeros, BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr, + GROUP_SIZE_M: tl.constexpr): + """ + Compute the matrix multiplication C = A x B. + A is of shape (M, K) float16 + B is of shape (K//8, N) int32 + C is of shape (M, N) float16 + scales is of shape (G, N) float16 + zeros is of shape (G, N) float16 + g_ptr is of shape (K) int32 + """ + infearure_per_bits = 32 // bits + + pid = tl.program_id(axis=0) + num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) + num_pid_n = tl.cdiv(N, BLOCK_SIZE_N) + num_pid_k = tl.cdiv(K, BLOCK_SIZE_K) + num_pid_in_group = GROUP_SIZE_M * num_pid_n + group_id = pid // num_pid_in_group + first_pid_m = group_id * GROUP_SIZE_M + group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M) + pid_m = first_pid_m + (pid % group_size_m) + pid_n = (pid % num_pid_in_group) // group_size_m + + offs_am = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_bn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + offs_k = tl.arange(0, BLOCK_SIZE_K) + a_ptrs = a_ptr + ( + offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak + ) # (BLOCK_SIZE_M, BLOCK_SIZE_K) + a_mask = (offs_am[:, None] < M) + # b_ptrs is set up such that it repeats elements along the K axis 8 times + b_ptrs = b_ptr + ( + (offs_k[:, None] // infearure_per_bits) * stride_bk + + offs_bn[None, :] * stride_bn) # (BLOCK_SIZE_K, BLOCK_SIZE_N) + g_ptrs = g_ptr + offs_k + # shifter is used to extract the N bits of each element in the 32-bit word from B + scales_ptrs = scales_ptr + offs_bn[None, :] + zeros_ptrs = zeros_ptr + (offs_bn[None, :] // infearure_per_bits) + + shifter = (offs_k % infearure_per_bits) * bits + zeros_shifter = (offs_bn % infearure_per_bits) * bits + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + + for _ in range(0, num_pid_k): + g_idx = tl.load(g_ptrs) + + # Fetch scales and zeros; these are per-outfeature and thus reused in the inner loop + scales = tl.load( + scales_ptrs + + g_idx[:, None] * stride_scales) # (BLOCK_SIZE_K, BLOCK_SIZE_N,) + zeros = tl.load( + zeros_ptrs + + g_idx[:, None] * stride_zeros) # (BLOCK_SIZE_K, BLOCK_SIZE_N,) + + zeros = (zeros >> zeros_shifter[None, :]) & maxq + zeros = (zeros + 1) + + a = tl.load(a_ptrs, mask=a_mask, + other=0.) # (BLOCK_SIZE_M, BLOCK_SIZE_K) + b = tl.load(b_ptrs) # (BLOCK_SIZE_K, BLOCK_SIZE_N), but repeated + + # Now we need to unpack b (which is N-bit values) into 32-bit values + b = (b >> shifter[:, None]) & maxq # Extract the N-bit values + b = (b - zeros) * scales # Scale and shift + + accumulator += tl.dot(a, b) + a_ptrs += BLOCK_SIZE_K + b_ptrs += (BLOCK_SIZE_K // infearure_per_bits) * stride_bk + g_ptrs += BLOCK_SIZE_K + + c_ptrs = c_ptr + stride_cm * offs_am[:, + None] + stride_cn * offs_bn[None, :] + c_mask = (offs_am[:, None] < M) & (offs_bn[None, :] < N) + tl.store(c_ptrs, accumulator, mask=c_mask) + + +def quant_matmul_248(input, qweight, scales, qzeros, g_idx, bits, maxq): + with torch.cuda.device(input.device): + output = torch.empty((input.shape[0], qweight.shape[1]), + device=input.device, + dtype=input.dtype) + grid = lambda META: (triton.cdiv( + input.shape[0], META['BLOCK_SIZE_M']) * triton.cdiv( + qweight.shape[1], META['BLOCK_SIZE_N']), ) + quant_matmul_248_kernel[grid](input, qweight, output, + scales.to(input.dtype), qzeros, g_idx, + input.shape[0], qweight.shape[1], + input.shape[1], bits, maxq, + input.stride(0), input.stride(1), + qweight.stride(0), qweight.stride(1), + output.stride(0), output.stride(1), + scales.stride(0), qzeros.stride(0)) + return output + + +class QuantLinearInferenceOnlyFunction(torch.autograd.Function): + + @staticmethod + @custom_fwd(cast_inputs=torch.float16) + def forward(ctx, input, qweight, scales, qzeros, g_idx, bits, maxq): + output = quant_matmul_248(input, qweight, scales, qzeros, g_idx, bits, + maxq) + return output