From 82e6b2e139ba9706fb8e1435a169faf5e0ec19ef Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=A5=9A=E5=A4=A9=E7=BF=94?= Date: Mon, 18 Sep 2023 18:40:52 +0800 Subject: [PATCH 01/18] Add gptq implementation compatible with awq interface --- csrc/quantization.cpp | 18 +- csrc/quantization/gptq/cu_compat.cuh | 58 ++++ csrc/quantization/gptq/cuda_buffers.cu | 75 +++++ csrc/quantization/gptq/cuda_buffers.cuh | 55 ++++ .../gptq/cuda_func/column_remap.cu | 63 ++++ .../gptq/cuda_func/column_remap.cuh | 19 ++ csrc/quantization/gptq/cuda_func/q4_matmul.cu | 260 ++++++++++++++++ .../quantization/gptq/cuda_func/q4_matmul.cuh | 43 +++ csrc/quantization/gptq/cuda_func/q4_matrix.cu | 225 ++++++++++++++ .../quantization/gptq/cuda_func/q4_matrix.cuh | 53 ++++ csrc/quantization/gptq/exllama_ext.cpp | 244 +++++++++++++++ csrc/quantization/gptq/hip_compat.cuh | 49 +++ csrc/quantization/gptq/matrix.cuh | 294 ++++++++++++++++++ csrc/quantization/gptq/tuning.h | 13 + csrc/quantization/gptq/util.cuh | 33 ++ setup.py | 8 +- vllm/config.py | 8 +- .../layers/quantized_linear/__init__.py | 3 + .../layers/quantized_linear/gptq.py | 184 +++++++++++ .../layers/quantized_linear/utils.py | 93 ++++++ vllm/model_executor/model_loader.py | 19 +- vllm/model_executor/models/llama.py | 37 ++- .../parallel_utils/tensor_parallel/layers.py | 1 + .../quantization_utils/__init__.py | 2 + vllm/model_executor/quantization_utils/awq.py | 6 +- .../model_executor/quantization_utils/base.py | 9 +- .../model_executor/quantization_utils/gptq.py | 77 +++++ vllm/model_executor/weight_utils.py | 6 + vllm/worker/worker.py | 3 +- 29 files changed, 1935 insertions(+), 23 deletions(-) create mode 100644 csrc/quantization/gptq/cu_compat.cuh create mode 100644 csrc/quantization/gptq/cuda_buffers.cu create mode 100644 csrc/quantization/gptq/cuda_buffers.cuh create mode 100644 csrc/quantization/gptq/cuda_func/column_remap.cu create mode 100644 csrc/quantization/gptq/cuda_func/column_remap.cuh create mode 100644 
csrc/quantization/gptq/cuda_func/q4_matmul.cu create mode 100644 csrc/quantization/gptq/cuda_func/q4_matmul.cuh create mode 100644 csrc/quantization/gptq/cuda_func/q4_matrix.cu create mode 100644 csrc/quantization/gptq/cuda_func/q4_matrix.cuh create mode 100644 csrc/quantization/gptq/exllama_ext.cpp create mode 100644 csrc/quantization/gptq/hip_compat.cuh create mode 100644 csrc/quantization/gptq/matrix.cuh create mode 100644 csrc/quantization/gptq/tuning.h create mode 100644 csrc/quantization/gptq/util.cuh create mode 100644 vllm/model_executor/layers/quantized_linear/gptq.py create mode 100644 vllm/model_executor/layers/quantized_linear/utils.py create mode 100644 vllm/model_executor/quantization_utils/gptq.py diff --git a/csrc/quantization.cpp b/csrc/quantization.cpp index 3afa7f6a231d..ae8ff8b84dc3 100644 --- a/csrc/quantization.cpp +++ b/csrc/quantization.cpp @@ -1,3 +1,4 @@ +#include #include torch::Tensor awq_gemm( @@ -7,9 +8,24 @@ torch::Tensor awq_gemm( torch::Tensor _zeros, int split_k_iters); +void gptq_set_tuning_params(int matmul_recons_thd, bool matmul_fused_remap, + bool matmul_no_half2); + +void gptq_prepare_buffers(torch::Device device, torch::Tensor temp_state, + torch::Tensor temp_dq); + +uintptr_t gptq_make_q4(torch::Tensor qweight, torch::Tensor qzeros, + torch::Tensor scales, torch::Tensor g_idx, int device); + +void gptq_q4_matmul(torch::Tensor x, uintptr_t w, torch::Tensor out); + PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { m.def( "awq_gemm", &awq_gemm, "Quantized GEMM for AWQ"); -} + m.def("gptq_set_tuning_params", &gptq_set_tuning_params, "gptq_set_tuning_params"); + m.def("gptq_prepare_buffers", &gptq_prepare_buffers, "gptq_prepare_buffers"); + m.def("gptq_make_q4", &gptq_make_q4, "gptq_make_q4"); + m.def("gptq_q4_matmul", &gptq_q4_matmul, "gptq_q4_matmul"); +} \ No newline at end of file diff --git a/csrc/quantization/gptq/cu_compat.cuh b/csrc/quantization/gptq/cu_compat.cuh new file mode 100644 index 000000000000..c5258813e147 --- 
// Adapted from turboderp exllama: https://github.com/turboderp/exllama

#define _cuda_buffers_cu
#include "cuda_buffers.cuh"

// One optional buffer set per CUDA device, indexed by device ordinal.
CudaBuffers* g_buffers[CUDA_MAX_DEVICES] = {NULL};

// Record the caller-owned scratch buffers for `_device` and create the
// auxiliary streams/events on that device. The half* buffers are only
// stored, never allocated or freed here.
CudaBuffers::CudaBuffers
(
    int _device,
    int _temp_state_size,
    half* _temp_state,
    half* _temp_dq
) :
    // Listed in member declaration order (device, temp_state, temp_state_size, temp_dq).
    device(_device),
    temp_state(_temp_state),
    temp_state_size(_temp_state_size),
    temp_dq(_temp_dq)
{
    cudaSetDevice(_device);

    cudaStreamCreate(&alt_stream_1);
    cudaStreamCreate(&alt_stream_2);
    cudaStreamCreate(&alt_stream_3);
    cudaEventCreate(&alt_stream_1_done);
    cudaEventCreate(&alt_stream_2_done);
    cudaEventCreate(&alt_stream_3_done);
}

// Destroy the streams and events created by the constructor.
CudaBuffers::~CudaBuffers()
{
    cudaStreamDestroy(alt_stream_1);
    cudaStreamDestroy(alt_stream_2);
    cudaStreamDestroy(alt_stream_3);
    cudaEventDestroy(alt_stream_1_done);
    cudaEventDestroy(alt_stream_2_done);
    cudaEventDestroy(alt_stream_3_done);
}

// Return the buffer set registered for `device_index`, or NULL when the index
// is outside the table or nothing has been registered for that device.
CudaBuffers* get_buffers(const int device_index)
{
    if (device_index < 0 || device_index >= CUDA_MAX_DEVICES) return NULL;
    return g_buffers[device_index];
}

// Register (or replace) the buffer set for `_device`.
void prepare_buffers_cuda
(
    int _device,
    int _temp_state_size,
    half* _temp_state,
    half* _temp_dq
)
{
    // An out-of-range ordinal would write past the end of g_buffers.
    if (_device < 0 || _device >= CUDA_MAX_DEVICES) return;

    // Release a previously registered set first; overwriting the slot would
    // leak its three streams and three events.
    if (g_buffers[_device])
    {
        delete g_buffers[_device];
        g_buffers[_device] = NULL;
    }

    g_buffers[_device] = new CudaBuffers
    (
        _device,
        _temp_state_size,
        _temp_state,
        _temp_dq
    );
}

// Delete every registered buffer set and clear the table.
void cleanup_buffers_cuda()
{
    for (int i = 0; i < CUDA_MAX_DEVICES; i++)
    {
        if (!g_buffers[i]) continue;
        delete g_buffers[i];
        g_buffers[i] = NULL;
    }
}
// Gather columns of a row-major half matrix: x_new[r][c] = x[r][x_map[c]].
//
// Each thread owns one destination column and walks the SHUF_BLOCKSIZE_Y rows
// assigned to its blockIdx.y (fewer for the last row tile).
__global__ void column_remap_kernel
(
    const half* __restrict__ x,
    half* __restrict__ x_new,
    const int x_width,
    const int x_height,
    const uint32_t* x_map
)
{
    const int dst_column = SHUF_BLOCKSIZE_X * blockIdx.x + threadIdx.x;
    if (dst_column >= x_width) return;

    const int first_row = SHUF_BLOCKSIZE_Y * blockIdx.y;
    const int last_row = min(first_row + SHUF_BLOCKSIZE_Y, x_height);

    // Source column is fixed for the whole tile, so look it up once.
    const int src_column = x_map[dst_column];

    for (int row = first_row; row < last_row; row++)
    {
        const int row_offset = row * x_width;
        x_new[row_offset + dst_column] = x[row_offset + src_column];
    }
}
// Partial product out += x @ w for a 4-bit quantized w.
//
// Grid layout, as used by the index computations below:
//   blockIdx.x / threadIdx.x -> output column (w_column)
//   blockIdx.y / threadIdx.y -> output row    (x_row)
//   blockIdx.z               -> slice of the reduction dimension, block_size_z wide
// Every z-slice accumulates its partial sum into `out` with atomicAdd.
//
// Template flags:
//   use_half2     - accumulate in half2 (dot_product_8*) instead of half (dot_product_8*_h)
//   use_groupsize - step the reduction one quantization group at a time
//   use_x_map     - read x through the column permutation `x_map`
//
// `no_zero` skips the initial clearing of `out` (the caller wants accumulation).
template<bool use_half2, bool use_groupsize, bool use_x_map>
__global__ void q4_matmul_kernel
(
    const half* __restrict__ x,
    const uint32_t* __restrict__ w,
    half* __restrict__ out,
    const half* __restrict__ w_scales,
    const uint32_t* __restrict__ w_zeros,
    const int height,
    const int dim,
    const int width,
    const int groupsize,
    const int block_size_z,
    const uint32_t* __restrict__ x_map,
    bool no_zero
)
{
    // Start of block

    int x_column = block_size_z * blockIdx.z;
    int x_column_end = min(dim, block_size_z * (blockIdx.z + 1));

    int w_column = THREADS_X * blockIdx.x + threadIdx.x;
    int x_row = THREADS_Y * blockIdx.y + threadIdx.y;

    int iterations = (x_column_end - x_column) / 8;

    // Views

    MatrixView_half x_(x, height, dim);
    MatrixView_half w_scales_(w_scales, dim / groupsize, width);
    MatrixView_q4_row w_zeros_(w_zeros, dim / groupsize, width);
    MatrixView_q4_column w_(w, dim, width);
    MatrixView_half_rw out_(out, height, width);

    // Zero output: each even thread clears one 32-bit word, i.e. its own half
    // and the half owned by the odd thread next to it.

    if (!no_zero && blockIdx.z == 0 && (threadIdx.x & 1) == 0)
    {
        *((uint32_t*) out_.item_ptr(x_row, w_column)) = 0;
    }

    // The barrier must be reached by every thread of the block. Inside the
    // branch above it was only reached by the even threads of z-slice 0, which
    // is a divergent barrier and let odd threads add before their word was cleared.
    // NOTE(review): this only orders threads within one block; z-slices > 0 are
    // separate blocks and are not ordered against the clearing done by slice 0.
    __syncthreads();

    // Loop over part of x row (and w column)

    half2 acc = {};
    half acc_h = {};

    if constexpr (use_groupsize)
    {
        // For quant matrices where groupsize divides BLOCK_SIZE_Z we always start on a group boundary, so this
        // could be slightly faster

        for (int k = x_column, group = x_column / groupsize; k < x_column + iterations * 8; group++, k += groupsize)
        {
            if constexpr (use_half2)
            {
                half2 w_scale = w_scales_.item_half2half2(group, w_column);
                uint32_t w_zero = w_zeros_.item(group, w_column) + 1;

                if constexpr (use_x_map) acc = dot_product_8_x_map(acc, x_, x_row, k, w_, k, w_column, w_scale, w_zero, groupsize / 8, x_map);
                else                     acc = dot_product_8      (acc, x_, x_row, k, w_, k, w_column, w_scale, w_zero, groupsize / 8);
            }
            else
            {
                half w_scale = w_scales_.item(group, w_column);
                uint32_t w_zero = w_zeros_.item(group, w_column) + 1;

                if constexpr (use_x_map) acc_h = dot_product_8_x_map_h(acc_h, x_, x_row, k, w_, k, w_column, w_scale, w_zero, groupsize / 8, x_map);
                else                     acc_h = dot_product_8_h      (acc_h, x_, x_row, k, w_, k, w_column, w_scale, w_zero, groupsize / 8);
            }
        }
    }
    else
    {
        // Otherwise assume groupsize is a multiple of 8, do 8 columns per iteration and trust the cache

        for (int k = x_column; k < x_column + iterations * 8; k += 8)
        {
            if constexpr (use_half2)
            {
                int group = k / groupsize;
                half2 w_scale = w_scales_.item_half2half2(group, w_column);
                uint32_t w_zero = w_zeros_.item(group, w_column) + 1;

                if constexpr (use_x_map) acc = dot_product_8_x_map(acc, x_, x_row, k, w_, k, w_column, w_scale, w_zero, 1, x_map);
                else                     acc = dot_product_8      (acc, x_, x_row, k, w_, k, w_column, w_scale, w_zero, 1);
            }
            else
            {
                int group = k / groupsize;
                half w_scale = w_scales_.item(group, w_column);
                uint32_t w_zero = w_zeros_.item(group, w_column) + 1;

                if constexpr (use_x_map) acc_h = dot_product_8_x_map_h(acc_h, x_, x_row, k, w_, k, w_column, w_scale, w_zero, 1, x_map);
                else                     acc_h = dot_product_8_h      (acc_h, x_, x_row, k, w_, k, w_column, w_scale, w_zero, 1);
            }
        }
    }

    // Add to block result

    if constexpr (use_half2)
    {
        // Fold the two half2 lanes into one scalar before publishing.
        half result = __hadd(__low2half(acc), __high2half(acc));
        atomicAdd(out_.item_ptr(x_row, w_column), result);
    }
    else
    {
        atomicAdd(out_.item_ptr(x_row, w_column), acc_h);
    }
}
// Compute out = x @ w (or out += x @ w when `no_zero` is set) by first
// dequantizing `w` into the per-device temp_dq buffer and then running a
// dense half-precision GEMM through cuBLAS.
//
// If `w` carries a column permutation (cuda_x_map), x is first remapped into
// the per-device temp_state buffer, which must hold at least x_height * dim halves.
void q4_matmul_recons_cuda
(
    ExLlamaTuning* tuningParams,
    const half* x,
    const int x_height,
    Q4Matrix* w,
    half* out,
    const cublasHandle_t handle,
    bool no_zero
)
{
    int height = x_height;
    int dim = w->height;
    int width = w->width;

    cudaSetDevice(w->device);
    CudaBuffers* buffers = get_buffers(w->device);

    // get_buffers() yields NULL until buffers have been prepared for this
    // device; fail with a clear error instead of dereferencing it below.
    TORCH_CHECK(buffers != NULL, "GPTQ buffers have not been prepared for device ", w->device);

    const half* x_mapped = x;
    if (w->cuda_x_map)
    {
        TORCH_CHECK(buffers->temp_state_size >= x_height * dim, "The temp_state buffer is too small in the exllama backend. Please call the exllama_set_max_input_length function to increase the buffer size. Example:\nfrom auto_gptq import exllama_set_max_input_length\nmodel = exllama_set_max_input_length(model, 4096)");
        column_remap_cuda(x, buffers->temp_state, x_height, dim, w->cuda_x_map);
        x_mapped = buffers->temp_state;
    }

    w->reconstruct(buffers->temp_dq);

    // NOTE(review): __CUDA_ARCH__ is only defined while device code is being
    // compiled, so in this host function the first branch is never selected
    // and the cublasHgemm path is what actually runs.
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 700
    const float alpha = 1.0f;
    const float beta = no_zero ? 1.0f : 0.0f;
    cublasSgemmEx(handle, CUBLAS_OP_N, CUBLAS_OP_N, width, height, dim, &alpha, buffers->temp_dq, CUDA_R_16F, width,
                  x_mapped, CUDA_R_16F, dim, &beta, out, CUDA_R_16F, width);
#else
    const half alpha = __float2half(1.0f);
    const half beta = no_zero ? __float2half(1.0f) : __float2half(0.0f);
    cublasHgemm(handle, CUBLAS_OP_N, CUBLAS_OP_N, width, height, dim, &alpha, buffers->temp_dq, width, x_mapped, dim, &beta, out, width);
#endif
}
// Wrap caller-owned quantized tensors (qweight, qzeros, scales) living on
// `_device`. When `_g_idx` is non-null the rows of qweight are reordered in
// place by make_sequential(), which also allocates cuda_x_map.
Q4Matrix::Q4Matrix
(
    const int _height,
    const int _width,
    const int _groups,

    uint32_t* _qweight,
    uint32_t* _qzeros,
    half* _scales,
    uint32_t* _g_idx,

    const int _device
) :
    height(_height),
    width(_width),
    groups(_groups),
    device(_device)
{
    cudaSetDevice(device);

    cuda_qweight = _qweight;
    cuda_qzeros = _qzeros;
    cuda_scales = _scales;

    groupsize = height / groups;

    if (_g_idx) make_sequential(_g_idx);
}

// Release the only device memory this object owns. qweight, qzeros and scales
// are borrowed from the caller and are deliberately left alone.
Q4Matrix::~Q4Matrix()
{
    // cuda_x_map is cudaMalloc'd by make_sequential() and was never freed.
    if (cuda_x_map)
    {
        cudaFree(cuda_x_map);
        cuda_x_map = NULL;
    }
}
// Reorder the rows of the packed weight matrix so that rows belonging to the
// same quantization group become contiguous, and remember the permutation in
// cuda_x_map so activations can be remapped to match.
//
// cpu_g_idx is read on the host: it holds one group index per row (`height` entries).
// NOTE(review): entries are used as indices into a `groups`-sized table without
// validation; presumably the caller guarantees g_idx[i] < groups - confirm.
void Q4Matrix::make_sequential(const uint32_t* cpu_g_idx)
{
    uint32_t* cuda_new_qweight = NULL;
    cudaMalloc(&cuda_new_qweight, height / 8 * width * sizeof(uint32_t));
    cudaMalloc(&cuda_x_map, height * sizeof(uint32_t));  // TODO: Should probably be allocated in PyTorch

    uint32_t* cpu_g_idx_map = (uint32_t*) calloc(groups, sizeof(uint32_t));
    uint32_t* cpu_x_map = (uint32_t*) malloc(height * sizeof(uint32_t));
    uint32_t* cpu_x_map_inv = (uint32_t*) malloc(height * sizeof(uint32_t));

    // Group histogram: number of rows assigned to each group

    for (int i = 0; i < height; i++) cpu_g_idx_map[cpu_g_idx[i]]++;

    // Group map: exclusive prefix sum, turning counts into the first target row of each group

    for (int i = 0, acc = 0; i < groups; i++)
    {
        // Full-width temporary: a `short` here truncated any group holding
        // more than 32767 rows and corrupted every later offset.
        uint32_t tmp = cpu_g_idx_map[i];
        cpu_g_idx_map[i] = acc;
        acc += tmp;
    }

    // X map (inverse): source row -> target row, consuming one slot of the group per row

    for (int row = 0; row < height; row++)
    {
        uint32_t target_group = cpu_g_idx[row];
        uint32_t target_row = cpu_g_idx_map[target_group];
        cpu_g_idx_map[target_group]++;
        cpu_x_map_inv[row] = target_row;
    }

    // X map: target row -> source row

    for (int row = 0; row < height; row++) cpu_x_map[cpu_x_map_inv[row]] = row;

    // Move to CUDA

    cudaMemcpyAsync(cuda_x_map, cpu_x_map, height * sizeof(uint32_t), cudaMemcpyHostToDevice);

    // Rearrange rows in w. Each thread handles two adjacent 32-bit columns
    // (the kernel reads them as one 64-bit word), hence the factor 2.

    dim3 threads(UNSHUF_BLOCKSIZE_X, 1, 1);
    dim3 blocks
    (
        (width + UNSHUF_BLOCKSIZE_X * 2 - 1) / (UNSHUF_BLOCKSIZE_X * 2),
        height / 8,
        1
    );

    make_sequential_kernel<<<blocks, threads>>>(cuda_qweight, cuda_new_qweight, cuda_x_map, height / 8, width);

    // Replace qweights

    cudaMemcpyAsync(cuda_qweight, cuda_new_qweight, height / 8 * width * sizeof(uint32_t), cudaMemcpyDeviceToDevice);

    // Cleanup: wait for the async copies and the kernel before releasing the
    // host and device staging buffers they read from.

    cudaDeviceSynchronize();
    cudaFree(cuda_new_qweight);
    free(cpu_g_idx_map);
    free(cpu_x_map);
    free(cpu_x_map_inv);
}
width, groupsize); +} \ No newline at end of file diff --git a/csrc/quantization/gptq/cuda_func/q4_matrix.cuh b/csrc/quantization/gptq/cuda_func/q4_matrix.cuh new file mode 100644 index 000000000000..50cb72a41518 --- /dev/null +++ b/csrc/quantization/gptq/cuda_func/q4_matrix.cuh @@ -0,0 +1,53 @@ +// Adapted from turboderp exllama: https://github.com/turboderp/exllama + +#ifndef _q4_matrix_cuh +#define _q4_matrix_cuh + +#include +#include +#include + +class Q4Matrix +{ +public: + + int device; + + int height; + int width; + int groups; + int groupsize; + + uint32_t* cuda_qweight = NULL; + uint32_t* cuda_qzeros = NULL; + half* cuda_scales = NULL; + uint32_t* cuda_x_map = NULL; + + Q4Matrix + ( + const int _height, + const int _width, + const int _groups, + + uint32_t* _qweight, + uint32_t* _qzeros, + half* _scales, + uint32_t* _g_idx, + + const int _device + ); + + ~Q4Matrix(); + + void reconstruct(half* out); + +private: + + void make_sequential(const uint32_t* cpu_g_idx); + +}; + +void g_q4_keep_matrix(Q4Matrix* m); +void g_q4_free_matrices(); + +#endif \ No newline at end of file diff --git a/csrc/quantization/gptq/exllama_ext.cpp b/csrc/quantization/gptq/exllama_ext.cpp new file mode 100644 index 000000000000..3332d6955eba --- /dev/null +++ b/csrc/quantization/gptq/exllama_ext.cpp @@ -0,0 +1,244 @@ +// Adapted from turboderp exllama: https://github.com/turboderp/exllama + +#include +#include +#include +#include +#include +#include +#include +#include "util.cuh" +#include "tuning.h" +#include "cuda_buffers.cuh" +#include "cuda_func/q4_matrix.cuh" +#include "cuda_func/q4_matmul.cuh" +#include "cuda_func/column_remap.cuh" + +// Check CUDA return code. We don't want to include Torch headers in the .cu files because parsing them adds almost a +// minute to the compile time on a 12900K. Also passing exceptions back to Python is super tricky, so in place of +// exceptions, CUDA functions return with a cudaError_t which we can parse and dump to the console. 
+ +void check_cuda(cudaError_t ret) +{ + switch (ret) + { + case cudaSuccess: + break; + + case cudaUnspecified: + printf(" **** Unspecified error\n"); + TORCH_CHECK(false, "CUDA error"); + break; + + default: + printf(" **** CUDA error\n"); \ + printf(" **** %s\n", cudaGetErrorString(ret)); \ + TORCH_CHECK(false, "CUDA error"); \ + break; + } +} + +// Some decluttering macros + +#define STRINGIFY_(__x) #__x +#define STRINGIFY(__x) STRINGIFY_(__x) +#define TORCH_CHECK_DTYPE(__x, __dtype) TORCH_CHECK((__x).dtype() == torch::__dtype, #__x " is incorrect datatype, must be " #__dtype) +#define TORCH_CHECK_DTYPE_OPT(__x, __dtype) TORCH_CHECK((__x).device().is_meta() || (__x).dtype() == torch::__dtype, #__x " is incorrect datatype, must be " #__dtype) +#define TORCH_CHECK_SHAPES(__x, __dim_x, __y, __dim_y, __scale_y) TORCH_CHECK((__x).size(__dim_x) == (__y).size(__dim_y) * __scale_y, #__x " and " #__y " have incompatible shapes") +#define TORCH_CHECK_SHAPES_OPT(__x, __dim_x, __y, __dim_y, __scale_y) TORCH_CHECK((__x).device().is_meta() || (__x).size(__dim_x) == (__y).size(__dim_y) * __scale_y, #__x " and " #__y " have incompatible shapes") +#define TORCH_CHECK_SHAPE_MOD(__x, __dim_x, __mod) TORCH_CHECK((__x).size(__dim_x) % __mod == 0, #__x ".shape[" STRINGIFY(__dim_x) "] must be a multiple of " STRINGIFY(__mod)) +#define TORCH_CHECK_BUFFER_SIZE(__buffer, __minimum_size) TORCH_CHECK((__buffer).numel() >= __minimum_size, #__buffer " is too small") + +#define TORCH_CHECK_DEVICE_INDEX(__index) \ +do { \ + TORCH_CHECK(__index >= 0, "no device index"); \ + TORCH_CHECK(__index < CUDA_MAX_DEVICES, "invalid device index"); \ +} while(0) + +#define TORCH_CHECK_QUANT(__w, __w_scales, __w_zeros, __seq_g_idx, __x_map) \ +do { \ + TORCH_CHECK_DTYPE(__w, kInt); \ + TORCH_CHECK_DTYPE(__w_scales, kHalf); \ + TORCH_CHECK_DTYPE(__w_zeros, kInt); \ + TORCH_CHECK_DTYPE_OPT(__seq_g_idx, kShort); \ + TORCH_CHECK_DTYPE_OPT(__x_map, kInt); \ + TORCH_CHECK_SHAPES_OPT(__seq_g_idx, 0, __w, 0, 2 * 
8); \ + TORCH_CHECK_SHAPES_OPT(__x_map, 0, __w, 0, 8); \ +} while(0) + +int get_groupsize(torch::Tensor w, torch::Tensor w_zeros) +{ + int groupsize = w.size(0) * 8 / w_zeros.size(0); + TORCH_CHECK(groupsize * w_zeros.size(0) == w.size(0) * 8, "w.shape[-2] must be a multiple of zeros.shape[-2]") + return groupsize; +} + + +// Tuning parameters + +ExLlamaTuning tuningParams; + +void gptq_set_tuning_params +( + int matmul_recons_thd, + bool matmul_fused_remap, + bool matmul_no_half2 +) +{ + tuningParams.matmul_recons_thd = matmul_recons_thd; + tuningParams.matmul_fused_remap = matmul_fused_remap; + tuningParams.matmul_no_half2 = matmul_no_half2; +} + + +// Release all unmanaged objects allocated by the extension + +void gptq_cleanup() +{ + cleanup_buffers_cuda(); + g_q4_free_matrices(); +} + + +// Prepare buffers for forward pass + +void gptq_prepare_buffers +( + torch::Device device, + torch::Tensor temp_state, + torch::Tensor temp_dq +) +{ + int device_index = device.index(); + TORCH_CHECK_DEVICE_INDEX(device_index); + const at::cuda::OptionalCUDAGuard device_guard(device); + + prepare_buffers_cuda + ( + device_index, + // buffer size used for sanity checks + temp_state.numel(), + (half*) temp_state.data_ptr(), + (half*) temp_dq.data_ptr() + ); +} + + +// Create Q4Matrix, return handle + +uintptr_t gptq_make_q4 +( + torch::Tensor qweight, + torch::Tensor qzeros, + torch::Tensor scales, + torch::Tensor g_idx, + int device +) +{ + TORCH_CHECK_DTYPE(qweight, kInt); + TORCH_CHECK_DTYPE(qzeros, kInt); + TORCH_CHECK_DTYPE(scales, kHalf); + TORCH_CHECK_DTYPE_OPT(g_idx, kInt); + TORCH_CHECK_SHAPES(qweight, 1, qzeros, 1, 8); + TORCH_CHECK_SHAPES(scales, 1, qweight, 1, 1); + TORCH_CHECK_SHAPES(qzeros, 0, scales, 0, 1); + + int width = qweight.size(1); + int height = qweight.size(0) * 8; + int groups = qzeros.size(0); + + Q4Matrix* m = new Q4Matrix + ( + height, + width, + groups, + + (uint32_t*) qweight.data_ptr(), + (uint32_t*) qzeros.data_ptr(), + (half*) 
scales.data_ptr(), + g_idx.device().is_meta() ? NULL : (uint32_t*) g_idx.data_ptr(), + + device + ); + + g_q4_keep_matrix(m); + return reinterpret_cast (m); +} + + +// Matmul half @ quant -> half + +void gptq_q4_matmul +( + torch::Tensor x, + uintptr_t w, + torch::Tensor out +) +{ + Q4Matrix* wm = reinterpret_cast (w); + + TORCH_CHECK_DTYPE(x, kHalf); + TORCH_CHECK_DTYPE(out, kHalf); + TORCH_CHECK_SHAPES(x, 0, out, 0, 1); + TORCH_CHECK(wm->height == x.size(-1), "x and w have incompatible shapes") + + const at::cuda::OptionalCUDAGuard device_guard(device_of(x)); + + int x_height = x.size(0); + + if (tuningParams.matmul_recons_thd == 0 || x_height < tuningParams.matmul_recons_thd) + { + q4_matmul_cuda + ( + &tuningParams, + (half*) x.data_ptr(), + x_height, + wm, + (half*) out.data_ptr() + ); + } + else + { + q4_matmul_recons_cuda + ( + &tuningParams, + (half*) x.data_ptr(), + x_height, + wm, + (half*) out.data_ptr(), + at::cuda::getCurrentCUDABlasHandle() + ); + } +} + + +// Remap columns in half tensor + +void gptq_column_remap +( + torch::Tensor x, + torch::Tensor x_new, + torch::Tensor x_map +) +{ + TORCH_CHECK_DTYPE(x, kHalf); + TORCH_CHECK_DTYPE(x_new, kHalf); + TORCH_CHECK_DTYPE(x_map, kInt); + TORCH_CHECK_SHAPES(x_map, 0, x, 1, 1); + + int height = x.size(0); + int width = x.size(1); + + TORCH_CHECK_BUFFER_SIZE(x_new, height * width); + + const at::cuda::OptionalCUDAGuard device_guard(device_of(x)); + + column_remap_cuda + ( + (half*) x.data_ptr(), + (half*) x_new.data_ptr(), + height, + width, + (uint32_t*) x_map.data_ptr() + ); +} diff --git a/csrc/quantization/gptq/hip_compat.cuh b/csrc/quantization/gptq/hip_compat.cuh new file mode 100644 index 000000000000..5cd2e8553ef6 --- /dev/null +++ b/csrc/quantization/gptq/hip_compat.cuh @@ -0,0 +1,49 @@ +// Adapted from turboderp exllama: https://github.com/turboderp/exllama + +#ifndef _hip_compat_cuh +#define _hip_compat_cuh + +// Workaround for a bug in hipamd, backported from upstream. 
+__device__ __forceinline__ __half __compat_hrcp(__half x) { + return __half_raw{ + static_cast<_Float16>(__builtin_amdgcn_rcph(static_cast<__half_raw>(x).data))}; +} + +__device__ __forceinline__ __half2 __compat_h2rcp(__half2 x) { + return _Float16_2{static_cast<_Float16>(__builtin_amdgcn_rcph(x.x)), + static_cast<_Float16>(__builtin_amdgcn_rcph(x.y))}; +} + +#define hrcp __compat_hrcp +#define h2rcp __compat_h2rcp + +// Workaround for hipify_python using rocblas instead of hipblas. +__host__ __forceinline__ hipblasStatus_t __compat_hipblasHgemm(hipblasHandle_t handle, + hipblasOperation_t transA, + hipblasOperation_t transB, + int m, + int n, + int k, + const half* alpha, + const half* AP, + int lda, + const half* BP, + int ldb, + const half* beta, + half* CP, + int ldc) { + return hipblasHgemm(handle, transA, transB, m, n, k, + reinterpret_cast(alpha), + reinterpret_cast(AP), lda, + reinterpret_cast(BP), ldb, + reinterpret_cast(beta), + reinterpret_cast(CP), ldc); +} + +#define rocblas_handle hipblasHandle_t +#define rocblas_operation_none HIPBLAS_OP_N +#define rocblas_get_stream hipblasGetStream +#define rocblas_set_stream hipblasSetStream +#define rocblas_hgemm __compat_hipblasHgemm + +#endif diff --git a/csrc/quantization/gptq/matrix.cuh b/csrc/quantization/gptq/matrix.cuh new file mode 100644 index 000000000000..2fd5ab0b36cd --- /dev/null +++ b/csrc/quantization/gptq/matrix.cuh @@ -0,0 +1,294 @@ +// Adapted from turboderp exllama: https://github.com/turboderp/exllama + +#ifndef _matrix_cuh +#define _matrix_cuh + +#include +#include + +class MatrixView_half +{ +public: + const half* data; + const int height; + const int width; + + __device__ __forceinline__ MatrixView_half(const half* data, const int height, const int width) + : data(data), height(height), width(width) + { } + + __device__ __forceinline__ half item(int row, int column) const { return data[row * width + column]; } + __device__ __forceinline__ half2 item_half2(int row, int column) const { 
// Read-only view over a row-major matrix of half values.
// Element (row, column) lives at data[row * width + column].
class MatrixView_half
{
public:
    const half* data;
    const int height;
    const int width;

    __device__ __forceinline__ MatrixView_half(const half* data, const int height, const int width)
        : data(data), height(height), width(width)
    { }

    // Single element.
    __device__ __forceinline__ half item(int row, int column) const { return data[row * width + column]; }
    // The half2 pair containing (row, column); the flat index is halved, so an odd index
    // yields the pair that starts one element earlier.
    __device__ __forceinline__ half2 item_half2(int row, int column) const { return ((half2*)data)[(row * width + column) / 2]; }
    // Single element broadcast into both lanes of a half2.
    __device__ __forceinline__ half2 item_half2half2(int row, int column) const { return __half2half2(data[row * width + column]); }
    // Address of a single element.
    __device__ __forceinline__ const half* item_ptr(int row, int column) const { return &data[row * width + column]; }
};

// Mutable counterpart of MatrixView_half, with the same row-major indexing.
class MatrixView_half_rw
{
public:
    half* data;
    const int height;
    const int width;

    __device__ __forceinline__ MatrixView_half_rw(half* data, const int height, const int width)
        : data(data), height(height), width(width)
    { }

    __device__ __forceinline__ half item(int row, int column) const { return data[row * width + column]; }
    __device__ __forceinline__ half2 item_half2(int row, int column) const { return ((half2*)data)[(row * width + column) / 2]; }
    __device__ __forceinline__ half* item_ptr(int row, int column) { return &data[row * width + column]; }
    // Store one element.
    __device__ __forceinline__ void set(int row, int column, half value) { data[row * width + column] = value; }
    // Store one half2 pair; same halved flat index as item_half2.
    __device__ __forceinline__ void set_half2(int row, int column, half2 value) { ((half2*)data)[(row * width + column) / 2] = value; }
};

// View over 4-bit values packed eight per uint32 along each row:
// columns 8k .. 8k+7 of a row share one word, lowest nibble first.
class MatrixView_q4_row
{
public:
    const uint32_t* data;
    const int height;
    const int width;

    __device__ __forceinline__ MatrixView_q4_row(const uint32_t* data, const int height, const int width)
        : data(data), height(height), width(width)
    { }

    // Unpacked 4-bit value (0..15) at (row, column).
    __device__ __forceinline__ int item(int row, int column) const
    {
        int shift = (column & 0x07) * 4;
        return (data[row * width / 8 + column / 8] >> shift) & 0x0f;
    }
};

// View over 4-bit values packed eight per uint32 along each column:
// rows 8k .. 8k+7 of a column share one word, lowest nibble first.
class MatrixView_q4_column
{
public:
    const uint32_t* data;
    const int height;
    const int width;

    __device__ __forceinline__ MatrixView_q4_column(const uint32_t* data, const int height, const int width)
        : data(data), height(height), width(width)
    { }

    // Unpacked 4-bit value (0..15) at (row, column).
    __device__ __forceinline__ int item(int row, int column) const
    {
        int shift = (row & 0x07) * 4;
        return (data[row / 8 * width + column] >> shift) & 0x0f;
    }

    // Whole packed word holding (row, column), and its address.
    __device__ __forceinline__ uint32_t item_uint32_t(int row, int column) { return data[row / 8 * width + column]; }
    __device__ __forceinline__ const uint32_t* item_uint32_ptr(int row, int column) { return &data[row / 8 * width + column]; }
};

// TODO: Rewrite all these dot product functions using functors or something, move to q4_matmul.cu

// Accumulated dot product of 8-element row vectors in h and quantized column vectors in v, constant zero/scale

// half2 variant. Each of the `count` iterations consumes one packed word of v
// (8 nibbles, advancing by v_.width words) and 8 consecutive halves of h
// (as 4 half2 loads). Products are kept per half2 lane, so the result holds
// two partial sums; combining them is left to the caller.
// NOTE(review): v_zero is unsigned, so `(int)(nibble) - v_zero` is evaluated in
// unsigned arithmetic and converted back to int for __int2half_rn.
__device__ __forceinline__ half2 dot_product_8
(
    const half2 acc,
    MatrixView_half& h_,
    const int h_row,
    const int h_column,                 // divisible by 8
    MatrixView_q4_column& v_,
    const int v_row,                    // divisible by 8
    const int v_column,
    const half2 v_scale_2,
    const uint32_t v_zero,              // + 1 (!!)
    const int count
)
{
    const half2* h_ptr = (const half2*) h_.item_ptr(h_row, h_column);
    const uint32_t* v_ptr = (const uint32_t*) v_.item_uint32_ptr(v_row, v_column);
    half2 result = acc;

    for (int i = 0; i < count; i++)
    {
        // One packed word per iteration; step to the same column, 8 rows down.
        uint32_t v_read = *v_ptr; v_ptr += v_.width;

        // Unpack the 8 nibbles and subtract the zero point.
        half v_0 = __int2half_rn((int)((v_read      ) & 0x0f) - v_zero);
        half v_1 = __int2half_rn((int)((v_read >>  4) & 0x0f) - v_zero);
        half v_2 = __int2half_rn((int)((v_read >>  8) & 0x0f) - v_zero);
        half v_3 = __int2half_rn((int)((v_read >> 12) & 0x0f) - v_zero);
        half v_4 = __int2half_rn((int)((v_read >> 16) & 0x0f) - v_zero);
        half v_5 = __int2half_rn((int)((v_read >> 20) & 0x0f) - v_zero);
        half v_6 = __int2half_rn((int)((v_read >> 24) & 0x0f) - v_zero);
        half v_7 = __int2half_rn((int)((v_read >> 28)       ) - v_zero);

        half2 v_01 = __halves2half2(v_0, v_1);
        half2 v_23 = __halves2half2(v_2, v_3);
        half2 v_45 = __halves2half2(v_4, v_5);
        half2 v_67 = __halves2half2(v_6, v_7);

//         half2 v_01 = q4_table[v_zero - 1][(v_read      ) & 0xff]; // (constant memory is too slow apparently)
//         half2 v_23 = q4_table[v_zero - 1][(v_read >>  8) & 0xff];
//         half2 v_45 = q4_table[v_zero - 1][(v_read >> 16) & 0xff];
//         half2 v_67 = q4_table[v_zero - 1][(v_read >> 24)       ];

        // Unscaled 8-term product first, then one scaled FMA into the accumulator.
        half2 tmp = __hmul2(*h_ptr++, v_01);
        tmp = __hfma2(*h_ptr++, v_23, tmp);
        tmp = __hfma2(*h_ptr++, v_45, tmp);
        tmp = __hfma2(*h_ptr++, v_67, tmp);
        result = __hfma2(v_scale_2, tmp, result);
    }

    return result;
}

// Scalar-half variant of dot_product_8: same traversal, but every product is
// accumulated into a single half instead of two half2 lanes.
__device__ __forceinline__ half dot_product_8_h
(
    const half acc,
    MatrixView_half& h_,
    const int h_row,
    const int h_column,                 // divisible by 8
    MatrixView_q4_column& v_,
    const int v_row,                    // divisible by 8
    const int v_column,
    const half v_scale,
    const uint32_t v_zero,              // + 1 (!!)
    const int count
)
{
    const half* h_ptr = h_.item_ptr(h_row, h_column);
    const uint32_t* v_ptr = (const uint32_t*) v_.item_uint32_ptr(v_row, v_column);
    half result = acc;

    for (int i = 0; i < count; i++)
    {
        uint32_t v_read = *v_ptr; v_ptr += v_.width;

        half v_0 = __int2half_rn((int)((v_read      ) & 0x0f) - v_zero);
        half v_1 = __int2half_rn((int)((v_read >>  4) & 0x0f) - v_zero);
        half v_2 = __int2half_rn((int)((v_read >>  8) & 0x0f) - v_zero);
        half v_3 = __int2half_rn((int)((v_read >> 12) & 0x0f) - v_zero);
        half v_4 = __int2half_rn((int)((v_read >> 16) & 0x0f) - v_zero);
        half v_5 = __int2half_rn((int)((v_read >> 20) & 0x0f) - v_zero);
        half v_6 = __int2half_rn((int)((v_read >> 24) & 0x0f) - v_zero);
        half v_7 = __int2half_rn((int)((v_read >> 28)       ) - v_zero);

        half tmp = __hmul(*h_ptr++, v_0);
        tmp = __hfma(*h_ptr++, v_1, tmp);
        tmp = __hfma(*h_ptr++, v_2, tmp);
        tmp = __hfma(*h_ptr++, v_3, tmp);
        tmp = __hfma(*h_ptr++, v_4, tmp);
        tmp = __hfma(*h_ptr++, v_5, tmp);
        tmp = __hfma(*h_ptr++, v_6, tmp);
        tmp = __hfma(*h_ptr++, v_7, tmp);
        result = __hfma(v_scale, tmp, result);
    }

    return result;
}

// Accumulated dot product of 8-element row vectors in h and quantized column vectors in v, constant zero/scale, with x_map

// half2 variant with column remapping: instead of reading h contiguously, the
// k-th operand is h[h_row][x_map[h_column + k]]. x_map is walked sequentially,
// 8 entries per iteration.
__device__ __forceinline__ half2 dot_product_8_x_map
(
    const half2 acc,
    MatrixView_half& h_,
    const int h_row,
    const int h_column,                 // divisible by 8
    MatrixView_q4_column& v_,
    const int v_row,                    // divisible by 8
    const int v_column,
    const half2 v_scale_2,
    const uint32_t v_zero,              // + 1 (!!)
    const int count,
    const uint32_t* x_map
)
{
    // Base of the row; the column offset is applied through x_map instead.
    const half* h_ptr = h_.item_ptr(h_row, 0);
    const uint32_t* x_map_ptr = x_map + h_column;
    const uint32_t* v_ptr = (const uint32_t*) v_.item_uint32_ptr(v_row, v_column);
    half2 result = acc;

    for (int i = 0; i < count; i++)
    {
        uint32_t v_read = *v_ptr; v_ptr += v_.width;

        half v_0 = __int2half_rn((int)((v_read      ) & 0x0f) - v_zero);
        half v_1 = __int2half_rn((int)((v_read >>  4) & 0x0f) - v_zero);
        half v_2 = __int2half_rn((int)((v_read >>  8) & 0x0f) - v_zero);
        half v_3 = __int2half_rn((int)((v_read >> 12) & 0x0f) - v_zero);
        half v_4 = __int2half_rn((int)((v_read >> 16) & 0x0f) - v_zero);
        half v_5 = __int2half_rn((int)((v_read >> 20) & 0x0f) - v_zero);
        half v_6 = __int2half_rn((int)((v_read >> 24) & 0x0f) - v_zero);
        half v_7 = __int2half_rn((int)((v_read >> 28)       ) - v_zero);

        half2 v_01 = __halves2half2(v_0, v_1);
        half2 v_23 = __halves2half2(v_2, v_3);
        half2 v_45 = __halves2half2(v_4, v_5);
        half2 v_67 = __halves2half2(v_6, v_7);

        // Gather the 8 remapped inputs for this word.
        half h_0 = h_ptr[*x_map_ptr++];
        half h_1 = h_ptr[*x_map_ptr++];
        half h_2 = h_ptr[*x_map_ptr++];
        half h_3 = h_ptr[*x_map_ptr++];
        half h_4 = h_ptr[*x_map_ptr++];
        half h_5 = h_ptr[*x_map_ptr++];
        half h_6 = h_ptr[*x_map_ptr++];
        half h_7 = h_ptr[*x_map_ptr++];

        half2 h_01 = __halves2half2(h_0, h_1);
        half2 h_23 = __halves2half2(h_2, h_3);
        half2 h_45 = __halves2half2(h_4, h_5);
        half2 h_67 = __halves2half2(h_6, h_7);

        half2 tmp = __hmul2(h_01, v_01);
        tmp = __hfma2(h_23, v_23, tmp);
        tmp = __hfma2(h_45, v_45, tmp);
        tmp = __hfma2(h_67, v_67, tmp);
        result = __hfma2(v_scale_2, tmp, result);
    }

    return result;
}

// Scalar-half variant of dot_product_8_x_map.
__device__ __forceinline__ half dot_product_8_x_map_h
(
    const half acc,
    MatrixView_half& h_,
    const int h_row,
    const int h_column,                 // divisible by 8
    MatrixView_q4_column& v_,
    const int v_row,                    // divisible by 8
    const int v_column,
    const half v_scale,
    const uint32_t v_zero,              // + 1 (!!)
    const int count,
    const uint32_t* x_map
)
{
    const half* h_ptr = h_.item_ptr(h_row, 0);
    const uint32_t* x_map_ptr = x_map + h_column;
    const uint32_t* v_ptr = (const uint32_t*) v_.item_uint32_ptr(v_row, v_column);
    half result = acc;

    for (int i = 0; i < count; i++)
    {
        uint32_t v_read = *v_ptr; v_ptr += v_.width;

        half v_0 = __int2half_rn((int)((v_read      ) & 0x0f) - v_zero);
        half v_1 = __int2half_rn((int)((v_read >>  4) & 0x0f) - v_zero);
        half v_2 = __int2half_rn((int)((v_read >>  8) & 0x0f) - v_zero);
        half v_3 = __int2half_rn((int)((v_read >> 12) & 0x0f) - v_zero);
        half v_4 = __int2half_rn((int)((v_read >> 16) & 0x0f) - v_zero);
        half v_5 = __int2half_rn((int)((v_read >> 20) & 0x0f) - v_zero);
        half v_6 = __int2half_rn((int)((v_read >> 24) & 0x0f) - v_zero);
        half v_7 = __int2half_rn((int)((v_read >> 28)       ) - v_zero);

        half tmp = __hmul(h_ptr[*x_map_ptr++], v_0);
        tmp = __hfma(h_ptr[*x_map_ptr++], v_1, tmp);
        tmp = __hfma(h_ptr[*x_map_ptr++], v_2, tmp);
        tmp = __hfma(h_ptr[*x_map_ptr++], v_3, tmp);
        tmp = __hfma(h_ptr[*x_map_ptr++], v_4, tmp);
        tmp = __hfma(h_ptr[*x_map_ptr++], v_5, tmp);
        tmp = __hfma(h_ptr[*x_map_ptr++], v_6, tmp);
        tmp = __hfma(h_ptr[*x_map_ptr++], v_7, tmp);
        result = __hfma(v_scale, tmp, result);
    }

    return result;
}

#endif

// ---- csrc/quantization/gptq/tuning.h ----
// Adapted from turboderp exllama: https://github.com/turboderp/exllama

#ifndef _tuning_h
#define _tuning_h

// Plain bundle of the three matmul tuning knobs that are passed to the
// q4 matmul entry points.
struct ExLlamaTuning
{
    int matmul_recons_thd;      // row-count threshold for the reconstruct path (0 disables it)
    bool matmul_fused_remap;
    bool matmul_no_half2;
};

#endif
000000000000..7b397573214b --- /dev/null +++ b/csrc/quantization/gptq/util.cuh @@ -0,0 +1,33 @@ +// Adapted from turboderp exllama: https://github.com/turboderp/exllama + +#ifndef _util_cuh +#define _util_cuh + +#include +#include +#include +#include + +#if defined(USE_ROCM) +#define cudaUnspecified hipErrorUnknown +#else +#define cudaUnspecified cudaErrorApiFailureBase +#endif + +// React to failure on return code != cudaSuccess + +#define _cuda_check(fn) \ +do { \ + {_cuda_err = fn;} \ + if (_cuda_err != cudaSuccess) goto _cuda_fail; \ +} while(false) + +// React to failure on return code == 0 + +#define _alloc_check(fn) \ +do { \ + if (!(fn)) { _cuda_err = cudaUnspecified; goto _cuda_fail; } \ + else _cuda_err = cudaSuccess; \ +} while(false) + +#endif diff --git a/setup.py b/setup.py index 047ee8d0e894..7026300fd3f9 100644 --- a/setup.py +++ b/setup.py @@ -150,8 +150,12 @@ def get_nvcc_cuda_version(cuda_dir: str) -> Version: quantization_extension = CUDAExtension( name="vllm.quantization_ops", sources=[ - "csrc/quantization.cpp", - "csrc/quantization/awq/gemm_kernels.cu", + "csrc/quantization.cpp", "csrc/quantization/awq/gemm_kernels.cu", + "csrc/quantization/gptq/exllama_ext.cpp", + "csrc/quantization/gptq/cuda_buffers.cu", + "csrc/quantization/gptq/cuda_func/column_remap.cu", + "csrc/quantization/gptq/cuda_func/q4_matmul.cu", + "csrc/quantization/gptq/cuda_func/q4_matrix.cu" ], extra_compile_args={ "cxx": CXX_FLAGS, diff --git a/vllm/config.py b/vllm/config.py index dd92fbccd899..318b3a04f5e2 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -2,6 +2,7 @@ import torch from transformers import PretrainedConfig +from transformers.utils.quantization_config import QuantizationMethod from vllm.logger import init_logger from vllm.transformers_utils.config import get_config @@ -106,7 +107,12 @@ def _verify_tokenizer_mode(self) -> None: self.tokenizer_mode = tokenizer_mode def _verify_quantization(self) -> None: - supported_quantization = ["awq"] + 
from typing import Optional

import torch
from torch.nn.parameter import Parameter

from vllm import quantization_ops
from vllm.model_executor.parallel_utils.tensor_parallel.layers import (
    ColumnParallelLinear, RowParallelLinear)
from vllm.model_executor.parallel_utils.tensor_parallel.mappings import (
    gather_from_tensor_model_parallel_region)


def _make_q4_handle(layer) -> int:
    """Build the kernel-side Q4 matrix for *layer* and return its handle.

    Shared by both GPTQ linear layers; reads ``qweight``, ``qzeros``,
    ``scales``, ``g_idx`` and ``quant_config.desc_act`` from the layer.
    """
    assert layer.qweight.device.type == "cuda"
    assert layer.qweight.device.index is not None

    # make_q4 segfaults if g_idx is not on cpu in the act-order case.
    # In the non act-order case, None needs to be passed for g_idx; a tensor
    # on the meta device is the stand-in the extension recognises.
    if not layer.quant_config.desc_act:
        g_idx = torch.empty((1, 1), device="meta")
    else:
        g_idx = layer.g_idx.to("cpu")
    return quantization_ops.gptq_make_q4(layer.qweight, layer.qzeros,
                                         layer.scales, g_idx,
                                         layer.qweight.device.index)


def _q4_matmul(layer, x: torch.Tensor) -> torch.Tensor:
    """Multiply *x* by the quantized weight of *layer*.

    The input is flattened to 2-D for the kernel. The result is returned in
    that flattened ``(rows, out_features)`` form so callers can post-process
    it before restoring the leading dimensions.
    """
    reshaped_x = x.reshape(-1, x.shape[-1])
    # Size the buffer from the flattened input: for inputs with more than two
    # dimensions x.shape[0] is smaller than the number of rows the kernel
    # writes.
    output = torch.empty((reshaped_x.shape[0], layer.qweight.shape[-1]),
                         dtype=torch.float16,
                         device=x.device)
    quantization_ops.gptq_q4_matmul(reshaped_x, layer.q4, output)
    return output


class GPTQColumnParallelLinear(ColumnParallelLinear):
    """Column-parallel linear layer backed by GPTQ 4-bit weights."""

    def create_weights(self, dtype: torch.dtype) -> None:
        """Allocate the (empty) quantized parameters on the GPU.

        ``group_size == -1`` in the config means one group spanning the whole
        input dimension.
        """
        assert self.input_size % self.quant_config.pack_factor == 0
        assert (self.output_size_per_partition %
                self.quant_config.pack_factor == 0)
        group_size = self.quant_config.group_size if (
            self.quant_config.group_size != -1) else self.input_size

        self.qweight = Parameter(
            torch.empty(
                self.input_size // self.quant_config.pack_factor,
                self.output_size_per_partition,
                device="cuda",
                dtype=torch.int32,
            ),
            requires_grad=False,
        )
        self.qzeros = Parameter(
            torch.empty(
                self.input_size // group_size,
                self.output_size_per_partition //
                self.quant_config.pack_factor,
                device="cuda",
                dtype=torch.int32,
            ),
            requires_grad=False,
        )
        self.scales = Parameter(
            torch.empty(
                self.input_size // group_size,
                self.output_size_per_partition,
                device="cuda",
                dtype=dtype,
            ),
            requires_grad=False,
        )
        # Default group index: input row i belongs to group i // group_size.
        self.g_idx = Parameter(
            torch.tensor(
                [i // group_size for i in range(self.input_size)],
                device="cuda",
                dtype=torch.int32,
            ),
            requires_grad=False,
        )

    def post_init(self):
        """Create the kernel handle; call once the weights are loaded."""
        self.q4 = _make_q4_handle(self)

    def apply_weights(
        self,
        x: torch.Tensor,
        bias: Optional[torch.Tensor],
    ) -> torch.Tensor:
        """Return ``x @ W (+ bias)`` with the leading dims of *x* preserved."""
        out_shape = x.shape[:-1] + (self.qweight.shape[-1], )
        output = _q4_matmul(self, x)
        if bias is not None:
            # The original assigned to an undefined name here, which raised
            # NameError whenever a bias was supplied.
            output = output + bias
        return output.reshape(out_shape)


class GPTQRowParallelLinear(RowParallelLinear):
    """Row-parallel linear layer backed by GPTQ 4-bit weights.

    When the checkpoint uses act-order together with real groups
    (``desc_act`` and ``group_size != -1``) the layer is not sharded: every
    rank keeps the full weight and ``self.parallel`` is False.
    """

    def create_weights(self, dtype: torch.dtype) -> None:
        """Allocate the (empty) quantized parameters on the GPU."""
        assert (self.input_size_per_partition %
                self.quant_config.pack_factor == 0)
        assert self.output_size % self.quant_config.pack_factor == 0
        # Ignore tensor parallel when group_size != -1 and desc_act
        if self.quant_config.desc_act and self.quant_config.group_size != -1:
            self.input_size_per_partition = self.input_size
            self.parallel = False
        else:
            self.parallel = True
        group_size = self.quant_config.group_size if (
            self.quant_config.group_size != -1
        ) else self.input_size_per_partition
        self.qweight = Parameter(
            torch.empty(
                self.input_size_per_partition // self.quant_config.pack_factor,
                self.output_size,
                device="cuda",
                dtype=torch.int32,
            ),
            requires_grad=False,
        )
        self.qzeros = Parameter(
            torch.empty(
                self.input_size_per_partition // group_size,
                self.output_size // self.quant_config.pack_factor,
                device="cuda",
                dtype=torch.int32,
            ),
            requires_grad=False,
        )
        self.scales = Parameter(
            torch.empty(
                self.input_size_per_partition // group_size,
                self.output_size,
                device="cuda",
                dtype=dtype,
            ),
            requires_grad=False,
        )
        self.g_idx = Parameter(
            torch.tensor(
                [
                    i // group_size
                    for i in range(self.input_size_per_partition)
                ],
                device="cuda",
                dtype=torch.int32,
            ),
            requires_grad=False,
        )

    def post_init(self):
        """Create the kernel handle; call once the weights are loaded."""
        self.q4 = _make_q4_handle(self)

    def apply_weights(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``x @ W`` with the leading dims of *x* preserved."""
        out_shape = x.shape[:-1] + (self.qweight.shape[-1], )
        output = _q4_matmul(self, x)
        # NOTE(review): this branch is only taken on the non-sharded path,
        # where forward() below may divide by world_size a second time.
        # Confirm the intended scaling for world_size > 1.
        if self.quant_config.desc_act and self.quant_config.group_size != -1:
            output = output / self.world_size
        return output.reshape(out_shape)

    def forward(self, input_):
        """Run the layer and return ``(output, output_bias)``.

        Sharded layers defer to ``RowParallelLinear.forward``. The
        non-sharded case gathers a split input first, then applies the full
        weight locally.
        """
        # Set up backprop all-reduce.
        if self.parallel:
            return super().forward(input_)
        if self.input_is_parallel:
            input_ = gather_from_tensor_model_parallel_region(input_)
        output_ = self.apply_weights(input_)
        if not self.reduce_results and self.world_size > 1:
            output_ = output_ / self.world_size

        if not self.skip_bias_add:
            output = output_ + self.bias if self.bias is not None else output_
            output_bias = None
        else:
            output = output_
            output_bias = self.bias
        return output, output_bias
from typing import Optional

import torch

from vllm import quantization_ops
from vllm.model_executor.layers.quantized_linear.gptq import (
    GPTQColumnParallelLinear, GPTQRowParallelLinear)


def quant_post_init(model, max_input_length: Optional[int] = None):
    """Finish initialising the GPTQ layers of *model* and return it.

    Scans the model for GPTQ linear layers, allocates the per-device scratch
    buffers the exllama kernels need, registers them with the extension, sets
    the tuning parameters and finally calls ``post_init()`` on every GPTQ
    layer. Models without GPTQ layers only get their CUDA cache emptied.

    Args:
        model: The model whose sub-modules are scanned.
        max_input_length: Number of rows of the ``temp_state`` buffer. It is
            specific to the exllama backend and is only used when at least
            one layer has ``desc_act`` enabled.

    Raises:
        ValueError: If an act-order layer is present and
            ``max_input_length`` is None.
    """
    gptq_layer_types = (GPTQColumnParallelLinear, GPTQRowParallelLinear)
    device_to_buffers_size = {}

    model_uses_exllama = False
    use_act_order = False
    for _, submodule in model.named_modules():
        if not isinstance(submodule, gptq_layer_types):
            continue
        model_uses_exllama = True
        device = submodule.qweight.device
        sizes = device_to_buffers_size.setdefault(device, {
            "max_dq_buffer_size": 1,
            "max_inner_outer_dim": 1,
        })

        # One dequantized element per 4-bit value: 8 per packed int32.
        sizes["max_dq_buffer_size"] = max(sizes["max_dq_buffer_size"],
                                          submodule.qweight.numel() * 8)

        if isinstance(submodule, GPTQColumnParallelLinear):
            in_features = submodule.input_size
            out_features = submodule.output_size_per_partition
        else:
            in_features = submodule.input_size_per_partition
            out_features = submodule.output_size
        if submodule.quant_config.desc_act:
            use_act_order = True
            sizes["max_inner_outer_dim"] = max(sizes["max_inner_outer_dim"],
                                               in_features, out_features)

    if model_uses_exllama:
        if use_act_order and max_input_length is None:
            # Without this check torch.zeros((None, ...)) fails with an
            # unhelpful TypeError.
            raise ValueError(
                "max_input_length must be provided for GPTQ models "
                "quantized with desc_act=True.")
        max_input_len = max_input_length if use_act_order else 1

        device_to_buffers = {}
        for device, buffers_size in device_to_buffers_size.items():
            # The temp_state buffer is required to reorder X in the act-order
            # case. The temp_dq buffer is required to dequantize weights when
            # using cuBLAS, typically for the prefill.
            device_to_buffers[device] = {
                "temp_state":
                torch.zeros(
                    (max_input_len, buffers_size["max_inner_outer_dim"]),
                    dtype=torch.float16,
                    device=device),
                "temp_dq":
                torch.zeros((1, buffers_size["max_dq_buffer_size"]),
                            dtype=torch.float16,
                            device=device),
                "max_dq_buffer_size":
                buffers_size["max_dq_buffer_size"],
                "max_inner_outer_dim":
                buffers_size["max_inner_outer_dim"],
            }

        # Buffers need to be persistent to avoid any bug.
        model.device_to_buffers = device_to_buffers

        for device, buffers in model.device_to_buffers.items():
            quantization_ops.gptq_prepare_buffers(device,
                                                  buffers["temp_state"],
                                                  buffers["temp_dq"])

        # Using the default from exllama repo here.
        matmul_recons_thd = 8
        matmul_fused_remap = False
        matmul_no_half2 = False
        quantization_ops.gptq_set_tuning_params(matmul_recons_thd,
                                                matmul_fused_remap,
                                                matmul_no_half2)

        # The buffers need to have been initialized first before calling
        # make_q4.
        for _, submodule in model.named_modules():
            if isinstance(submodule, gptq_layer_types):
                submodule.post_init()

    torch.cuda.empty_cache()

    return model
quant_config = None if model_config.quantization is not None: - if model_class not in _MODEL_CLASSES_SUPPORT_QUANTIZATION: + if model_class not in _MODEL_CLASSES_SUPPORT_QUANTIZATION[ + model_config.quantization]: raise ValueError( f"Quantization is not supported for {model_class}.") quant_config = get_quant_config(model_config.quantization, model_config.model, + model_config.hf_config, model_config.download_dir) supported_dtypes = quant_config.get_supported_act_dtypes() if model_config.dtype not in supported_dtypes: @@ -78,7 +82,8 @@ def get_model(model_config: ModelConfig) -> nn.Module: with _set_default_torch_dtype(model_config.dtype): # Create a model instance. # The weights will be initialized as empty tensors. - if model_class in _MODEL_CLASSES_SUPPORT_QUANTIZATION: + if model_class in _MODEL_CLASSES_SUPPORT_QUANTIZATION[ + model_config.quantization]: model = model_class(model_config.hf_config, quant_config) else: model = model_class(model_config.hf_config) @@ -92,4 +97,6 @@ def get_model(model_config: ModelConfig) -> nn.Module: model.load_weights(model_config.model, model_config.download_dir, model_config.load_format, model_config.revision) model = model.cuda() + if model_config.quantization is not None: + quant_post_init(model, max_tokens) return model.eval() diff --git a/vllm/model_executor/models/llama.py b/vllm/model_executor/models/llama.py index e87f0073c520..c081462f0026 100644 --- a/vllm/model_executor/models/llama.py +++ b/vllm/model_executor/models/llama.py @@ -44,7 +44,7 @@ from vllm.model_executor.quantization_utils import QuantizationConfig from vllm.model_executor.weight_utils import ( load_tensor_parallel_weights, load_padded_tensor_parallel_vocab, - hf_model_weights_iterator) + hf_model_weights_iterator, convert_pyslice_to_tensor) from vllm.sequence import SamplerOutput KVCache = Tuple[torch.Tensor, torch.Tensor] @@ -298,17 +298,23 @@ def load_weights(self, load_format: str = "auto", revision: Optional[str] = None): if self.quant_config is 
None: - weight_suffixes = ["weight"] + column_weight_suffixes = ["weight", "bias"] + row_weight_suffixes = ["weight"] + ignore_weight_suffixes = [] else: - weight_suffixes = self.quant_config.get_tp_tensor_names() + column_weight_suffixes = self.quant_config.get_column_tp_tensor_names( + ) + row_weight_suffixes = self.quant_config.get_row_tp_tensor_names() + ignore_weight_suffixes = self.quant_config.get_ignore_tensor_names( + ) column_parallel_weights: List[str] = [] for layer in self._column_parallel_layers: - for suffix in weight_suffixes: + for suffix in column_weight_suffixes: column_parallel_weights.append(f"{layer}.{suffix}") row_parallel_weights: List[str] = [] for layer in self._row_parallel_layers: - for suffix in weight_suffixes: + for suffix in row_weight_suffixes: row_parallel_weights.append(f"{layer}.{suffix}") tp_size = get_tensor_model_parallel_world_size() @@ -330,6 +336,8 @@ def load_weights(self, model_name_or_path, cache_dir, load_format, revision): if "rotary_emb.inv_freq" in name: continue + if any(name.endswith(suffix) for suffix in ignore_weight_suffixes): + continue is_packed = False is_transposed = False @@ -337,13 +345,16 @@ def load_weights(self, is_packed = self.quant_config.is_packed(name) is_transposed = self.quant_config.is_transposed(name) if is_transposed: - loaded_weight = loaded_weight.T + loaded_weight = convert_pyslice_to_tensor(loaded_weight).T is_attention_weight = False for weight_name, shard_size, offset in attention_weight_specs: if weight_name not in name: continue - param = state_dict[name.replace(weight_name, "qkv_proj")] + name = name.replace(weight_name, "qkv_proj") + if name not in state_dict: + continue + param = state_dict[name] if is_transposed: param = param.T @@ -351,6 +362,8 @@ def load_weights(self, shard_size //= self.quant_config.pack_factor offset //= self.quant_config.pack_factor + if "g_idx" in name: + break loaded_weight = loaded_weight[ shard_size * tensor_model_parallel_rank:shard_size * 
(tensor_model_parallel_rank + 1)] @@ -367,10 +380,16 @@ def load_weights(self, for stride_id, weight_name in enumerate(["gate_proj", "up_proj"]): if weight_name not in name: continue - param = state_dict[name.replace(weight_name, "gate_up_proj")] + name = name.replace(weight_name, "gate_up_proj") + if name not in state_dict: + continue + param = state_dict[name] if is_transposed: param = param.T + if "g_idx" in name: + break + shard_size = param.shape[0] // 2 loaded_weight = loaded_weight[ shard_size * tensor_model_parallel_rank:shard_size * @@ -384,6 +403,8 @@ def load_weights(self, if is_gate_up_weight: continue + if name not in state_dict: + continue param = state_dict[name] if is_transposed: param = param.T diff --git a/vllm/model_executor/parallel_utils/tensor_parallel/layers.py b/vllm/model_executor/parallel_utils/tensor_parallel/layers.py index bfaf9c5f7349..0c7c817e1831 100644 --- a/vllm/model_executor/parallel_utils/tensor_parallel/layers.py +++ b/vllm/model_executor/parallel_utils/tensor_parallel/layers.py @@ -114,6 +114,7 @@ def __init__(self, num_embeddings: int, embedding_dim: int, *, self.num_embeddings_per_partition, self.embedding_dim, device=torch.cuda.current_device(), dtype=params_dtype)) + def forward(self, input_): if self.tensor_model_parallel_size > 1: # Build the mask. 
diff --git a/vllm/model_executor/quantization_utils/__init__.py b/vllm/model_executor/quantization_utils/__init__.py index df67758f7110..1abe07fe2b6e 100644 --- a/vllm/model_executor/quantization_utils/__init__.py +++ b/vllm/model_executor/quantization_utils/__init__.py @@ -1,10 +1,12 @@ from typing import Type from vllm.model_executor.quantization_utils.awq import AWQConfig +from vllm.model_executor.quantization_utils.gptq import GPTQConfig from vllm.model_executor.quantization_utils.base import QuantizationConfig _QUANTIZATION_REGISTRY = { "awq": AWQConfig, + "gptq": GPTQConfig, } diff --git a/vllm/model_executor/quantization_utils/awq.py b/vllm/model_executor/quantization_utils/awq.py index ed8987e15792..dc5eac3c70cb 100644 --- a/vllm/model_executor/quantization_utils/awq.py +++ b/vllm/model_executor/quantization_utils/awq.py @@ -62,6 +62,8 @@ def get_packed_tensor_names(cls) -> List[str]: def get_transposed_tensor_names(cls) -> List[str]: return ["qweight", "qzeros", "scales"] - @classmethod - def get_tp_tensor_names(cls) -> List[str]: + def get_row_tp_tensor_names(self) -> List[str]: return ["qweight", "qzeros", "scales"] + + def get_column_tp_tensor_names(self) -> List[str]: + return ["qweight", "qzeros", "scales", "bias"] diff --git a/vllm/model_executor/quantization_utils/base.py b/vllm/model_executor/quantization_utils/base.py index cb406f4c98e1..493a631d93ab 100644 --- a/vllm/model_executor/quantization_utils/base.py +++ b/vllm/model_executor/quantization_utils/base.py @@ -60,6 +60,11 @@ def is_transposed(cls, tensor_name: str) -> bool: return any(tag in tensor_name for tag in cls.get_transposed_tensor_names()) - @classmethod - def get_tp_tensor_names(cls) -> List[str]: + def get_row_tp_tensor_names(self) -> List[str]: + raise NotImplementedError + + def get_column_tp_tensor_names(self) -> List[str]: raise NotImplementedError + + def get_ignore_tensor_names(self) -> List[str]: + return [] diff --git a/vllm/model_executor/quantization_utils/gptq.py 
b/vllm/model_executor/quantization_utils/gptq.py new file mode 100644 index 000000000000..8825efec388d --- /dev/null +++ b/vllm/model_executor/quantization_utils/gptq.py @@ -0,0 +1,77 @@ +from typing import Any, Dict, List + +import torch + +from vllm.model_executor.quantization_utils.base import QuantizationConfig + + +class GPTQConfig(QuantizationConfig): + """Config class for GPTQ. + + Reference: https://arxiv.org/abs/2306.00978 + """ + + def __init__( + self, + weight_bits: int, + group_size: int, + desc_act: bool, + ) -> None: + self.weight_bits = weight_bits + self.group_size = group_size + self.desc_act = desc_act + self.pack_factor = 32 // self.weight_bits + # exllama kernel v1 only supports 4 bit + if self.weight_bits != 4: + raise ValueError( + "Currently, only 4-bit weight quantization is supported for " + f"GPTQ, but got {self.weight_bits} bits.") + + def __repr__(self) -> str: + return (f"GPTQConfig(weight_bits={self.weight_bits}, " + f"group_size={self.group_size}, " + f"desc_act={self.desc_act})") + + @classmethod + def get_name(cls) -> str: + return "gptq" + + @classmethod + def get_supported_act_dtypes(cls) -> List[torch.dtype]: + return [torch.half] + + @classmethod + def get_config_filenames(cls) -> List[str]: + return [ + "quant_config.json", + ] + + @classmethod + def from_config(cls, config: Dict[str, Any]) -> "GPTQConfig": + weight_bits = cls.get_from_keys(config, ["bits"]) + group_size = cls.get_from_keys(config, ["group_size"]) + desc_act = cls.get_from_keys(config, ["desc_act"]) + return cls(weight_bits, group_size, desc_act) + + @classmethod + def get_packed_tensor_names(cls) -> List[str]: + return ["qzeros"] + + @classmethod + def get_transposed_tensor_names(cls) -> List[str]: + return ["qweight", "qzeros", "scales"] + + def get_row_tp_tensor_names(self) -> List[str]: + if self.desc_act and self.group_size != -1: + return [] + if self.group_size == -1: + return ["qweight", "g_idx"] + return ["qweight", "qzeros", "scales"] + + def 
get_column_tp_tensor_names(self) -> List[str]: + return ["qweight", "qzeros", "scales", "bias"] + + def get_ignore_tensor_names(self) -> List[str]: + if self.desc_act and self.group_size != -1: + return [] + return ["g_idx"] diff --git a/vllm/model_executor/weight_utils.py b/vllm/model_executor/weight_utils.py index 74de96842296..17b88e9fed8b 100644 --- a/vllm/model_executor/weight_utils.py +++ b/vllm/model_executor/weight_utils.py @@ -11,6 +11,7 @@ import numpy as np import torch from tqdm.auto import tqdm +from transformers import PretrainedConfig from vllm.logger import init_logger from vllm.model_executor.quantization_utils import get_quant_class @@ -84,8 +85,13 @@ def convert_bin_to_safetensor_file( def get_quant_config( quantization: str, model_name_or_path: str, + hf_config: PretrainedConfig, cache_dir: Optional[str] = None, ) -> QuantizationConfig: + if quantization == "gptq" and hasattr(hf_config, "quantization_config"): + config = hf_config.quantization_config + return get_quant_class(quantization).from_config(config) + is_local = os.path.isdir(model_name_or_path) if not is_local: # Download the config files. diff --git a/vllm/worker/worker.py b/vllm/worker/worker.py index 2d2021d9fe95..ab5e79d26828 100644 --- a/vllm/worker/worker.py +++ b/vllm/worker/worker.py @@ -64,7 +64,8 @@ def init_model(self): # Initialize the model. 
set_random_seed(self.model_config.seed) - self.model = get_model(self.model_config) + self.model = get_model(self.model_config, + self.scheduler_config.max_num_batched_tokens) @torch.inference_mode() def profile_num_available_blocks( From 049a37c079a376d6d4f2753cf146227c573ddc89 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=A5=9A=E5=A4=A9=E7=BF=94?= Date: Mon, 25 Sep 2023 20:35:09 +0800 Subject: [PATCH 02/18] fix bug in model loading --- vllm/model_executor/model_loader.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/vllm/model_executor/model_loader.py b/vllm/model_executor/model_loader.py index a0b78566face..fb774c23fddf 100644 --- a/vllm/model_executor/model_loader.py +++ b/vllm/model_executor/model_loader.py @@ -96,8 +96,9 @@ def get_model(model_config: ModelConfig, max_tokens: int) -> nn.Module: with _set_default_torch_dtype(model_config.dtype): # Create a model instance. # The weights will be initialized as empty tensors. - if model_class in _MODEL_CLASSES_SUPPORT_QUANTIZATION[ - model_config.quantization]: + if model_config.quantization is not None and ( + model_class in _MODEL_CLASSES_SUPPORT_QUANTIZATION[ + model_config.quantization]): model = model_class(model_config.hf_config, quant_config) else: model = model_class(model_config.hf_config) From 556357863bae780cd86dfef1f87dc1bf54ca00af Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=A5=9A=E5=A4=A9=E7=BF=94?= Date: Wed, 27 Sep 2023 14:46:23 +0800 Subject: [PATCH 03/18] Add fallback kernel for desc act models --- csrc/quantization.cpp | 57 +++++++-- csrc/quantization/gptq/exllama_ext.cpp | 2 +- csrc/quantization/gptq/old_matmul.cpp | 25 ++++ csrc/quantization/gptq/old_matmul_kernel.cu | 111 ++++++++++++++++++ setup.py | 4 +- .../layers/quantized_linear/gptq.py | 58 +++++---- .../layers/quantized_linear/utils.py | 6 +- .../model_executor/quantization_utils/gptq.py | 4 +- vllm/model_executor/weight_utils.py | 9 +- vllm/worker/worker.py | 5 +- 10 files changed, 228 insertions(+), 53 
deletions(-) create mode 100644 csrc/quantization/gptq/old_matmul.cpp create mode 100644 csrc/quantization/gptq/old_matmul_kernel.cu diff --git a/csrc/quantization.cpp b/csrc/quantization.cpp index ae8ff8b84dc3..7ffac0813569 100644 --- a/csrc/quantization.cpp +++ b/csrc/quantization.cpp @@ -8,24 +8,59 @@ torch::Tensor awq_gemm( torch::Tensor _zeros, int split_k_iters); -void gptq_set_tuning_params(int matmul_recons_thd, bool matmul_fused_remap, - bool matmul_no_half2); +void gptq_set_tuning_params( + int matmul_recons_thd, + bool matmul_fused_remap, + bool matmul_no_half2); -void gptq_prepare_buffers(torch::Device device, torch::Tensor temp_state, - torch::Tensor temp_dq); +void gptq_prepare_buffers( + torch::Device device, + torch::Tensor temp_state, + torch::Tensor temp_dq); -uintptr_t gptq_make_q4(torch::Tensor qweight, torch::Tensor qzeros, - torch::Tensor scales, torch::Tensor g_idx, int device); +uintptr_t gptq_make_q4( + torch::Tensor qweight, + torch::Tensor qzeros, + torch::Tensor scales, + torch::Tensor g_idx, + int device); -void gptq_q4_matmul(torch::Tensor x, uintptr_t w, torch::Tensor out); +void gptq_q4_matmul( + torch::Tensor x, + uintptr_t w, + torch::Tensor out); + +void gptq_descact_matmul( + torch::Tensor vec, + torch::Tensor mat, + torch::Tensor mul, + torch::Tensor scales, + torch::Tensor zeros, + torch::Tensor g_idx); PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { m.def( "awq_gemm", &awq_gemm, "Quantized GEMM for AWQ"); - m.def("gptq_set_tuning_params", &gptq_set_tuning_params, "gptq_set_tuning_params"); - m.def("gptq_prepare_buffers", &gptq_prepare_buffers, "gptq_prepare_buffers"); - m.def("gptq_make_q4", &gptq_make_q4, "gptq_make_q4"); - m.def("gptq_q4_matmul", &gptq_q4_matmul, "gptq_q4_matmul"); + m.def( + "gptq_set_tuning_params", + &gptq_set_tuning_params, + "Set tuning params for GPTQ"); + m.def( + "gptq_prepare_buffers", + &gptq_prepare_buffers, + "Prepare buffers for GPTQ"); + m.def( + "gptq_make_q4", + &gptq_make_q4, + "Preprocess weight 
for GPTQ"); + m.def( + "gptq_q4_matmul", + &gptq_q4_matmul, + "Quantized GEMM for GPTQ"); + m.def( + "gptq_descact_matmul", + &gptq_descact_matmul, + "Quantized GEMM for GPTQ for parallelized desc_act layer"); } \ No newline at end of file diff --git a/csrc/quantization/gptq/exllama_ext.cpp b/csrc/quantization/gptq/exllama_ext.cpp index 3332d6955eba..369fa0c05a8d 100644 --- a/csrc/quantization/gptq/exllama_ext.cpp +++ b/csrc/quantization/gptq/exllama_ext.cpp @@ -241,4 +241,4 @@ void gptq_column_remap width, (uint32_t*) x_map.data_ptr() ); -} +} \ No newline at end of file diff --git a/csrc/quantization/gptq/old_matmul.cpp b/csrc/quantization/gptq/old_matmul.cpp new file mode 100644 index 000000000000..f5b4be77701e --- /dev/null +++ b/csrc/quantization/gptq/old_matmul.cpp @@ -0,0 +1,25 @@ +#include +#include +#include + +void vecquant4matmul_cuda( + torch::Tensor vec, + torch::Tensor mat, + torch::Tensor mul, + torch::Tensor scales, + torch::Tensor zeros, + torch::Tensor g_idx +); + +void gptq_descact_matmul( + torch::Tensor vec, + torch::Tensor mat, + torch::Tensor mul, + torch::Tensor scales, + torch::Tensor zeros, + torch::Tensor g_idx +) +{ + const at::cuda::OptionalCUDAGuard device_guard(device_of(vec)); + vecquant4matmul_cuda(vec, mat, mul, scales, zeros, g_idx); +} diff --git a/csrc/quantization/gptq/old_matmul_kernel.cu b/csrc/quantization/gptq/old_matmul_kernel.cu new file mode 100644 index 000000000000..68c39fca8680 --- /dev/null +++ b/csrc/quantization/gptq/old_matmul_kernel.cu @@ -0,0 +1,111 @@ +#include +#include +#include +#include +#include +#include "cu_compat.cuh" + +const int BLOCKWIDTH = 256; +const int BLOCKHEIGHT = 32; + +__device__ inline unsigned int as_unsigned(int i) { + return *reinterpret_cast(&i); +} + +__device__ inline int as_int(int i) { + return *reinterpret_cast(&i); +} + +template +__global__ void VecQuant4MatMulKernel( + const scalar_t* __restrict__ vec, + const int* __restrict__ mat, + scalar_t* __restrict__ mul, + const scalar_t* 
__restrict__ scales, + const int* __restrict__ zeros, + const int* __restrict__ g_idx, + int batch, + int vec_height, + int height, + int width, + int zero_width +) { + int h = BLOCKHEIGHT * blockIdx.x; + int w = BLOCKWIDTH * blockIdx.y + threadIdx.x; + int h_end = min(h + BLOCKHEIGHT, height); + + __shared__ scalar_t blockvec[BLOCKWIDTH]; + int i = width * h + w; + int g_h = h * 8; + int h_range = (h_end - h) * 8; + int k; + unsigned int g; + scalar_t w_tmp; + + + int z_w = w / 8; + int z_mod = (w % 8) * 4; + + float weight[BLOCKWIDTH]; + + if (w < width) { + for (k = 0; k < h_range; ++k) { + int k_w = (k / 8); + int k_bit = (k % 8) * 4; + + g = as_int(g_idx[g_h + k]); + scalar_t scale = scales[g * width + w]; + scalar_t zero = scalar_t(((as_unsigned(zeros[g * zero_width + z_w]) >> z_mod) & 0xF) + 1); + w_tmp = ((as_unsigned(mat[i + (k_w * width)]) >> k_bit) & 0xF); + weight[k] = scale * (w_tmp - zero); + } + } + + scalar_t res; + for (int b = 0; b < batch; ++b) { + res = 0; + + if (threadIdx.x < h_range) { + blockvec[threadIdx.x] = vec[b * vec_height + blockIdx.x * BLOCKWIDTH + threadIdx.x]; + } + __syncthreads(); + if (w < width) { + for (k = 0; k < h_range; ++k){ + res += weight[k] * blockvec[k]; + } + atomicAdd(&mul[b * width + w], res); + } + __syncthreads(); + } +} + +void vecquant4matmul_cuda( + torch::Tensor vec, + torch::Tensor mat, + torch::Tensor mul, + torch::Tensor scales, + torch::Tensor zeros, + torch::Tensor g_idx +) { + int batch = vec.size(0); + int vec_height = vec.size(1); + int height = mat.size(0); + int width = mat.size(1); + int zero_width = zeros.size(1); + + dim3 blocks( + (height + BLOCKHEIGHT - 1) / BLOCKHEIGHT, + (width + BLOCKWIDTH - 1) / BLOCKWIDTH + ); + dim3 threads(BLOCKWIDTH); + + AT_DISPATCH_FLOATING_TYPES( + vec.type(), "vecquant4matmul_cuda", ([&] { + VecQuant4MatMulKernel<<>>( + vec.data(), mat.data(), mul.data(), + scales.data(), zeros.data(), g_idx.data(), + batch, vec_height, height, width, zero_width + ); + }) + ); +} \ 
No newline at end of file diff --git a/setup.py b/setup.py index 7026300fd3f9..dd0a5c889b04 100644 --- a/setup.py +++ b/setup.py @@ -155,7 +155,9 @@ def get_nvcc_cuda_version(cuda_dir: str) -> Version: "csrc/quantization/gptq/cuda_buffers.cu", "csrc/quantization/gptq/cuda_func/column_remap.cu", "csrc/quantization/gptq/cuda_func/q4_matmul.cu", - "csrc/quantization/gptq/cuda_func/q4_matrix.cu" + "csrc/quantization/gptq/cuda_func/q4_matrix.cu", + "csrc/quantization/gptq/old_matmul.cpp", + "csrc/quantization/gptq/old_matmul_kernel.cu" ], extra_compile_args={ "cxx": CXX_FLAGS, diff --git a/vllm/model_executor/layers/quantized_linear/gptq.py b/vllm/model_executor/layers/quantized_linear/gptq.py index 4d9c3fa4284d..fd2e8b6ae2af 100644 --- a/vllm/model_executor/layers/quantized_linear/gptq.py +++ b/vllm/model_executor/layers/quantized_linear/gptq.py @@ -6,8 +6,6 @@ from vllm import quantization_ops from vllm.model_executor.parallel_utils.tensor_parallel.layers import ( ColumnParallelLinear, RowParallelLinear) -from vllm.model_executor.parallel_utils.tensor_parallel.mappings import ( - gather_from_tensor_model_parallel_region) class GPTQLinear(torch.nn.Module): @@ -22,6 +20,7 @@ def __init__(self, self.input_size = input_size self.output_size = output_size self.quant_config = quant_config + self.use_exllama = True group_size = self.quant_config.group_size if ( self.quant_config.group_size != -1) else self.input_size self.qweight = Parameter( @@ -103,6 +102,7 @@ def create_weights(self, dtype: torch.dtype) -> None: assert self.input_size % self.quant_config.pack_factor == 0 assert (self.output_size_per_partition % self.quant_config.pack_factor == 0) + self.use_exllama = True group_size = self.quant_config.group_size if ( self.quant_config.group_size != -1) else self.input_size @@ -179,15 +179,16 @@ def create_weights(self, dtype: torch.dtype) -> None: assert (self.input_size_per_partition % self.quant_config.pack_factor == 0) assert self.output_size % 
self.quant_config.pack_factor == 0 - # Ignore tensor parallel when group_size != -1 and desc_act - if self.quant_config.desc_act and self.quant_config.group_size != -1: - self.input_size_per_partition = self.input_size - self.parallel = False - else: - self.parallel = True group_size = self.quant_config.group_size if ( self.quant_config.group_size != -1 ) else self.input_size_per_partition + if self.world_size > 1 and (self.quant_config.desc_act + and self.quant_config.group_size != -1): + group_number = self.input_size // group_size + self.use_exllama = False + else: + group_number = self.input_size_per_partition // group_size + self.use_exllama = True self.qweight = Parameter( torch.empty( self.input_size_per_partition // self.quant_config.pack_factor, @@ -199,7 +200,7 @@ def create_weights(self, dtype: torch.dtype) -> None: ) self.qzeros = Parameter( torch.empty( - self.input_size_per_partition // group_size, + group_number, self.output_size // self.quant_config.pack_factor, device="cuda", dtype=torch.int32, @@ -208,7 +209,7 @@ def create_weights(self, dtype: torch.dtype) -> None: ) self.scales = Parameter( torch.empty( - self.input_size_per_partition // group_size, + group_number, self.output_size, device="cuda", dtype=dtype, @@ -228,6 +229,8 @@ def create_weights(self, dtype: torch.dtype) -> None: ) def post_init(self): + if not self.use_exllama: + return assert self.qweight.device.type == "cuda" assert self.qweight.device.index is not None @@ -244,26 +247,19 @@ def post_init(self): def apply_weights(self, x: torch.Tensor) -> torch.Tensor: out_shape = x.shape[:-1] + (self.qweight.shape[-1], ) reshaped_x = x.reshape(-1, x.shape[-1]) - output = torch.empty((x.shape[0], self.qweight.shape[-1]), - dtype=torch.float16, - device=x.device) - quantization_ops.gptq_q4_matmul(reshaped_x, self.q4, output) - return output.reshape(out_shape) - def forward(self, input_): - # Set up backprop all-reduce. 
- if self.parallel: - return super().forward(input_) - if self.input_is_parallel: - input_ = gather_from_tensor_model_parallel_region(input_) - output_ = self.apply_weights(input_) - if not self.reduce_results and self.world_size > 1: - output_ = output_ / self.world_size - - if not self.skip_bias_add: - output = output_ + self.bias if self.bias is not None else output_ - output_bias = None + if self.use_exllama: + output = torch.empty((x.shape[0], self.qweight.shape[-1]), + dtype=torch.float16, + device=x.device) + quantization_ops.gptq_q4_matmul(reshaped_x, self.q4, output) else: - output = output_ - output_bias = self.bias - return output, output_bias + output = torch.zeros((x.shape[0], self.qweight.shape[-1]), + dtype=torch.float32, + device=x.device) + quantization_ops.gptq_descact_matmul(reshaped_x.float(), + self.qweight, output, + self.scales.float(), + self.qzeros, self.g_idx) + output = output.half() + return output.reshape(out_shape) diff --git a/vllm/model_executor/layers/quantized_linear/utils.py b/vllm/model_executor/layers/quantized_linear/utils.py index c34a2d0e197f..679084824cfb 100644 --- a/vllm/model_executor/layers/quantized_linear/utils.py +++ b/vllm/model_executor/layers/quantized_linear/utils.py @@ -17,9 +17,9 @@ def quant_post_init(model, max_input_length: Optional[int] = None): model_uses_exllama = False use_act_order = False for _, submodule in model.named_modules(): - if isinstance( - submodule, - (GPTQColumnParallelLinear, GPTQRowParallelLinear, GPTQLinear)): + if isinstance(submodule, + (GPTQColumnParallelLinear, GPTQRowParallelLinear, + GPTQLinear)) and submodule.use_exllama: model_uses_exllama = True device = submodule.qweight.device if device not in device_to_buffers_size: diff --git a/vllm/model_executor/quantization_utils/gptq.py b/vllm/model_executor/quantization_utils/gptq.py index 28d75d28837e..012d6c5c1c9d 100644 --- a/vllm/model_executor/quantization_utils/gptq.py +++ b/vllm/model_executor/quantization_utils/gptq.py @@ -68,9 
+68,9 @@ def get_transposed_tensor_names(cls) -> List[str]: def get_row_tp_tensor_names(self) -> List[str]: if self.desc_act and self.group_size != -1: - return [] - if self.group_size == -1: return ["qweight", "g_idx"] + if self.group_size == -1: + return ["qweight"] return ["qweight", "qzeros", "scales"] def get_column_tp_tensor_names(self) -> List[str]: diff --git a/vllm/model_executor/weight_utils.py b/vllm/model_executor/weight_utils.py index bb47098a1745..bed5d8586e5c 100644 --- a/vllm/model_executor/weight_utils.py +++ b/vllm/model_executor/weight_utils.py @@ -288,10 +288,15 @@ def load_tensor_parallel_weights( break for p in row_parallel_weight_names: if p in param_name: - shard_size = param.shape[1] + shard_size = param.shape[-1] start_idx = tensor_model_parallel_rank * shard_size end_idx = (tensor_model_parallel_rank + 1) * shard_size - loaded_weight = loaded_weight[:, start_idx:end_idx] + if isinstance(loaded_weight, torch.Tensor): + loaded_weight = loaded_weight[..., start_idx:end_idx] + else: + index = [slice(None)] * (len(loaded_weight.get_shape()) - + 1) + [slice(start_idx, end_idx)] + loaded_weight = loaded_weight[index] break loaded_weight = convert_pyslice_to_tensor(loaded_weight) diff --git a/vllm/worker/worker.py b/vllm/worker/worker.py index 95da6f84a89f..e7b1d5fa0b6f 100644 --- a/vllm/worker/worker.py +++ b/vllm/worker/worker.py @@ -225,8 +225,9 @@ def _prepare_inputs( # Optimization: Pad the input length to be a multiple of 8. # This is required for utilizing the Tensor Cores in NVIDIA GPUs. - input_tokens = _pad_to_alignment(input_tokens, multiple_of=8) - input_positions = _pad_to_alignment(input_positions, multiple_of=8) + if self.model_config.quantization is None: + input_tokens = _pad_to_alignment(input_tokens, multiple_of=8) + input_positions = _pad_to_alignment(input_positions, multiple_of=8) # Convert to tensors. 
tokens_tensor = torch.tensor(input_tokens, From 0470121bd4be1a565ea5cde70eef859f84c72fcd Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=A5=9A=E5=A4=A9=E7=BF=94?= Date: Wed, 27 Sep 2023 22:36:59 +0800 Subject: [PATCH 04/18] Fix engine args and opt model --- vllm/engine/arg_utils.py | 2 +- vllm/model_executor/models/opt.py | 2 +- vllm/model_executor/quantization_utils/gptq.py | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index 65a5d74fa56b..d7e9f2685564 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -156,7 +156,7 @@ def add_cli_args( parser.add_argument('--quantization', '-q', type=str, - choices=['awq', None], + choices=['awq', 'gptq', None], default=None, help='Method used to quantize the weights') return parser diff --git a/vllm/model_executor/models/opt.py b/vllm/model_executor/models/opt.py index c3d2dd4118fa..61a19ecd2c2c 100644 --- a/vllm/model_executor/models/opt.py +++ b/vllm/model_executor/models/opt.py @@ -286,7 +286,7 @@ def forward( class OPTForCausalLM(nn.Module): - def __init__(self, config, quant_config): + def __init__(self, config, quant_config=None): super().__init__() self.config = config self.quant_config = quant_config diff --git a/vllm/model_executor/quantization_utils/gptq.py b/vllm/model_executor/quantization_utils/gptq.py index 012d6c5c1c9d..42783220bbc8 100644 --- a/vllm/model_executor/quantization_utils/gptq.py +++ b/vllm/model_executor/quantization_utils/gptq.py @@ -48,7 +48,7 @@ def get_min_capability(cls) -> int: @classmethod def get_config_filenames(cls) -> List[str]: return [ - "quant_config.json", + "quantize_config.json", ] @classmethod From f9d0ccc170b82199f102c55c0618b60bb084e138 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=A5=9A=E5=A4=A9=E7=BF=94?= Date: Mon, 9 Oct 2023 16:52:25 +0800 Subject: [PATCH 05/18] Add mistral model --- vllm/model_executor/model_loader.py | 20 +++++++++--- vllm/model_executor/models/aquila.py | 3 +- 
vllm/model_executor/models/baichuan.py | 3 +- vllm/model_executor/models/bloom.py | 3 +- vllm/model_executor/models/falcon.py | 6 ++-- vllm/model_executor/models/gpt2.py | 3 +- vllm/model_executor/models/gpt_bigcode.py | 3 +- vllm/model_executor/models/gpt_j.py | 3 +- vllm/model_executor/models/gpt_neox.py | 11 ++++--- vllm/model_executor/models/internlm.py | 3 +- vllm/model_executor/models/mistral.py | 37 +++++++++++------------ vllm/model_executor/models/mpt.py | 3 +- vllm/model_executor/models/opt.py | 3 +- vllm/model_executor/models/qwen.py | 23 +++++++------- 14 files changed, 63 insertions(+), 61 deletions(-) diff --git a/vllm/model_executor/model_loader.py b/vllm/model_executor/model_loader.py index cebaada01442..d0ddf8bca7e7 100644 --- a/vllm/model_executor/model_loader.py +++ b/vllm/model_executor/model_loader.py @@ -37,11 +37,21 @@ _MODEL_CLASSES_SUPPORT_QUANTIZATION = { "awq": [LlamaForCausalLM], "gptq": [ - LlamaForCausalLM, QWenLMHeadModel, BaiChuanForCausalLM, - BaichuanForCausalLM, BloomForCausalLM, GPT2LMHeadModel, - GPTJForCausalLM, GPTNeoXForCausalLM, GPTBigCodeForCausalLM, - InternLMForCausalLM, FalconForCausalLM, AquilaForCausalLM, - OPTForCausalLM, MPTForCausalLM + LlamaForCausalLM, + QWenLMHeadModel, + BaiChuanForCausalLM, + BaichuanForCausalLM, + BloomForCausalLM, + GPT2LMHeadModel, + GPTJForCausalLM, + GPTNeoXForCausalLM, + GPTBigCodeForCausalLM, + InternLMForCausalLM, + FalconForCausalLM, + AquilaForCausalLM, + OPTForCausalLM, + MPTForCausalLM, + MistralForCausalLM, ], } diff --git a/vllm/model_executor/models/aquila.py b/vllm/model_executor/models/aquila.py index 7e40012f4ca1..efd20bc92a50 100644 --- a/vllm/model_executor/models/aquila.py +++ b/vllm/model_executor/models/aquila.py @@ -42,8 +42,7 @@ get_parallel_weight) from vllm.model_executor.parallel_utils.parallel_state import ( get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size) -from vllm.model_executor.parallel_utils.layers import ( - VocabParallelEmbedding) +from 
vllm.model_executor.parallel_utils.layers import VocabParallelEmbedding from vllm.sequence import SamplerOutput from vllm.transformers_utils.configs.aquila import AquilaConfig diff --git a/vllm/model_executor/models/baichuan.py b/vllm/model_executor/models/baichuan.py index 9f5163bc0d96..826a45e20b05 100644 --- a/vllm/model_executor/models/baichuan.py +++ b/vllm/model_executor/models/baichuan.py @@ -42,8 +42,7 @@ get_parallel_weight) from vllm.model_executor.parallel_utils.parallel_state import ( get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size) -from vllm.model_executor.parallel_utils.layers import ( - VocabParallelEmbedding) +from vllm.model_executor.parallel_utils.layers import VocabParallelEmbedding from vllm.sequence import SamplerOutput from vllm.transformers_utils.configs.baichuan import BaiChuanConfig diff --git a/vllm/model_executor/models/bloom.py b/vllm/model_executor/models/bloom.py index ebecf5551b94..e6859658b138 100644 --- a/vllm/model_executor/models/bloom.py +++ b/vllm/model_executor/models/bloom.py @@ -39,8 +39,7 @@ get_parallel_weight) from vllm.model_executor.parallel_utils.parallel_state import ( get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size) -from vllm.model_executor.parallel_utils.layers import ( - VocabParallelEmbedding) +from vllm.model_executor.parallel_utils.layers import VocabParallelEmbedding from vllm.sequence import SamplerOutput KVCache = Tuple[torch.Tensor, torch.Tensor] diff --git a/vllm/model_executor/models/falcon.py b/vllm/model_executor/models/falcon.py index f4247e5970d2..eed135b54120 100644 --- a/vllm/model_executor/models/falcon.py +++ b/vllm/model_executor/models/falcon.py @@ -39,8 +39,7 @@ get_parallel_weight) from vllm.model_executor.parallel_utils.parallel_state import ( get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size) -from vllm.model_executor.parallel_utils.layers import ( - VocabParallelEmbedding) +from vllm.model_executor.parallel_utils.layers import 
VocabParallelEmbedding from vllm.model_executor.parallel_utils.communication_op import ( tensor_model_parallel_all_reduce) from vllm.sequence import SamplerOutput @@ -246,7 +245,8 @@ def __init__(self, bias=config.bias, gather_output=False, skip_bias_add=True, - quant_config=quant_config) + quant_config=quant_config, + ) self.act = nn.GELU() self.reduce_row_parallel_results = not (config.new_decoder_architecture or config.parallel_attn) diff --git a/vllm/model_executor/models/gpt2.py b/vllm/model_executor/models/gpt2.py index bfd536a036cc..82243c1f6da3 100644 --- a/vllm/model_executor/models/gpt2.py +++ b/vllm/model_executor/models/gpt2.py @@ -39,8 +39,7 @@ get_parallel_weight) from vllm.model_executor.parallel_utils.parallel_state import ( get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size) -from vllm.model_executor.parallel_utils.layers import ( - VocabParallelEmbedding) +from vllm.model_executor.parallel_utils.layers import VocabParallelEmbedding from vllm.sequence import SamplerOutput KVCache = Tuple[torch.Tensor, torch.Tensor] diff --git a/vllm/model_executor/models/gpt_bigcode.py b/vllm/model_executor/models/gpt_bigcode.py index 4e1a5cc6fafe..bff7835ebb20 100644 --- a/vllm/model_executor/models/gpt_bigcode.py +++ b/vllm/model_executor/models/gpt_bigcode.py @@ -40,8 +40,7 @@ get_parallel_weight) from vllm.model_executor.parallel_utils.parallel_state import ( get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size) -from vllm.model_executor.parallel_utils.layers import ( - VocabParallelEmbedding) +from vllm.model_executor.parallel_utils.layers import VocabParallelEmbedding from vllm.sequence import SamplerOutput KVCache = Tuple[torch.Tensor, torch.Tensor] diff --git a/vllm/model_executor/models/gpt_j.py b/vllm/model_executor/models/gpt_j.py index be0216db4520..ad1d30fc0ad7 100644 --- a/vllm/model_executor/models/gpt_j.py +++ b/vllm/model_executor/models/gpt_j.py @@ -38,8 +38,7 @@ get_parallel_weight) from 
vllm.model_executor.parallel_utils.parallel_state import ( get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size) -from vllm.model_executor.parallel_utils.layers import ( - VocabParallelEmbedding) +from vllm.model_executor.parallel_utils.layers import VocabParallelEmbedding from vllm.sequence import SamplerOutput KVCache = Tuple[torch.Tensor, torch.Tensor] diff --git a/vllm/model_executor/models/gpt_neox.py b/vllm/model_executor/models/gpt_neox.py index 46d48ff76980..aa313a1726f1 100644 --- a/vllm/model_executor/models/gpt_neox.py +++ b/vllm/model_executor/models/gpt_neox.py @@ -38,8 +38,7 @@ get_parallel_weight) from vllm.model_executor.parallel_utils.parallel_state import ( get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size) -from vllm.model_executor.parallel_utils.layers import ( - VocabParallelEmbedding) +from vllm.model_executor.parallel_utils.layers import VocabParallelEmbedding from vllm.sequence import SamplerOutput KVCache = Tuple[torch.Tensor, torch.Tensor] @@ -65,12 +64,14 @@ def __init__(self, config.hidden_size, 3 * config.hidden_size, gather_output=False, - quant_config=quant_config) + quant_config=quant_config, + ) self.dense = ParallelLinear.row( config.hidden_size, config.hidden_size, input_is_parallel=True, - quant_config=quant_config) + quant_config=quant_config, + ) scaling = self.head_size**-0.5 rotary_dim = int(self.head_size * config.rotary_pct) assert rotary_dim % 2 == 0 @@ -114,7 +115,7 @@ def __init__( config.hidden_size, config.intermediate_size, gather_output=False, - quant_config=quant_config + quant_config=quant_config, ) self.dense_4h_to_h = ParallelLinear.row( config.intermediate_size, diff --git a/vllm/model_executor/models/internlm.py b/vllm/model_executor/models/internlm.py index 1b0be7487f73..70d6eefeec53 100644 --- a/vllm/model_executor/models/internlm.py +++ b/vllm/model_executor/models/internlm.py @@ -14,8 +14,7 @@ from vllm.model_executor.quantization_utils import QuantizationConfig from 
vllm.model_executor.parallel_utils.parallel_state import ( get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size) -from vllm.model_executor.parallel_utils.layers import ( - VocabParallelEmbedding) +from vllm.model_executor.parallel_utils.layers import VocabParallelEmbedding from vllm.model_executor.weight_utils import ( hf_model_weights_iterator, load_padded_tensor_parallel_vocab, load_tensor_parallel_weights, convert_pyslice_to_tensor, diff --git a/vllm/model_executor/models/mistral.py b/vllm/model_executor/models/mistral.py index d298ea7d2be4..696f456d9ea4 100644 --- a/vllm/model_executor/models/mistral.py +++ b/vllm/model_executor/models/mistral.py @@ -42,7 +42,8 @@ from vllm.model_executor.quantization_utils import QuantizationConfig from vllm.model_executor.weight_utils import ( convert_pyslice_to_tensor, hf_model_weights_iterator, - load_tensor_parallel_weights, load_padded_tensor_parallel_vocab) + load_tensor_parallel_weights, load_padded_tensor_parallel_vocab, + get_parallel_weight) from vllm.sequence import SamplerOutput from vllm.transformers_utils.configs.mistral import MistralConfig @@ -289,28 +290,16 @@ def forward( input_metadata) return next_tokens - _column_parallel_layers = [] - _row_parallel_layers = ["o_proj", "down_proj"] + column_parallel_layers = [] + row_parallel_layers = ["o_proj", "down_proj"] def load_weights(self, model_name_or_path: str, cache_dir: Optional[str] = None, load_format: str = "auto", revision: Optional[str] = None): - if self.quant_config is None: - weight_suffixes = ["weight"] - else: - weight_suffixes = self.quant_config.get_tp_tensor_names() - - column_parallel_weights: List[str] = [] - for layer in self._column_parallel_layers: - for suffix in weight_suffixes: - column_parallel_weights.append(f"{layer}.{suffix}") - row_parallel_weights: List[str] = [] - for layer in self._row_parallel_layers: - for suffix in weight_suffixes: - row_parallel_weights.append(f"{layer}.{suffix}") - + (column_parallel_weights, 
row_parallel_weights, + ignore_weight_suffixes) = get_parallel_weight(self) tp_size = get_tensor_model_parallel_world_size() tensor_model_parallel_rank = get_tensor_model_parallel_rank() q_proj_shard_size = (self.config.hidden_size // tp_size) @@ -330,6 +319,8 @@ def load_weights(self, model_name_or_path, cache_dir, load_format, revision): if "rotary_emb.inv_freq" in name: continue + if any(name.endswith(suffix) for suffix in ignore_weight_suffixes): + continue is_packed = False is_transposed = False @@ -344,7 +335,10 @@ def load_weights(self, for weight_name, shard_size, offset in attention_weight_specs: if weight_name not in name: continue - param = state_dict[name.replace(weight_name, "qkv_proj")] + name = name.replace(weight_name, "qkv_proj") + if name not in state_dict or "g_idx" in name: + break + param = state_dict[name] if is_transposed: param = param.T @@ -368,7 +362,10 @@ def load_weights(self, for stride_id, weight_name in enumerate(["gate_proj", "up_proj"]): if weight_name not in name: continue - param = state_dict[name.replace(weight_name, "gate_up_proj")] + name = name.replace(weight_name, "gate_up_proj") + if "g_idx" in name or name not in state_dict: + break + param = state_dict[name] if is_transposed: param = param.T @@ -385,6 +382,8 @@ def load_weights(self, if is_gate_up_weight: continue + if name not in state_dict: + continue param = state_dict[name] if is_transposed: param = param.T diff --git a/vllm/model_executor/models/mpt.py b/vllm/model_executor/models/mpt.py index 4d51fb9eb9e9..fd46bb19a9dd 100644 --- a/vllm/model_executor/models/mpt.py +++ b/vllm/model_executor/models/mpt.py @@ -18,8 +18,7 @@ get_parallel_weight) from vllm.model_executor.parallel_utils.parallel_state import ( get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size) -from vllm.model_executor.parallel_utils.layers import ( - VocabParallelEmbedding) +from vllm.model_executor.parallel_utils.layers import VocabParallelEmbedding from vllm.sequence import 
SamplerOutput from vllm.transformers_utils.configs.mpt import MPTConfig diff --git a/vllm/model_executor/models/opt.py b/vllm/model_executor/models/opt.py index 521f71c84f73..50ffa5690ebc 100644 --- a/vllm/model_executor/models/opt.py +++ b/vllm/model_executor/models/opt.py @@ -39,8 +39,7 @@ get_parallel_weight) from vllm.model_executor.parallel_utils.parallel_state import ( get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size) -from vllm.model_executor.parallel_utils.layers import ( - VocabParallelEmbedding) +from vllm.model_executor.parallel_utils.layers import VocabParallelEmbedding from vllm.sequence import SamplerOutput KVCache = Tuple[torch.Tensor, torch.Tensor] diff --git a/vllm/model_executor/models/qwen.py b/vllm/model_executor/models/qwen.py index 656bcb55b6ab..b3daa5c311b8 100644 --- a/vllm/model_executor/models/qwen.py +++ b/vllm/model_executor/models/qwen.py @@ -31,9 +31,7 @@ get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size, ) -from vllm.model_executor.parallel_utils.layers import ( - VocabParallelEmbedding, -) +from vllm.model_executor.parallel_utils.layers import VocabParallelEmbedding from vllm.sequence import SamplerOutput from vllm.transformers_utils.configs.qwen import QWenConfig @@ -77,13 +75,16 @@ def forward(self, x): class QWenAttention(nn.Module): - def __init__(self, - hidden_size: int, - num_heads: int, - max_position_embeddings: int, - rope_theta: float = 10000, - rope_scaling: Optional[Dict[str, Any]] = None, - quant_config: Optional[QuantizationConfig] = None,): + + def __init__( + self, + hidden_size: int, + num_heads: int, + max_position_embeddings: int, + rope_theta: float = 10000, + rope_scaling: Optional[Dict[str, Any]] = None, + quant_config: Optional[QuantizationConfig] = None, + ): super().__init__() self.hidden_size = hidden_size tensor_model_parallel_world_size = get_tensor_model_parallel_world_size( @@ -252,7 +253,7 @@ def __init__(self, vocab_size, bias=False, gather_output=False, - 
quant_config=None + quant_config=None, ) self.sampler = Sampler(config.vocab_size) From cbf94333be5f63b201a1b93b5b7fa2143ad31981 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=A5=9A=E5=A4=A9=E7=BF=94?= Date: Wed, 11 Oct 2023 13:31:23 +0800 Subject: [PATCH 06/18] Fix bug in gpt layer --- vllm/model_executor/layers/quantized_linear/gptq.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/vllm/model_executor/layers/quantized_linear/gptq.py b/vllm/model_executor/layers/quantized_linear/gptq.py index fd2e8b6ae2af..3b6775a831c2 100644 --- a/vllm/model_executor/layers/quantized_linear/gptq.py +++ b/vllm/model_executor/layers/quantized_linear/gptq.py @@ -4,8 +4,8 @@ from torch.nn.parameter import Parameter from vllm import quantization_ops -from vllm.model_executor.parallel_utils.tensor_parallel.layers import ( - ColumnParallelLinear, RowParallelLinear) +from vllm.model_executor.parallel_utils.layers import (ColumnParallelLinear, + RowParallelLinear) class GPTQLinear(torch.nn.Module): @@ -182,8 +182,8 @@ def create_weights(self, dtype: torch.dtype) -> None: group_size = self.quant_config.group_size if ( self.quant_config.group_size != -1 ) else self.input_size_per_partition - if self.world_size > 1 and (self.quant_config.desc_act - and self.quant_config.group_size != -1): + if self.tp_size > 1 and (self.quant_config.desc_act + and self.quant_config.group_size != -1): group_number = self.input_size // group_size self.use_exllama = False else: From 9a994610bab2e003666d1718db9d2d0870b6dc37 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=A5=9A=E5=A4=A9=E7=BF=94?= Date: Tue, 24 Oct 2023 22:09:54 +0800 Subject: [PATCH 07/18] Fix squeezellm --- benchmarks/benchmark_latency.py | 2 +- benchmarks/benchmark_throughput.py | 2 +- vllm/model_executor/model_loader.py | 1 + 3 files changed, 3 insertions(+), 2 deletions(-) diff --git a/benchmarks/benchmark_latency.py b/benchmarks/benchmark_latency.py index e560cb1fbfc0..bc0660af8ccf 100644 --- 
a/benchmarks/benchmark_latency.py +++ b/benchmarks/benchmark_latency.py @@ -70,7 +70,7 @@ def run_to_completion(profile: bool = False): parser.add_argument('--tokenizer', type=str, default=None) parser.add_argument('--quantization', '-q', - choices=['awq', 'squeezellm', None], + choices=['awq', 'squeezellm', 'gptq', None], default=None) parser.add_argument('--tensor-parallel-size', '-tp', type=int, default=1) parser.add_argument('--input-len', type=int, default=32) diff --git a/benchmarks/benchmark_throughput.py b/benchmarks/benchmark_throughput.py index fc578b497286..2d1af7313823 100644 --- a/benchmarks/benchmark_throughput.py +++ b/benchmarks/benchmark_throughput.py @@ -201,7 +201,7 @@ def main(args: argparse.Namespace): parser.add_argument("--tokenizer", type=str, default=None) parser.add_argument('--quantization', '-q', - choices=['awq', 'squeezellm', None], + choices=['awq', 'squeezellm', 'gptq', None], default=None) parser.add_argument("--tensor-parallel-size", "-tp", type=int, default=1) parser.add_argument("--n", diff --git a/vllm/model_executor/model_loader.py b/vllm/model_executor/model_loader.py index 0293e5ece6fc..f0cca2f86c96 100644 --- a/vllm/model_executor/model_loader.py +++ b/vllm/model_executor/model_loader.py @@ -37,6 +37,7 @@ # FIXME(woosuk): Remove this once all models support quantization. 
_MODEL_CLASSES_SUPPORT_QUANTIZATION = { "awq": [LlamaForCausalLM, MistralForCausalLM], + "squeezellm": [LlamaForCausalLM, MistralForCausalLM], "gptq": [ LlamaForCausalLM, QWenLMHeadModel, From 2593dfe8e8e0af059c4ef9d4b4ce78a2303dc447 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=A5=9A=E5=A4=A9=E7=BF=94?= Date: Thu, 2 Nov 2023 16:30:39 +0800 Subject: [PATCH 08/18] Use exllama v2 kernels for better performance --- csrc/quantization.cpp | 57 ++- .../gptq/{cu_compat.cuh => compat.cuh} | 6 +- csrc/quantization/gptq/cuda_buffers.cu | 75 ---- csrc/quantization/gptq/cuda_buffers.cuh | 55 --- .../gptq/cuda_func/column_remap.cu | 63 ---- .../gptq/cuda_func/column_remap.cuh | 19 - csrc/quantization/gptq/cuda_func/q4_matmul.cu | 260 ------------- .../quantization/gptq/cuda_func/q4_matmul.cuh | 43 --- csrc/quantization/gptq/cuda_func/q4_matrix.cu | 225 ------------ .../quantization/gptq/cuda_func/q4_matrix.cuh | 53 --- csrc/quantization/gptq/exllama_ext.cpp | 263 ++++---------- csrc/quantization/gptq/hip_compat.cuh | 49 --- csrc/quantization/gptq/matrix.cuh | 294 --------------- csrc/quantization/gptq/matrix_view.cuh | 121 +++++++ csrc/quantization/gptq/old_matmul.cpp | 25 -- csrc/quantization/gptq/old_matmul_kernel.cu | 16 +- csrc/quantization/gptq/q_gemm.cu | 168 +++++++++ csrc/quantization/gptq/q_gemm.cuh | 33 ++ csrc/quantization/gptq/q_gemm_kernel_gptq.cuh | 217 +++++++++++ csrc/quantization/gptq/q_matrix.cu | 341 ++++++++++++++++++ csrc/quantization/gptq/q_matrix.cuh | 58 +++ csrc/quantization/gptq/qdq_4.cuh | 222 ++++++++++++ csrc/quantization/gptq/qdq_util.cuh | 51 +++ csrc/quantization/gptq/tuning.h | 13 - csrc/quantization/gptq/util.cuh | 33 -- setup.py | 7 +- .../layers/quantized_linear/gptq.py | 162 +++++++-- .../layers/quantized_linear/utils.py | 92 ++--- .../quantization_utils/squeezellm.py | 6 +- 29 files changed, 1473 insertions(+), 1554 deletions(-) rename csrc/quantization/gptq/{cu_compat.cuh => compat.cuh} (92%) delete mode 100644 
csrc/quantization/gptq/cuda_buffers.cu delete mode 100644 csrc/quantization/gptq/cuda_buffers.cuh delete mode 100644 csrc/quantization/gptq/cuda_func/column_remap.cu delete mode 100644 csrc/quantization/gptq/cuda_func/column_remap.cuh delete mode 100644 csrc/quantization/gptq/cuda_func/q4_matmul.cu delete mode 100644 csrc/quantization/gptq/cuda_func/q4_matmul.cuh delete mode 100644 csrc/quantization/gptq/cuda_func/q4_matrix.cu delete mode 100644 csrc/quantization/gptq/cuda_func/q4_matrix.cuh delete mode 100644 csrc/quantization/gptq/hip_compat.cuh delete mode 100644 csrc/quantization/gptq/matrix.cuh create mode 100644 csrc/quantization/gptq/matrix_view.cuh delete mode 100644 csrc/quantization/gptq/old_matmul.cpp create mode 100644 csrc/quantization/gptq/q_gemm.cu create mode 100644 csrc/quantization/gptq/q_gemm.cuh create mode 100644 csrc/quantization/gptq/q_gemm_kernel_gptq.cuh create mode 100644 csrc/quantization/gptq/q_matrix.cu create mode 100644 csrc/quantization/gptq/q_matrix.cuh create mode 100644 csrc/quantization/gptq/qdq_4.cuh create mode 100644 csrc/quantization/gptq/qdq_util.cuh delete mode 100644 csrc/quantization/gptq/tuning.h delete mode 100644 csrc/quantization/gptq/util.cuh diff --git a/csrc/quantization.cpp b/csrc/quantization.cpp index 01e3de19cf99..b9919d868f16 100644 --- a/csrc/quantization.cpp +++ b/csrc/quantization.cpp @@ -8,27 +8,22 @@ torch::Tensor awq_gemm( torch::Tensor _zeros, int split_k_iters); -void gptq_set_tuning_params( - int matmul_recons_thd, - bool matmul_fused_remap, - bool matmul_no_half2); - -void gptq_prepare_buffers( - torch::Device device, - torch::Tensor temp_state, - torch::Tensor temp_dq); - -uintptr_t gptq_make_q4( - torch::Tensor qweight, - torch::Tensor qzeros, - torch::Tensor scales, - torch::Tensor g_idx, - int device); - -void gptq_q4_matmul( - torch::Tensor x, - uintptr_t w, - torch::Tensor out); +uintptr_t make_q_matrix( + torch::Tensor q_weight, + torch::Tensor q_perm, + torch::Tensor q_invperm, + 
torch::Tensor gptq_qzeros, + torch::Tensor gptq_scales, + torch::Tensor gptq_g_idx, + torch::Tensor temp_dq +); + +void gemm_half_q_half( + torch::Tensor a, + uintptr_t b, + torch::Tensor c, + bool force_cuda +); void gptq_descact_matmul( torch::Tensor vec, @@ -50,21 +45,13 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { &awq_gemm, "Quantized GEMM for AWQ"); m.def( - "gptq_set_tuning_params", - &gptq_set_tuning_params, - "Set tuning params for GPTQ"); - m.def( - "gptq_prepare_buffers", - &gptq_prepare_buffers, - "Prepare buffers for GPTQ"); - m.def( - "gptq_make_q4", - &gptq_make_q4, - "Preprocess weight for GPTQ"); + "make_q_matrix", + &make_q_matrix, + "make_q_matrix"); m.def( - "gptq_q4_matmul", - &gptq_q4_matmul, - "Quantized GEMM for GPTQ"); + "gemm_half_q_half", + &gemm_half_q_half, + "gemm_half_q_half"); m.def( "gptq_descact_matmul", &gptq_descact_matmul, diff --git a/csrc/quantization/gptq/cu_compat.cuh b/csrc/quantization/gptq/compat.cuh similarity index 92% rename from csrc/quantization/gptq/cu_compat.cuh rename to csrc/quantization/gptq/compat.cuh index c5258813e147..12684ff8b59f 100644 --- a/csrc/quantization/gptq/cu_compat.cuh +++ b/csrc/quantization/gptq/compat.cuh @@ -1,7 +1,5 @@ -// Adapted from turboderp exllama: https://github.com/turboderp/exllama - -#ifndef _cuda_compat_cuh -#define _cuda_compat_cuh +#ifndef _compat_cuh +#define _compat_cuh // atomicAdd for half types, to support CC < 7.x diff --git a/csrc/quantization/gptq/cuda_buffers.cu b/csrc/quantization/gptq/cuda_buffers.cu deleted file mode 100644 index 4416027c8387..000000000000 --- a/csrc/quantization/gptq/cuda_buffers.cu +++ /dev/null @@ -1,75 +0,0 @@ -// Adapted from turboderp exllama: https://github.com/turboderp/exllama - -#define _cuda_buffers_cu -#include "cuda_buffers.cuh" - -CudaBuffers* g_buffers[CUDA_MAX_DEVICES] = {NULL}; -// __constant__ half2 q4_table[16][256]; -// half2 q4_table_host[16][256]; -// bool q4_table_init = false; - -CudaBuffers::CudaBuffers -( - int _device, - 
int _temp_state_size, - half* _temp_state, - half* _temp_dq -) : - device(_device), - temp_state_size(_temp_state_size), - temp_state(_temp_state), - temp_dq(_temp_dq) -{ - cudaSetDevice(_device); - - cudaStreamCreate(&alt_stream_1); - cudaStreamCreate(&alt_stream_2); - cudaStreamCreate(&alt_stream_3); - cudaEventCreate(&alt_stream_1_done); - cudaEventCreate(&alt_stream_2_done); - cudaEventCreate(&alt_stream_3_done); -} - -CudaBuffers::~CudaBuffers() -{ - cudaStreamDestroy(alt_stream_1); - cudaStreamDestroy(alt_stream_2); - cudaStreamDestroy(alt_stream_3); - cudaEventDestroy(alt_stream_1_done); - cudaEventDestroy(alt_stream_2_done); - cudaEventDestroy(alt_stream_3_done); -} - -CudaBuffers* get_buffers(const int device_index) -{ - return g_buffers[device_index]; -} - -void prepare_buffers_cuda -( - int _device, - int _temp_state_size, - half* _temp_state, - half* _temp_dq -) -{ - CudaBuffers* buffers = new CudaBuffers - ( - _device, - _temp_state_size, - _temp_state, - _temp_dq - ); - - g_buffers[_device] = buffers; -} - -void cleanup_buffers_cuda() -{ - for (int i = 0; i < CUDA_MAX_DEVICES; i++) - { - if (!g_buffers[i]) continue; - delete g_buffers[i]; - g_buffers[i] = NULL; - } -} diff --git a/csrc/quantization/gptq/cuda_buffers.cuh b/csrc/quantization/gptq/cuda_buffers.cuh deleted file mode 100644 index 0bf2057c665c..000000000000 --- a/csrc/quantization/gptq/cuda_buffers.cuh +++ /dev/null @@ -1,55 +0,0 @@ -// Adapted from turboderp exllama: https://github.com/turboderp/exllama - -#ifndef _cuda_buffers_cuh -#define _cuda_buffers_cuh - -#include -#include -#include -#include - -const int CUDA_MAX_DEVICES = 16; - -// #ifndef _cuda_buffers_cu -// extern __constant__ half2 q4_table[16][256]; -// #endif - -class CudaBuffers -{ -public: - int device; - - half* temp_state; // [max_hidden_rows * intermediate_size] - int temp_state_size; - half* temp_dq; // size of largest quant tensor * 8 - - cudaStream_t alt_stream_1; - cudaStream_t alt_stream_2; - cudaStream_t 
alt_stream_3; - cudaEvent_t alt_stream_1_done; - cudaEvent_t alt_stream_2_done; - cudaEvent_t alt_stream_3_done; - - CudaBuffers - ( - int _device, - int _temp_state_size, - half* _temp_state, - half* _temp_dq - ); - ~CudaBuffers(); -}; - -CudaBuffers* get_buffers(const int device_index); - -void prepare_buffers_cuda -( - int _device, - int _temp_state_size, - half* _temp_state, - half* _temp_dq -); - -void cleanup_buffers_cuda(); - -#endif diff --git a/csrc/quantization/gptq/cuda_func/column_remap.cu b/csrc/quantization/gptq/cuda_func/column_remap.cu deleted file mode 100644 index 30e4039dd2e9..000000000000 --- a/csrc/quantization/gptq/cuda_func/column_remap.cu +++ /dev/null @@ -1,63 +0,0 @@ -// Adapted from turboderp exllama: https://github.com/turboderp/exllama - -#include "column_remap.cuh" -#include "../util.cuh" - -const int SHUF_BLOCKSIZE_X = 256; -const int SHUF_BLOCKSIZE_Y = 16; - -__global__ void column_remap_kernel -( - const half* __restrict__ x, - half* __restrict__ x_new, - const int x_width, - const int x_height, - const uint32_t* x_map -) -{ - int x_column = SHUF_BLOCKSIZE_X * blockIdx.x + threadIdx.x; - int x_row = SHUF_BLOCKSIZE_Y * blockIdx.y; - if (x_column >= x_width) return; - //if (x_row >= x_height) return; - - int x_stride = x_width; - int x_idx = x_row * x_stride + x_column; - - int x_row_end = min(x_row + SHUF_BLOCKSIZE_Y, x_height); - int x_idx_end = x_row_end * x_stride + x_column; - - int s_column = x_map[x_column]; - int s_idx = x_row * x_stride + s_column; - - while (x_idx < x_idx_end) - { - x_new[x_idx] = x[s_idx]; - x_idx += x_stride; - s_idx += x_stride; - } -} - -// Remap columns in x to correspond to sequential group index before matmul -// -// perform x -> seq_x such that seq_x @ seq_w == x @ w - -void column_remap_cuda -( - const half* x, - half* x_new, - const int x_height, - const int x_width, - const uint32_t* x_map -) -{ - dim3 threads(SHUF_BLOCKSIZE_X, 1, 1); - - dim3 blocks - ( - (x_width + SHUF_BLOCKSIZE_X - 1) / 
SHUF_BLOCKSIZE_X, - (x_height + SHUF_BLOCKSIZE_Y - 1) / SHUF_BLOCKSIZE_Y, - 1 - ); - - column_remap_kernel<<>>(x, x_new, x_width, x_height, x_map); -} diff --git a/csrc/quantization/gptq/cuda_func/column_remap.cuh b/csrc/quantization/gptq/cuda_func/column_remap.cuh deleted file mode 100644 index 6571c17d6fd5..000000000000 --- a/csrc/quantization/gptq/cuda_func/column_remap.cuh +++ /dev/null @@ -1,19 +0,0 @@ -// Adapted from turboderp exllama: https://github.com/turboderp/exllama - -#ifndef _column_remap_cuh -#define _column_remap_cuh - -#include -#include -#include - -void column_remap_cuda -( - const half* x, - half* x_new, - const int x_height, - const int x_width, - const uint32_t* x_map -); - -#endif \ No newline at end of file diff --git a/csrc/quantization/gptq/cuda_func/q4_matmul.cu b/csrc/quantization/gptq/cuda_func/q4_matmul.cu deleted file mode 100644 index 0ee6e16dc862..000000000000 --- a/csrc/quantization/gptq/cuda_func/q4_matmul.cu +++ /dev/null @@ -1,260 +0,0 @@ -// Adapted from turboderp exllama: https://github.com/turboderp/exllama - -#include "q4_matmul.cuh" -#include "column_remap.cuh" -#include "../util.cuh" -#include "../matrix.cuh" -#include "../cu_compat.cuh" -#include "../cuda_buffers.cuh" -#if defined(USE_ROCM) -#include "../hip_compat.cuh" -#endif - -const int THREADS_X = 32; // Block size and thread count along columns in w and out -const int THREADS_Y = 1; // Block size and thread count along rows in x and out - -typedef void (*fp_q4_matmul_kernel) -( - const half*, - const uint32_t*, - half*, - const half*, - const uint32_t*, - const int, - const int, - const int, - const int, - const int, - const uint32_t*, - bool -); - -template -__global__ void q4_matmul_kernel -( - const half* __restrict__ x, - const uint32_t* __restrict__ w, - half* __restrict__ out, - const half* __restrict__ w_scales, - const uint32_t* __restrict__ w_zeros, - const int height, - const int dim, - const int width, - const int groupsize, - const int block_size_z, - 
const uint32_t* __restrict__ x_map, - bool no_zero -) -{ - // Start of block - - int x_column = block_size_z * blockIdx.z; - int x_column_end = min(dim, block_size_z * (blockIdx.z + 1)); - - int w_column = THREADS_X * blockIdx.x + threadIdx.x; - int x_row = THREADS_Y * blockIdx.y + threadIdx.y; - - int iterations = (x_column_end - x_column) / 8; - - // Views - - MatrixView_half x_(x, height, dim); - MatrixView_half w_scales_(w_scales, dim / groupsize, width); - MatrixView_q4_row w_zeros_(w_zeros, dim / groupsize, width); - MatrixView_q4_column w_(w, dim, width); - MatrixView_half_rw out_(out, height, width); - - // Zero output - - if (!no_zero && blockIdx.z == 0 && (threadIdx.x & 1) == 0) - { - *((uint32_t*) out_.item_ptr(x_row, w_column)) = 0; - __syncthreads(); - } - - // Loop over part of x row (and w column) - - half2 acc = {}; - half acc_h = {}; - - if constexpr (use_groupsize) - { - // For quant matrices where groupsize divides BLOCK_SIZE_Z we always start on a group boundary, so this - // could be slightly faster - - for (int k = x_column, group = x_column / groupsize; k < x_column + iterations * 8; group++, k += groupsize) - { - if constexpr (use_half2) - { - half2 w_scale = w_scales_.item_half2half2(group, w_column); - uint32_t w_zero = w_zeros_.item(group, w_column) + 1; - - if constexpr (use_x_map) acc = dot_product_8_x_map(acc, x_, x_row, k, w_, k, w_column, w_scale, w_zero, groupsize / 8, x_map); - else acc = dot_product_8 (acc, x_, x_row, k, w_, k, w_column, w_scale, w_zero, groupsize / 8); - } - else - { - half w_scale = w_scales_.item(group, w_column); - uint32_t w_zero = w_zeros_.item(group, w_column) + 1; - - if constexpr (use_x_map) acc_h = dot_product_8_x_map_h(acc_h, x_, x_row, k, w_, k, w_column, w_scale, w_zero, groupsize / 8, x_map); - else acc_h = dot_product_8_h (acc_h, x_, x_row, k, w_, k, w_column, w_scale, w_zero, groupsize / 8); - } - } - } - else - { - // Otherwise assume groupsize is a multiple of 8, do 8 columns per iteration and 
trust the cache - - for (int k = x_column; k < x_column + iterations * 8; k += 8) - { - if constexpr (use_half2) - { - int group = k / groupsize; - half2 w_scale = w_scales_.item_half2half2(group, w_column); - uint32_t w_zero = w_zeros_.item(group, w_column) + 1; - - if constexpr (use_x_map) acc = dot_product_8_x_map(acc, x_, x_row, k, w_, k, w_column, w_scale, w_zero, 1, x_map); - else acc = dot_product_8 (acc, x_, x_row, k, w_, k, w_column, w_scale, w_zero, 1); - } - else - { - int group = k / groupsize; - half w_scale = w_scales_.item(group, w_column); - uint32_t w_zero = w_zeros_.item(group, w_column) + 1; - - if constexpr (use_x_map) acc_h = dot_product_8_x_map_h(acc_h, x_, x_row, k, w_, k, w_column, w_scale, w_zero, 1, x_map); - else acc_h = dot_product_8_h (acc_h, x_, x_row, k, w_, k, w_column, w_scale, w_zero, 1); - } - } - } - - // Add to block result - - if constexpr (use_half2) - { - half result = __hadd(__low2half(acc), __high2half(acc)); - atomicAdd(out_.item_ptr(x_row, w_column), result); - } - else - { - atomicAdd(out_.item_ptr(x_row, w_column), acc_h); - } -} - -fp_q4_matmul_kernel q4_matmul_kernel_pick(ExLlamaTuning* tuningParams, int block_size_z, int groupsize, uint32_t* x_map) -{ - // - if (tuningParams->matmul_no_half2) { - if (block_size_z % groupsize == 0) { - if (x_map) return q4_matmul_kernel; - else return q4_matmul_kernel; - } else { - if (x_map) return q4_matmul_kernel; - else return q4_matmul_kernel; - } - } else { - if (block_size_z % groupsize == 0) - { - if (x_map) return q4_matmul_kernel; - else return q4_matmul_kernel; - } else { - if (x_map) return q4_matmul_kernel; - else return q4_matmul_kernel; - } - } -}; - -// Compute y = x @ w - -void q4_matmul_cuda -( - ExLlamaTuning* tuningParams, - const half* x, - const int x_height, - const Q4Matrix* w, - half* out, - bool no_zero, - cudaStream_t alt_stream -) -{ - int height = x_height; - int dim = w->height; - int width = w->width; - - cudaSetDevice(w->device); - - uint32_t* x_map = 
w->cuda_x_map; - const half* x_mapped = x; - if (x_map && !tuningParams->matmul_fused_remap && !alt_stream) - { - CudaBuffers* buffers = get_buffers(w->device); - column_remap_cuda(x, buffers->temp_state, x_height, dim, w->cuda_x_map); - x_mapped = buffers->temp_state; - x_map = NULL; - } - - int block_size_z; - if (w->width == 4096) block_size_z = 384; // 7B - else if (w->width == 11008) block_size_z = 256; - else if (w->width == 5120) block_size_z = 384; // 13B - else if (w->width == 13824) block_size_z = 256; - else if (w->width == 6656) block_size_z = 256; // 33B - else if (w->width == 17920) block_size_z = 128; - else block_size_z = 256; - - //if (!no_zero) cudaMemsetAsync(out, 0, x_height * w->width * sizeof(half)); - - dim3 threads(THREADS_X, THREADS_Y, 1); - - dim3 blocks - ( - (width + threads.x - 1) / threads.x, - (height + threads.y - 1) / threads.y, - (dim + block_size_z - 1) / block_size_z - ); - - fp_q4_matmul_kernel kernel = q4_matmul_kernel_pick(tuningParams, block_size_z, w->groupsize, x_map); - - kernel<<>> (x_mapped, w->cuda_qweight, out, w->cuda_scales, w->cuda_qzeros, height, dim, width, w->groupsize, block_size_z, x_map, no_zero); -} - -void q4_matmul_recons_cuda -( - ExLlamaTuning* tuningParams, - const half* x, - const int x_height, - Q4Matrix* w, - half* out, - const cublasHandle_t handle, - bool no_zero -) -{ - int height = x_height; - int dim = w->height; - int width = w->width; - - cudaSetDevice(w->device); - CudaBuffers* buffers = get_buffers(w->device); - - const half* x_mapped = x; - if (w->cuda_x_map) - { - TORCH_CHECK(buffers->temp_state_size >= x_height * dim, "The temp_state buffer is too small in the exllama backend. Please call the exllama_set_max_input_length function to increase the buffer size. 
Example:\nfrom auto_gptq import exllama_set_max_input_length\nmodel = exllama_set_max_input_length(model, 4096)"); - column_remap_cuda(x, buffers->temp_state, x_height, dim, w->cuda_x_map); - x_mapped = buffers->temp_state; - } - - w->reconstruct(buffers->temp_dq); - -#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 700 - const float alpha = 1.0f; - const float beta = no_zero ? 1.0f : 0.0f; - cublasSgemmEx(handle, CUBLAS_OP_N, CUBLAS_OP_N, width, height, dim, &alpha, buffers->temp_dq, CUDA_R_16F, width, - x_mapped, CUDA_R_16F, dim, &beta, out, CUDA_R_16F, width); -#else - const half alpha = __float2half(1.0f); - const half beta = no_zero ? __float2half(1.0f) : __float2half(0.0f); - cublasHgemm(handle, CUBLAS_OP_N, CUBLAS_OP_N, width, height, dim, &alpha, buffers->temp_dq, width, x_mapped, dim, &beta, out, width); -#endif -} diff --git a/csrc/quantization/gptq/cuda_func/q4_matmul.cuh b/csrc/quantization/gptq/cuda_func/q4_matmul.cuh deleted file mode 100644 index 49967648f2fd..000000000000 --- a/csrc/quantization/gptq/cuda_func/q4_matmul.cuh +++ /dev/null @@ -1,43 +0,0 @@ -// Adapted from turboderp exllama: https://github.com/turboderp/exllama - -#ifndef _q4_matmul_cuh -#define _q4_matmul_cuh - -#include -#include -#include -#include -#include - -#include "q4_matrix.cuh" -#include "../tuning.h" - -// Workaround for hipify_python using rocblas instead of hipblas. 
-#if defined(USE_ROCM) -#include -#define rocblas_handle hipblasHandle_t -#endif - -void q4_matmul_cuda -( - ExLlamaTuning* tuningParams, - const half* x, - const int x_height, - const Q4Matrix* w, - half* out, - bool no_zero = false, - cudaStream_t alt_stream = NULL -); - -void q4_matmul_recons_cuda -( - ExLlamaTuning* tuningParams, - const half* x, - const int x_height, - Q4Matrix* w, - half* out, - const cublasHandle_t handle, - bool no_zero = false -); - -#endif diff --git a/csrc/quantization/gptq/cuda_func/q4_matrix.cu b/csrc/quantization/gptq/cuda_func/q4_matrix.cu deleted file mode 100644 index 2b3600e0fbc2..000000000000 --- a/csrc/quantization/gptq/cuda_func/q4_matrix.cu +++ /dev/null @@ -1,225 +0,0 @@ -// Adapted from turboderp exllama: https://github.com/turboderp/exllama - -#include "q4_matrix.cuh" -#include -#include "../util.cuh" -#include "../matrix.cuh" - -using namespace std; - -const int UNSHUF_BLOCKSIZE_X = 64; - -const int RECONS_THREADS_X = 64; // Block size and thread count along columns in out, each thread converts 1 column -const int RECONS_THREADS_Y = 1; // Block size and thread count along rows in x and out, each thread converts 8 rows - -vector g_q4_matrices; - -void g_q4_keep_matrix(Q4Matrix* m) -{ - g_q4_matrices.push_back(m); -} - -void g_q4_free_matrices() -{ - for (const auto& m : g_q4_matrices) delete m; - g_q4_matrices.clear(); -} - -Q4Matrix::Q4Matrix -( - const int _height, - const int _width, - const int _groups, - - uint32_t* _qweight, - uint32_t* _qzeros, - half* _scales, - uint32_t* _g_idx, - - const int _device -) : - height(_height), - width(_width), - groups(_groups), - device(_device) -{ - cudaSetDevice(device); - - cuda_qweight = _qweight; - cuda_qzeros = _qzeros; - cuda_scales = _scales; - - groupsize = height / groups; - - if (_g_idx) make_sequential(_g_idx); -} - -Q4Matrix::~Q4Matrix() -{ -} - -// Make sequential - -__global__ void make_sequential_kernel -( - const uint32_t* __restrict__ w, - uint32_t* __restrict__ 
w_new, - const uint32_t* __restrict__ x_map, - const int w_height, - const int w_width -) -{ - const uint64_t* w2 = (uint64_t*) w; - uint64_t* w_new2 = (uint64_t*) w_new; - int w2_stride = w_width >> 1; - - int w2_column = UNSHUF_BLOCKSIZE_X * blockIdx.x + threadIdx.x; - if (w2_column >= w2_stride) return; - - int w_new2_row = blockIdx.y; - - int x_map_idx = w_new2_row << 3; - - uint64_t dst = 0; - - #pragma unroll - for (int i = 0; i < 8; i++) - { - int source_row = x_map[x_map_idx++]; - - int w2_row = source_row >> 3; - int w2_subrow = source_row & 0x07; - int w2_row_shift = w2_subrow << 2; - int wnew2_row_shift = i << 2; - - uint64_t src = w2[w2_row * w2_stride + w2_column]; - src >>= w2_row_shift; - src &= 0x0000000f0000000f; - src <<= wnew2_row_shift; - dst |= src; - } - - w_new2[w_new2_row * w2_stride + w2_column] = dst; -} - -void Q4Matrix::make_sequential(const uint32_t* cpu_g_idx) -{ - uint32_t* cuda_new_qweight = NULL; - cudaMalloc(&cuda_new_qweight, height / 8 * width * sizeof(uint32_t)); - cudaMalloc(&cuda_x_map, height * sizeof(uint32_t)); // TODO: Should probably be allocated in PyTorch - - uint32_t* cpu_g_idx_map = (uint32_t*) calloc(groups, sizeof(uint32_t)); - uint32_t* cpu_x_map = (uint32_t*) malloc(height * sizeof(uint32_t)); - uint32_t* cpu_x_map_inv = (uint32_t*) malloc(height * sizeof(uint32_t)); - - // Group histogram - - for (int i = 0; i < height; i++) cpu_g_idx_map[cpu_g_idx[i]]++; - - // Group map - - for (int i = 0, acc = 0; i < groups; i++) - { - short tmp = cpu_g_idx_map[i]; - cpu_g_idx_map[i] = acc; - acc += tmp; - } - - // X map (inverse) - - for (int row = 0; row < height; row++) - { - uint32_t target_group = cpu_g_idx[row]; - uint32_t target_row = cpu_g_idx_map[target_group]; - cpu_g_idx_map[target_group]++; - cpu_x_map_inv[row] = target_row; - } - - // X map - - for (int row = 0; row < height; row++) cpu_x_map[cpu_x_map_inv[row]] = row; - - // Move to CUDA - - cudaMemcpyAsync(cuda_x_map, cpu_x_map, height * sizeof(uint32_t), 
cudaMemcpyHostToDevice); - - // Rearrange rows in w - - dim3 threads(UNSHUF_BLOCKSIZE_X, 1, 1); - dim3 blocks - ( - (width + UNSHUF_BLOCKSIZE_X * 2 - 1) / (UNSHUF_BLOCKSIZE_X * 2), - height / 8, - 1 - ); - - make_sequential_kernel<<>>(cuda_qweight, cuda_new_qweight, cuda_x_map, height / 8, width); - - // Replace qweights - - cudaMemcpyAsync(cuda_qweight, cuda_new_qweight, height / 8 * width * sizeof(uint32_t), cudaMemcpyDeviceToDevice); - - // Cleanup - - cudaDeviceSynchronize(); - cudaFree(cuda_new_qweight); - free(cpu_g_idx_map); - free(cpu_x_map); - free(cpu_x_map_inv); -} - -__global__ void reconstruct_kernel -( - const uint32_t* __restrict__ w, - half* __restrict__ out, // (y) - const half* __restrict__ w_scales, - const uint32_t* __restrict__ w_zeros, - const int height, - const int width, - const int groupsize -) -{ - // Start of block - - int column = RECONS_THREADS_X * blockIdx.x + threadIdx.x; - int row = (RECONS_THREADS_Y * blockIdx.y + threadIdx.y) * 8; - if (column >= width) return; - - // Views - - MatrixView_q4_column w_(w, height, width); - MatrixView_half_rw out_(out, height, width); - MatrixView_half w_scales_(w_scales, height / groupsize, width); - MatrixView_q4_row w_zeros_(w_zeros, height / groupsize, width); - - // Groupsize version - - int group = row / groupsize; - - half w_scale = w_scales_.item(group, column); - uint32_t w_zero = w_zeros_.item(group, column) + 1; - - uint32_t w_read = w_.item_uint32_t(row, column); - half* out_ptr = out_.item_ptr(row, column); - - #pragma unroll - for (int s = 0; s < 32; s += 4) - { - half w_item = __hmul(__int2half_rn((int)((w_read >> s) & 0x0f) - w_zero), w_scale); - *out_ptr = w_item; out_ptr += out_.width; - } -} - -void Q4Matrix::reconstruct(half* out) -{ - dim3 threads(RECONS_THREADS_X, RECONS_THREADS_Y, 1); - - dim3 blocks - ( - (width + threads.x - 1) / threads.x, - (height / 8 + threads.y - 1) / threads.y, - 1 - ); - - reconstruct_kernel<<>>(cuda_qweight, out, cuda_scales, cuda_qzeros, height / 8, 
width, groupsize); -} \ No newline at end of file diff --git a/csrc/quantization/gptq/cuda_func/q4_matrix.cuh b/csrc/quantization/gptq/cuda_func/q4_matrix.cuh deleted file mode 100644 index 50cb72a41518..000000000000 --- a/csrc/quantization/gptq/cuda_func/q4_matrix.cuh +++ /dev/null @@ -1,53 +0,0 @@ -// Adapted from turboderp exllama: https://github.com/turboderp/exllama - -#ifndef _q4_matrix_cuh -#define _q4_matrix_cuh - -#include -#include -#include - -class Q4Matrix -{ -public: - - int device; - - int height; - int width; - int groups; - int groupsize; - - uint32_t* cuda_qweight = NULL; - uint32_t* cuda_qzeros = NULL; - half* cuda_scales = NULL; - uint32_t* cuda_x_map = NULL; - - Q4Matrix - ( - const int _height, - const int _width, - const int _groups, - - uint32_t* _qweight, - uint32_t* _qzeros, - half* _scales, - uint32_t* _g_idx, - - const int _device - ); - - ~Q4Matrix(); - - void reconstruct(half* out); - -private: - - void make_sequential(const uint32_t* cpu_g_idx); - -}; - -void g_q4_keep_matrix(Q4Matrix* m); -void g_q4_free_matrices(); - -#endif \ No newline at end of file diff --git a/csrc/quantization/gptq/exllama_ext.cpp b/csrc/quantization/gptq/exllama_ext.cpp index 369fa0c05a8d..47e4b5539c60 100644 --- a/csrc/quantization/gptq/exllama_ext.cpp +++ b/csrc/quantization/gptq/exllama_ext.cpp @@ -1,5 +1,3 @@ -// Adapted from turboderp exllama: https://github.com/turboderp/exllama - #include #include #include @@ -7,238 +5,99 @@ #include #include #include -#include "util.cuh" -#include "tuning.h" -#include "cuda_buffers.cuh" -#include "cuda_func/q4_matrix.cuh" -#include "cuda_func/q4_matmul.cuh" -#include "cuda_func/column_remap.cuh" - -// Check CUDA return code. We don't want to include Torch headers in the .cu files because parsing them adds almost a -// minute to the compile time on a 12900K. 
Also passing exceptions back to Python is super tricky, so in place of -// exceptions, CUDA functions return with a cudaError_t which we can parse and dump to the console. - -void check_cuda(cudaError_t ret) -{ - switch (ret) - { - case cudaSuccess: - break; - - case cudaUnspecified: - printf(" **** Unspecified error\n"); - TORCH_CHECK(false, "CUDA error"); - break; - - default: - printf(" **** CUDA error\n"); \ - printf(" **** %s\n", cudaGetErrorString(ret)); \ - TORCH_CHECK(false, "CUDA error"); \ - break; - } -} + +#include "q_matrix.cuh" +#include "q_gemm.cuh" // Some decluttering macros -#define STRINGIFY_(__x) #__x -#define STRINGIFY(__x) STRINGIFY_(__x) #define TORCH_CHECK_DTYPE(__x, __dtype) TORCH_CHECK((__x).dtype() == torch::__dtype, #__x " is incorrect datatype, must be " #__dtype) #define TORCH_CHECK_DTYPE_OPT(__x, __dtype) TORCH_CHECK((__x).device().is_meta() || (__x).dtype() == torch::__dtype, #__x " is incorrect datatype, must be " #__dtype) #define TORCH_CHECK_SHAPES(__x, __dim_x, __y, __dim_y, __scale_y) TORCH_CHECK((__x).size(__dim_x) == (__y).size(__dim_y) * __scale_y, #__x " and " #__y " have incompatible shapes") #define TORCH_CHECK_SHAPES_OPT(__x, __dim_x, __y, __dim_y, __scale_y) TORCH_CHECK((__x).device().is_meta() || (__x).size(__dim_x) == (__y).size(__dim_y) * __scale_y, #__x " and " #__y " have incompatible shapes") -#define TORCH_CHECK_SHAPE_MOD(__x, __dim_x, __mod) TORCH_CHECK((__x).size(__dim_x) % __mod == 0, #__x ".shape[" STRINGIFY(__dim_x) "] must be a multiple of " STRINGIFY(__mod)) -#define TORCH_CHECK_BUFFER_SIZE(__buffer, __minimum_size) TORCH_CHECK((__buffer).numel() >= __minimum_size, #__buffer " is too small") - -#define TORCH_CHECK_DEVICE_INDEX(__index) \ -do { \ - TORCH_CHECK(__index >= 0, "no device index"); \ - TORCH_CHECK(__index < CUDA_MAX_DEVICES, "invalid device index"); \ -} while(0) - -#define TORCH_CHECK_QUANT(__w, __w_scales, __w_zeros, __seq_g_idx, __x_map) \ -do { \ - TORCH_CHECK_DTYPE(__w, kInt); \ - 
TORCH_CHECK_DTYPE(__w_scales, kHalf); \ - TORCH_CHECK_DTYPE(__w_zeros, kInt); \ - TORCH_CHECK_DTYPE_OPT(__seq_g_idx, kShort); \ - TORCH_CHECK_DTYPE_OPT(__x_map, kInt); \ - TORCH_CHECK_SHAPES_OPT(__seq_g_idx, 0, __w, 0, 2 * 8); \ - TORCH_CHECK_SHAPES_OPT(__x_map, 0, __w, 0, 8); \ -} while(0) - -int get_groupsize(torch::Tensor w, torch::Tensor w_zeros) -{ - int groupsize = w.size(0) * 8 / w_zeros.size(0); - TORCH_CHECK(groupsize * w_zeros.size(0) == w.size(0) * 8, "w.shape[-2] must be a multiple of zeros.shape[-2]") - return groupsize; -} - - -// Tuning parameters - -ExLlamaTuning tuningParams; - -void gptq_set_tuning_params -( - int matmul_recons_thd, - bool matmul_fused_remap, - bool matmul_no_half2 -) -{ - tuningParams.matmul_recons_thd = matmul_recons_thd; - tuningParams.matmul_fused_remap = matmul_fused_remap; - tuningParams.matmul_no_half2 = matmul_no_half2; -} -// Release all unmanaged objects allocated by the extension - -void gptq_cleanup() -{ - cleanup_buffers_cuda(); - g_q4_free_matrices(); -} +// Quant matrix - -// Prepare buffers for forward pass - -void gptq_prepare_buffers +uintptr_t make_q_matrix ( - torch::Device device, - torch::Tensor temp_state, + torch::Tensor q_weight, + torch::Tensor q_perm, + torch::Tensor q_invperm, + torch::Tensor gptq_qzeros, + torch::Tensor gptq_scales, + torch::Tensor gptq_g_idx, torch::Tensor temp_dq ) { - int device_index = device.index(); - TORCH_CHECK_DEVICE_INDEX(device_index); - const at::cuda::OptionalCUDAGuard device_guard(device); + TORCH_CHECK_DTYPE(q_weight, kInt); + TORCH_CHECK_DTYPE_OPT(q_perm, kShort); + TORCH_CHECK_DTYPE_OPT(q_invperm, kShort); + TORCH_CHECK_DTYPE_OPT(gptq_qzeros, kInt); + TORCH_CHECK_DTYPE_OPT(gptq_scales, kHalf); + TORCH_CHECK_DTYPE_OPT(gptq_g_idx, kInt); + TORCH_CHECK_SHAPES(q_weight, 1, gptq_qzeros, 1, 8); + TORCH_CHECK_SHAPES(q_weight, 1, gptq_scales, 1, 1); - prepare_buffers_cuda - ( - device_index, - // buffer size used for sanity checks - temp_state.numel(), - (half*) 
temp_state.data_ptr(), - (half*) temp_dq.data_ptr() - ); -} + TORCH_CHECK_SHAPES(q_perm, 0, q_invperm, 0, 1); + int device = q_weight.device().index(); + int width = q_weight.size(1); + int groups; + int height; -// Create Q4Matrix, return handle + groups = gptq_qzeros.size(0); + height = q_weight.size(0) * 8; -uintptr_t gptq_make_q4 -( - torch::Tensor qweight, - torch::Tensor qzeros, - torch::Tensor scales, - torch::Tensor g_idx, - int device -) -{ - TORCH_CHECK_DTYPE(qweight, kInt); - TORCH_CHECK_DTYPE(qzeros, kInt); - TORCH_CHECK_DTYPE(scales, kHalf); - TORCH_CHECK_DTYPE_OPT(g_idx, kInt); - TORCH_CHECK_SHAPES(qweight, 1, qzeros, 1, 8); - TORCH_CHECK_SHAPES(scales, 1, qweight, 1, 1); - TORCH_CHECK_SHAPES(qzeros, 0, scales, 0, 1); - - int width = qweight.size(1); - int height = qweight.size(0) * 8; - int groups = qzeros.size(0); - - Q4Matrix* m = new Q4Matrix + TORCH_CHECK(temp_dq.size(0) >= width * height, "Insufficient size of temp_dq buffer") + + QMatrix* m = new QMatrix ( + device, height, width, groups, - - (uint32_t*) qweight.data_ptr(), - (uint32_t*) qzeros.data_ptr(), - (half*) scales.data_ptr(), - g_idx.device().is_meta() ? NULL : (uint32_t*) g_idx.data_ptr(), - - device + (uint32_t*) q_weight.data_ptr(), + q_perm.device().is_meta() ? NULL : (uint16_t*) q_perm.data_ptr(), + q_invperm.device().is_meta() ? NULL : (uint16_t*) q_invperm.data_ptr(), + gptq_qzeros.device().is_meta() ? NULL : (uint32_t*) gptq_qzeros.data_ptr(), + gptq_scales.device().is_meta() ? NULL : (half*) gptq_scales.data_ptr(), + gptq_g_idx.device().is_meta() ? 
NULL : (uint32_t*) gptq_g_idx.data_ptr(), + (half*) temp_dq.data_ptr() ); - g_q4_keep_matrix(m); return reinterpret_cast (m); } - -// Matmul half @ quant -> half - -void gptq_q4_matmul +void gemm_half_q_half ( - torch::Tensor x, - uintptr_t w, - torch::Tensor out + torch::Tensor a, + uintptr_t b, + torch::Tensor c, + bool force_cuda ) { - Q4Matrix* wm = reinterpret_cast (w); - - TORCH_CHECK_DTYPE(x, kHalf); - TORCH_CHECK_DTYPE(out, kHalf); - TORCH_CHECK_SHAPES(x, 0, out, 0, 1); - TORCH_CHECK(wm->height == x.size(-1), "x and w have incompatible shapes") - - const at::cuda::OptionalCUDAGuard device_guard(device_of(x)); - - int x_height = x.size(0); - - if (tuningParams.matmul_recons_thd == 0 || x_height < tuningParams.matmul_recons_thd) - { - q4_matmul_cuda - ( - &tuningParams, - (half*) x.data_ptr(), - x_height, - wm, - (half*) out.data_ptr() - ); - } - else - { - q4_matmul_recons_cuda - ( - &tuningParams, - (half*) x.data_ptr(), - x_height, - wm, - (half*) out.data_ptr(), - at::cuda::getCurrentCUDABlasHandle() - ); - } -} + QMatrix* qm = reinterpret_cast (b); + TORCH_CHECK_DTYPE(a, kHalf); + TORCH_CHECK_DTYPE(c, kHalf); + TORCH_CHECK_SHAPES(a, 0, c, 0, 1); + TORCH_CHECK(qm->height == a.size(1), "a and b have incompatible shapes") + TORCH_CHECK(qm->width == c.size(1), "b and c have incompatible shapes") -// Remap columns in half tensor + const at::cuda::OptionalCUDAGuard device_guard(device_of(a)); -void gptq_column_remap -( - torch::Tensor x, - torch::Tensor x_new, - torch::Tensor x_map -) -{ - TORCH_CHECK_DTYPE(x, kHalf); - TORCH_CHECK_DTYPE(x_new, kHalf); - TORCH_CHECK_DTYPE(x_map, kInt); - TORCH_CHECK_SHAPES(x_map, 0, x, 1, 1); - - int height = x.size(0); - int width = x.size(1); - - TORCH_CHECK_BUFFER_SIZE(x_new, height * width); - - const at::cuda::OptionalCUDAGuard device_guard(device_of(x)); - - column_remap_cuda + gemm_half_q_half_cuda ( - (half*) x.data_ptr(), - (half*) x_new.data_ptr(), - height, - width, - (uint32_t*) x_map.data_ptr() + 
at::cuda::getCurrentCUDABlasHandle(), + (const half*) a.data_ptr(), + qm, + (half*) c.data_ptr(), + c.size(0), // m + c.size(1), // n + a.size(1), // k + true, + NULL, + force_cuda ); -} \ No newline at end of file +} diff --git a/csrc/quantization/gptq/hip_compat.cuh b/csrc/quantization/gptq/hip_compat.cuh deleted file mode 100644 index 5cd2e8553ef6..000000000000 --- a/csrc/quantization/gptq/hip_compat.cuh +++ /dev/null @@ -1,49 +0,0 @@ -// Adapted from turboderp exllama: https://github.com/turboderp/exllama - -#ifndef _hip_compat_cuh -#define _hip_compat_cuh - -// Workaround for a bug in hipamd, backported from upstream. -__device__ __forceinline__ __half __compat_hrcp(__half x) { - return __half_raw{ - static_cast<_Float16>(__builtin_amdgcn_rcph(static_cast<__half_raw>(x).data))}; -} - -__device__ __forceinline__ __half2 __compat_h2rcp(__half2 x) { - return _Float16_2{static_cast<_Float16>(__builtin_amdgcn_rcph(x.x)), - static_cast<_Float16>(__builtin_amdgcn_rcph(x.y))}; -} - -#define hrcp __compat_hrcp -#define h2rcp __compat_h2rcp - -// Workaround for hipify_python using rocblas instead of hipblas. 
-__host__ __forceinline__ hipblasStatus_t __compat_hipblasHgemm(hipblasHandle_t handle, - hipblasOperation_t transA, - hipblasOperation_t transB, - int m, - int n, - int k, - const half* alpha, - const half* AP, - int lda, - const half* BP, - int ldb, - const half* beta, - half* CP, - int ldc) { - return hipblasHgemm(handle, transA, transB, m, n, k, - reinterpret_cast(alpha), - reinterpret_cast(AP), lda, - reinterpret_cast(BP), ldb, - reinterpret_cast(beta), - reinterpret_cast(CP), ldc); -} - -#define rocblas_handle hipblasHandle_t -#define rocblas_operation_none HIPBLAS_OP_N -#define rocblas_get_stream hipblasGetStream -#define rocblas_set_stream hipblasSetStream -#define rocblas_hgemm __compat_hipblasHgemm - -#endif diff --git a/csrc/quantization/gptq/matrix.cuh b/csrc/quantization/gptq/matrix.cuh deleted file mode 100644 index 2fd5ab0b36cd..000000000000 --- a/csrc/quantization/gptq/matrix.cuh +++ /dev/null @@ -1,294 +0,0 @@ -// Adapted from turboderp exllama: https://github.com/turboderp/exllama - -#ifndef _matrix_cuh -#define _matrix_cuh - -#include -#include - -class MatrixView_half -{ -public: - const half* data; - const int height; - const int width; - - __device__ __forceinline__ MatrixView_half(const half* data, const int height, const int width) - : data(data), height(height), width(width) - { } - - __device__ __forceinline__ half item(int row, int column) const { return data[row * width + column]; } - __device__ __forceinline__ half2 item_half2(int row, int column) const { return ((half2*)data)[(row * width + column) / 2]; } - __device__ __forceinline__ half2 item_half2half2(int row, int column) const { return __half2half2(data[row * width + column]); } - __device__ __forceinline__ const half* item_ptr(int row, int column) const { return &data[row * width + column]; } -}; - -class MatrixView_half_rw -{ -public: - half* data; - const int height; - const int width; - - __device__ __forceinline__ MatrixView_half_rw(half* data, const int height, const int 
width) - : data(data), height(height), width(width) - { } - - __device__ __forceinline__ half item(int row, int column) const { return data[row * width + column]; } - __device__ __forceinline__ half2 item_half2(int row, int column) const { return ((half2*)data)[(row * width + column) / 2]; } - __device__ __forceinline__ half* item_ptr(int row, int column) { return &data[row * width + column]; } - __device__ __forceinline__ void set(int row, int column, half value) { data[row * width + column] = value; } - __device__ __forceinline__ void set_half2(int row, int column, half2 value) { ((half2*)data)[(row * width + column) / 2] = value; } -}; - -class MatrixView_q4_row -{ -public: - const uint32_t* data; - const int height; - const int width; - - __device__ __forceinline__ MatrixView_q4_row(const uint32_t* data, const int height, const int width) - : data(data), height(height), width(width) - { } - - __device__ __forceinline__ int item(int row, int column) const - { - int shift = (column & 0x07) * 4; - return (data[row * width / 8 + column / 8] >> shift) & 0x0f; - } -}; - -class MatrixView_q4_column -{ -public: - const uint32_t* data; - const int height; - const int width; - - __device__ __forceinline__ MatrixView_q4_column(const uint32_t* data, const int height, const int width) - : data(data), height(height), width(width) - { } - - __device__ __forceinline__ int item(int row, int column) const - { - int shift = (row & 0x07) * 4; - return (data[row / 8 * width + column] >> shift) & 0x0f; - } - - __device__ __forceinline__ uint32_t item_uint32_t(int row, int column) { return data[row / 8 * width + column]; } - __device__ __forceinline__ const uint32_t* item_uint32_ptr(int row, int column) { return &data[row / 8 * width + column]; } -}; - -// TODO: Rewrite all these dot product functions using functors or something, move to q4_matmul.cu - -// Accumulated dot product of 8-element row vectors in h and quantized column vectors in v, constant zero/scale - -__device__ 
__forceinline__ half2 dot_product_8 -( - const half2 acc, - MatrixView_half& h_, - const int h_row, - const int h_column, // divisible by 8 - MatrixView_q4_column& v_, - const int v_row, // divisible by 8 - const int v_column, - const half2 v_scale_2, - const uint32_t v_zero, // + 1 (!!) - const int count -) -{ - const half2* h_ptr = (const half2*) h_.item_ptr(h_row, h_column); - const uint32_t* v_ptr = (const uint32_t*) v_.item_uint32_ptr(v_row, v_column); - half2 result = acc; - - for (int i = 0; i < count; i++) - { - uint32_t v_read = *v_ptr; v_ptr += v_.width; - - half v_0 = __int2half_rn((int)((v_read ) & 0x0f) - v_zero); - half v_1 = __int2half_rn((int)((v_read >> 4) & 0x0f) - v_zero); - half v_2 = __int2half_rn((int)((v_read >> 8) & 0x0f) - v_zero); - half v_3 = __int2half_rn((int)((v_read >> 12) & 0x0f) - v_zero); - half v_4 = __int2half_rn((int)((v_read >> 16) & 0x0f) - v_zero); - half v_5 = __int2half_rn((int)((v_read >> 20) & 0x0f) - v_zero); - half v_6 = __int2half_rn((int)((v_read >> 24) & 0x0f) - v_zero); - half v_7 = __int2half_rn((int)((v_read >> 28) ) - v_zero); - - half2 v_01 = __halves2half2(v_0, v_1); - half2 v_23 = __halves2half2(v_2, v_3); - half2 v_45 = __halves2half2(v_4, v_5); - half2 v_67 = __halves2half2(v_6, v_7); - -// half2 v_01 = q4_table[v_zero - 1][(v_read ) & 0xff]; // (constant memory is too slow apparently) -// half2 v_23 = q4_table[v_zero - 1][(v_read >> 8) & 0xff]; -// half2 v_45 = q4_table[v_zero - 1][(v_read >> 16) & 0xff]; -// half2 v_67 = q4_table[v_zero - 1][(v_read >> 24) ]; - - half2 tmp = __hmul2(*h_ptr++, v_01); - tmp = __hfma2(*h_ptr++, v_23, tmp); - tmp = __hfma2(*h_ptr++, v_45, tmp); - tmp = __hfma2(*h_ptr++, v_67, tmp); - result = __hfma2(v_scale_2, tmp, result); - } - - return result; -} - -__device__ __forceinline__ half dot_product_8_h -( - const half acc, - MatrixView_half& h_, - const int h_row, - const int h_column, // divisible by 8 - MatrixView_q4_column& v_, - const int v_row, // divisible by 8 - const int 
v_column, - const half v_scale, - const uint32_t v_zero, // + 1 (!!) - const int count -) -{ - const half* h_ptr = h_.item_ptr(h_row, h_column); - const uint32_t* v_ptr = (const uint32_t*) v_.item_uint32_ptr(v_row, v_column); - half result = acc; - - for (int i = 0; i < count; i++) - { - uint32_t v_read = *v_ptr; v_ptr += v_.width; - - half v_0 = __int2half_rn((int)((v_read ) & 0x0f) - v_zero); - half v_1 = __int2half_rn((int)((v_read >> 4) & 0x0f) - v_zero); - half v_2 = __int2half_rn((int)((v_read >> 8) & 0x0f) - v_zero); - half v_3 = __int2half_rn((int)((v_read >> 12) & 0x0f) - v_zero); - half v_4 = __int2half_rn((int)((v_read >> 16) & 0x0f) - v_zero); - half v_5 = __int2half_rn((int)((v_read >> 20) & 0x0f) - v_zero); - half v_6 = __int2half_rn((int)((v_read >> 24) & 0x0f) - v_zero); - half v_7 = __int2half_rn((int)((v_read >> 28) ) - v_zero); - - half tmp = __hmul(*h_ptr++, v_0); - tmp = __hfma(*h_ptr++, v_1, tmp); - tmp = __hfma(*h_ptr++, v_2, tmp); - tmp = __hfma(*h_ptr++, v_3, tmp); - tmp = __hfma(*h_ptr++, v_4, tmp); - tmp = __hfma(*h_ptr++, v_5, tmp); - tmp = __hfma(*h_ptr++, v_6, tmp); - tmp = __hfma(*h_ptr++, v_7, tmp); - result = __hfma(v_scale, tmp, result); - } - - return result; -} - -// Accumulated dot product of 8-element row vectors in h and quantized column vectors in v, constant zero/scale, with x_map - -__device__ __forceinline__ half2 dot_product_8_x_map -( - const half2 acc, - MatrixView_half& h_, - const int h_row, - const int h_column, // divisible by 8 - MatrixView_q4_column& v_, - const int v_row, // divisible by 8 - const int v_column, - const half2 v_scale_2, - const uint32_t v_zero, // + 1 (!!) 
- const int count, - const uint32_t* x_map -) -{ - const half* h_ptr = h_.item_ptr(h_row, 0); - const uint32_t* x_map_ptr = x_map + h_column; - const uint32_t* v_ptr = (const uint32_t*) v_.item_uint32_ptr(v_row, v_column); - half2 result = acc; - - for (int i = 0; i < count; i++) - { - uint32_t v_read = *v_ptr; v_ptr += v_.width; - - half v_0 = __int2half_rn((int)((v_read ) & 0x0f) - v_zero); - half v_1 = __int2half_rn((int)((v_read >> 4) & 0x0f) - v_zero); - half v_2 = __int2half_rn((int)((v_read >> 8) & 0x0f) - v_zero); - half v_3 = __int2half_rn((int)((v_read >> 12) & 0x0f) - v_zero); - half v_4 = __int2half_rn((int)((v_read >> 16) & 0x0f) - v_zero); - half v_5 = __int2half_rn((int)((v_read >> 20) & 0x0f) - v_zero); - half v_6 = __int2half_rn((int)((v_read >> 24) & 0x0f) - v_zero); - half v_7 = __int2half_rn((int)((v_read >> 28) ) - v_zero); - - half2 v_01 = __halves2half2(v_0, v_1); - half2 v_23 = __halves2half2(v_2, v_3); - half2 v_45 = __halves2half2(v_4, v_5); - half2 v_67 = __halves2half2(v_6, v_7); - - half h_0 = h_ptr[*x_map_ptr++]; - half h_1 = h_ptr[*x_map_ptr++]; - half h_2 = h_ptr[*x_map_ptr++]; - half h_3 = h_ptr[*x_map_ptr++]; - half h_4 = h_ptr[*x_map_ptr++]; - half h_5 = h_ptr[*x_map_ptr++]; - half h_6 = h_ptr[*x_map_ptr++]; - half h_7 = h_ptr[*x_map_ptr++]; - - half2 h_01 = __halves2half2(h_0, h_1); - half2 h_23 = __halves2half2(h_2, h_3); - half2 h_45 = __halves2half2(h_4, h_5); - half2 h_67 = __halves2half2(h_6, h_7); - - half2 tmp = __hmul2(h_01, v_01); - tmp = __hfma2(h_23, v_23, tmp); - tmp = __hfma2(h_45, v_45, tmp); - tmp = __hfma2(h_67, v_67, tmp); - result = __hfma2(v_scale_2, tmp, result); - } - - return result; -} - -__device__ __forceinline__ half dot_product_8_x_map_h -( - const half acc, - MatrixView_half& h_, - const int h_row, - const int h_column, // divisible by 8 - MatrixView_q4_column& v_, - const int v_row, // divisible by 8 - const int v_column, - const half v_scale, - const uint32_t v_zero, // + 1 (!!) 
- const int count, - const uint32_t* x_map -) -{ - const half* h_ptr = h_.item_ptr(h_row, 0); - const uint32_t* x_map_ptr = x_map + h_column; - const uint32_t* v_ptr = (const uint32_t*) v_.item_uint32_ptr(v_row, v_column); - half result = acc; - - for (int i = 0; i < count; i++) - { - uint32_t v_read = *v_ptr; v_ptr += v_.width; - - half v_0 = __int2half_rn((int)((v_read ) & 0x0f) - v_zero); - half v_1 = __int2half_rn((int)((v_read >> 4) & 0x0f) - v_zero); - half v_2 = __int2half_rn((int)((v_read >> 8) & 0x0f) - v_zero); - half v_3 = __int2half_rn((int)((v_read >> 12) & 0x0f) - v_zero); - half v_4 = __int2half_rn((int)((v_read >> 16) & 0x0f) - v_zero); - half v_5 = __int2half_rn((int)((v_read >> 20) & 0x0f) - v_zero); - half v_6 = __int2half_rn((int)((v_read >> 24) & 0x0f) - v_zero); - half v_7 = __int2half_rn((int)((v_read >> 28) ) - v_zero); - - half tmp = __hmul(h_ptr[*x_map_ptr++], v_0); - tmp = __hfma(h_ptr[*x_map_ptr++], v_1, tmp); - tmp = __hfma(h_ptr[*x_map_ptr++], v_2, tmp); - tmp = __hfma(h_ptr[*x_map_ptr++], v_3, tmp); - tmp = __hfma(h_ptr[*x_map_ptr++], v_4, tmp); - tmp = __hfma(h_ptr[*x_map_ptr++], v_5, tmp); - tmp = __hfma(h_ptr[*x_map_ptr++], v_6, tmp); - tmp = __hfma(h_ptr[*x_map_ptr++], v_7, tmp); - result = __hfma(v_scale, tmp, result); - } - - return result; -} - -#endif diff --git a/csrc/quantization/gptq/matrix_view.cuh b/csrc/quantization/gptq/matrix_view.cuh new file mode 100644 index 000000000000..d1264d896bfc --- /dev/null +++ b/csrc/quantization/gptq/matrix_view.cuh @@ -0,0 +1,121 @@ +#ifndef _matrix_view_cuh +#define _matrix_view_cuh + +#include +#include + +#include "qdq_util.cuh" + +class MatrixView_half +{ +public: + const half* data; + const int height; + const int width; + + __device__ __forceinline__ MatrixView_half(const half* data, const int height, const int width) + : data(data), height(height), width(width) + { } + + __device__ __forceinline__ half item(int row, int column) const { return data[row * width + column]; } + 
__device__ __forceinline__ half2 item_half2(int row, int column) const { return ((half2*)data)[(row * width + column) / 2]; } + __device__ __forceinline__ half2 item_half2half2(int row, int column) const { return __half2half2(data[row * width + column]); } + __device__ __forceinline__ const half* item_ptr(int row, int column) const { return &data[row * width + column]; } + + __device__ __forceinline__ void item4(half (&items)[4], int row, int column) const + { + half2* ptr = (half2*) item_ptr(row, column); + half2 i01 = ptr[0]; + half2 i23 = ptr[1]; + items[0] = __low2half(i01); + items[1] = __high2half(i01); + items[2] = __low2half(i23); + items[3] = __high2half(i23); + } + __device__ __forceinline__ void item4_f(float (&items)[4], int row, int column) const + { + half2* ptr = (half2*)item_ptr(row, column); + half2 i01 = ptr[0]; + half2 i23 = ptr[1]; + items[0] = __half2float(__low2half(i01)); + items[1] = __half2float(__high2half(i01)); + items[2] = __half2float(__low2half(i23)); + items[3] = __half2float(__high2half(i23)); + } + + __device__ __forceinline__ void item4_h2(half2 (&items)[4], int row, int column) const + { + half2* ptr = (half2*)item_ptr(row, column); + half2 i01 = ptr[0]; + half2 i23 = ptr[1]; + items[0] = __half2half2(__low2half(i01)); + items[1] = __half2half2(__high2half(i01)); + items[2] = __half2half2(__low2half(i23)); + items[3] = __half2half2(__high2half(i23)); + } +}; + +class MatrixView_half_rw +{ +public: + half* data; + const int height; + const int width; + + __device__ __forceinline__ MatrixView_half_rw(half* data, const int height, const int width) + : data(data), height(height), width(width) + { } + + __device__ __forceinline__ half item(int row, int column) const { return data[row * width + column]; } + __device__ __forceinline__ half2 item_half2(int row, int column) const { return ((half2*)data)[(row * width + column) / 2]; } + __device__ __forceinline__ half* item_ptr(int row, int column) { return &data[row * width + column]; } + 
__device__ __forceinline__ void set(int row, int column, half value) { data[row * width + column] = value; } + __device__ __forceinline__ void set_half2(int row, int column, half2 value) { ((half2*)data)[(row * width + column) / 2] = value; } + + __device__ __forceinline__ void set4(int row, int column, half v0, half v1, half v2, half v3) + { + half2 v01 = __halves2half2(v0, v1); + half2 v23 = __halves2half2(v2, v3); + half2* ptr = (half2*) item_ptr(row, column); + ptr[0] = v01; + ptr[1] = v23; + } +}; + +class MatrixView_q4_row +{ +public: + const uint32_t* data; + const int height; + const int width; + + __device__ __forceinline__ MatrixView_q4_row(const uint32_t* data, const int height, const int width) + : data(data), height(height), width(width) + { } + + __device__ __forceinline__ int item(int row, int column) const + { + int shift = (column & 0x07) * 4; + return (data[row * width / 8 + column / 8] >> shift) & 0x0f; + } + + __device__ __forceinline__ void item2(int (&items)[2], int row, int column) const + { + int shift = (column & 0x07) * 4; + uint32_t d = data[row * width / 8 + column / 8] >> shift; + items[0] = d & 0x0f; + items[1] = (d >> 4) & 0x0f; + } + + __device__ __forceinline__ void item4(int (&items)[4], int row, int column) const + { + int shift = (column & 0x07) * 4; + uint32_t d = data[row * width / 8 + column / 8] >> shift; + items[0] = d & 0x0f; + items[1] = (d >> 4) & 0x0f; + items[2] = (d >> 8) & 0x0f; + items[3] = (d >> 12) & 0x0f; + } +}; + +#endif diff --git a/csrc/quantization/gptq/old_matmul.cpp b/csrc/quantization/gptq/old_matmul.cpp deleted file mode 100644 index f5b4be77701e..000000000000 --- a/csrc/quantization/gptq/old_matmul.cpp +++ /dev/null @@ -1,25 +0,0 @@ -#include -#include -#include - -void vecquant4matmul_cuda( - torch::Tensor vec, - torch::Tensor mat, - torch::Tensor mul, - torch::Tensor scales, - torch::Tensor zeros, - torch::Tensor g_idx -); - -void gptq_descact_matmul( - torch::Tensor vec, - torch::Tensor mat, - 
torch::Tensor mul, - torch::Tensor scales, - torch::Tensor zeros, - torch::Tensor g_idx -) -{ - const at::cuda::OptionalCUDAGuard device_guard(device_of(vec)); - vecquant4matmul_cuda(vec, mat, mul, scales, zeros, g_idx); -} diff --git a/csrc/quantization/gptq/old_matmul_kernel.cu b/csrc/quantization/gptq/old_matmul_kernel.cu index 68c39fca8680..a79bf4b3edc0 100644 --- a/csrc/quantization/gptq/old_matmul_kernel.cu +++ b/csrc/quantization/gptq/old_matmul_kernel.cu @@ -1,9 +1,10 @@ #include #include +#include #include #include #include -#include "cu_compat.cuh" +#include "compat.cuh" const int BLOCKWIDTH = 256; const int BLOCKHEIGHT = 32; @@ -108,4 +109,17 @@ void vecquant4matmul_cuda( ); }) ); +} + +void gptq_descact_matmul( + torch::Tensor vec, + torch::Tensor mat, + torch::Tensor mul, + torch::Tensor scales, + torch::Tensor zeros, + torch::Tensor g_idx +) +{ + const at::cuda::OptionalCUDAGuard device_guard(device_of(vec)); + vecquant4matmul_cuda(vec, mat, mul, scales, zeros, g_idx); } \ No newline at end of file diff --git a/csrc/quantization/gptq/q_gemm.cu b/csrc/quantization/gptq/q_gemm.cu new file mode 100644 index 000000000000..5fa657bdb149 --- /dev/null +++ b/csrc/quantization/gptq/q_gemm.cu @@ -0,0 +1,168 @@ +#include "q_gemm.cuh" +#include "matrix_view.cuh" + +#include "qdq_4.cuh" + +#define BLOCK_KN_SIZE 128 +#define BLOCK_M_SIZE_MAX 8 +#define MAX_GROUPS_IN_BLOCK (BLOCK_KN_SIZE / 32) +#define CLEAR_N_SIZE 256 +#define MAX_Q_GEMM_ROWS 50 +#define DIVIDE(x, size) (((x) + (size) - 1) / (size)) + +#include "q_gemm_kernel_gptq.cuh" + +#if defined(USE_ROCM) +__host__ __forceinline__ hipblasStatus_t __compat_hipblasHgemm(hipblasHandle_t handle, + hipblasOperation_t transA, + hipblasOperation_t transB, + int m, + int n, + int k, + const half* alpha, + const half* AP, + int lda, + const half* BP, + int ldb, + const half* beta, + half* CP, + int ldc) { + return hipblasHgemm(handle, transA, transB, m, n, k, + reinterpret_cast(alpha), + reinterpret_cast(AP), lda, + 
reinterpret_cast(BP), ldb, + reinterpret_cast(beta), + reinterpret_cast(CP), ldc); +} +#define hipblasHgemm __compat_hipblasHgemm + +// Previous version of PyTorch were converting to rocBLAS instead of hipBLAS. +#define rocblas_operation_none HIPBLAS_OP_N +#define rocblas_hgemm __compat_hipblasHgemm +#endif + +void gemm_half_q_half_cuda_part +( + const half* a, + QMatrix* b, + half* c, + int size_m, + int size_n, + int size_k, + int m_count, + bool clear +) +{ + dim3 blockDim, gridDim; + blockDim.x = BLOCK_KN_SIZE; + blockDim.y = 1; + blockDim.z = 1; + gridDim.x = DIVIDE(size_n, BLOCK_KN_SIZE * 4); + gridDim.y = DIVIDE(size_m, m_count); + gridDim.z = DIVIDE(size_k, BLOCK_KN_SIZE); + + fp_gemm_half_q_half_gptq_kernel kernel = pick_gemm_half_q_half_gptq_kernel(true, m_count); + + kernel<<>> + ( + a, + b->cuda_q_weight, + b->cuda_gptq_qzeros, + b->cuda_gptq_scales, + c, + size_m, + size_n, + size_k, + b->groups, + b->groupsize, + b->cuda_q_perm, + clear + ); +} + +void gemm_half_q_half_cuda +( + cublasHandle_t cublas_handle, + const half* a, + QMatrix* b, + half* c, + int size_m, + int size_n, + int size_k, + bool clear, + half* temp_dq, + bool force_cuda +) +{ + if (size_m > MAX_Q_GEMM_ROWS && !force_cuda) + { + + // Reconstruct FP16 matrix, then cuBLAS + + if (!temp_dq) temp_dq = b->temp_dq; + b->reconstruct(temp_dq); + + //cublasSetMathMode(cublas_handle, CUBLAS_TENSOR_OP_MATH); + + const half alpha = __float2half(1.0f); + const half beta = clear ? 
__float2half(0.0f) : __float2half(1.0f); + cublasHgemm(cublas_handle, + CUBLAS_OP_N, + CUBLAS_OP_N, + size_n, size_m, size_k, + &alpha, temp_dq, size_n, + a, size_k, + &beta, c, size_n); + + } + else + { + // Quantized matmul + + //if (clear) clear_tensor_cuda(c, size_m, size_n); + + int max_chunks = size_m / BLOCK_M_SIZE_MAX; + int last_chunk = max_chunks * BLOCK_M_SIZE_MAX; + int last_chunk_size = size_m - last_chunk; + + if (max_chunks) + { + gemm_half_q_half_cuda_part(a, b, c, last_chunk, size_n, size_k, BLOCK_M_SIZE_MAX, clear); + } + + if (last_chunk_size) + { + gemm_half_q_half_cuda_part(a + last_chunk * size_k, b, c + last_chunk * size_n, last_chunk_size, size_n, size_k, last_chunk_size, clear); + } + } +} + +__global__ void clear_kernel +( + half* __restrict__ c, + const int size_m, + const int size_n +) +{ + int m = blockIdx.y; + int n = (blockIdx.x * CLEAR_N_SIZE + threadIdx.x) * 8; + if (n >= size_n) return; + int4* c_ptr = (int4*)(c + m * size_n + n); + *c_ptr = {}; +} + +void clear_tensor_cuda +( + half* c, + int size_m, + int size_n +) +{ + return; + dim3 blockDim, gridDim; + blockDim.x = CLEAR_N_SIZE; + blockDim.y = 1; + gridDim.x = DIVIDE(size_n / 8, CLEAR_N_SIZE); + gridDim.y = size_m; + clear_kernel<<>>(c, size_m, size_n); +} diff --git a/csrc/quantization/gptq/q_gemm.cuh b/csrc/quantization/gptq/q_gemm.cuh new file mode 100644 index 000000000000..c69f1a709689 --- /dev/null +++ b/csrc/quantization/gptq/q_gemm.cuh @@ -0,0 +1,33 @@ +#ifndef _q_gemm_cuh +#define _q_gemm_cuh + +#include +#include +#include +#include +#include + +#include "q_matrix.cuh" + +void gemm_half_q_half_cuda +( + cublasHandle_t cublas_handle, + const half* a, + QMatrix* b, + half* c, + int size_m, + int size_n, + int size_k, + bool clear = false, + half* reconstruct = NULL, + bool force_cuda = false +); + +void clear_tensor_cuda +( + half* c, + int size_m, + int size_n +); + +#endif \ No newline at end of file diff --git a/csrc/quantization/gptq/q_gemm_kernel_gptq.cuh 
b/csrc/quantization/gptq/q_gemm_kernel_gptq.cuh new file mode 100644 index 000000000000..29c86e9555da --- /dev/null +++ b/csrc/quantization/gptq/q_gemm_kernel_gptq.cuh @@ -0,0 +1,217 @@ +#include "compat.cuh" + +__forceinline__ __device__ half2 dot22_8(half2(&dq)[4], const half* a_ptr, const half2 g_result) +{ + half2 result = {}; + const half2* a2_ptr = (const half2*)a_ptr; + #pragma unroll + for (int i = 0; i < 4; i++) result = __hfma2(dq[i], *a2_ptr++, result); + return __hadd2(result, g_result); +} + +__forceinline__ __device__ float dot22_8_f(half2(&dq)[4], const half* a_ptr) +{ + half2 result = {}; + const half2* a2_ptr = (const half2*)a_ptr; + #pragma unroll + for (int i = 0; i < 4; i++) result = __hfma2(dq[i], *a2_ptr++, result); + return __half2float(__low2half(result)) + __half2float(__high2half(result)); +} + +typedef void (*fp_gemm_half_q_half_gptq_kernel) +( + const half*, + const uint32_t*, + const uint32_t*, + const half*, + half*, + const int, + const int, + const int, + const int, + const int, + const uint16_t*, + const bool +); + +template +__global__ void gemm_half_q_half_gptq_kernel +( + const half* __restrict__ a, + const uint32_t* __restrict__ b_q_weight, + const uint32_t* __restrict__ b_gptq_qzeros, + const half* __restrict__ b_gptq_scales, + half* __restrict__ c, + const int size_m, + const int size_n, + const int size_k, + const int groups, + const int groupsize, + const uint16_t* __restrict__ b_q_perm, + const bool clear +) +{ + MatrixView_half a_(a, size_m, size_k); + MatrixView_half_rw c_(c, size_m, size_n); + MatrixView_q4_row b_gptq_qzeros_(b_gptq_qzeros, groups, size_n); + MatrixView_half b_gptq_scales_(b_gptq_scales, groups, size_n); + + int t = threadIdx.x; + + // Block + + int offset_n = blockIdx.x * BLOCK_KN_SIZE * 4; + int offset_m = blockIdx.y * m_count; + int offset_k = blockIdx.z * BLOCK_KN_SIZE; + + int end_n = min(offset_n + BLOCK_KN_SIZE * 4, size_n); + int end_m = min(offset_m + m_count, size_m); + int end_k = min(offset_k 
+ BLOCK_KN_SIZE, size_k); + + int n = offset_n + t * 4; + + // Preload block_a + + __shared__ half block_a[m_count][BLOCK_KN_SIZE]; + + if (offset_k + t < end_k) + { + for (int m = 0; m < m_count; ++m) + { + const half* a_ptr = a_.item_ptr(offset_m + m, 0); + half* block_a_ptr = block_a[m]; + + half a0; + if (b_q_perm) a0 = a_ptr[b_q_perm[offset_k + t]]; + else a0 = a_ptr[offset_k + t]; + block_a_ptr[t] = a0; + } + } + + // Zero output + + if (n >= size_n) return; + + if (clear && blockIdx.z == 0) // && (threadIdx.x & 1) == 0) + { + for (int m = 0; m < m_count; m++) + *((uint64_t*)c_.item_ptr(offset_m + m, n)) = 0; + } + + __syncthreads(); + + // Find initial group + + int group = offset_k / groupsize; + int nextgroup = offset_k + groupsize; + + // a, b offset + + int qk = offset_k / (32 / 4); + + const uint32_t* b_ptr = b_q_weight + qk * size_n + n; + const half* a_ptr = &block_a[0][0]; + int a_stride = BLOCK_KN_SIZE; + + // Initial group + + int zeros[4]; + float scales[4]; + half2 z1z16[4][2]; + half2 y1y16[4][2]; + b_gptq_qzeros_.item4(zeros, group, n); + b_gptq_scales_.item4_f(scales, group, n); + dequant_4bit_8_prep_zero(zeros[0] + 1, z1z16[0], y1y16[0]); + dequant_4bit_8_prep_zero(zeros[1] + 1, z1z16[1], y1y16[1]); + dequant_4bit_8_prep_zero(zeros[2] + 1, z1z16[2], y1y16[2]); + dequant_4bit_8_prep_zero(zeros[3] + 1, z1z16[3], y1y16[3]); + +// __syncthreads(); + + // Column result + + float block_c[m_count][4] = {}; + + // Dequantize and multiply + + int k = offset_k; + while (k < end_k) + { + if (k == nextgroup) + { + group++; + nextgroup += groupsize; + b_gptq_qzeros_.item4(zeros, group, n); + b_gptq_scales_.item4_f(scales, group, n); + dequant_4bit_8_prep_zero(zeros[0] + 1, z1z16[0], y1y16[0]); + dequant_4bit_8_prep_zero(zeros[1] + 1, z1z16[1], y1y16[1]); + dequant_4bit_8_prep_zero(zeros[2] + 1, z1z16[2], y1y16[2]); + dequant_4bit_8_prep_zero(zeros[3] + 1, z1z16[3], y1y16[3]); + } + + #pragma unroll + for (int j = 0; j < 4; j++) + { + const int4* b_ptr4 = 
(int4*) b_ptr; + int4 load_int4 = *b_ptr4; + + half2 dq[4][4]; + dequant_4bit_8_gptq(load_int4.x, dq[0], z1z16[0], y1y16[0], size_n, false); + dequant_4bit_8_gptq(load_int4.y, dq[1], z1z16[1], y1y16[1], size_n, false); + dequant_4bit_8_gptq(load_int4.z, dq[2], z1z16[2], y1y16[2], size_n, false); + dequant_4bit_8_gptq(load_int4.w, dq[3], z1z16[3], y1y16[3], size_n, false); + + #pragma unroll + for (int m = 0; m < m_count; m++) + { + block_c[m][0] = fma(dot22_8_f(dq[0], a_ptr + m * a_stride), scales[0], block_c[m][0]); + block_c[m][1] = fma(dot22_8_f(dq[1], a_ptr + m * a_stride), scales[1], block_c[m][1]); + block_c[m][2] = fma(dot22_8_f(dq[2], a_ptr + m * a_stride), scales[2], block_c[m][2]); + block_c[m][3] = fma(dot22_8_f(dq[3], a_ptr + m * a_stride), scales[3], block_c[m][3]); + } + + b_ptr += size_n; + a_ptr += 8; + } + + k += 32; + } + + for (int m = 0; m < m_count; m++) + { + half2 *out = (half2*) c_.item_ptr(offset_m + m, n); + half2 result01 = __halves2half2(__float2half_rn(block_c[m][0]), __float2half_rn(block_c[m][1])); + half2 result23 = __halves2half2(__float2half_rn(block_c[m][2]), __float2half_rn(block_c[m][3])); + atomicAdd(out , result01); + atomicAdd(out + 1, result23); + } +} + +fp_gemm_half_q_half_gptq_kernel pick_gemm_half_q_half_gptq_kernel(bool first_block, const int m_count) +{ + #if BLOCK_M_SIZE_MAX >= 1 + if (m_count == 1) return gemm_half_q_half_gptq_kernel; + #endif + #if BLOCK_M_SIZE_MAX >= 2 + if (m_count == 2) return gemm_half_q_half_gptq_kernel; + #endif + #if BLOCK_M_SIZE_MAX >= 3 + if (m_count == 3) return gemm_half_q_half_gptq_kernel; + #endif + #if BLOCK_M_SIZE_MAX >= 4 + if (m_count == 4) return gemm_half_q_half_gptq_kernel; + #endif + #if BLOCK_M_SIZE_MAX >= 5 + if (m_count == 5) return gemm_half_q_half_gptq_kernel; + #endif + #if BLOCK_M_SIZE_MAX >= 6 + if (m_count == 6) return gemm_half_q_half_gptq_kernel; + #endif + #if BLOCK_M_SIZE_MAX >= 7 + if (m_count == 7) return gemm_half_q_half_gptq_kernel; + #endif + #if 
BLOCK_M_SIZE_MAX >= 8 + if (m_count == 8) return gemm_half_q_half_gptq_kernel; + #endif + return NULL; +} diff --git a/csrc/quantization/gptq/q_matrix.cu b/csrc/quantization/gptq/q_matrix.cu new file mode 100644 index 000000000000..e6d48e588d2b --- /dev/null +++ b/csrc/quantization/gptq/q_matrix.cu @@ -0,0 +1,341 @@ +#include "q_matrix.cuh" +#include "matrix_view.cuh" + +#include "qdq_4.cuh" + +#define BLOCK_KN_SIZE 128 + +#define THREADS_X 32 +#define THREADS_Y 32 +#define DIVIDE(x, size) (((x) + (size) - 1) / (size)) + +// Shuffle quantized data on load + +__global__ void shuffle_kernel +( + uint32_t* __restrict__ b_q_weight, + const int size_k, + const int size_n +) +{ + int n = blockIdx.x * THREADS_X + threadIdx.x; + if (n >= size_n) return; + int k = 0; + uint32_t* b_ptr = b_q_weight + n; + while (k < size_k) { shuffle_4bit_8 (b_ptr, size_n); b_ptr += 1 * size_n; k += 8; } +} + + +// QMatrix constructor + +QMatrix::QMatrix +( + const int _device, + const int _height, + const int _width, + const int _groups, + + uint32_t* _q_weight, + uint16_t* _q_perm, + uint16_t* _q_invperm, + + uint32_t* _gptq_qzeros, + half* _gptq_scales, + uint32_t* _gptq_g_idx, + + half* _temp_dq +) : + device(_device), + height(_height), + width(_width), + groups(_groups), + temp_dq(_temp_dq) +{ + cudaSetDevice(device); + + cuda_q_weight = _q_weight; + cuda_q_perm = _q_perm; + cuda_q_invperm = _q_invperm; + cuda_gptq_qzeros = _gptq_qzeros; + cuda_gptq_scales = _gptq_scales; + + is_gptq = true; + + groupsize = 1; + while (groupsize * groups < height) groupsize *= 2; + + if (_gptq_g_idx) make_sequential(_gptq_g_idx); + + // Shuffle quantized data + + dim3 blockDim, gridDim; + blockDim.x = THREADS_X; + blockDim.y = 1; + gridDim.x = DIVIDE(width, THREADS_X); + gridDim.y = 1; + + shuffle_kernel<<>>(cuda_q_weight, height, width); +} + + +// Reconstruct b[k,n] (GPTQ) + +__global__ void reconstruct_gptq_kernel +( + const uint32_t* __restrict__ b_q_weight, + const uint16_t* __restrict__ b_q_perm, 
+ const uint32_t* __restrict__ b_gptq_qzeros, + const half* __restrict__ b_gptq_scales, + const int size_k, + const int size_n, + const int groupsize, + const int groups, + half* __restrict__ b +) +{ + MatrixView_half_rw b_(b, size_k, size_n); + MatrixView_q4_row b_gptq_qzeros_(b_gptq_qzeros, groups, size_n); + MatrixView_half b_gptq_scales_(b_gptq_scales, groups, size_n); + + int offset_k = BLOCK_KN_SIZE * blockIdx.y; + int offset_n = BLOCK_KN_SIZE * blockIdx.x * 4; + + int end_k = min(offset_k + BLOCK_KN_SIZE, size_k); + + // Preload remapping table + + __shared__ uint16_t perm[BLOCK_KN_SIZE]; + int t = threadIdx.x; + + if (b_q_perm) + { + if (offset_k + t < size_k) + perm[t] = b_q_perm[offset_k + t]; + } + + // Column + + int n = offset_n + t * 4; + if (n >= size_n) return; + + // Find initial group + + int group = offset_k / groupsize; + int nextgroup = offset_k + groupsize; + + // b offset + + int qk = offset_k / (32 / 4); + + const uint32_t* b_ptr = b_q_weight + qk * size_n + n; + + // Initial zeros/scale + + int zeros[4]; + half2 scales[4]; + half2 z1z16[4][2]; + half2 y1y16[4][2]; + b_gptq_qzeros_.item4(zeros, group, n); + b_gptq_scales_.item4_h2(scales, group, n); + dequant_4bit_8_prep_zero(zeros[0] + 1, z1z16[0], y1y16[0]); + dequant_4bit_8_prep_zero(zeros[1] + 1, z1z16[1], y1y16[1]); + dequant_4bit_8_prep_zero(zeros[2] + 1, z1z16[2], y1y16[2]); + dequant_4bit_8_prep_zero(zeros[3] + 1, z1z16[3], y1y16[3]); + + __syncthreads(); + + int k = offset_k; + int lk = 0; + + while (k < end_k) + { + if (k == nextgroup) + { + group++; + nextgroup += groupsize; + b_gptq_qzeros_.item4(zeros, group, n); + b_gptq_scales_.item4_h2(scales, group, n); + dequant_4bit_8_prep_zero(zeros[0] + 1, z1z16[0], y1y16[0]); + dequant_4bit_8_prep_zero(zeros[1] + 1, z1z16[1], y1y16[1]); + dequant_4bit_8_prep_zero(zeros[2] + 1, z1z16[2], y1y16[2]); + dequant_4bit_8_prep_zero(zeros[3] + 1, z1z16[3], y1y16[3]); + } + + for (int p = 0; p < 4; p++) + { + half2 dq[4][4]; + const int4* b_ptr4 
= (int4*) b_ptr; + int4 load_int4 = *b_ptr4; + + dequant_4bit_8_gptq(load_int4.x, dq[0], z1z16[0], y1y16[0], size_n, false); + dequant_4bit_8_gptq(load_int4.y, dq[1], z1z16[1], y1y16[1], size_n, false); + dequant_4bit_8_gptq(load_int4.z, dq[2], z1z16[2], y1y16[2], size_n, false); + dequant_4bit_8_gptq(load_int4.w, dq[3], z1z16[3], y1y16[3], size_n, false); + + b_ptr += size_n; + //half* dqh = (half*)dq; + if (b_q_perm) + { + for (int j = 0; j < 4; j++) + { + for (int v = 0; v < 4; v++) dq[v][j] = __hmul2(scales[v], dq[v][j]); + b_.set4(perm[lk++], n, __low2half(dq[0][j]), __low2half(dq[1][j]), __low2half(dq[2][j]), __low2half(dq[3][j])); + b_.set4(perm[lk++], n, __high2half(dq[0][j]), __high2half(dq[1][j]), __high2half(dq[2][j]), __high2half(dq[3][j])); + } + } + else + { + for (int j = 0; j < 4; j++) + { + for (int v = 0; v < 4; v++) dq[v][j] = __hmul2(scales[v], dq[v][j]); + b_.set4(offset_k + lk++, n, __low2half(dq[0][j]), __low2half(dq[1][j]), __low2half(dq[2][j]), __low2half(dq[3][j])); + b_.set4(offset_k + lk++, n, __high2half(dq[0][j]), __high2half(dq[1][j]), __high2half(dq[2][j]), __high2half(dq[3][j])); + } + } + } + k += 32; + } +} + +void QMatrix::reconstruct(half* out) +{ + dim3 blockDim, gridDim; + blockDim.x = BLOCK_KN_SIZE; + blockDim.y = 1; + gridDim.x = DIVIDE(width, BLOCK_KN_SIZE); + gridDim.y = DIVIDE(height, BLOCK_KN_SIZE); + + reconstruct_gptq_kernel<<>> + ( + cuda_q_weight, + cuda_q_perm, + cuda_gptq_qzeros, + cuda_gptq_scales, + height, + width, + groupsize, + groups, + out + ); +} + +__global__ void make_sequential_kernel +( + const uint32_t* __restrict__ w, + uint32_t* __restrict__ w_new, + const uint16_t* __restrict__ q_perm, + const int w_height, + const int w_width +) +{ + const uint64_t* w2 = (uint64_t*) w; + uint64_t* w_new2 = (uint64_t*) w_new; + int w2_stride = w_width >> 1; + + int w2_column = THREADS_X * blockIdx.x + threadIdx.x; + if (w2_column >= w2_stride) return; + + int w_new2_row = blockIdx.y; + + int q_perm_idx = w_new2_row 
<< 3; + + uint64_t dst = 0; + + #pragma unroll + for (int i = 0; i < 8; i++) + { + int source_row = q_perm[q_perm_idx++]; + + int w2_row = source_row >> 3; + int w2_subrow = source_row & 0x07; + int w2_row_shift = w2_subrow << 2; + int wnew2_row_shift = i << 2; + + uint64_t src = w2[w2_row * w2_stride + w2_column]; + src >>= w2_row_shift; + src &= 0x0000000f0000000f; + src <<= wnew2_row_shift; + dst |= src; + } + + w_new2[w_new2_row * w2_stride + w2_column] = dst; +} + +void QMatrix::make_sequential(const uint32_t* cpu_g_idx) +{ + uint32_t* cuda_new_qweight = NULL; + cudaMalloc(&cuda_new_qweight, height / 8 * width * sizeof(uint32_t)); + + uint32_t* cpu_g_idx_map = (uint32_t*) calloc(groups, sizeof(uint32_t)); + uint32_t* cpu_x_map = (uint32_t*) malloc(height * sizeof(uint32_t)); + uint32_t* cpu_x_map_inv = (uint32_t*) malloc(height * sizeof(uint32_t)); + + // Group histogram + + for (int i = 0; i < height; i++) cpu_g_idx_map[cpu_g_idx[i]]++; + + // Group map + + for (int i = 0, acc = 0; i < groups; i++) + { + short tmp = cpu_g_idx_map[i]; + cpu_g_idx_map[i] = acc; + acc += tmp; + } + + // X map (inverse) + + for (int row = 0; row < height; row++) + { + uint32_t target_group = cpu_g_idx[row]; + uint32_t target_row = cpu_g_idx_map[target_group]; + cpu_g_idx_map[target_group]++; + cpu_x_map_inv[row] = target_row; + } + + // X map + + for (int row = 0; row < height; row++) cpu_x_map[cpu_x_map_inv[row]] = row; + + // Reduce to uint16_t + + uint16_t* cpu_x_map16 = (uint16_t*)cpu_x_map; + uint16_t* cpu_x_map_inv16 = (uint16_t*)cpu_x_map_inv; + for (int row = 0; row < height; row++) cpu_x_map16[row] = (uint16_t) cpu_x_map[row]; + for (int row = 0; row < height; row++) cpu_x_map_inv16[row] = (uint16_t) cpu_x_map_inv[row]; + + // Move to CUDA + + cudaMemcpyAsync(cuda_q_perm, cpu_x_map16, height * sizeof(uint16_t), cudaMemcpyHostToDevice); + cudaMemcpyAsync(cuda_q_invperm, cpu_x_map_inv16, height * sizeof(uint16_t), cudaMemcpyHostToDevice); + + // Rearrange rows in w + + 
dim3 blockDim, gridDim; + blockDim.x = THREADS_X; + blockDim.y = 1; + gridDim.x = DIVIDE(width, THREADS_X); + gridDim.y = height / 8; + + make_sequential_kernel<<>> + ( + cuda_q_weight, + cuda_new_qweight, + cuda_q_perm, + height / 8, + width + ); + + // Replace qweights + + cudaMemcpyAsync(cuda_q_weight, cuda_new_qweight, height / 8 * width * sizeof(uint32_t), cudaMemcpyDeviceToDevice); + + // Cleanup + + cudaDeviceSynchronize(); + + cudaFree(cuda_new_qweight); + free(cpu_g_idx_map); + free(cpu_x_map); + free(cpu_x_map_inv); +} diff --git a/csrc/quantization/gptq/q_matrix.cuh b/csrc/quantization/gptq/q_matrix.cuh new file mode 100644 index 000000000000..86124ffdc217 --- /dev/null +++ b/csrc/quantization/gptq/q_matrix.cuh @@ -0,0 +1,58 @@ +#ifndef _q_matrix_cuh +#define _q_matrix_cuh + +#include +#include +#include +#include + +#define MAX_SUPERGROUPS 16 + +class QMatrix +{ +public: + + int device; + bool is_gptq; + + int height; + int width; + int groups; + int groupsize; + + uint32_t* cuda_q_weight = NULL; + uint16_t* cuda_q_perm = NULL; + uint16_t* cuda_q_invperm = NULL; + uint32_t* cuda_gptq_qzeros = NULL; + half* cuda_gptq_scales = NULL; + + half* temp_dq; + + QMatrix + ( + const int _device, + const int _height, + const int _width, + const int _groups, + + uint32_t* _q_weight, + uint16_t* _q_perm, + uint16_t* _q_invperm, + + uint32_t* _gptq_qzeros, + half* _gptq_scales, + uint32_t* _gptq_g_idx, + + half* _temp_dq + ); + + ~QMatrix(); + + void reconstruct(half* out); + void make_sequential(const uint32_t* cpu_g_idx); + +private: + +}; + +#endif diff --git a/csrc/quantization/gptq/qdq_4.cuh b/csrc/quantization/gptq/qdq_4.cuh new file mode 100644 index 000000000000..a7bde6d30508 --- /dev/null +++ b/csrc/quantization/gptq/qdq_4.cuh @@ -0,0 +1,222 @@ +#ifndef _qdq_4_cuh +#define _qdq_4_cuh + +#include "qdq_util.cuh" + +// Permutation: +// +// 77775555 33331111 66664444 22220000 + +__forceinline__ __device__ void shuffle_4bit_8 +( + uint32_t* q, + int stride +) +{ 
+ uint32_t qa = q[0]; + uint32_t qb = 0; + + #pragma unroll + for (int i = 0; i < 4; i++) + { + uint32_t qa0 = qa & 0x0f; + uint32_t qa1 = (qa & 0xf0) >> 4; + qa >>= 8; + qb |= (qa1 << (i * 4 + 16)); + qb |= (qa0 << (i * 4)); + } + q[0] = qb; +} + +__forceinline__ __device__ void dequant_4bit_8 +( + const uint32_t q_0, + half2 (&dq)[4], + int stride +) +{ + const uint32_t c0 = 0x64006400; + const half y16_ = __float2half_rn(1.0f / 16.0f); + const half2 y16 = __halves2half2(y16_, y16_); + const half z1_ = __float2half_rn(-1024.0f - 8.0f); + const half z16_ = __float2half_rn(-1024.0f / 16.0f - 8.0f); + const half2 z1 = __halves2half2(z1_, z1_); + const half2 z16 = __halves2half2(z16_, z16_); + + uint32_t qa = q_0; + half2_uint32 q0((qa & 0x000f000f) | c0); // half2(q[ 0], q[ 1]) + 1024 + half2_uint32 q1((qa & 0x00f000f0) | c0); // half2(q[ 2], q[ 3]) * 16 + 1024 + qa >>= 8; + half2_uint32 q2((qa & 0x000f000f) | c0); // half2(q[ 4], q[ 5]) + 1024 + half2_uint32 q3((qa & 0x00f000f0) | c0); // half2(q[ 6], q[ 7]) * 16 + 1024 + + dq[0] = __hadd2(q0.as_half2, z1); + dq[1] = __hfma2(q1.as_half2, y16, z16); + dq[2] = __hadd2(q2.as_half2, z1); + dq[3] = __hfma2(q3.as_half2, y16, z16); +} + +__forceinline__ __device__ void dequant_4bit_8_prep_zero_scale +( + const uint32_t zero, + const half scale, + half2 (&z1z16)[2], + half2 (&y1y16)[2] +) +{ + half_uint16 z1(0xe400 | zero); // half(-1024.0f - zero); + half z16 = __hsub(__int2half_rn(-64), __int2half_rn(zero)); + + half2 scale2 = __half2half2(scale); + + z1z16[0] = __hmul2(scale2, __half2half2(z1.as_half)); + z1z16[1] = __hmul2(scale2, __half2half2(z16)); + + const half y1 = __float2half_rn(1.0f); + const half y16 = __float2half_rn(1.0f / 16.0f); + + y1y16[0] = __hmul2(scale2, __half2half2(y1)); + y1y16[1] = __hmul2(scale2, __half2half2(y16)); +} + +__forceinline__ __device__ void dequant_4bit_8_prep_zero +( + const uint32_t zero, + half2(&z1z16)[2], + half2(&y1y16)[2] +) +{ + half_uint16 z1(0xe400 | zero); // half(-1024.0f 
- zero); + half z16 = __hsub(__int2half_rn(-64), __int2half_rn(zero)); + + z1z16[0] = __half2half2(z1.as_half); + z1z16[1] = __half2half2(z16); + + const half y1 = __float2half_rn(1.0f); + const half y16 = __float2half_rn(1.0f / 16.0f); + + y1y16[0] = __half2half2(y1); + y1y16[1] = __half2half2(y16); +} + + +__forceinline__ __device__ void dequant_4bit_8_gptq +( + const uint32_t q_0, + half2 (&dq)[4], + half2 (&z1z16)[2], + half2 (&y1y16)[2], + int stride, + bool scaled +) +{ + const uint32_t c0 = 0x64006400; + + uint32_t qa = q_0; + half2_uint32 q0((qa & 0x000f000f) | c0); // half2( q[0] + 1024, q[1] + 1024 ) + half2_uint32 q1((qa & 0x00f000f0) | c0); // half2( q[2] * 16 + 1024, q[3] * 16 + 1024 ) + qa >>= 8; + half2_uint32 q2((qa & 0x000f000f) | c0); // half2( q[4] + 1024, q[5] + 1024 ) + half2_uint32 q3((qa & 0x00f000f0) | c0); // half2( q[6] * 16 + 1024, q[7] * 16 + 1024 ) + + if (scaled) + { + dq[0] = __hfma2(q0.as_half2, y1y16[0], z1z16[0]); // half2( q[0] * s - z * s, q[1] * s - z * s) + dq[1] = __hfma2(q1.as_half2, y1y16[1], z1z16[1]); // half2( q[2] * s - z * s, q[3] * s - z * s) + dq[2] = __hfma2(q2.as_half2, y1y16[0], z1z16[0]); + dq[3] = __hfma2(q3.as_half2, y1y16[1], z1z16[1]); + } + else + { + dq[0] = __hadd2(q0.as_half2, z1z16[0]); // half2( q[0] - z, q[1] - z ) + dq[1] = __hfma2(q1.as_half2, y1y16[1], z1z16[1]); // half2( q[2] - z, q[3] - z ) + dq[2] = __hadd2(q2.as_half2, z1z16[0]); // half2( q[4] - z, q[5] - z ) + dq[3] = __hfma2(q3.as_half2, y1y16[1], z1z16[1]); // half2( q[6] - z, q[7] - z ) + } +} + +#else + +__forceinline__ __device__ void shuffle_4bit_8 +( + uint32_t* q, + int stride +) +{ +} + +__forceinline__ __device__ void dequant_4bit_8 +( + const uint32_t q_0, + half2 (&dq)[4], + int stride +) +{ + half dqh[8]; + for (int i = 0; i < 8; i++) dqh[i] = dq_ns(exb(q_0, i * 4, 0x0f), 8); + + for (int i = 0; i < 4; i++) dq[i] = __halves2half2(dqh[i * 2], dqh[i * 2 + 1]); +} + +__forceinline__ __device__ void dequant_4bit_8_prep_zero_scale +( + 
const uint32_t zero, + const half scale, + half2 (&z1)[2], + half2 (&y1)[2] +) +{ + half z = __int2half_rn(-((int)zero)); + z = __hmul(z, scale); + z1[0] = __half2half2(z); + y1[0] = __half2half2(scale); +} + +__forceinline__ __device__ void dequant_4bit_8_prep_zero +( + const uint32_t zero, + half2(&z1)[2], + half2(&y1)[2] +) +{ + half z = __int2half_rn(-((int)zero)); + z1[0] = __half2half2(z); +} + +__forceinline__ __device__ void dequant_4bit_8_gptq +( + const uint32_t q_0, + half2 (&dq)[4], + half2 (&z1)[2], + half2 (&y1)[2], + int stride, + bool scaled +) +{ + half2 dqh2[8]; + + uint32_t qa = q_0; + for (int i = 0; i < 4; i++) + { + half d0 = __int2half_rn(qa & 0x0f); qa >>= 4; + half d1 = __int2half_rn(qa & 0x0f); qa >>= 4; + dqh2[i] = __halves2half2(d0, d1); + } + + if (scaled) + { + dq[0] = __hfma2(dqh2[0], y1[0], z1[0]); + dq[1] = __hfma2(dqh2[1], y1[0], z1[0]); + dq[2] = __hfma2(dqh2[2], y1[0], z1[0]); + dq[3] = __hfma2(dqh2[3], y1[0], z1[0]); + } + else + { + dq[0] = __hadd2(dqh2[0], z1[0]); + dq[1] = __hadd2(dqh2[1], z1[0]); + dq[2] = __hadd2(dqh2[2], z1[0]); + dq[3] = __hadd2(dqh2[3], z1[0]); + } +} + +#endif diff --git a/csrc/quantization/gptq/qdq_util.cuh b/csrc/quantization/gptq/qdq_util.cuh new file mode 100644 index 000000000000..71657191b911 --- /dev/null +++ b/csrc/quantization/gptq/qdq_util.cuh @@ -0,0 +1,51 @@ +#ifndef _qdq_util_cuh +#define _qdq_util_cuh + +union half2_uint32 +{ + uint32_t as_uint32; + half2 as_half2; + __device__ half2_uint32(uint32_t val) : as_uint32(val) {} + __device__ half2_uint32(half2 val) : as_half2(val) {} +}; + +union half_uint16 +{ + uint16_t as_uint16; + half as_half; + __device__ half_uint16(uint16_t val) : as_uint16(val) {} + __device__ half_uint16(half val) : as_half(val) {} +}; + +// Max_scale premultiplied by 1/256 + +__forceinline__ __device__ half dq_scale(const int qs, const half max_scale) +{ + int qs_i = qs + 1; + half qs_h = __int2half_rn(qs_i * qs_i); + qs_h = __hmul(qs_h, max_scale); + return qs_h; +} 
+ +__forceinline__ __device__ half dq(const int q, const int qzero, const half scale) +{ + return __hmul(__int2half_rn(q - qzero), scale); +} + +__forceinline__ __device__ half dq_ns(const int q, const int qzero) +{ + //return __hsub(__int2half_rn(q), __int2half_rn(qzero)); + return __int2half_rn(q - qzero); +} + +__forceinline__ __device__ int exb(const uint32_t q, const int shift, const int mask) +{ + return (int)((q >> shift) & mask); +} + +__forceinline__ __device__ int exb(const uint32_t q1, const uint32_t q0, const int shift, const int mask) +{ + return (int)(__funnelshift_rc(q0, q1, shift) & mask); +} + +#endif diff --git a/csrc/quantization/gptq/tuning.h b/csrc/quantization/gptq/tuning.h deleted file mode 100644 index 770ca46aa7c8..000000000000 --- a/csrc/quantization/gptq/tuning.h +++ /dev/null @@ -1,13 +0,0 @@ -// Adapted from turboderp exllama: https://github.com/turboderp/exllama - -#ifndef _tuning_h -#define _tuning_h - -struct ExLlamaTuning -{ - int matmul_recons_thd; - bool matmul_fused_remap; - bool matmul_no_half2; -}; - -#endif diff --git a/csrc/quantization/gptq/util.cuh b/csrc/quantization/gptq/util.cuh deleted file mode 100644 index 7b397573214b..000000000000 --- a/csrc/quantization/gptq/util.cuh +++ /dev/null @@ -1,33 +0,0 @@ -// Adapted from turboderp exllama: https://github.com/turboderp/exllama - -#ifndef _util_cuh -#define _util_cuh - -#include -#include -#include -#include - -#if defined(USE_ROCM) -#define cudaUnspecified hipErrorUnknown -#else -#define cudaUnspecified cudaErrorApiFailureBase -#endif - -// React to failure on return code != cudaSuccess - -#define _cuda_check(fn) \ -do { \ - {_cuda_err = fn;} \ - if (_cuda_err != cudaSuccess) goto _cuda_fail; \ -} while(false) - -// React to failure on return code == 0 - -#define _alloc_check(fn) \ -do { \ - if (!(fn)) { _cuda_err = cudaUnspecified; goto _cuda_fail; } \ - else _cuda_err = cudaSuccess; \ -} while(false) - -#endif diff --git a/setup.py b/setup.py index 
f07a813d07b4..781435154ec3 100644 --- a/setup.py +++ b/setup.py @@ -201,11 +201,8 @@ def get_torch_arch_list() -> Set[str]: "csrc/quantization.cpp", "csrc/quantization/awq/gemm_kernels.cu", "csrc/quantization/squeezellm/quant_cuda_kernel.cu", "csrc/quantization/gptq/exllama_ext.cpp", - "csrc/quantization/gptq/cuda_buffers.cu", - "csrc/quantization/gptq/cuda_func/column_remap.cu", - "csrc/quantization/gptq/cuda_func/q4_matmul.cu", - "csrc/quantization/gptq/cuda_func/q4_matrix.cu", - "csrc/quantization/gptq/old_matmul.cpp", + "csrc/quantization/gptq/q_matrix.cu", + "csrc/quantization/gptq/q_gemm.cu", "csrc/quantization/gptq/old_matmul_kernel.cu" ], extra_compile_args={ diff --git a/vllm/model_executor/layers/quantized_linear/gptq.py b/vllm/model_executor/layers/quantized_linear/gptq.py index d54cbe59e9c1..dd84c8c4dab5 100644 --- a/vllm/model_executor/layers/quantized_linear/gptq.py +++ b/vllm/model_executor/layers/quantized_linear/gptq.py @@ -8,6 +8,29 @@ RowParallelLinear) +class ExLlamaV2DeviceTensors: + + def __init__(self, device_idx, scratch_bytes): + self.device_idx = device_idx + self.scratch_bytes = scratch_bytes + self.scratch = None + + def prepare(self): + self.scratch = torch.empty( + (self.scratch_bytes // 2, ), + dtype=torch.half, + device=f"cuda:{self.device_idx}", + ) + + def get_scratch_slice(self, size_bytes): + if self.scratch is None: + self.prepare() + size_bytes = ((size_bytes + 127) // 128) * 128 + size_half = size_bytes // 2 + scratch_slice = self.scratch.narrow(0, 0, size_half) + return scratch_slice + + class GPTQLinear(torch.nn.Module): def __init__(self, @@ -69,19 +92,36 @@ def __init__(self, else: self.register_parameter("bias", None) - def post_init(self): + def post_init(self, temp_dq): assert self.qweight.device.type == "cuda" assert self.qweight.device.index is not None - # make_q4 segfaults if g_idx is not on cpu in the act-order case. - # In the non act-order case, None needs to be passed for g_idx. 
+ none_tensor = torch.empty((1, 1), device="meta") + temp_dq = temp_dq.get_scratch_slice(self.temp_dq_size()) if not self.quant_config.desc_act: - g_idx = torch.empty((1, 1), device="meta") + self.q4 = quantization_ops.make_q_matrix( + self.qweight, + none_tensor, + none_tensor, + self.qzeros, + self.scales, + none_tensor, + temp_dq, + ) else: - g_idx = self.g_idx.to("cpu") - self.q4 = quantization_ops.gptq_make_q4(self.qweight, self.qzeros, - self.scales, g_idx, - self.qweight.device.index) + self.q_perm = torch.empty((self.input_size, ), + dtype=torch.short, + device=self.qweight.device) + self.q_invperm = torch.empty_like(self.q_perm) + self.q4 = quantization_ops.make_q_matrix( + self.qweight, + self.q_perm, + self.q_invperm, + self.qzeros, + self.scales, + self.g_idx.cpu(), + temp_dq, + ) def forward(self, input_): out_shape = input_.shape[:-1] + (self.qweight.shape[-1], ) @@ -89,12 +129,21 @@ def forward(self, input_): output = torch.empty((reshaped_x.shape[0], self.qweight.shape[-1]), dtype=torch.float16, device=input_.device) - quantization_ops.gptq_q4_matmul(reshaped_x, self.q4, output) + quantization_ops.gemm_half_q_half(reshaped_x, self.q4, output, False) output = output.reshape(out_shape) output = output + self.bias if self.bias is not None else output return output + def temp_dq_size(self): + return self.input_size * self.output_size * 2 + 128 + + def temp_fwd_size(self, max_tokens): + return self.output_size * max_tokens * 4 + 128 + + def scratch_space_fixed(self, max_tokens): + return self.temp_dq_size() + self.temp_fwd_size(max_tokens) + class GPTQColumnParallelLinear(ColumnParallelLinear): @@ -143,19 +192,36 @@ def create_weights(self, dtype: torch.dtype) -> None: requires_grad=False, ) - def post_init(self): + def post_init(self, temp_dq): assert self.qweight.device.type == "cuda" assert self.qweight.device.index is not None - # make_q4 segfaults if g_idx is not on cpu in the act-order case. 
- # In the non act-order case, None needs to be passed for g_idx. + none_tensor = torch.empty((1, 1), device="meta") + temp_dq = temp_dq.get_scratch_slice(self.temp_dq_size()) if not self.quant_config.desc_act: - g_idx = torch.empty((1, 1), device="meta") + self.q4 = quantization_ops.make_q_matrix( + self.qweight, + none_tensor, + none_tensor, + self.qzeros, + self.scales, + none_tensor, + temp_dq, + ) else: - g_idx = self.g_idx.to("cpu") - self.q4 = quantization_ops.gptq_make_q4(self.qweight, self.qzeros, - self.scales, g_idx, - self.qweight.device.index) + self.q_perm = torch.empty((self.input_size, ), + dtype=torch.short, + device=self.qweight.device) + self.q_invperm = torch.empty_like(self.q_perm) + self.q4 = quantization_ops.make_q_matrix( + self.qweight, + self.q_perm, + self.q_invperm, + self.qzeros, + self.scales, + self.g_idx.cpu(), + temp_dq, + ) def apply_weights( self, @@ -167,11 +233,20 @@ def apply_weights( output = torch.empty((reshaped_x.shape[0], self.qweight.shape[-1]), dtype=torch.float16, device=x.device) - quantization_ops.gptq_q4_matmul(reshaped_x, self.q4, output) + quantization_ops.gemm_half_q_half(reshaped_x, self.q4, output, False) if bias is not None: output = output + bias return output.reshape(out_shape) + def temp_dq_size(self): + return self.input_size * self.output_size_per_partition * 2 + 128 + + def temp_fwd_size(self, max_tokens): + return self.output_size_per_partition * max_tokens * 4 + 128 + + def scratch_space_fixed(self, max_tokens): + return self.temp_dq_size() + self.temp_fwd_size(max_tokens) + class GPTQRowParallelLinear(RowParallelLinear): @@ -228,21 +303,38 @@ def create_weights(self, dtype: torch.dtype) -> None: requires_grad=False, ) - def post_init(self): + def post_init(self, temp_dq): if not self.use_exllama: return assert self.qweight.device.type == "cuda" assert self.qweight.device.index is not None - # make_q4 segfaults if g_idx is not on cpu in the act-order case. 
- # In the non act-order case, None needs to be passed for g_idx. + none_tensor = torch.empty((1, 1), device="meta") + temp_dq = temp_dq.get_scratch_slice(self.temp_dq_size()) if not self.quant_config.desc_act: - g_idx = torch.empty((1, 1), device="meta") + self.q4 = quantization_ops.make_q_matrix( + self.qweight, + none_tensor, + none_tensor, + self.qzeros, + self.scales, + none_tensor, + temp_dq, + ) else: - g_idx = self.g_idx.to("cpu") - self.q4 = quantization_ops.gptq_make_q4(self.qweight, self.qzeros, - self.scales, g_idx, - self.qweight.device.index) + self.q_perm = torch.empty((self.input_size, ), + dtype=torch.short, + device=self.qweight.device) + self.q_invperm = torch.empty_like(self.q_perm) + self.q4 = quantization_ops.make_q_matrix( + self.qweight, + self.q_perm, + self.q_invperm, + self.qzeros, + self.scales, + self.g_idx.cpu(), + temp_dq, + ) def apply_weights(self, x: torch.Tensor) -> torch.Tensor: out_shape = x.shape[:-1] + (self.qweight.shape[-1], ) @@ -252,7 +344,8 @@ def apply_weights(self, x: torch.Tensor) -> torch.Tensor: output = torch.empty((reshaped_x.shape[0], self.qweight.shape[-1]), dtype=torch.float16, device=x.device) - quantization_ops.gptq_q4_matmul(reshaped_x, self.q4, output) + quantization_ops.gemm_half_q_half(reshaped_x, self.q4, output, + False) else: output = torch.zeros((reshaped_x.shape[0], self.qweight.shape[-1]), dtype=torch.float32, @@ -263,3 +356,18 @@ def apply_weights(self, x: torch.Tensor) -> torch.Tensor: self.qzeros, self.g_idx) output = output.half() return output.reshape(out_shape) + + def temp_dq_size(self): + if not self.use_exllama: + return 0 + return self.input_size_per_partition * self.output_size * 2 + 128 + + def temp_fwd_size(self, max_tokens): + if not self.use_exllama: + return 0 + return self.output_size * max_tokens * 4 + 128 + + def scratch_space_fixed(self, max_tokens): + if not self.use_exllama: + return 0 + return self.temp_dq_size() + self.temp_fwd_size(max_tokens) diff --git 
a/vllm/model_executor/layers/quantized_linear/utils.py b/vllm/model_executor/layers/quantized_linear/utils.py index 679084824cfb..6f2b54796450 100644 --- a/vllm/model_executor/layers/quantized_linear/utils.py +++ b/vllm/model_executor/layers/quantized_linear/utils.py @@ -2,93 +2,47 @@ import torch -from vllm import quantization_ops from vllm.model_executor.layers.quantized_linear.gptq import ( - GPTQColumnParallelLinear, GPTQRowParallelLinear, GPTQLinear) + GPTQColumnParallelLinear, + GPTQRowParallelLinear, + GPTQLinear, + ExLlamaV2DeviceTensors, +) -def quant_post_init(model, max_input_length: Optional[int] = None): +def quant_post_init(model, max_tokens: Optional[int] = None): """ - The max_input_length argument is specific to the exllama backend, + The max_tokens argument is specific to the exllama backend, that requires to initialize a buffer temp_state. """ - device_to_buffers_size = {} + fixed_bytes = {} model_uses_exllama = False - use_act_order = False for _, submodule in model.named_modules(): if isinstance(submodule, (GPTQColumnParallelLinear, GPTQRowParallelLinear, GPTQLinear)) and submodule.use_exllama: model_uses_exllama = True device = submodule.qweight.device - if device not in device_to_buffers_size: - device_to_buffers_size[device] = { - "max_dq_buffer_size": 1, - "max_inner_outer_dim": 1 - } - - device_to_buffers_size[device]["max_dq_buffer_size"] = max( - device_to_buffers_size[device]["max_dq_buffer_size"], - submodule.qweight.numel() * 8) - - in_features = submodule.input_size_per_partition if isinstance( - submodule, GPTQRowParallelLinear) else submodule.input_size - out_features = submodule.output_size_per_partition if isinstance( - submodule, GPTQColumnParallelLinear) else submodule.output_size - if submodule.quant_config.desc_act: - use_act_order = True - device_to_buffers_size[device]["max_inner_outer_dim"] = max( - device_to_buffers_size[device]["max_inner_outer_dim"], - in_features, out_features) + scratch_fixed = 
submodule.scratch_space_fixed(max_tokens) + fixed_bytes[device] = max(scratch_fixed, + fixed_bytes.get(device, 0)) if model_uses_exllama: - device_to_buffers = {} - max_input_len = max_input_length if use_act_order else 1 - for device, buffers_size in device_to_buffers_size.items(): - # The temp_state buffer is required to reorder X in the act-order - # case. The temp_dq buffer is required to dequantize weights when - # using cuBLAS, typically for the prefill. - device_to_buffers[device] = { - "temp_state": - torch.zeros( - (max_input_len, buffers_size["max_inner_outer_dim"]), - dtype=torch.float16, - device=device), - "temp_dq": - torch.zeros((1, buffers_size["max_dq_buffer_size"]), - dtype=torch.float16, - device=device), - "max_dq_buffer_size": - buffers_size["max_dq_buffer_size"], - "max_inner_outer_dim": - buffers_size["max_inner_outer_dim"], - } - - # Buffers need to be persistent to avoid any bug. - model.device_to_buffers = device_to_buffers + device_tensors = {} + for device, scratch_bytes in fixed_bytes.items(): + device_tensors[device] = ExLlamaV2DeviceTensors( + device.index, scratch_bytes) - for device, buffers in model.device_to_buffers.items(): - quantization_ops.gptq_prepare_buffers(device, - buffers["temp_state"], - buffers["temp_dq"]) + # have persistent buffers, otherwise we will get OOM + model.device_tensors = device_tensors - # Using the default from exllama repo here. - matmul_recons_thd = 8 - matmul_fused_remap = False - matmul_no_half2 = False - quantization_ops.gptq_set_tuning_params(matmul_recons_thd, - matmul_fused_remap, - matmul_no_half2) - - # The buffers need to have been initialized first before calling - # make_q4. 
for _, submodule in model.named_modules(): - if isinstance( - submodule, - (GPTQColumnParallelLinear, GPTQRowParallelLinear, GPTQLinear)): - submodule.post_init() - - torch.cuda.empty_cache() + if isinstance(submodule, + (GPTQColumnParallelLinear, GPTQRowParallelLinear, + GPTQLinear)) and submodule.use_exllama: + device = submodule.qweight.device + submodule.post_init(temp_dq=model.device_tensors[device]) + torch.cuda.empty_cache() return model diff --git a/vllm/model_executor/quantization_utils/squeezellm.py b/vllm/model_executor/quantization_utils/squeezellm.py index 8a1db3e23321..65997dcba627 100644 --- a/vllm/model_executor/quantization_utils/squeezellm.py +++ b/vllm/model_executor/quantization_utils/squeezellm.py @@ -56,10 +56,8 @@ def get_packed_tensors(cls) -> Dict[str, int]: def get_transposed_tensor_names(cls) -> List[str]: return ["qweight"] - @classmethod - def get_col_parallel_tensor_names(cls) -> List[str]: + def get_col_parallel_tensor_names(self) -> List[str]: return ["qweight", "lookup_table"] - @classmethod - def get_row_parallel_tensor_names(cls) -> List[str]: + def get_row_parallel_tensor_names(self) -> List[str]: return ["qweight"] From 2d8dc1d7eda6282b9f6d7a175559a7bc7f78edb8 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=A5=9A=E5=A4=A9=E7=BF=94?= Date: Tue, 14 Nov 2023 14:14:18 +0800 Subject: [PATCH 09/18] Fix chatglm --- vllm/model_executor/models/chatglm.py | 1 + 1 file changed, 1 insertion(+) diff --git a/vllm/model_executor/models/chatglm.py b/vllm/model_executor/models/chatglm.py index 13ed652b3b3a..6b45e9401f15 100644 --- a/vllm/model_executor/models/chatglm.py +++ b/vllm/model_executor/models/chatglm.py @@ -368,6 +368,7 @@ def load_weights( state_dict = self.state_dict() for name, loaded_weight in hf_model_weights_iterator( model_name_or_path, cache_dir, load_format, revision): + packed_dim = None is_transposed = False if self.quant_config is not None: packed_dim = self.quant_config.get_packed_dim(name) From 
17b6f2b86fc95de186ea355e7dde3ca073f3be12 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=A5=9A=E5=A4=A9=E7=BF=94?= Date: Fri, 1 Dec 2023 22:57:01 +0800 Subject: [PATCH 10/18] Fix phi model --- vllm/model_executor/models/phi_1_5.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/vllm/model_executor/models/phi_1_5.py b/vllm/model_executor/models/phi_1_5.py index fbf7aa0a1491..6e2c1fa507cd 100644 --- a/vllm/model_executor/models/phi_1_5.py +++ b/vllm/model_executor/models/phi_1_5.py @@ -310,6 +310,8 @@ def load_weights(self, continue # pylint: disable=E1136 + if name not in params_dict: + continue param = params_dict[name] weight_loader = getattr(param, "weight_loader", default_weight_loader) From 62bd8ce7ec4b1a595cfc009fdd584249e9b6afd2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=A5=9A=E5=A4=A9=E7=BF=94?= Date: Sat, 2 Dec 2023 12:45:59 +0800 Subject: [PATCH 11/18] move post init to first forward pass to make code cleaner --- csrc/quantization.cpp | 8 +- csrc/quantization/gptq/exllama_ext.cpp | 11 +-- csrc/quantization/gptq/q_gemm.cu | 2 - csrc/quantization/gptq/q_matrix.cu | 7 +- csrc/quantization/gptq/q_matrix.cuh | 6 +- .../layers/quantization/gptq.py | 75 ++++++++----------- .../layers/quantization/utils.py | 66 ---------------- vllm/model_executor/model_loader.py | 5 +- vllm/worker/worker.py | 3 +- 9 files changed, 42 insertions(+), 141 deletions(-) delete mode 100644 vllm/model_executor/layers/quantization/utils.py diff --git a/csrc/quantization.cpp b/csrc/quantization.cpp index b9919d868f16..fd66a84e3985 100644 --- a/csrc/quantization.cpp +++ b/csrc/quantization.cpp @@ -14,16 +14,14 @@ uintptr_t make_q_matrix( torch::Tensor q_invperm, torch::Tensor gptq_qzeros, torch::Tensor gptq_scales, - torch::Tensor gptq_g_idx, - torch::Tensor temp_dq -); + torch::Tensor gptq_g_idx); void gemm_half_q_half( torch::Tensor a, uintptr_t b, torch::Tensor c, - bool force_cuda -); + torch::Tensor temp_dq, + bool force_cuda); void gptq_descact_matmul( torch::Tensor vec, diff 
--git a/csrc/quantization/gptq/exllama_ext.cpp b/csrc/quantization/gptq/exllama_ext.cpp index 47e4b5539c60..ac74f6d1d64b 100644 --- a/csrc/quantization/gptq/exllama_ext.cpp +++ b/csrc/quantization/gptq/exllama_ext.cpp @@ -26,8 +26,7 @@ uintptr_t make_q_matrix torch::Tensor q_invperm, torch::Tensor gptq_qzeros, torch::Tensor gptq_scales, - torch::Tensor gptq_g_idx, - torch::Tensor temp_dq + torch::Tensor gptq_g_idx ) { TORCH_CHECK_DTYPE(q_weight, kInt); @@ -49,8 +48,6 @@ uintptr_t make_q_matrix groups = gptq_qzeros.size(0); height = q_weight.size(0) * 8; - TORCH_CHECK(temp_dq.size(0) >= width * height, "Insufficient size of temp_dq buffer") - QMatrix* m = new QMatrix ( device, @@ -62,8 +59,7 @@ uintptr_t make_q_matrix q_invperm.device().is_meta() ? NULL : (uint16_t*) q_invperm.data_ptr(), gptq_qzeros.device().is_meta() ? NULL : (uint32_t*) gptq_qzeros.data_ptr(), gptq_scales.device().is_meta() ? NULL : (half*) gptq_scales.data_ptr(), - gptq_g_idx.device().is_meta() ? NULL : (uint32_t*) gptq_g_idx.data_ptr(), - (half*) temp_dq.data_ptr() + gptq_g_idx.device().is_meta() ? 
NULL : (uint32_t*) gptq_g_idx.data_ptr() ); return reinterpret_cast (m); @@ -74,6 +70,7 @@ void gemm_half_q_half torch::Tensor a, uintptr_t b, torch::Tensor c, + torch::Tensor temp_dq, bool force_cuda ) { @@ -97,7 +94,7 @@ void gemm_half_q_half c.size(1), // n a.size(1), // k true, - NULL, + (half*) temp_dq.data_ptr(), force_cuda ); } diff --git a/csrc/quantization/gptq/q_gemm.cu b/csrc/quantization/gptq/q_gemm.cu index 5fa657bdb149..6f86718cf021 100644 --- a/csrc/quantization/gptq/q_gemm.cu +++ b/csrc/quantization/gptq/q_gemm.cu @@ -98,8 +98,6 @@ void gemm_half_q_half_cuda { // Reconstruct FP16 matrix, then cuBLAS - - if (!temp_dq) temp_dq = b->temp_dq; b->reconstruct(temp_dq); //cublasSetMathMode(cublas_handle, CUBLAS_TENSOR_OP_MATH); diff --git a/csrc/quantization/gptq/q_matrix.cu b/csrc/quantization/gptq/q_matrix.cu index e6d48e588d2b..23e40309dfca 100644 --- a/csrc/quantization/gptq/q_matrix.cu +++ b/csrc/quantization/gptq/q_matrix.cu @@ -41,15 +41,12 @@ QMatrix::QMatrix uint32_t* _gptq_qzeros, half* _gptq_scales, - uint32_t* _gptq_g_idx, - - half* _temp_dq + uint32_t* _gptq_g_idx ) : device(_device), height(_height), width(_width), - groups(_groups), - temp_dq(_temp_dq) + groups(_groups) { cudaSetDevice(device); diff --git a/csrc/quantization/gptq/q_matrix.cuh b/csrc/quantization/gptq/q_matrix.cuh index 86124ffdc217..3fedbc0823f9 100644 --- a/csrc/quantization/gptq/q_matrix.cuh +++ b/csrc/quantization/gptq/q_matrix.cuh @@ -26,8 +26,6 @@ public: uint32_t* cuda_gptq_qzeros = NULL; half* cuda_gptq_scales = NULL; - half* temp_dq; - QMatrix ( const int _device, @@ -41,9 +39,7 @@ public: uint32_t* _gptq_qzeros, half* _gptq_scales, - uint32_t* _gptq_g_idx, - - half* _temp_dq + uint32_t* _gptq_g_idx ); ~QMatrix(); diff --git a/vllm/model_executor/layers/quantization/gptq.py b/vllm/model_executor/layers/quantization/gptq.py index 01e2f5077754..a774f18074a8 100644 --- a/vllm/model_executor/layers/quantization/gptq.py +++ 
b/vllm/model_executor/layers/quantization/gptq.py @@ -174,16 +174,44 @@ def apply_weights(self, weights: Dict[str, torch.Tensor], x: torch.Tensor, bias: Optional[torch.Tensor] = None) -> torch.Tensor: - #q4 = weights["q4"] qweight = weights["qweight"] + height, width = weights["qweight"].shape out_shape = x.shape[:-1] + (qweight.shape[-1], ) reshaped_x = x.reshape(-1, x.shape[-1]) if weights["use_exllama"]: + if "q4" not in weights: + if not self.quant_config.desc_act: + none_tensor = torch.empty((1, 1), device="meta") + weights["q4"] = quantization_ops.make_q_matrix( + weights["qweight"], + none_tensor, + none_tensor, + weights["qzeros"], + weights["scales"], + none_tensor, + ) + else: + weights["q_perm"] = torch.empty( + (height * self.quant_config.pack_factor, ), + dtype=torch.short, + device=weights["qweight"].device) + weights["q_invperm"] = torch.empty_like(weights["q_perm"]) + weights["q4"] = quantization_ops.make_q_matrix( + weights["qweight"], + weights["q_perm"], + weights["q_invperm"], + weights["qzeros"], + weights["scales"], + weights["g_idx"].cpu(), + ) + temp_dq = torch.empty((height * self.quant_config.pack_factor, width), + dtype=torch.float16, + device=x.device) output = torch.empty((reshaped_x.shape[0], qweight.shape[-1]), dtype=torch.float16, device=x.device) quantization_ops.gemm_half_q_half(reshaped_x, weights["q4"], output, - False) + temp_dq, False) else: output = torch.zeros((reshaped_x.shape[0], qweight.shape[-1]), dtype=torch.float32, @@ -196,46 +224,3 @@ def apply_weights(self, if bias is not None: output = output + bias return output.reshape(out_shape) - - def temp_dq_size(self, input_size, output_size): - return input_size * output_size * 2 + 128 - - def temp_fwd_size(self, output_size, max_tokens): - return output_size * max_tokens * 4 + 128 - - def scratch_space_fixed(self, input_size, output_size, max_tokens): - return self.temp_dq_size(input_size, output_size) + self.temp_fwd_size( - output_size, max_tokens) - - def post_init(self, 
linear_weights, temp_dq): - if not linear_weights["use_exllama"]: - return - none_tensor = torch.empty((1, 1), device="meta") - height, width = linear_weights["qweight"].shape - temp_dq = temp_dq.get_scratch_slice( - self.temp_dq_size(height * self.quant_config.pack_factor, width)) - if not self.quant_config.desc_act: - linear_weights["q4"] = quantization_ops.make_q_matrix( - linear_weights["qweight"], - none_tensor, - none_tensor, - linear_weights["qzeros"], - linear_weights["scales"], - none_tensor, - temp_dq, - ) - else: - linear_weights["q_perm"] = torch.empty( - (height * self.quant_config.pack_factor, ), - dtype=torch.short, - device=linear_weights["qweight"].device) - linear_weights["q_invperm"] = torch.empty_like(linear_weights["q_perm"]) - linear_weights["q4"] = quantization_ops.make_q_matrix( - linear_weights["qweight"], - linear_weights["q_perm"], - linear_weights["q_invperm"], - linear_weights["qzeros"], - linear_weights["scales"], - linear_weights["g_idx"].cpu(), - temp_dq, - ) diff --git a/vllm/model_executor/layers/quantization/utils.py b/vllm/model_executor/layers/quantization/utils.py deleted file mode 100644 index d970153070c5..000000000000 --- a/vllm/model_executor/layers/quantization/utils.py +++ /dev/null @@ -1,66 +0,0 @@ -from typing import Optional - -import torch - - -class ExLlamaV2DeviceTensors: - - def __init__(self, device_idx, scratch_bytes): - self.device_idx = device_idx - self.scratch_bytes = scratch_bytes - self.scratch = None - - def prepare(self): - self.scratch = torch.empty( - (self.scratch_bytes // 2, ), - dtype=torch.half, - device=f"cuda:{self.device_idx}", - ) - - def get_scratch_slice(self, size_bytes): - if self.scratch is None: - self.prepare() - size_bytes = ((size_bytes + 127) // 128) * 128 - size_half = size_bytes // 2 - scratch_slice = self.scratch.narrow(0, 0, size_half) - return scratch_slice - - -def quant_post_init(model, max_tokens: Optional[int] = None): - """ - The max_tokens argument is specific to the exllama 
backend, - that requires to initialize a buffer temp_state. - """ - fixed_bytes = {} - - model_uses_exllama = False - for _, submodule in model.named_modules(): - if hasattr(submodule, "linear_weights") and getattr( - submodule.linear_method, "use_exllama", False): - model_uses_exllama = True - device = submodule.linear_weights["qweight"].device - height, width = submodule.linear_weights["qweight"].shape - scratch_fixed = submodule.linear_method.scratch_space_fixed( - height * submodule.linear_method.quant_config.pack_factor, - width, max_tokens) - fixed_bytes[device] = max(scratch_fixed, - fixed_bytes.get(device, 0)) - - if model_uses_exllama: - device_tensors = {} - for device, scratch_bytes in fixed_bytes.items(): - device_tensors[device] = ExLlamaV2DeviceTensors( - device.index, scratch_bytes) - - # have persistent buffers, otherwise we will get OOM - model.device_tensors = device_tensors - - for _, submodule in model.named_modules(): - if hasattr(submodule, "linear_weights") and getattr( - submodule.linear_method, "use_exllama", False): - device = submodule.qweight.device - submodule.linear_method.post_init(submodule.linear_weights, - temp_dq=model.device_tensors[device]) - torch.cuda.empty_cache() - - return model diff --git a/vllm/model_executor/model_loader.py b/vllm/model_executor/model_loader.py index ce77e13e2969..71a22c7771b2 100644 --- a/vllm/model_executor/model_loader.py +++ b/vllm/model_executor/model_loader.py @@ -10,7 +10,6 @@ from vllm.model_executor.models import * # pylint: disable=wildcard-import from vllm.model_executor.weight_utils import (get_quant_config, initialize_dummy_weights) -from vllm.model_executor.layers.quantization.utils import quant_post_init # TODO(woosuk): Lazy-load the model classes. 
_MODEL_REGISTRY = { @@ -59,7 +58,7 @@ def _get_model_architecture(config: PretrainedConfig) -> Type[nn.Module]: f"Supported architectures: {list(_MODEL_REGISTRY.keys())}") -def get_model(model_config: ModelConfig, max_tokens: int) -> nn.Module: +def get_model(model_config: ModelConfig) -> nn.Module: model_class = _get_model_architecture(model_config.hf_config) # Get the (maybe quantized) linear method. @@ -99,6 +98,4 @@ def get_model(model_config: ModelConfig, max_tokens: int) -> nn.Module: model.load_weights(model_config.model, model_config.download_dir, model_config.load_format, model_config.revision) model = model.cuda() - if model_config.quantization is not None: - quant_post_init(model, max_tokens) return model.eval() diff --git a/vllm/worker/worker.py b/vllm/worker/worker.py index 5d8dbbad0ea6..bbbc2e7f45a6 100644 --- a/vllm/worker/worker.py +++ b/vllm/worker/worker.py @@ -67,8 +67,7 @@ def init_model(self): # Initialize the model. set_random_seed(self.model_config.seed) - self.model = get_model(self.model_config, - self.scheduler_config.max_num_batched_tokens) + self.model = get_model(self.model_config) @torch.inference_mode() def profile_num_available_blocks( From b6b8c63cacfc218a24080d6566763d3d7a547b0a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=A5=9A=E5=A4=A9=E7=BF=94?= Date: Sun, 10 Dec 2023 18:06:40 +0800 Subject: [PATCH 12/18] Update GPTQ kernel and fix minor problems --- benchmarks/benchmark_latency.py | 2 +- benchmarks/benchmark_throughput.py | 2 +- csrc/ops.h | 32 +- csrc/pybind.cpp | 14 +- csrc/quantization/gptq/compat.cuh | 8 + csrc/quantization/gptq/exllama_ext.cpp | 100 --- csrc/quantization/gptq/matrix_view.cuh | 30 + csrc/quantization/gptq/old_matmul_kernel.cu | 125 --- csrc/quantization/gptq/q_gemm.cu | 789 ++++++++++++++++-- csrc/quantization/gptq/q_gemm.cuh | 33 - csrc/quantization/gptq/q_gemm_kernel_gptq.cuh | 217 ----- csrc/quantization/gptq/q_matrix.cu | 338 -------- csrc/quantization/gptq/q_matrix.cuh | 54 -- 
csrc/quantization/gptq/qdq_4.cuh | 13 + csrc/quantization/gptq/qdq_util.cuh | 9 + setup.py | 3 - vllm/config.py | 2 +- vllm/entrypoints/llm.py | 6 +- vllm/model_executor/layers/linear.py | 46 +- .../layers/quantization/__init__.py | 6 +- .../model_executor/layers/quantization/awq.py | 25 +- .../layers/quantization/gptq.py | 128 ++- .../layers/quantization/squeezellm.py | 17 +- vllm/model_executor/models/aquila.py | 6 +- vllm/model_executor/models/baichuan.py | 6 +- vllm/model_executor/models/chatglm.py | 3 +- vllm/model_executor/models/falcon.py | 13 +- vllm/model_executor/models/gpt_j.py | 6 +- vllm/model_executor/models/internlm.py | 6 +- vllm/model_executor/models/llama.py | 6 +- vllm/model_executor/models/mistral.py | 6 +- vllm/model_executor/models/mpt.py | 3 +- vllm/model_executor/models/opt.py | 6 +- vllm/model_executor/models/phi_1_5.py | 5 +- vllm/model_executor/models/qwen.py | 6 +- vllm/model_executor/models/yi.py | 6 +- vllm/model_executor/weight_utils.py | 1 - 37 files changed, 973 insertions(+), 1105 deletions(-) delete mode 100644 csrc/quantization/gptq/exllama_ext.cpp delete mode 100644 csrc/quantization/gptq/old_matmul_kernel.cu delete mode 100644 csrc/quantization/gptq/q_gemm.cuh delete mode 100644 csrc/quantization/gptq/q_gemm_kernel_gptq.cuh delete mode 100644 csrc/quantization/gptq/q_matrix.cu delete mode 100644 csrc/quantization/gptq/q_matrix.cuh diff --git a/benchmarks/benchmark_latency.py b/benchmarks/benchmark_latency.py index 5fdf730896ec..b8720343244e 100644 --- a/benchmarks/benchmark_latency.py +++ b/benchmarks/benchmark_latency.py @@ -76,7 +76,7 @@ def run_to_completion(profile: bool = False): parser.add_argument('--tokenizer', type=str, default=None) parser.add_argument('--quantization', '-q', - choices=['awq', 'squeezellm', 'gptq', None], + choices=['awq', 'gptq', 'squeezellm', None], default=None) parser.add_argument('--tensor-parallel-size', '-tp', type=int, default=1) parser.add_argument('--input-len', type=int, default=32) diff 
--git a/benchmarks/benchmark_throughput.py b/benchmarks/benchmark_throughput.py index c47b4e4a9d24..5d17cb3f19ae 100644 --- a/benchmarks/benchmark_throughput.py +++ b/benchmarks/benchmark_throughput.py @@ -244,7 +244,7 @@ def main(args: argparse.Namespace): parser.add_argument("--tokenizer", type=str, default=None) parser.add_argument('--quantization', '-q', - choices=['awq', 'squeezellm', 'gptq', None], + choices=['awq', 'gptq', 'squeezellm', None], default=None) parser.add_argument("--tensor-parallel-size", "-tp", type=int, default=1) parser.add_argument("--n", diff --git a/csrc/ops.h b/csrc/ops.h index 3a2beb6d1602..a57f6002021d 100644 --- a/csrc/ops.h +++ b/csrc/ops.h @@ -74,25 +74,17 @@ void squeezellm_gemm( torch::Tensor mul, torch::Tensor lookup_table); -uintptr_t make_q_matrix( - torch::Tensor q_weight, - torch::Tensor q_perm, - torch::Tensor q_invperm, - torch::Tensor gptq_qzeros, - torch::Tensor gptq_scales, - torch::Tensor gptq_g_idx); - -void gemm_half_q_half( +torch::Tensor gptq_gemm +( torch::Tensor a, - uintptr_t b, - torch::Tensor c, - torch::Tensor temp_dq, - bool force_cuda); + torch::Tensor b_q_weight, + torch::Tensor b_gptq_qzeros, + torch::Tensor b_gptq_scales, + torch::Tensor b_g_idx, + bool use_exllama +); -void gptq_descact_matmul( - torch::Tensor vec, - torch::Tensor mat, - torch::Tensor mul, - torch::Tensor scales, - torch::Tensor zeros, - torch::Tensor g_idx); +void gptq_shuffle( + torch::Tensor q_weight, + torch::Tensor q_perm +); diff --git a/csrc/pybind.cpp b/csrc/pybind.cpp index 772ed46a21ac..3b140a49e4b9 100644 --- a/csrc/pybind.cpp +++ b/csrc/pybind.cpp @@ -50,19 +50,9 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { // Quantization ops ops.def("awq_gemm", &awq_gemm, "Quantized GEMM for AWQ"); + ops.def("gptq_gemm", &gptq_gemm, "Quantized GEMM for GPTQ"); + ops.def("gptq_shuffle", &gptq_shuffle, "Post processing for GPTQ"); ops.def("squeezellm_gemm", &squeezellm_gemm, "Quantized GEMM for SqueezeLLM"); - ops.def( - "make_q_matrix", - 
&make_q_matrix, - "Post processing for GPTQ"); - ops.def( - "gemm_half_q_half", - &gemm_half_q_half, - "Quantized GEMM for GPTQ"); - ops.def( - "gptq_descact_matmul", - &gptq_descact_matmul, - "Quantized GEMM for GPTQ for parallelized desc_act layer"); // Cache ops pybind11::module cache_ops = m.def_submodule("cache_ops", "vLLM cache ops"); diff --git a/csrc/quantization/gptq/compat.cuh b/csrc/quantization/gptq/compat.cuh index 12684ff8b59f..4da0bc6e2df3 100644 --- a/csrc/quantization/gptq/compat.cuh +++ b/csrc/quantization/gptq/compat.cuh @@ -1,6 +1,12 @@ +/* +Copied from https://github.com/turboderp/exllamav2 +*/ + #ifndef _compat_cuh #define _compat_cuh +namespace vllm { +namespace gptq { // atomicAdd for half types, to support CC < 7.x __device__ __forceinline__ void atomicAdd_half(half* address, half val) @@ -53,4 +59,6 @@ __device__ __forceinline__ void atomicAdd(half2* address, half2 val) { atomicAdd #endif #endif +} // namespace gptq +} // namespace vllm #endif diff --git a/csrc/quantization/gptq/exllama_ext.cpp b/csrc/quantization/gptq/exllama_ext.cpp deleted file mode 100644 index ac74f6d1d64b..000000000000 --- a/csrc/quantization/gptq/exllama_ext.cpp +++ /dev/null @@ -1,100 +0,0 @@ -#include -#include -#include -#include -#include -#include -#include - -#include "q_matrix.cuh" -#include "q_gemm.cuh" - -// Some decluttering macros - -#define TORCH_CHECK_DTYPE(__x, __dtype) TORCH_CHECK((__x).dtype() == torch::__dtype, #__x " is incorrect datatype, must be " #__dtype) -#define TORCH_CHECK_DTYPE_OPT(__x, __dtype) TORCH_CHECK((__x).device().is_meta() || (__x).dtype() == torch::__dtype, #__x " is incorrect datatype, must be " #__dtype) -#define TORCH_CHECK_SHAPES(__x, __dim_x, __y, __dim_y, __scale_y) TORCH_CHECK((__x).size(__dim_x) == (__y).size(__dim_y) * __scale_y, #__x " and " #__y " have incompatible shapes") -#define TORCH_CHECK_SHAPES_OPT(__x, __dim_x, __y, __dim_y, __scale_y) TORCH_CHECK((__x).device().is_meta() || (__x).size(__dim_x) == 
(__y).size(__dim_y) * __scale_y, #__x " and " #__y " have incompatible shapes") - - -// Quant matrix - -uintptr_t make_q_matrix -( - torch::Tensor q_weight, - torch::Tensor q_perm, - torch::Tensor q_invperm, - torch::Tensor gptq_qzeros, - torch::Tensor gptq_scales, - torch::Tensor gptq_g_idx -) -{ - TORCH_CHECK_DTYPE(q_weight, kInt); - TORCH_CHECK_DTYPE_OPT(q_perm, kShort); - TORCH_CHECK_DTYPE_OPT(q_invperm, kShort); - TORCH_CHECK_DTYPE_OPT(gptq_qzeros, kInt); - TORCH_CHECK_DTYPE_OPT(gptq_scales, kHalf); - TORCH_CHECK_DTYPE_OPT(gptq_g_idx, kInt); - TORCH_CHECK_SHAPES(q_weight, 1, gptq_qzeros, 1, 8); - TORCH_CHECK_SHAPES(q_weight, 1, gptq_scales, 1, 1); - - TORCH_CHECK_SHAPES(q_perm, 0, q_invperm, 0, 1); - - int device = q_weight.device().index(); - int width = q_weight.size(1); - int groups; - int height; - - groups = gptq_qzeros.size(0); - height = q_weight.size(0) * 8; - - QMatrix* m = new QMatrix - ( - device, - height, - width, - groups, - (uint32_t*) q_weight.data_ptr(), - q_perm.device().is_meta() ? NULL : (uint16_t*) q_perm.data_ptr(), - q_invperm.device().is_meta() ? NULL : (uint16_t*) q_invperm.data_ptr(), - gptq_qzeros.device().is_meta() ? NULL : (uint32_t*) gptq_qzeros.data_ptr(), - gptq_scales.device().is_meta() ? NULL : (half*) gptq_scales.data_ptr(), - gptq_g_idx.device().is_meta() ? 
NULL : (uint32_t*) gptq_g_idx.data_ptr() - ); - - return reinterpret_cast (m); -} - -void gemm_half_q_half -( - torch::Tensor a, - uintptr_t b, - torch::Tensor c, - torch::Tensor temp_dq, - bool force_cuda -) -{ - QMatrix* qm = reinterpret_cast (b); - - TORCH_CHECK_DTYPE(a, kHalf); - TORCH_CHECK_DTYPE(c, kHalf); - TORCH_CHECK_SHAPES(a, 0, c, 0, 1); - TORCH_CHECK(qm->height == a.size(1), "a and b have incompatible shapes") - TORCH_CHECK(qm->width == c.size(1), "b and c have incompatible shapes") - - const at::cuda::OptionalCUDAGuard device_guard(device_of(a)); - - gemm_half_q_half_cuda - ( - at::cuda::getCurrentCUDABlasHandle(), - (const half*) a.data_ptr(), - qm, - (half*) c.data_ptr(), - c.size(0), // m - c.size(1), // n - a.size(1), // k - true, - (half*) temp_dq.data_ptr(), - force_cuda - ); -} diff --git a/csrc/quantization/gptq/matrix_view.cuh b/csrc/quantization/gptq/matrix_view.cuh index d1264d896bfc..1fdf019b2902 100644 --- a/csrc/quantization/gptq/matrix_view.cuh +++ b/csrc/quantization/gptq/matrix_view.cuh @@ -1,3 +1,7 @@ +/* +Adapted from https://github.com/turboderp/exllamav2 and https://github.com/turboderp/exllama +*/ + #ifndef _matrix_view_cuh #define _matrix_view_cuh @@ -6,6 +10,9 @@ #include "qdq_util.cuh" +namespace vllm { +namespace gptq { + class MatrixView_half { public: @@ -118,4 +125,27 @@ public: } }; +class MatrixView_q4_column +{ +public: + const uint32_t* data; + const int height; + const int width; + + __device__ __forceinline__ MatrixView_q4_column(const uint32_t* data, const int height, const int width) + : data(data), height(height), width(width) + { } + + __device__ __forceinline__ int item(int row, int column) const + { + int shift = (row & 0x07) * 4; + return (data[row / 8 * width + column] >> shift) & 0x0f; + } + + __device__ __forceinline__ uint32_t item_uint32_t(int row, int column) { return data[row / 8 * width + column]; } + __device__ __forceinline__ const uint32_t* item_uint32_ptr(int row, int column) { return &data[row / 8 
* width + column]; } +}; + +} // namespace gptq +} // namespace vllm #endif diff --git a/csrc/quantization/gptq/old_matmul_kernel.cu b/csrc/quantization/gptq/old_matmul_kernel.cu deleted file mode 100644 index a79bf4b3edc0..000000000000 --- a/csrc/quantization/gptq/old_matmul_kernel.cu +++ /dev/null @@ -1,125 +0,0 @@ -#include -#include -#include -#include -#include -#include -#include "compat.cuh" - -const int BLOCKWIDTH = 256; -const int BLOCKHEIGHT = 32; - -__device__ inline unsigned int as_unsigned(int i) { - return *reinterpret_cast(&i); -} - -__device__ inline int as_int(int i) { - return *reinterpret_cast(&i); -} - -template -__global__ void VecQuant4MatMulKernel( - const scalar_t* __restrict__ vec, - const int* __restrict__ mat, - scalar_t* __restrict__ mul, - const scalar_t* __restrict__ scales, - const int* __restrict__ zeros, - const int* __restrict__ g_idx, - int batch, - int vec_height, - int height, - int width, - int zero_width -) { - int h = BLOCKHEIGHT * blockIdx.x; - int w = BLOCKWIDTH * blockIdx.y + threadIdx.x; - int h_end = min(h + BLOCKHEIGHT, height); - - __shared__ scalar_t blockvec[BLOCKWIDTH]; - int i = width * h + w; - int g_h = h * 8; - int h_range = (h_end - h) * 8; - int k; - unsigned int g; - scalar_t w_tmp; - - - int z_w = w / 8; - int z_mod = (w % 8) * 4; - - float weight[BLOCKWIDTH]; - - if (w < width) { - for (k = 0; k < h_range; ++k) { - int k_w = (k / 8); - int k_bit = (k % 8) * 4; - - g = as_int(g_idx[g_h + k]); - scalar_t scale = scales[g * width + w]; - scalar_t zero = scalar_t(((as_unsigned(zeros[g * zero_width + z_w]) >> z_mod) & 0xF) + 1); - w_tmp = ((as_unsigned(mat[i + (k_w * width)]) >> k_bit) & 0xF); - weight[k] = scale * (w_tmp - zero); - } - } - - scalar_t res; - for (int b = 0; b < batch; ++b) { - res = 0; - - if (threadIdx.x < h_range) { - blockvec[threadIdx.x] = vec[b * vec_height + blockIdx.x * BLOCKWIDTH + threadIdx.x]; - } - __syncthreads(); - if (w < width) { - for (k = 0; k < h_range; ++k){ - res += weight[k] 
* blockvec[k]; - } - atomicAdd(&mul[b * width + w], res); - } - __syncthreads(); - } -} - -void vecquant4matmul_cuda( - torch::Tensor vec, - torch::Tensor mat, - torch::Tensor mul, - torch::Tensor scales, - torch::Tensor zeros, - torch::Tensor g_idx -) { - int batch = vec.size(0); - int vec_height = vec.size(1); - int height = mat.size(0); - int width = mat.size(1); - int zero_width = zeros.size(1); - - dim3 blocks( - (height + BLOCKHEIGHT - 1) / BLOCKHEIGHT, - (width + BLOCKWIDTH - 1) / BLOCKWIDTH - ); - dim3 threads(BLOCKWIDTH); - - AT_DISPATCH_FLOATING_TYPES( - vec.type(), "vecquant4matmul_cuda", ([&] { - VecQuant4MatMulKernel<<>>( - vec.data(), mat.data(), mul.data(), - scales.data(), zeros.data(), g_idx.data(), - batch, vec_height, height, width, zero_width - ); - }) - ); -} - -void gptq_descact_matmul( - torch::Tensor vec, - torch::Tensor mat, - torch::Tensor mul, - torch::Tensor scales, - torch::Tensor zeros, - torch::Tensor g_idx -) -{ - const at::cuda::OptionalCUDAGuard device_guard(device_of(vec)); - vecquant4matmul_cuda(vec, mat, mul, scales, zeros, g_idx); -} \ No newline at end of file diff --git a/csrc/quantization/gptq/q_gemm.cu b/csrc/quantization/gptq/q_gemm.cu index 6f86718cf021..6d070c658f15 100644 --- a/csrc/quantization/gptq/q_gemm.cu +++ b/csrc/quantization/gptq/q_gemm.cu @@ -1,17 +1,32 @@ -#include "q_gemm.cuh" -#include "matrix_view.cuh" +/* +Adapted from https://github.com/turboderp/exllamav2 and https://github.com/qwopqwop200/GPTQ-for-LLaMa +*/ + +#include +#include +#include +#include +#include +#include +#include + +#include "compat.cuh" +#include "matrix_view.cuh" #include "qdq_4.cuh" +namespace vllm { +namespace gptq { + #define BLOCK_KN_SIZE 128 #define BLOCK_M_SIZE_MAX 8 #define MAX_GROUPS_IN_BLOCK (BLOCK_KN_SIZE / 32) -#define CLEAR_N_SIZE 256 #define MAX_Q_GEMM_ROWS 50 +#define MAX_ALT_GEMM_ROWS 8 +#define THREADS_X 32 +#define THREADS_Y 32 #define DIVIDE(x, size) (((x) + (size) - 1) / (size)) -#include "q_gemm_kernel_gptq.cuh" - 
#if defined(USE_ROCM) __host__ __forceinline__ hipblasStatus_t __compat_hipblasHgemm(hipblasHandle_t handle, hipblasOperation_t transA, @@ -41,16 +56,224 @@ __host__ __forceinline__ hipblasStatus_t __compat_hipblasHgemm(hipblasHandle_t #define rocblas_hgemm __compat_hipblasHgemm #endif +__forceinline__ __device__ half2 dot22_8(half2(&dq)[4], const half* a_ptr, const half2 g_result) +{ + half2 result = {}; + const half2* a2_ptr = (const half2*)a_ptr; + #pragma unroll + for (int i = 0; i < 4; i++) result = __hfma2(dq[i], *a2_ptr++, result); + return __hadd2(result, g_result); +} + +__forceinline__ __device__ float dot22_8_f(half2(&dq)[4], const half* a_ptr) +{ + half2 result = {}; + const half2* a2_ptr = (const half2*)a_ptr; + #pragma unroll + for (int i = 0; i < 4; i++) result = __hfma2(dq[i], *a2_ptr++, result); + return __half2float(__low2half(result)) + __half2float(__high2half(result)); +} + +typedef void (*fp_gemm_half_q_half_gptq_kernel) +( + const half*, + const uint32_t*, + const uint32_t*, + const half*, + half*, + const int, + const int, + const int, + const int, + const int* +); + +template +__global__ void gemm_half_q_half_gptq_kernel +( + const half* __restrict__ a, + const uint32_t* __restrict__ b_q_weight, + const uint32_t* __restrict__ b_gptq_qzeros, + const half* __restrict__ b_gptq_scales, + half* __restrict__ c, + const int size_m, + const int size_n, + const int size_k, + const int groups, + const int* __restrict__ b_q_perm +) +{ + MatrixView_half a_(a, size_m, size_k); + MatrixView_half_rw c_(c, size_m, size_n); + MatrixView_q4_row b_gptq_qzeros_(b_gptq_qzeros, groups, size_n); + MatrixView_half b_gptq_scales_(b_gptq_scales, groups, size_n); + + int t = threadIdx.x; + + // Block + int offset_n = blockIdx.x * BLOCK_KN_SIZE * 4; + int offset_m = blockIdx.y * m_count; + int offset_k = blockIdx.z * BLOCK_KN_SIZE; + + int end_n = min(offset_n + BLOCK_KN_SIZE * 4, size_n); + int end_m = min(offset_m + m_count, size_m); + int end_k = min(offset_k + 
BLOCK_KN_SIZE, size_k); + + int n = offset_n + t * 4; + + // Preload block_a + __shared__ half block_a[m_count][BLOCK_KN_SIZE]; + + if (offset_k + t < end_k) + { + for (int m = 0; m < m_count; ++m) + { + const half* a_ptr = a_.item_ptr(offset_m + m, 0); + half* block_a_ptr = block_a[m]; + + half a0; + if (b_q_perm) a0 = a_ptr[b_q_perm[offset_k + t]]; + else a0 = a_ptr[offset_k + t]; + block_a_ptr[t] = a0; + } + } + + // Zero output + if (n >= size_n) return; + + if (blockIdx.z == 0) + { + for (int m = 0; m < m_count; m++) + *((uint64_t*)c_.item_ptr(offset_m + m, n)) = 0; + } + + __syncthreads(); + + // Find initial group + int groupsize = size_k / groups; + int group = offset_k / groupsize; + int nextgroup = offset_k + groupsize; + + // a, b offset + int qk = offset_k / (32 / 4); + + const uint32_t* b_ptr = b_q_weight + qk * size_n + n; + const half* a_ptr = &block_a[0][0]; + int a_stride = BLOCK_KN_SIZE; + + // Initial group + int zeros[4]; + float scales[4]; + half2 z1z16[4][2]; + half2 y1y16[4][2]; + b_gptq_qzeros_.item4(zeros, group, n); + b_gptq_scales_.item4_f(scales, group, n); + dequant_4bit_8_prep_zero(zeros[0] + 1, z1z16[0], y1y16[0]); + dequant_4bit_8_prep_zero(zeros[1] + 1, z1z16[1], y1y16[1]); + dequant_4bit_8_prep_zero(zeros[2] + 1, z1z16[2], y1y16[2]); + dequant_4bit_8_prep_zero(zeros[3] + 1, z1z16[3], y1y16[3]); + + // Column result + float block_c[m_count][4] = {}; + + // Dequantize and multiply + int k = offset_k; + while (k < end_k) + { + if (k == nextgroup) + { + group++; + nextgroup += groupsize; + b_gptq_qzeros_.item4(zeros, group, n); + b_gptq_scales_.item4_f(scales, group, n); + dequant_4bit_8_prep_zero(zeros[0] + 1, z1z16[0], y1y16[0]); + dequant_4bit_8_prep_zero(zeros[1] + 1, z1z16[1], y1y16[1]); + dequant_4bit_8_prep_zero(zeros[2] + 1, z1z16[2], y1y16[2]); + dequant_4bit_8_prep_zero(zeros[3] + 1, z1z16[3], y1y16[3]); + } + + #pragma unroll + for (int j = 0; j < 4; j++) + { + const int4* b_ptr4 = (int4*) b_ptr; + int4 load_int4 = *b_ptr4; 
+ + half2 dq[4][4]; + dequant_4bit_8_gptq(load_int4.x, dq[0], z1z16[0], y1y16[0], size_n, false); + dequant_4bit_8_gptq(load_int4.y, dq[1], z1z16[1], y1y16[1], size_n, false); + dequant_4bit_8_gptq(load_int4.z, dq[2], z1z16[2], y1y16[2], size_n, false); + dequant_4bit_8_gptq(load_int4.w, dq[3], z1z16[3], y1y16[3], size_n, false); + + #pragma unroll + for (int m = 0; m < m_count; m++) + { + block_c[m][0] = fma(dot22_8_f(dq[0], a_ptr + m * a_stride), scales[0], block_c[m][0]); + block_c[m][1] = fma(dot22_8_f(dq[1], a_ptr + m * a_stride), scales[1], block_c[m][1]); + block_c[m][2] = fma(dot22_8_f(dq[2], a_ptr + m * a_stride), scales[2], block_c[m][2]); + block_c[m][3] = fma(dot22_8_f(dq[3], a_ptr + m * a_stride), scales[3], block_c[m][3]); + } + + b_ptr += size_n; + a_ptr += 8; + } + + k += 32; + } + + for (int m = 0; m < m_count; m++) + { + half2 *out = (half2*) c_.item_ptr(offset_m + m, n); + half2 result01 = __halves2half2(__float2half_rn(block_c[m][0]), __float2half_rn(block_c[m][1])); + half2 result23 = __halves2half2(__float2half_rn(block_c[m][2]), __float2half_rn(block_c[m][3])); + atomicAdd(out , result01); + atomicAdd(out + 1, result23); + } +} + + +fp_gemm_half_q_half_gptq_kernel pick_gemm_half_q_half_gptq_kernel(bool first_block, const int m_count) +{ + #if BLOCK_M_SIZE_MAX >= 1 + if (m_count == 1) return gemm_half_q_half_gptq_kernel; + #endif + #if BLOCK_M_SIZE_MAX >= 2 + if (m_count == 2) return gemm_half_q_half_gptq_kernel; + #endif + #if BLOCK_M_SIZE_MAX >= 3 + if (m_count == 3) return gemm_half_q_half_gptq_kernel; + #endif + #if BLOCK_M_SIZE_MAX >= 4 + if (m_count == 4) return gemm_half_q_half_gptq_kernel; + #endif + #if BLOCK_M_SIZE_MAX >= 5 + if (m_count == 5) return gemm_half_q_half_gptq_kernel; + #endif + #if BLOCK_M_SIZE_MAX >= 6 + if (m_count == 6) return gemm_half_q_half_gptq_kernel; + #endif + #if BLOCK_M_SIZE_MAX >= 7 + if (m_count == 7) return gemm_half_q_half_gptq_kernel; + #endif + #if BLOCK_M_SIZE_MAX >= 8 + if (m_count == 8) return 
gemm_half_q_half_gptq_kernel; + #endif + return NULL; +} + + void gemm_half_q_half_cuda_part ( const half* a, - QMatrix* b, + const uint32_t* b_q_weight, + const uint32_t* b_gptq_qzeros, + const half* b_gptq_scales, + const int* b_q_perm, half* c, int size_m, int size_n, int size_k, int m_count, - bool clear + int groups ) { dim3 blockDim, gridDim; @@ -66,44 +289,391 @@ void gemm_half_q_half_cuda_part kernel<<>> ( a, - b->cuda_q_weight, - b->cuda_gptq_qzeros, - b->cuda_gptq_scales, + b_q_weight, + b_gptq_qzeros, + b_gptq_scales, c, size_m, size_n, size_k, - b->groups, - b->groupsize, - b->cuda_q_perm, - clear + groups, + b_q_perm + ); +} + + +__global__ void reconstruct_exllama_kernel +( + const uint32_t* __restrict__ b_q_weight, + const int* __restrict__ b_q_perm, + const uint32_t* __restrict__ b_gptq_qzeros, + const half* __restrict__ b_gptq_scales, + const int size_k, + const int size_n, + const int groups, + half* __restrict__ b +) +{ + MatrixView_half_rw b_(b, size_k, size_n); + MatrixView_q4_row b_gptq_qzeros_(b_gptq_qzeros, groups, size_n); + MatrixView_half b_gptq_scales_(b_gptq_scales, groups, size_n); + + int offset_k = BLOCK_KN_SIZE * blockIdx.y; + int offset_n = BLOCK_KN_SIZE * blockIdx.x * 4; + + int end_k = min(offset_k + BLOCK_KN_SIZE, size_k); + + // Preload remapping table + __shared__ int perm[BLOCK_KN_SIZE]; + int t = threadIdx.x; + + if (b_q_perm) + { + if (offset_k + t < size_k) + perm[t] = b_q_perm[offset_k + t]; + } + + // Column + int n = offset_n + t * 4; + if (n >= size_n) return; + + // Find initial group + int groupsize = size_k / groups; + int group = offset_k / groupsize; + int nextgroup = offset_k + groupsize; + + // b offset + int qk = offset_k / (32 / 4); + + const uint32_t* b_ptr = b_q_weight + qk * size_n + n; + + // Initial zeros/scale + int zeros[4]; + half2 scales[4]; + half2 z1z16[4][2]; + half2 y1y16[4][2]; + b_gptq_qzeros_.item4(zeros, group, n); + b_gptq_scales_.item4_h2(scales, group, n); + 
dequant_4bit_8_prep_zero(zeros[0] + 1, z1z16[0], y1y16[0]); + dequant_4bit_8_prep_zero(zeros[1] + 1, z1z16[1], y1y16[1]); + dequant_4bit_8_prep_zero(zeros[2] + 1, z1z16[2], y1y16[2]); + dequant_4bit_8_prep_zero(zeros[3] + 1, z1z16[3], y1y16[3]); + + __syncthreads(); + + int k = offset_k; + int lk = 0; + + while (k < end_k) + { + if (k == nextgroup) + { + group++; + nextgroup += groupsize; + b_gptq_qzeros_.item4(zeros, group, n); + b_gptq_scales_.item4_h2(scales, group, n); + dequant_4bit_8_prep_zero(zeros[0] + 1, z1z16[0], y1y16[0]); + dequant_4bit_8_prep_zero(zeros[1] + 1, z1z16[1], y1y16[1]); + dequant_4bit_8_prep_zero(zeros[2] + 1, z1z16[2], y1y16[2]); + dequant_4bit_8_prep_zero(zeros[3] + 1, z1z16[3], y1y16[3]); + } + + for (int p = 0; p < 4; p++) + { + half2 dq[4][4]; + const int4* b_ptr4 = (int4*) b_ptr; + int4 load_int4 = *b_ptr4; + + dequant_4bit_8_gptq(load_int4.x, dq[0], z1z16[0], y1y16[0], size_n, false); + dequant_4bit_8_gptq(load_int4.y, dq[1], z1z16[1], y1y16[1], size_n, false); + dequant_4bit_8_gptq(load_int4.z, dq[2], z1z16[2], y1y16[2], size_n, false); + dequant_4bit_8_gptq(load_int4.w, dq[3], z1z16[3], y1y16[3], size_n, false); + + b_ptr += size_n; + //half* dqh = (half*)dq; + if (b_q_perm) + { + for (int j = 0; j < 4; j++) + { + for (int v = 0; v < 4; v++) dq[v][j] = __hmul2(scales[v], dq[v][j]); + b_.set4(perm[lk++], n, __low2half(dq[0][j]), __low2half(dq[1][j]), __low2half(dq[2][j]), __low2half(dq[3][j])); + b_.set4(perm[lk++], n, __high2half(dq[0][j]), __high2half(dq[1][j]), __high2half(dq[2][j]), __high2half(dq[3][j])); + } + } + else + { + for (int j = 0; j < 4; j++) + { + for (int v = 0; v < 4; v++) dq[v][j] = __hmul2(scales[v], dq[v][j]); + b_.set4(offset_k + lk++, n, __low2half(dq[0][j]), __low2half(dq[1][j]), __low2half(dq[2][j]), __low2half(dq[3][j])); + b_.set4(offset_k + lk++, n, __high2half(dq[0][j]), __high2half(dq[1][j]), __high2half(dq[2][j]), __high2half(dq[3][j])); + } + } + } + k += 32; + } +} + + +void reconstruct_exllama +( + 
const uint32_t* b_q_weight, + const uint32_t* b_gptq_qzeros, + const half* b_gptq_scales, + const int* b_q_perm, + half* out, + int height, + int width, + int groups +) +{ + dim3 blockDim, gridDim; + blockDim.x = BLOCK_KN_SIZE; + blockDim.y = 1; + gridDim.y = DIVIDE(height, BLOCK_KN_SIZE); + gridDim.x = DIVIDE(width, BLOCK_KN_SIZE); + + reconstruct_exllama_kernel<<>> + ( + b_q_weight, + b_q_perm, + b_gptq_qzeros, + b_gptq_scales, + height, + width, + groups, + out + ); +} + + +__global__ void gemm_half_q_half_alt_kernel( + const half2* __restrict__ vec, + const uint32_t* __restrict__ mat, + half* __restrict__ mul, + const half* __restrict__ scales, + const uint32_t* __restrict__ zeros, + const int* __restrict__ g_idx, + int batch, + int height, + int width +) +{ + int zero_width = width / 8; + int vec_height = height * 4; + const int blockwidth2 = BLOCK_KN_SIZE / 2; + int b = blockIdx.y * BLOCK_M_SIZE_MAX; + int b_end = min(BLOCK_M_SIZE_MAX, batch - b); + int h = BLOCK_KN_SIZE * blockIdx.z / 8; + int h_end = min(BLOCK_KN_SIZE / 8, height - h) * 4; + int w = BLOCK_KN_SIZE * blockIdx.x + threadIdx.x; + + __shared__ half2 blockvec[BLOCK_M_SIZE_MAX][blockwidth2]; + if (threadIdx.x < h_end) { + for (int m = 0; m < b_end; ++m) { + blockvec[m][threadIdx.x] = + vec[(m + b) * vec_height + blockIdx.z * BLOCK_KN_SIZE / 2 + + threadIdx.x]; + } + } + + __shared__ half2 deq2[256][8]; + int val = threadIdx.x / 8; + int off = threadIdx.x % 8; + for (; val < 256; val += BLOCK_KN_SIZE / 8) { + deq2[val][off] = __halves2half2( + __int2half_rn(val & 0xF), __int2half_rn(val >> 4) + ); + } + + if (blockIdx.z == 0) + { + for (int m = 0; m < b_end; m++) + mul[(b + m) * width + w] = __int2half_rn(0); + } + __syncthreads(); + + int i = width * h + w; + int g_h = h * 8; + int k = 0; + int z_w = w / 8; + int z_mod = (w % 8) * 4; + half2 res2; + half res[BLOCK_M_SIZE_MAX] = {}; + + unsigned int tmp; + while (k < h_end) { + tmp = mat[i]; + half2 scales_tmp[4]; + half2 zeros_tmp[4]; + for (int 
tmp_k = 0; tmp_k < 4; tmp_k++) { + int g = g_idx[g_h + (k + tmp_k) * 2]; + int g2 = g_idx[g_h + (k + tmp_k) * 2 + 1]; + half scale_f = scales[g * width + w]; + half scale_f2 = scales[g2 * width + w]; + half2 scale = __halves2half2(scale_f, scale_f2); + half2 zero = __halves2half2( + __hmul(scale_f, __int2half_rn(-((zeros[g * zero_width + z_w] >> z_mod) & 0xF) - 1)), + __hmul(scale_f2, __int2half_rn(-((zeros[g2 * zero_width + z_w] >> z_mod) & 0xF) - 1)) + ); + scales_tmp[tmp_k] = scale; + zeros_tmp[tmp_k] = zero; + } + for (int m = 0; m < b_end; m++) { + res2 = {}; + res2 = __hfma2(__hfma2(deq2[(tmp >> 0) & 0xff][off], scales_tmp[0], zeros_tmp[0]), blockvec[m][k + 0], res2); + res2 = __hfma2(__hfma2(deq2[(tmp >> 8) & 0xff][off], scales_tmp[1], zeros_tmp[1]), blockvec[m][k + 1], res2); + res2 = __hfma2(__hfma2(deq2[(tmp >> 16) & 0xff][off], scales_tmp[2], zeros_tmp[2]), blockvec[m][k + 2], res2); + res2 = __hfma2(__hfma2(deq2[(tmp >> 24) & 0xff][off], scales_tmp[3], zeros_tmp[3]), blockvec[m][k + 3], res2); + res[m] = __hadd(res[m], __hadd(res2.x, res2.y)); + } + i += width; + k += 4; + } + for (int m = 0; m < b_end; m++) { + atomicAdd(&mul[(b + m) * width + w], res[m]); + } +} + + +void gemm_half_q_half_alt +( + const half* a, + const uint32_t* b_q_weight, + const uint32_t* b_gptq_qzeros, + const half* b_gptq_scales, + const int* b_g_idx, + half* c, + int size_m, + int size_n, + int size_k +) +{ + dim3 blockDim, gridDim; + blockDim.x = BLOCK_KN_SIZE; + blockDim.y = 1; + blockDim.z = 1; + gridDim.x = DIVIDE(size_n, BLOCK_KN_SIZE); + gridDim.y = DIVIDE(size_m, BLOCK_M_SIZE_MAX); + gridDim.z = DIVIDE(size_k, BLOCK_KN_SIZE); + + gemm_half_q_half_alt_kernel<<>> + ( + (const half2*) a, + b_q_weight, + c, + b_gptq_scales, + b_gptq_qzeros, + b_g_idx, + size_m, + size_k / 8, + size_n + ); +} + + +__global__ void reconstruct_gptq_kernel +( + const uint32_t* __restrict__ w, + const half* __restrict__ w_scales, + const uint32_t* __restrict__ w_zeros, + const int* __restrict__ 
g_idx, + const int height, + const int width, + const int group, + half* __restrict__ out +) +{ + // Start of block + + int column = BLOCK_KN_SIZE * blockIdx.x + threadIdx.x; + int row = blockIdx.y * 8; + if (column >= width) return; + + // Views + + MatrixView_q4_column w_(w, height, width); + MatrixView_half_rw out_(out, height, width); + MatrixView_half w_scales_(w_scales, group, width); + MatrixView_q4_row w_zeros_(w_zeros, group, width); + + uint32_t w_read = w_.item_uint32_t(row, column); + half* out_ptr = out_.item_ptr(row, column); + + #pragma unroll + for (int s = 0; s < 32; s += 4) + { + int group = g_idx[row + s / 4]; + half w_scale = w_scales_.item(group, column); + uint32_t w_zero = w_zeros_.item(group, column) + 1; + half w_item = __hmul(__int2half_rn((int)((w_read >> s) & 0x0f) - w_zero), w_scale); + *out_ptr = w_item; out_ptr += out_.width; + } +} + + +void reconstruct_gptq +( + const uint32_t* b_q_weight, + const uint32_t* b_gptq_qzeros, + const half* b_gptq_scales, + const int* b_g_idx, + half* out, + int height, + int width, + int groups +) +{ + dim3 blockDim, gridDim; + blockDim.x = BLOCK_KN_SIZE; + blockDim.y = 1; + gridDim.y = DIVIDE(height, 8); + gridDim.x = DIVIDE(width, BLOCK_KN_SIZE); + reconstruct_gptq_kernel<<>> + ( + b_q_weight, + b_gptq_scales, + b_gptq_qzeros, + b_g_idx, + height, + width, + groups, + out ); } + void gemm_half_q_half_cuda ( cublasHandle_t cublas_handle, const half* a, - QMatrix* b, + const uint32_t* b_q_weight, + const uint32_t* b_gptq_qzeros, + const half* b_gptq_scales, + const int* b_g_idx, half* c, + half* temp_dq, int size_m, int size_n, int size_k, - bool clear, - half* temp_dq, - bool force_cuda + int groups, + bool use_exllama ) { - if (size_m > MAX_Q_GEMM_ROWS && !force_cuda) - { - + if ((use_exllama && size_m > MAX_Q_GEMM_ROWS) || (!use_exllama && size_m > MAX_ALT_GEMM_ROWS)) { // Reconstruct FP16 matrix, then cuBLAS - b->reconstruct(temp_dq); - - //cublasSetMathMode(cublas_handle, CUBLAS_TENSOR_OP_MATH); + 
if (use_exllama) { + reconstruct_exllama(b_q_weight, b_gptq_qzeros, b_gptq_scales, b_g_idx, temp_dq, + size_k, size_n, groups); + } + else + { + reconstruct_gptq(b_q_weight, b_gptq_qzeros, b_gptq_scales, b_g_idx, + temp_dq, size_k, size_n, groups); + } const half alpha = __float2half(1.0f); - const half beta = clear ? __float2half(0.0f) : __float2half(1.0f); + const half beta = __float2half(0.0f); cublasHgemm(cublas_handle, CUBLAS_OP_N, CUBLAS_OP_N, @@ -111,56 +681,179 @@ void gemm_half_q_half_cuda &alpha, temp_dq, size_n, a, size_k, &beta, c, size_n); - } - else + else if (use_exllama) { // Quantized matmul - - //if (clear) clear_tensor_cuda(c, size_m, size_n); - int max_chunks = size_m / BLOCK_M_SIZE_MAX; int last_chunk = max_chunks * BLOCK_M_SIZE_MAX; int last_chunk_size = size_m - last_chunk; if (max_chunks) { - gemm_half_q_half_cuda_part(a, b, c, last_chunk, size_n, size_k, BLOCK_M_SIZE_MAX, clear); + gemm_half_q_half_cuda_part(a, b_q_weight, b_gptq_qzeros, b_gptq_scales, b_g_idx, + c, last_chunk, size_n, size_k, BLOCK_M_SIZE_MAX, + groups); } if (last_chunk_size) { - gemm_half_q_half_cuda_part(a + last_chunk * size_k, b, c + last_chunk * size_n, last_chunk_size, size_n, size_k, last_chunk_size, clear); + gemm_half_q_half_cuda_part(a + last_chunk * size_k, b_q_weight, b_gptq_qzeros, + b_gptq_scales, b_g_idx, c + last_chunk * size_n, + last_chunk_size, size_n, size_k, last_chunk_size, + groups); } } + else + { + gemm_half_q_half_alt(a, b_q_weight, b_gptq_qzeros, b_gptq_scales, b_g_idx, + c, size_m, size_n, size_k); + } } -__global__ void clear_kernel + +__global__ void shuffle_kernel ( - half* __restrict__ c, - const int size_m, + uint32_t* __restrict__ b_q_weight, + const int size_k, const int size_n ) { - int m = blockIdx.y; - int n = (blockIdx.x * CLEAR_N_SIZE + threadIdx.x) * 8; + int n = blockIdx.x * THREADS_X + threadIdx.x; if (n >= size_n) return; - int4* c_ptr = (int4*)(c + m * size_n + n); - *c_ptr = {}; + int k = 0; + uint32_t* b_ptr = b_q_weight + n; 
+ while (k < size_k) { shuffle_4bit_8 (b_ptr, size_n); b_ptr += 1 * size_n; k += 8; } } -void clear_tensor_cuda + +__global__ void make_sequential_kernel ( - half* c, - int size_m, - int size_n + const uint32_t* __restrict__ w, + uint32_t* __restrict__ w_new, + const int* __restrict__ q_perm, + const int w_height, + const int w_width ) { - return; + const uint64_t* w2 = (uint64_t*) w; + uint64_t* w_new2 = (uint64_t*) w_new; + int w2_stride = w_width >> 1; + int w2_column = THREADS_X * blockIdx.x + threadIdx.x; + if (w2_column >= w2_stride) return; + int w_new2_row = blockIdx.y; + int q_perm_idx = w_new2_row << 3; + uint64_t dst = 0; + + #pragma unroll + for (int i = 0; i < 8; i++) + { + int source_row = q_perm[q_perm_idx++]; + + int w2_row = source_row >> 3; + int w2_subrow = source_row & 0x07; + int w2_row_shift = w2_subrow << 2; + int wnew2_row_shift = i << 2; + + uint64_t src = w2[w2_row * w2_stride + w2_column]; + src >>= w2_row_shift; + src &= 0x0000000f0000000f; + src <<= wnew2_row_shift; + dst |= src; + } + w_new2[w_new2_row * w2_stride + w2_column] = dst; +} + + +void shuffle_exllama_weight +( + uint32_t* q_weight, + int* q_perm, + int height, + int width +) +{ + if (q_perm) + { + uint32_t* new_qweight = NULL; + cudaMalloc(&new_qweight, height / 8 * width * sizeof(uint32_t)); + + dim3 blockDim, gridDim; + blockDim.x = THREADS_X; + blockDim.y = 1; + gridDim.x = DIVIDE(width, THREADS_X); + gridDim.y = height / 8; + + make_sequential_kernel<<>> + ( + q_weight, + new_qweight, + q_perm, + height / 8, + width + ); + // Replace qweights + cudaMemcpyAsync(q_weight, new_qweight, height / 8 * width * sizeof(uint32_t), cudaMemcpyDeviceToDevice); + // Cleanup + cudaDeviceSynchronize(); + cudaFree(new_qweight); + } dim3 blockDim, gridDim; - blockDim.x = CLEAR_N_SIZE; + blockDim.x = THREADS_X; blockDim.y = 1; - gridDim.x = DIVIDE(size_n / 8, CLEAR_N_SIZE); - gridDim.y = size_m; - clear_kernel<<>>(c, size_m, size_n); + gridDim.x = DIVIDE(width, THREADS_X); + gridDim.y = 
1; + shuffle_kernel<<>>(q_weight, height, width); +} + +} // namespace gptq +} // namespace vllm + +torch::Tensor gptq_gemm +( + torch::Tensor a, + torch::Tensor b_q_weight, + torch::Tensor b_gptq_qzeros, + torch::Tensor b_gptq_scales, + torch::Tensor b_g_idx, + bool use_exllama +) +{ + const at::cuda::OptionalCUDAGuard device_guard(device_of(a)); + auto options = torch::TensorOptions().dtype(a.dtype()).device(a.device()); + at::Tensor c = torch::empty({a.size(0), b_q_weight.size(1)}, options); + at::Tensor temp_dq = torch::empty({b_q_weight.size(0) * 8, b_q_weight.size(1)}, options); + + vllm::gptq::gemm_half_q_half_cuda + ( + at::cuda::getCurrentCUDABlasHandle(), + (const half*) a.data_ptr(), + (const uint32_t*) b_q_weight.data_ptr(), + (const uint32_t*)b_gptq_qzeros.data_ptr(), + (const half*) b_gptq_scales.data_ptr(), + b_g_idx.device().is_meta() ? NULL : (const int*) b_g_idx.data_ptr(), + (half*) c.data_ptr(), + (half*) temp_dq.data_ptr(), + c.size(0), // m + c.size(1), // n + a.size(1), // k + b_gptq_qzeros.size(0), // group number + use_exllama + ); + return c; +} + +void gptq_shuffle +( + torch::Tensor q_weight, + torch::Tensor q_perm +) +{ + const at::cuda::OptionalCUDAGuard device_guard(device_of(q_weight)); + vllm::gptq::shuffle_exllama_weight( + (uint32_t*) q_weight.data_ptr(), + q_perm.device().is_meta() ? 
NULL : (int*) q_perm.data_ptr(), + q_weight.size(0) * 8, + q_weight.size(1) + ); } diff --git a/csrc/quantization/gptq/q_gemm.cuh b/csrc/quantization/gptq/q_gemm.cuh deleted file mode 100644 index c69f1a709689..000000000000 --- a/csrc/quantization/gptq/q_gemm.cuh +++ /dev/null @@ -1,33 +0,0 @@ -#ifndef _q_gemm_cuh -#define _q_gemm_cuh - -#include -#include -#include -#include -#include - -#include "q_matrix.cuh" - -void gemm_half_q_half_cuda -( - cublasHandle_t cublas_handle, - const half* a, - QMatrix* b, - half* c, - int size_m, - int size_n, - int size_k, - bool clear = false, - half* reconstruct = NULL, - bool force_cuda = false -); - -void clear_tensor_cuda -( - half* c, - int size_m, - int size_n -); - -#endif \ No newline at end of file diff --git a/csrc/quantization/gptq/q_gemm_kernel_gptq.cuh b/csrc/quantization/gptq/q_gemm_kernel_gptq.cuh deleted file mode 100644 index 29c86e9555da..000000000000 --- a/csrc/quantization/gptq/q_gemm_kernel_gptq.cuh +++ /dev/null @@ -1,217 +0,0 @@ -#include "compat.cuh" - -__forceinline__ __device__ half2 dot22_8(half2(&dq)[4], const half* a_ptr, const half2 g_result) -{ - half2 result = {}; - const half2* a2_ptr = (const half2*)a_ptr; - #pragma unroll - for (int i = 0; i < 4; i++) result = __hfma2(dq[i], *a2_ptr++, result); - return __hadd2(result, g_result); -} - -__forceinline__ __device__ float dot22_8_f(half2(&dq)[4], const half* a_ptr) -{ - half2 result = {}; - const half2* a2_ptr = (const half2*)a_ptr; - #pragma unroll - for (int i = 0; i < 4; i++) result = __hfma2(dq[i], *a2_ptr++, result); - return __half2float(__low2half(result)) + __half2float(__high2half(result)); -} - -typedef void (*fp_gemm_half_q_half_gptq_kernel) -( - const half*, - const uint32_t*, - const uint32_t*, - const half*, - half*, - const int, - const int, - const int, - const int, - const int, - const uint16_t*, - const bool -); - -template -__global__ void gemm_half_q_half_gptq_kernel -( - const half* __restrict__ a, - const uint32_t* 
__restrict__ b_q_weight, - const uint32_t* __restrict__ b_gptq_qzeros, - const half* __restrict__ b_gptq_scales, - half* __restrict__ c, - const int size_m, - const int size_n, - const int size_k, - const int groups, - const int groupsize, - const uint16_t* __restrict__ b_q_perm, - const bool clear -) -{ - MatrixView_half a_(a, size_m, size_k); - MatrixView_half_rw c_(c, size_m, size_n); - MatrixView_q4_row b_gptq_qzeros_(b_gptq_qzeros, groups, size_n); - MatrixView_half b_gptq_scales_(b_gptq_scales, groups, size_n); - - int t = threadIdx.x; - - // Block - - int offset_n = blockIdx.x * BLOCK_KN_SIZE * 4; - int offset_m = blockIdx.y * m_count; - int offset_k = blockIdx.z * BLOCK_KN_SIZE; - - int end_n = min(offset_n + BLOCK_KN_SIZE * 4, size_n); - int end_m = min(offset_m + m_count, size_m); - int end_k = min(offset_k + BLOCK_KN_SIZE, size_k); - - int n = offset_n + t * 4; - - // Preload block_a - - __shared__ half block_a[m_count][BLOCK_KN_SIZE]; - - if (offset_k + t < end_k) - { - for (int m = 0; m < m_count; ++m) - { - const half* a_ptr = a_.item_ptr(offset_m + m, 0); - half* block_a_ptr = block_a[m]; - - half a0; - if (b_q_perm) a0 = a_ptr[b_q_perm[offset_k + t]]; - else a0 = a_ptr[offset_k + t]; - block_a_ptr[t] = a0; - } - } - - // Zero output - - if (n >= size_n) return; - - if (clear && blockIdx.z == 0) // && (threadIdx.x & 1) == 0) - { - for (int m = 0; m < m_count; m++) - *((uint64_t*)c_.item_ptr(offset_m + m, n)) = 0; - } - - __syncthreads(); - - // Find initial group - - int group = offset_k / groupsize; - int nextgroup = offset_k + groupsize; - - // a, b offset - - int qk = offset_k / (32 / 4); - - const uint32_t* b_ptr = b_q_weight + qk * size_n + n; - const half* a_ptr = &block_a[0][0]; - int a_stride = BLOCK_KN_SIZE; - - // Initial group - - int zeros[4]; - float scales[4]; - half2 z1z16[4][2]; - half2 y1y16[4][2]; - b_gptq_qzeros_.item4(zeros, group, n); - b_gptq_scales_.item4_f(scales, group, n); - dequant_4bit_8_prep_zero(zeros[0] + 1, z1z16[0], 
y1y16[0]); - dequant_4bit_8_prep_zero(zeros[1] + 1, z1z16[1], y1y16[1]); - dequant_4bit_8_prep_zero(zeros[2] + 1, z1z16[2], y1y16[2]); - dequant_4bit_8_prep_zero(zeros[3] + 1, z1z16[3], y1y16[3]); - -// __syncthreads(); - - // Column result - - float block_c[m_count][4] = {}; - - // Dequantize and multiply - - int k = offset_k; - while (k < end_k) - { - if (k == nextgroup) - { - group++; - nextgroup += groupsize; - b_gptq_qzeros_.item4(zeros, group, n); - b_gptq_scales_.item4_f(scales, group, n); - dequant_4bit_8_prep_zero(zeros[0] + 1, z1z16[0], y1y16[0]); - dequant_4bit_8_prep_zero(zeros[1] + 1, z1z16[1], y1y16[1]); - dequant_4bit_8_prep_zero(zeros[2] + 1, z1z16[2], y1y16[2]); - dequant_4bit_8_prep_zero(zeros[3] + 1, z1z16[3], y1y16[3]); - } - - #pragma unroll - for (int j = 0; j < 4; j++) - { - const int4* b_ptr4 = (int4*) b_ptr; - int4 load_int4 = *b_ptr4; - - half2 dq[4][4]; - dequant_4bit_8_gptq(load_int4.x, dq[0], z1z16[0], y1y16[0], size_n, false); - dequant_4bit_8_gptq(load_int4.y, dq[1], z1z16[1], y1y16[1], size_n, false); - dequant_4bit_8_gptq(load_int4.z, dq[2], z1z16[2], y1y16[2], size_n, false); - dequant_4bit_8_gptq(load_int4.w, dq[3], z1z16[3], y1y16[3], size_n, false); - - #pragma unroll - for (int m = 0; m < m_count; m++) - { - block_c[m][0] = fma(dot22_8_f(dq[0], a_ptr + m * a_stride), scales[0], block_c[m][0]); - block_c[m][1] = fma(dot22_8_f(dq[1], a_ptr + m * a_stride), scales[1], block_c[m][1]); - block_c[m][2] = fma(dot22_8_f(dq[2], a_ptr + m * a_stride), scales[2], block_c[m][2]); - block_c[m][3] = fma(dot22_8_f(dq[3], a_ptr + m * a_stride), scales[3], block_c[m][3]); - } - - b_ptr += size_n; - a_ptr += 8; - } - - k += 32; - } - - for (int m = 0; m < m_count; m++) - { - half2 *out = (half2*) c_.item_ptr(offset_m + m, n); - half2 result01 = __halves2half2(__float2half_rn(block_c[m][0]), __float2half_rn(block_c[m][1])); - half2 result23 = __halves2half2(__float2half_rn(block_c[m][2]), __float2half_rn(block_c[m][3])); - atomicAdd(out , 
result01); - atomicAdd(out + 1, result23); - } -} - -fp_gemm_half_q_half_gptq_kernel pick_gemm_half_q_half_gptq_kernel(bool first_block, const int m_count) -{ - #if BLOCK_M_SIZE_MAX >= 1 - if (m_count == 1) return gemm_half_q_half_gptq_kernel; - #endif - #if BLOCK_M_SIZE_MAX >= 2 - if (m_count == 2) return gemm_half_q_half_gptq_kernel; - #endif - #if BLOCK_M_SIZE_MAX >= 3 - if (m_count == 3) return gemm_half_q_half_gptq_kernel; - #endif - #if BLOCK_M_SIZE_MAX >= 4 - if (m_count == 4) return gemm_half_q_half_gptq_kernel; - #endif - #if BLOCK_M_SIZE_MAX >= 5 - if (m_count == 5) return gemm_half_q_half_gptq_kernel; - #endif - #if BLOCK_M_SIZE_MAX >= 6 - if (m_count == 6) return gemm_half_q_half_gptq_kernel; - #endif - #if BLOCK_M_SIZE_MAX >= 7 - if (m_count == 7) return gemm_half_q_half_gptq_kernel; - #endif - #if BLOCK_M_SIZE_MAX >= 8 - if (m_count == 8) return gemm_half_q_half_gptq_kernel; - #endif - return NULL; -} diff --git a/csrc/quantization/gptq/q_matrix.cu b/csrc/quantization/gptq/q_matrix.cu deleted file mode 100644 index 23e40309dfca..000000000000 --- a/csrc/quantization/gptq/q_matrix.cu +++ /dev/null @@ -1,338 +0,0 @@ -#include "q_matrix.cuh" -#include "matrix_view.cuh" - -#include "qdq_4.cuh" - -#define BLOCK_KN_SIZE 128 - -#define THREADS_X 32 -#define THREADS_Y 32 -#define DIVIDE(x, size) (((x) + (size) - 1) / (size)) - -// Shuffle quantized data on load - -__global__ void shuffle_kernel -( - uint32_t* __restrict__ b_q_weight, - const int size_k, - const int size_n -) -{ - int n = blockIdx.x * THREADS_X + threadIdx.x; - if (n >= size_n) return; - int k = 0; - uint32_t* b_ptr = b_q_weight + n; - while (k < size_k) { shuffle_4bit_8 (b_ptr, size_n); b_ptr += 1 * size_n; k += 8; } -} - - -// QMatrix constructor - -QMatrix::QMatrix -( - const int _device, - const int _height, - const int _width, - const int _groups, - - uint32_t* _q_weight, - uint16_t* _q_perm, - uint16_t* _q_invperm, - - uint32_t* _gptq_qzeros, - half* _gptq_scales, - uint32_t* _gptq_g_idx 
-) : - device(_device), - height(_height), - width(_width), - groups(_groups) -{ - cudaSetDevice(device); - - cuda_q_weight = _q_weight; - cuda_q_perm = _q_perm; - cuda_q_invperm = _q_invperm; - cuda_gptq_qzeros = _gptq_qzeros; - cuda_gptq_scales = _gptq_scales; - - is_gptq = true; - - groupsize = 1; - while (groupsize * groups < height) groupsize *= 2; - - if (_gptq_g_idx) make_sequential(_gptq_g_idx); - - // Shuffle quantized data - - dim3 blockDim, gridDim; - blockDim.x = THREADS_X; - blockDim.y = 1; - gridDim.x = DIVIDE(width, THREADS_X); - gridDim.y = 1; - - shuffle_kernel<<>>(cuda_q_weight, height, width); -} - - -// Reconstruct b[k,n] (GPTQ) - -__global__ void reconstruct_gptq_kernel -( - const uint32_t* __restrict__ b_q_weight, - const uint16_t* __restrict__ b_q_perm, - const uint32_t* __restrict__ b_gptq_qzeros, - const half* __restrict__ b_gptq_scales, - const int size_k, - const int size_n, - const int groupsize, - const int groups, - half* __restrict__ b -) -{ - MatrixView_half_rw b_(b, size_k, size_n); - MatrixView_q4_row b_gptq_qzeros_(b_gptq_qzeros, groups, size_n); - MatrixView_half b_gptq_scales_(b_gptq_scales, groups, size_n); - - int offset_k = BLOCK_KN_SIZE * blockIdx.y; - int offset_n = BLOCK_KN_SIZE * blockIdx.x * 4; - - int end_k = min(offset_k + BLOCK_KN_SIZE, size_k); - - // Preload remapping table - - __shared__ uint16_t perm[BLOCK_KN_SIZE]; - int t = threadIdx.x; - - if (b_q_perm) - { - if (offset_k + t < size_k) - perm[t] = b_q_perm[offset_k + t]; - } - - // Column - - int n = offset_n + t * 4; - if (n >= size_n) return; - - // Find initial group - - int group = offset_k / groupsize; - int nextgroup = offset_k + groupsize; - - // b offset - - int qk = offset_k / (32 / 4); - - const uint32_t* b_ptr = b_q_weight + qk * size_n + n; - - // Initial zeros/scale - - int zeros[4]; - half2 scales[4]; - half2 z1z16[4][2]; - half2 y1y16[4][2]; - b_gptq_qzeros_.item4(zeros, group, n); - b_gptq_scales_.item4_h2(scales, group, n); - 
dequant_4bit_8_prep_zero(zeros[0] + 1, z1z16[0], y1y16[0]); - dequant_4bit_8_prep_zero(zeros[1] + 1, z1z16[1], y1y16[1]); - dequant_4bit_8_prep_zero(zeros[2] + 1, z1z16[2], y1y16[2]); - dequant_4bit_8_prep_zero(zeros[3] + 1, z1z16[3], y1y16[3]); - - __syncthreads(); - - int k = offset_k; - int lk = 0; - - while (k < end_k) - { - if (k == nextgroup) - { - group++; - nextgroup += groupsize; - b_gptq_qzeros_.item4(zeros, group, n); - b_gptq_scales_.item4_h2(scales, group, n); - dequant_4bit_8_prep_zero(zeros[0] + 1, z1z16[0], y1y16[0]); - dequant_4bit_8_prep_zero(zeros[1] + 1, z1z16[1], y1y16[1]); - dequant_4bit_8_prep_zero(zeros[2] + 1, z1z16[2], y1y16[2]); - dequant_4bit_8_prep_zero(zeros[3] + 1, z1z16[3], y1y16[3]); - } - - for (int p = 0; p < 4; p++) - { - half2 dq[4][4]; - const int4* b_ptr4 = (int4*) b_ptr; - int4 load_int4 = *b_ptr4; - - dequant_4bit_8_gptq(load_int4.x, dq[0], z1z16[0], y1y16[0], size_n, false); - dequant_4bit_8_gptq(load_int4.y, dq[1], z1z16[1], y1y16[1], size_n, false); - dequant_4bit_8_gptq(load_int4.z, dq[2], z1z16[2], y1y16[2], size_n, false); - dequant_4bit_8_gptq(load_int4.w, dq[3], z1z16[3], y1y16[3], size_n, false); - - b_ptr += size_n; - //half* dqh = (half*)dq; - if (b_q_perm) - { - for (int j = 0; j < 4; j++) - { - for (int v = 0; v < 4; v++) dq[v][j] = __hmul2(scales[v], dq[v][j]); - b_.set4(perm[lk++], n, __low2half(dq[0][j]), __low2half(dq[1][j]), __low2half(dq[2][j]), __low2half(dq[3][j])); - b_.set4(perm[lk++], n, __high2half(dq[0][j]), __high2half(dq[1][j]), __high2half(dq[2][j]), __high2half(dq[3][j])); - } - } - else - { - for (int j = 0; j < 4; j++) - { - for (int v = 0; v < 4; v++) dq[v][j] = __hmul2(scales[v], dq[v][j]); - b_.set4(offset_k + lk++, n, __low2half(dq[0][j]), __low2half(dq[1][j]), __low2half(dq[2][j]), __low2half(dq[3][j])); - b_.set4(offset_k + lk++, n, __high2half(dq[0][j]), __high2half(dq[1][j]), __high2half(dq[2][j]), __high2half(dq[3][j])); - } - } - } - k += 32; - } -} - -void QMatrix::reconstruct(half* 
out) -{ - dim3 blockDim, gridDim; - blockDim.x = BLOCK_KN_SIZE; - blockDim.y = 1; - gridDim.x = DIVIDE(width, BLOCK_KN_SIZE); - gridDim.y = DIVIDE(height, BLOCK_KN_SIZE); - - reconstruct_gptq_kernel<<>> - ( - cuda_q_weight, - cuda_q_perm, - cuda_gptq_qzeros, - cuda_gptq_scales, - height, - width, - groupsize, - groups, - out - ); -} - -__global__ void make_sequential_kernel -( - const uint32_t* __restrict__ w, - uint32_t* __restrict__ w_new, - const uint16_t* __restrict__ q_perm, - const int w_height, - const int w_width -) -{ - const uint64_t* w2 = (uint64_t*) w; - uint64_t* w_new2 = (uint64_t*) w_new; - int w2_stride = w_width >> 1; - - int w2_column = THREADS_X * blockIdx.x + threadIdx.x; - if (w2_column >= w2_stride) return; - - int w_new2_row = blockIdx.y; - - int q_perm_idx = w_new2_row << 3; - - uint64_t dst = 0; - - #pragma unroll - for (int i = 0; i < 8; i++) - { - int source_row = q_perm[q_perm_idx++]; - - int w2_row = source_row >> 3; - int w2_subrow = source_row & 0x07; - int w2_row_shift = w2_subrow << 2; - int wnew2_row_shift = i << 2; - - uint64_t src = w2[w2_row * w2_stride + w2_column]; - src >>= w2_row_shift; - src &= 0x0000000f0000000f; - src <<= wnew2_row_shift; - dst |= src; - } - - w_new2[w_new2_row * w2_stride + w2_column] = dst; -} - -void QMatrix::make_sequential(const uint32_t* cpu_g_idx) -{ - uint32_t* cuda_new_qweight = NULL; - cudaMalloc(&cuda_new_qweight, height / 8 * width * sizeof(uint32_t)); - - uint32_t* cpu_g_idx_map = (uint32_t*) calloc(groups, sizeof(uint32_t)); - uint32_t* cpu_x_map = (uint32_t*) malloc(height * sizeof(uint32_t)); - uint32_t* cpu_x_map_inv = (uint32_t*) malloc(height * sizeof(uint32_t)); - - // Group histogram - - for (int i = 0; i < height; i++) cpu_g_idx_map[cpu_g_idx[i]]++; - - // Group map - - for (int i = 0, acc = 0; i < groups; i++) - { - short tmp = cpu_g_idx_map[i]; - cpu_g_idx_map[i] = acc; - acc += tmp; - } - - // X map (inverse) - - for (int row = 0; row < height; row++) - { - uint32_t target_group = 
cpu_g_idx[row]; - uint32_t target_row = cpu_g_idx_map[target_group]; - cpu_g_idx_map[target_group]++; - cpu_x_map_inv[row] = target_row; - } - - // X map - - for (int row = 0; row < height; row++) cpu_x_map[cpu_x_map_inv[row]] = row; - - // Reduce to uint16_t - - uint16_t* cpu_x_map16 = (uint16_t*)cpu_x_map; - uint16_t* cpu_x_map_inv16 = (uint16_t*)cpu_x_map_inv; - for (int row = 0; row < height; row++) cpu_x_map16[row] = (uint16_t) cpu_x_map[row]; - for (int row = 0; row < height; row++) cpu_x_map_inv16[row] = (uint16_t) cpu_x_map_inv[row]; - - // Move to CUDA - - cudaMemcpyAsync(cuda_q_perm, cpu_x_map16, height * sizeof(uint16_t), cudaMemcpyHostToDevice); - cudaMemcpyAsync(cuda_q_invperm, cpu_x_map_inv16, height * sizeof(uint16_t), cudaMemcpyHostToDevice); - - // Rearrange rows in w - - dim3 blockDim, gridDim; - blockDim.x = THREADS_X; - blockDim.y = 1; - gridDim.x = DIVIDE(width, THREADS_X); - gridDim.y = height / 8; - - make_sequential_kernel<<>> - ( - cuda_q_weight, - cuda_new_qweight, - cuda_q_perm, - height / 8, - width - ); - - // Replace qweights - - cudaMemcpyAsync(cuda_q_weight, cuda_new_qweight, height / 8 * width * sizeof(uint32_t), cudaMemcpyDeviceToDevice); - - // Cleanup - - cudaDeviceSynchronize(); - - cudaFree(cuda_new_qweight); - free(cpu_g_idx_map); - free(cpu_x_map); - free(cpu_x_map_inv); -} diff --git a/csrc/quantization/gptq/q_matrix.cuh b/csrc/quantization/gptq/q_matrix.cuh deleted file mode 100644 index 3fedbc0823f9..000000000000 --- a/csrc/quantization/gptq/q_matrix.cuh +++ /dev/null @@ -1,54 +0,0 @@ -#ifndef _q_matrix_cuh -#define _q_matrix_cuh - -#include -#include -#include -#include - -#define MAX_SUPERGROUPS 16 - -class QMatrix -{ -public: - - int device; - bool is_gptq; - - int height; - int width; - int groups; - int groupsize; - - uint32_t* cuda_q_weight = NULL; - uint16_t* cuda_q_perm = NULL; - uint16_t* cuda_q_invperm = NULL; - uint32_t* cuda_gptq_qzeros = NULL; - half* cuda_gptq_scales = NULL; - - QMatrix - ( - const int 
_device, - const int _height, - const int _width, - const int _groups, - - uint32_t* _q_weight, - uint16_t* _q_perm, - uint16_t* _q_invperm, - - uint32_t* _gptq_qzeros, - half* _gptq_scales, - uint32_t* _gptq_g_idx - ); - - ~QMatrix(); - - void reconstruct(half* out); - void make_sequential(const uint32_t* cpu_g_idx); - -private: - -}; - -#endif diff --git a/csrc/quantization/gptq/qdq_4.cuh b/csrc/quantization/gptq/qdq_4.cuh index a7bde6d30508..cfc4635a22c1 100644 --- a/csrc/quantization/gptq/qdq_4.cuh +++ b/csrc/quantization/gptq/qdq_4.cuh @@ -1,8 +1,14 @@ +/* +Copied from https://github.com/turboderp/exllamav2 +*/ + #ifndef _qdq_4_cuh #define _qdq_4_cuh #include "qdq_util.cuh" +namespace vllm { +namespace gptq { // Permutation: // // 77775555 33331111 66664444 22220000 @@ -134,9 +140,13 @@ __forceinline__ __device__ void dequant_4bit_8_gptq dq[3] = __hfma2(q3.as_half2, y1y16[1], z1z16[1]); // half2( q[6] - z, q[7] - z ) } } +} // namespace gptq +} // namespace vllm #else +namespace vllm { +namespace gptq { __forceinline__ __device__ void shuffle_4bit_8 ( uint32_t* q, @@ -219,4 +229,7 @@ __forceinline__ __device__ void dequant_4bit_8_gptq } } +} // namespace gptq +} // namespace vllm + #endif diff --git a/csrc/quantization/gptq/qdq_util.cuh b/csrc/quantization/gptq/qdq_util.cuh index 71657191b911..1722a9aa6cb3 100644 --- a/csrc/quantization/gptq/qdq_util.cuh +++ b/csrc/quantization/gptq/qdq_util.cuh @@ -1,6 +1,13 @@ +/* +Copied from https://github.com/turboderp/exllamav2 +*/ + #ifndef _qdq_util_cuh #define _qdq_util_cuh +namespace vllm { +namespace gptq { + union half2_uint32 { uint32_t as_uint32; @@ -48,4 +55,6 @@ __forceinline__ __device__ int exb(const uint32_t q1, const uint32_t q0, const i return (int)(__funnelshift_rc(q0, q1, shift) & mask); } +} // namespace gptq +} // namespace vllm #endif diff --git a/setup.py b/setup.py index a705d81adb22..412896050fa2 100644 --- a/setup.py +++ b/setup.py @@ -152,10 +152,7 @@ def get_torch_arch_list() -> Set[str]: 
"csrc/layernorm_kernels.cu", "csrc/quantization/awq/gemm_kernels.cu", "csrc/quantization/squeezellm/quant_cuda_kernel.cu", - "csrc/quantization/gptq/exllama_ext.cpp", - "csrc/quantization/gptq/q_matrix.cu", "csrc/quantization/gptq/q_gemm.cu", - "csrc/quantization/gptq/old_matmul_kernel.cu", "csrc/cuda_utils_kernels.cu", "csrc/pybind.cpp", ], diff --git a/vllm/config.py b/vllm/config.py index 280150d83b92..805e426dc4d3 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -115,7 +115,7 @@ def _verify_tokenizer_mode(self) -> None: self.tokenizer_mode = tokenizer_mode def _verify_quantization(self) -> None: - supported_quantization = ["awq", "squeezellm", "gptq"] + supported_quantization = ["awq", "gptq", "squeezellm"] if self.quantization is not None: self.quantization = self.quantization.lower() diff --git a/vllm/entrypoints/llm.py b/vllm/entrypoints/llm.py index d15d7706f5c5..75edd49b0a52 100644 --- a/vllm/entrypoints/llm.py +++ b/vllm/entrypoints/llm.py @@ -38,9 +38,9 @@ class LLM: However, if the `torch_dtype` in the config is `float32`, we will use `float16` instead. quantization: The method used to quantize the model weights. Currently, - we support "awq" and "gptq". If None, we assume the model weights - are not quantized and use `dtype` to determine the data type of the - weights. + we support "awq", "gptq" and "squeezellm". If None, we assume the + model weights are not quantized and use `dtype` to determine the + data type of the weights. revision: The specific model version to use. It can be a branch name, a tag name, or a commit id. tokenizer_revision: The specific tokenizer version to use. 
It can be a diff --git a/vllm/model_executor/layers/linear.py b/vllm/model_executor/layers/linear.py index 6aed7357ae9b..a472de18907a 100644 --- a/vllm/model_executor/layers/linear.py +++ b/vllm/model_executor/layers/linear.py @@ -1,5 +1,5 @@ from abc import ABC, abstractmethod -from typing import Dict, List, Optional +from typing import Any, Dict, List, Optional import torch import torch.nn.functional as F @@ -21,11 +21,10 @@ class LinearMethodBase(ABC): """Base class for different (maybe quantized) linear methods.""" @abstractmethod - def create_weights(self, - input_size: int, - output_size: int, - params_dtype: torch.dtype, - parallel_type: str = "none") -> Dict[str, torch.Tensor]: + def create_weights(self, input_size_per_partition: int, + output_size_per_partition: int, input_size: int, + utput_size: int, + params_dtype: torch.dtype) -> Dict[str, Any]: """Create weights for a linear layer.""" raise NotImplementedError @@ -49,13 +48,12 @@ class UnquantizedLinearMethod(LinearMethodBase): def __init__(self, separate_bias_add: bool = False): self.separate_bias_add = separate_bias_add - def create_weights(self, - input_size: int, + def create_weights(self, input_size_per_partition: int, + output_size_per_partition: int, input_size: int, output_size: int, - params_dtype: torch.dtype, - parallel_type: str = "none") -> Dict[str, torch.Tensor]: - weight = Parameter(torch.empty(output_size, - input_size, + params_dtype: torch.dtype) -> Dict[str, Any]: + weight = Parameter(torch.empty(output_size_per_partition, + input_size_per_partition, device=torch.cuda.current_device(), dtype=params_dtype), requires_grad=False) @@ -108,9 +106,11 @@ def __init__( linear_method = UnquantizedLinearMethod() self.linear_method = linear_method self.linear_weights = self.linear_method.create_weights( - self.input_size, self.output_size, self.params_dtype) + self.input_size, self.output_size, self.input_size, + self.output_size, self.params_dtype) for name, weight in 
self.linear_weights.items(): - self.register_parameter(name, weight) + if isinstance(weight, torch.Tensor): + self.register_parameter(name, weight) if bias: self.bias = Parameter( torch.empty(self.output_size, @@ -174,11 +174,12 @@ def __init__( linear_method = UnquantizedLinearMethod() self.linear_method = linear_method self.linear_weights = self.linear_method.create_weights( - self.input_size, self.output_size_per_partition, self.params_dtype, - "column") + self.input_size, self.output_size_per_partition, self.input_size, + self.output_size, self.params_dtype) for name, weight in self.linear_weights.items(): - self.register_parameter(name, weight) - set_weight_attrs(weight, {"weight_loader": self.weight_loader}) + if isinstance(weight, torch.Tensor): + self.register_parameter(name, weight) + set_weight_attrs(weight, {"weight_loader": self.weight_loader}) if bias: self.bias = Parameter( torch.empty(self.output_size_per_partition, @@ -488,11 +489,12 @@ def __init__( linear_method = UnquantizedLinearMethod() self.linear_method = linear_method self.linear_weights = self.linear_method.create_weights( - self.input_size_per_partition, self.output_size, self.params_dtype, - "row") + self.input_size_per_partition, self.output_size, self.input_size, + self.output_size, self.params_dtype) for name, weight in self.linear_weights.items(): - self.register_parameter(name, weight) - set_weight_attrs(weight, {"weight_loader": self.weight_loader}) + if isinstance(weight, torch.Tensor): + self.register_parameter(name, weight) + set_weight_attrs(weight, {"weight_loader": self.weight_loader}) if not reduce_results and (bias and not skip_bias_add): raise ValueError("When not reduce the results, adding bias to the " diff --git a/vllm/model_executor/layers/quantization/__init__.py b/vllm/model_executor/layers/quantization/__init__.py index f0f3b5e10270..b3449eaff0e3 100644 --- a/vllm/model_executor/layers/quantization/__init__.py +++ b/vllm/model_executor/layers/quantization/__init__.py 
@@ -1,14 +1,14 @@ from typing import Type +from vllm.model_executor.layers.quantization.base_config import QuantizationConfig from vllm.model_executor.layers.quantization.awq import AWQConfig -from vllm.model_executor.layers.quantization.squeezellm import SqueezeLLMConfig from vllm.model_executor.layers.quantization.gptq import GPTQConfig -from vllm.model_executor.layers.quantization.base_config import QuantizationConfig +from vllm.model_executor.layers.quantization.squeezellm import SqueezeLLMConfig _QUANTIZATION_CONFIG_REGISTRY = { "awq": AWQConfig, - "squeezellm": SqueezeLLMConfig, "gptq": GPTQConfig, + "squeezellm": SqueezeLLMConfig, } diff --git a/vllm/model_executor/layers/quantization/awq.py b/vllm/model_executor/layers/quantization/awq.py index 9e3614fd9aac..831576b1d7cd 100644 --- a/vllm/model_executor/layers/quantization/awq.py +++ b/vllm/model_executor/layers/quantization/awq.py @@ -77,17 +77,16 @@ class AWQLinearMethod(LinearMethodBase): def __init__(self, quant_config: AWQConfig): self.quant_config = quant_config - def create_weights(self, - input_size: int, + def create_weights(self, input_size_per_partition: int, + output_size_per_partition: int, input_size: int, output_size: int, - params_dtype: torch.dtype, - parallel_type: str = "none") -> Dict[str, torch.Tensor]: - if input_size % self.quant_config.group_size != 0: + params_dtype: torch.dtype) -> Dict[str, Any]: + if input_size_per_partition % self.quant_config.group_size != 0: raise ValueError( "The input size is not aligned with the quantized " "weight shape. This can be caused by too large " "tensor parallel size.") - if output_size % self.quant_config.pack_factor != 0: + if output_size_per_partition % self.quant_config.pack_factor != 0: raise ValueError( "The output size is not aligned with the quantized " "weight shape. 
This can be caused by too large " @@ -95,8 +94,8 @@ def create_weights(self, qweight = Parameter( torch.empty( - input_size, - output_size // self.quant_config.pack_factor, + input_size_per_partition, + output_size_per_partition // self.quant_config.pack_factor, device="cuda", dtype=torch.int32, ), @@ -111,8 +110,8 @@ def create_weights(self, }) qzeros = Parameter( torch.empty( - input_size // self.quant_config.group_size, - output_size // self.quant_config.pack_factor, + input_size_per_partition // self.quant_config.group_size, + output_size_per_partition // self.quant_config.pack_factor, device="cuda", dtype=torch.int32, ), @@ -127,8 +126,8 @@ def create_weights(self, }) scales = Parameter( torch.empty( - input_size // self.quant_config.group_size, - output_size, + input_size_per_partition // self.quant_config.group_size, + output_size_per_partition, device="cuda", dtype=params_dtype, ), @@ -145,7 +144,7 @@ def create_weights(self, } def apply_weights(self, - weights: Dict[str, torch.Tensor], + weights: Dict[str, Any], x: torch.Tensor, bias: Optional[torch.Tensor] = None) -> torch.Tensor: qweight = weights["qweight"] diff --git a/vllm/model_executor/layers/quantization/gptq.py b/vllm/model_executor/layers/quantization/gptq.py index 0c0b7b0b276f..7f03f6a6990e 100644 --- a/vllm/model_executor/layers/quantization/gptq.py +++ b/vllm/model_executor/layers/quantization/gptq.py @@ -1,3 +1,4 @@ +from enum import Enum from typing import Any, Dict, List, Optional import torch @@ -7,8 +8,6 @@ from vllm.model_executor.layers.linear import (LinearMethodBase, set_weight_attrs) from vllm.model_executor.layers.quantization.base_config import QuantizationConfig -from vllm.model_executor.parallel_utils.parallel_state import ( - get_tensor_model_parallel_world_size) class GPTQConfig(QuantizationConfig): @@ -71,6 +70,9 @@ def get_scaled_act_names(self) -> List[str]: return [] +ExlState = Enum('ExlState', ['Unused', 'Uninitialized', 'Ready']) + + class 
GPTQLinearMethod(LinearMethodBase): """Linear method for GPTQ. @@ -80,26 +82,41 @@ class GPTQLinearMethod(LinearMethodBase): def __init__(self, quant_config: GPTQConfig): self.quant_config = quant_config - self.use_exllama = True - def create_weights(self, input_size: int, output_size: int, - params_dtype: torch.dtype, - parallel_type: str = "none") -> Dict[str, torch.Tensor]: - if input_size % self.quant_config.group_size != 0: + def create_weights(self, input_size_per_partition: int, + output_size_per_partition: int, input_size: int, + output_size: int, + params_dtype: torch.dtype) -> Dict[str, Any]: + if input_size_per_partition % self.quant_config.group_size != 0: raise ValueError( "The input size is not aligned with the quantized " "weight shape. This can be caused by too large " "tensor parallel size.") - if output_size % self.quant_config.pack_factor != 0: + if output_size_per_partition % self.quant_config.pack_factor != 0: raise ValueError( "The output size is not aligned with the quantized " "weight shape. 
This can be caused by too large " "tensor parallel size.") + if self.quant_config.group_size != -1: + group_size = self.quant_config.group_size + else: + group_size = input_size + exllama_state = ExlState.Uninitialized + scale_and_zero_size = input_size // group_size + scale_and_zero_input_dim = None + if input_size != input_size_per_partition and self.quant_config.group_size != -1: + # For act-order models, we cannot use Exllama for row parallel layer + if self.quant_config.desc_act: + exllama_state = ExlState.Unused + else: + # we need to partition qzeros and scales for exllama kernel + scale_and_zero_size = input_size_per_partition // group_size + scale_and_zero_input_dim = 0 qweight = Parameter( torch.empty( - input_size // self.quant_config.pack_factor, - output_size, + input_size_per_partition // self.quant_config.pack_factor, + output_size_per_partition, device="cuda", dtype=torch.int32, ), @@ -114,29 +131,20 @@ def create_weights(self, input_size: int, output_size: int, }) g_idx = Parameter( torch.tensor( - [i // self.quant_config.group_size for i in range(input_size)], + [ + i // self.quant_config.group_size + for i in range(input_size_per_partition) + ], device="cuda", dtype=torch.int32, ), requires_grad=False, ) set_weight_attrs(g_idx, {"input_dim": 0}) - tp_size = get_tensor_model_parallel_world_size() - if parallel_type == "row" and tp_size > 1 and (self.quant_config.desc_act - and self.quant_config.group_size != -1): - input_size = input_size * tp_size - use_exllama = Parameter(torch.tensor(False, dtype=torch.bool, device="cuda"), requires_grad=False) - else: - use_exllama = Parameter(torch.tensor(True, dtype=torch.bool, device="cuda"), requires_grad=False) - if self.quant_config.desc_act or self.quant_config.group_size == -1: - input_dim = None - else: - input_dim = 0 - group_size = self.quant_config.group_size if self.quant_config.group_size != -1 else input_size qzeros = Parameter( torch.empty( - input_size // group_size, - output_size // 
self.quant_config.pack_factor, + scale_and_zero_size, + output_size_per_partition // self.quant_config.pack_factor, device="cuda", dtype=torch.int32, ), @@ -144,22 +152,22 @@ def create_weights(self, input_size: int, output_size: int, ) set_weight_attrs( qzeros, { - "input_dim": input_dim, + "input_dim": scale_and_zero_input_dim, "output_dim": 1, "packed_dim": 1, "pack_factor": self.quant_config.pack_factor, }) scales = Parameter( torch.empty( - input_size // group_size, - output_size, + scale_and_zero_size, + output_size_per_partition, device="cuda", dtype=params_dtype, ), requires_grad=False, ) set_weight_attrs(scales, { - "input_dim": input_dim, + "input_dim": scale_and_zero_input_dim, "output_dim": 1, }) return { @@ -167,60 +175,30 @@ def create_weights(self, input_size: int, output_size: int, "g_idx": g_idx, "qzeros": qzeros, "scales": scales, - "use_exllama": use_exllama, + "exllama_state": exllama_state, } def apply_weights(self, - weights: Dict[str, torch.Tensor], + weights: Dict[str, Any], x: torch.Tensor, bias: Optional[torch.Tensor] = None) -> torch.Tensor: qweight = weights["qweight"] - height, width = weights["qweight"].shape out_shape = x.shape[:-1] + (qweight.shape[-1], ) reshaped_x = x.reshape(-1, x.shape[-1]) - if weights["use_exllama"]: - if "q4" not in weights: - if not self.quant_config.desc_act: - none_tensor = torch.empty((1, 1), device="meta") - weights["q4"] = ops.make_q_matrix( - weights["qweight"], - none_tensor, - none_tensor, - weights["qzeros"], - weights["scales"], - none_tensor, - ) - else: - weights["q_perm"] = torch.empty( - (height * self.quant_config.pack_factor, ), - dtype=torch.short, - device=weights["qweight"].device) - weights["q_invperm"] = torch.empty_like(weights["q_perm"]) - weights["q4"] = ops.make_q_matrix( - weights["qweight"], - weights["q_perm"], - weights["q_invperm"], - weights["qzeros"], - weights["scales"], - weights["g_idx"].cpu(), - ) - temp_dq = torch.empty((height * self.quant_config.pack_factor, width), - 
dtype=torch.float16, - device=x.device) - output = torch.empty((reshaped_x.shape[0], qweight.shape[-1]), - dtype=torch.float16, - device=x.device) - ops.gemm_half_q_half(reshaped_x, weights["q4"], output, - temp_dq, False) - else: - output = torch.zeros((reshaped_x.shape[0], qweight.shape[-1]), - dtype=torch.float32, - device=x.device) - ops.gptq_descact_matmul(reshaped_x.float(), - weights["qweight"], output, - weights["scales"].float(), - weights["qzeros"], weights["g_idx"]) - output = output.half() + # exllama needs to shuffle the weight after the weight is loaded + # here we do the shuffle on first forward pass + if weights["exllama_state"] == ExlState.Uninitialized: + if self.quant_config.desc_act: + weights["g_idx"] = torch.argsort(weights["g_idx"]).to( + torch.int) + else: + weights["g_idx"] = torch.empty((1, 1), device="meta") + weights["exllama_state"] = ExlState.Ready + ops.gptq_shuffle(weights["qweight"], weights["g_idx"]) + output = ops.gptq_gemm(reshaped_x, weights["qweight"], + weights["qzeros"], weights["scales"], + weights["g_idx"], + weights["exllama_state"] == ExlState.Ready) if bias is not None: output = output + bias return output.reshape(out_shape) diff --git a/vllm/model_executor/layers/quantization/squeezellm.py b/vllm/model_executor/layers/quantization/squeezellm.py index 6018b13dff92..1a2d07d41742 100644 --- a/vllm/model_executor/layers/quantization/squeezellm.py +++ b/vllm/model_executor/layers/quantization/squeezellm.py @@ -66,20 +66,19 @@ class SqueezeLLMLinearMethod(LinearMethodBase): def __init__(self, quant_config: SqueezeLLMConfig): self.quant_config = quant_config - def create_weights(self, - input_size: int, + def create_weights(self, input_size_per_partition: int, + output_size_per_partition: int, input_size: int, output_size: int, - params_dtype: torch.dtype, - parallel_type: str = "none") -> Dict[str, torch.Tensor]: - if input_size % self.quant_config.pack_factor != 0: + params_dtype: torch.dtype) -> Dict[str, Any]: + if 
input_size_per_partition % self.quant_config.pack_factor != 0: raise ValueError( "The input size is not aligned with the quantized " "weight shape. This can be caused by too large " "tensor parallel size.") qweight = Parameter( torch.empty( - input_size // self.quant_config.pack_factor, - output_size, + input_size_per_partition // self.quant_config.pack_factor, + output_size_per_partition, device="cuda", dtype=torch.int32, ), @@ -94,7 +93,7 @@ def create_weights(self, }) lookup_table = Parameter( torch.empty( - output_size, + output_size_per_partition, self.quant_config.weight_bits**2, device="cuda", dtype=params_dtype, @@ -110,7 +109,7 @@ def create_weights(self, } def apply_weights(self, - weights: Dict[str, torch.Tensor], + weights: Dict[str, Any], x: torch.Tensor, bias: Optional[torch.Tensor] = None) -> torch.Tensor: qweight = weights["qweight"] diff --git a/vllm/model_executor/models/aquila.py b/vllm/model_executor/models/aquila.py index 278b506a0961..90a4edd47f75 100644 --- a/vllm/model_executor/models/aquila.py +++ b/vllm/model_executor/models/aquila.py @@ -333,14 +333,16 @@ def load_weights(self, if weight_name not in name: continue name = name.replace(weight_name, param_name) - if name not in params_dict: + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: continue param = params_dict[name] weight_loader = param.weight_loader weight_loader(param, loaded_weight, shard_id) break else: - if name not in params_dict: + # Skip loading extra bias for GPTQ models. 
+ if name.endswith(".bias") and name not in params_dict: continue param = params_dict[name] weight_loader = getattr(param, "weight_loader", diff --git a/vllm/model_executor/models/baichuan.py b/vllm/model_executor/models/baichuan.py index 4a90a9d9b2ce..57952e7f89d3 100644 --- a/vllm/model_executor/models/baichuan.py +++ b/vllm/model_executor/models/baichuan.py @@ -356,14 +356,16 @@ def load_weights(self, if weight_name not in name: continue name = name.replace(weight_name, param_name) - if name not in params_dict: + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: continue param = params_dict[name] weight_loader = param.weight_loader weight_loader(param, loaded_weight, shard_id) break else: - if name not in params_dict: + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: continue param = params_dict[name] weight_loader = getattr(param, "weight_loader", diff --git a/vllm/model_executor/models/chatglm.py b/vllm/model_executor/models/chatglm.py index e13e68d7b67b..a778e5521fdb 100644 --- a/vllm/model_executor/models/chatglm.py +++ b/vllm/model_executor/models/chatglm.py @@ -377,7 +377,8 @@ def load_weights(self, continue if "word_embeddings" in name: name = name.replace(".word_embeddings", "") - if name not in params_dict: + # Skip loading extra bias for GPTQ models. 
+ if name.endswith(".bias") and name not in params_dict: continue param = params_dict[name] weight_loader = getattr(param, "weight_loader", diff --git a/vllm/model_executor/models/falcon.py b/vllm/model_executor/models/falcon.py index 28ab2a93a619..34e71de0d232 100644 --- a/vllm/model_executor/models/falcon.py +++ b/vllm/model_executor/models/falcon.py @@ -36,7 +36,7 @@ from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.sampler import Sampler from vllm.model_executor.layers.vocab_parallel_embedding import ( - VocabParallelEmbedding) + VocabParallelEmbedding, ParallelLMHead) from vllm.model_executor.parallel_utils.communication_op import ( tensor_model_parallel_all_reduce) from vllm.model_executor.parallel_utils.parallel_state import ( @@ -377,6 +377,10 @@ def __init__( self.config = config self.linear_method = linear_method self.transformer = FalconModel(config, linear_method) + self.lm_head = ParallelLMHead( + config.vocab_size, + config.hidden_size, + ) self.sampler = Sampler(config.vocab_size) def forward( @@ -401,8 +405,8 @@ def sample( hidden_states: torch.Tensor, sampling_metadata: SamplingMetadata, ) -> SamplerOutput: - next_tokens = self.sampler(self.transformer.word_embeddings.weight, - hidden_states, sampling_metadata) + next_tokens = self.sampler(self.lm_head.weight, hidden_states, + sampling_metadata) return next_tokens def load_weights(self, @@ -421,7 +425,8 @@ def load_weights(self, params_dict = dict(self.named_parameters()) for name, loaded_weight in hf_model_weights_iterator( model_name_or_path, cache_dir, load_format, revision): - if name not in params_dict: + # Skip loading extra bias for GPTQ models. 
+ if name.endswith(".bias") and name not in params_dict: continue param = params_dict[name] if "query_key_value" in name: diff --git a/vllm/model_executor/models/gpt_j.py b/vllm/model_executor/models/gpt_j.py index e2c324f7c631..49dfbcd0d1bb 100644 --- a/vllm/model_executor/models/gpt_j.py +++ b/vllm/model_executor/models/gpt_j.py @@ -275,14 +275,16 @@ def load_weights(self, if weight_name not in name: continue name = name.replace(weight_name, param_name) - if name not in params_dict: + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: continue param = params_dict[name] weight_loader = param.weight_loader weight_loader(param, loaded_weight, shard_id) break else: - if name not in params_dict: + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: continue param = params_dict[name] weight_loader = getattr(param, "weight_loader", diff --git a/vllm/model_executor/models/internlm.py b/vllm/model_executor/models/internlm.py index 6b2d23c90f3e..cdefa552c178 100644 --- a/vllm/model_executor/models/internlm.py +++ b/vllm/model_executor/models/internlm.py @@ -287,14 +287,16 @@ def load_weights(self, if weight_name not in name: continue name = name.replace(weight_name, param_name) - if name not in params_dict: + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: continue param = params_dict[name] weight_loader = param.weight_loader weight_loader(param, loaded_weight, shard_id) break else: - if name not in params_dict: + # Skip loading extra bias for GPTQ models. 
+ if name.endswith(".bias") and name not in params_dict: continue param = params_dict[name] weight_loader = getattr(param, "weight_loader", diff --git a/vllm/model_executor/models/llama.py b/vllm/model_executor/models/llama.py index 7d9eb1021519..4371aadd90b8 100644 --- a/vllm/model_executor/models/llama.py +++ b/vllm/model_executor/models/llama.py @@ -326,14 +326,16 @@ def load_weights(self, if weight_name not in name: continue name = name.replace(weight_name, param_name) - if name not in params_dict: + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: continue param = params_dict[name] weight_loader = param.weight_loader weight_loader(param, loaded_weight, shard_id) break else: - if name not in params_dict: + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: continue param = params_dict[name] weight_loader = getattr(param, "weight_loader", diff --git a/vllm/model_executor/models/mistral.py b/vllm/model_executor/models/mistral.py index ba3a9baaec00..576d967457ba 100644 --- a/vllm/model_executor/models/mistral.py +++ b/vllm/model_executor/models/mistral.py @@ -322,14 +322,16 @@ def load_weights(self, if weight_name not in name: continue name = name.replace(weight_name, param_name) - if name not in params_dict: + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: continue param = params_dict[name] weight_loader = param.weight_loader weight_loader(param, loaded_weight, shard_id) break else: - if name not in params_dict: + # Skip loading extra bias for GPTQ models. 
+ if name.endswith(".bias") and name not in params_dict: continue param = params_dict[name] weight_loader = getattr(param, "weight_loader", diff --git a/vllm/model_executor/models/mpt.py b/vllm/model_executor/models/mpt.py index 6c54d83862b6..02fe980dc243 100644 --- a/vllm/model_executor/models/mpt.py +++ b/vllm/model_executor/models/mpt.py @@ -279,7 +279,8 @@ def load_weights(self, params_dict = dict(self.named_parameters(remove_duplicate=False)) for name, loaded_weight in hf_model_weights_iterator( model_name_or_path, cache_dir, load_format, revision): - if name not in params_dict: + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: continue param = params_dict[name] weight_loader = getattr(param, "weight_loader", diff --git a/vllm/model_executor/models/opt.py b/vllm/model_executor/models/opt.py index 9bb4583fa4de..3f9b33857314 100644 --- a/vllm/model_executor/models/opt.py +++ b/vllm/model_executor/models/opt.py @@ -346,14 +346,16 @@ def load_weights(self, if weight_name not in name: continue name = name.replace(weight_name, param_name) - if name not in params_dict: + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: continue param = params_dict[name] weight_loader = param.weight_loader weight_loader(param, loaded_weight, shard_id) break else: - if name not in params_dict: + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: continue param = params_dict[name] weight_loader = getattr(param, "weight_loader", diff --git a/vllm/model_executor/models/phi_1_5.py b/vllm/model_executor/models/phi_1_5.py index 23202424c027..0c9671e68311 100644 --- a/vllm/model_executor/models/phi_1_5.py +++ b/vllm/model_executor/models/phi_1_5.py @@ -305,9 +305,10 @@ def load_weights(self, if "rotary_emb.inv_freq" in name: continue - # pylint: disable=E1136 - if name not in params_dict: + # Skip loading extra bias for GPTQ models. 
+ if name.endswith(".bias") and name not in params_dict: continue + # pylint: disable=E1136 param = params_dict[name] weight_loader = getattr(param, "weight_loader", default_weight_loader) diff --git a/vllm/model_executor/models/qwen.py b/vllm/model_executor/models/qwen.py index 5b306deb6fc7..f41a3cc4385a 100644 --- a/vllm/model_executor/models/qwen.py +++ b/vllm/model_executor/models/qwen.py @@ -279,14 +279,16 @@ def load_weights(self, if weight_name not in name: continue name = name.replace(weight_name, param_name) - if name not in params_dict: + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: continue param = params_dict[name] weight_loader = param.weight_loader weight_loader(param, loaded_weight, shard_id) break else: - if name not in params_dict: + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: continue param = params_dict[name] weight_loader = getattr(param, "weight_loader", diff --git a/vllm/model_executor/models/yi.py b/vllm/model_executor/models/yi.py index 3c6f2a36cc37..d16f21c0d702 100644 --- a/vllm/model_executor/models/yi.py +++ b/vllm/model_executor/models/yi.py @@ -321,14 +321,16 @@ def load_weights(self, if weight_name not in name: continue name = name.replace(weight_name, param_name) - if name not in params_dict: + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: continue param = params_dict[name] weight_loader = param.weight_loader weight_loader(param, loaded_weight, shard_id) break else: - if name not in params_dict: + # Skip loading extra bias for GPTQ models. 
+ if name.endswith(".bias") and name not in params_dict: continue param = params_dict[name] weight_loader = getattr(param, "weight_loader", diff --git a/vllm/model_executor/weight_utils.py b/vllm/model_executor/weight_utils.py index ad2bc0518d8f..36ad0f389e2c 100644 --- a/vllm/model_executor/weight_utils.py +++ b/vllm/model_executor/weight_utils.py @@ -289,4 +289,3 @@ def initialize_dummy_weights( for param in model.state_dict().values(): if torch.is_floating_point(param): param.data.uniform_(low, high) - From d1954ab2561ea683b7d27d2f903c7bf4ce640bcd Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=A5=9A=E5=A4=A9=E7=BF=94?= Date: Mon, 11 Dec 2023 13:29:11 +0800 Subject: [PATCH 13/18] Fix typo --- vllm/model_executor/layers/linear.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/model_executor/layers/linear.py b/vllm/model_executor/layers/linear.py index a472de18907a..1fd7cf4b2591 100644 --- a/vllm/model_executor/layers/linear.py +++ b/vllm/model_executor/layers/linear.py @@ -23,7 +23,7 @@ class LinearMethodBase(ABC): @abstractmethod def create_weights(self, input_size_per_partition: int, output_size_per_partition: int, input_size: int, - utput_size: int, + output_size: int, params_dtype: torch.dtype) -> Dict[str, Any]: """Create weights for a linear layer.""" raise NotImplementedError From 62d6760667d839f215d2616b0d3fe58f0dccfce4 Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Fri, 15 Dec 2023 09:29:14 +0000 Subject: [PATCH 14/18] Minor fix --- .../layers/quantization/gptq.py | 38 ++++++++++++------- 1 file changed, 24 insertions(+), 14 deletions(-) diff --git a/vllm/model_executor/layers/quantization/gptq.py b/vllm/model_executor/layers/quantization/gptq.py index 2b3101619b3d..18dc229af58c 100644 --- a/vllm/model_executor/layers/quantization/gptq.py +++ b/vllm/model_executor/layers/quantization/gptq.py @@ -1,3 +1,4 @@ +import enum from enum import Enum from typing import Any, Dict, List, Optional @@ -7,7 +8,8 @@ from vllm._C import ops from 
vllm.model_executor.layers.linear import (LinearMethodBase, set_weight_attrs) -from vllm.model_executor.layers.quantization.base_config import QuantizationConfig +from vllm.model_executor.layers.quantization.base_config import ( + QuantizationConfig) class GPTQConfig(QuantizationConfig): @@ -52,9 +54,7 @@ def get_min_capability(cls) -> int: @classmethod def get_config_filenames(cls) -> List[str]: - return [ - "quantize_config.json", - ] + return ["quantize_config.json"] @classmethod def from_config(cls, config: Dict[str, Any]) -> "GPTQConfig": @@ -70,7 +70,11 @@ def get_scaled_act_names(self) -> List[str]: return [] -ExlState = Enum('ExlState', ['Unused', 'Uninitialized', 'Ready']) +class ExllamaState(Enum): + + UNUSED = enum.auto() + UNINITIALIZED = enum.auto() + READY = enum.auto() class GPTQLinearMethod(LinearMethodBase): @@ -83,10 +87,15 @@ class GPTQLinearMethod(LinearMethodBase): def __init__(self, quant_config: GPTQConfig): self.quant_config = quant_config - def create_weights(self, input_size_per_partition: int, - output_size_per_partition: int, input_size: int, - output_size: int, - params_dtype: torch.dtype) -> Dict[str, Any]: + def create_weights( + self, + input_size_per_partition: int, + output_size_per_partition: int, + input_size: int, + output_size: int, + params_dtype: torch.dtype, + ) -> Dict[str, Any]: + del output_size # Unused. if input_size_per_partition % self.quant_config.group_size != 0: raise ValueError( "The input size is not aligned with the quantized " @@ -97,17 +106,18 @@ def create_weights(self, input_size_per_partition: int, "The output size is not aligned with the quantized " "weight shape. 
This can be caused by too large " "tensor parallel size.") + if self.quant_config.group_size != -1: group_size = self.quant_config.group_size else: group_size = input_size - exllama_state = ExlState.Uninitialized + exllama_state = ExllamaState.UNINITIALIZED scale_and_zero_size = input_size // group_size scale_and_zero_input_dim = None if input_size != input_size_per_partition and self.quant_config.group_size != -1: # For act-order models, we cannot use Exllama for row parallel layer if self.quant_config.desc_act: - exllama_state = ExlState.Unused + exllama_state = ExllamaState.UNUSED else: # we need to partition qzeros and scales for exllama kernel scale_and_zero_size = input_size_per_partition // group_size @@ -187,18 +197,18 @@ def apply_weights(self, reshaped_x = x.reshape(-1, x.shape[-1]) # exllama needs to shuffle the weight after the weight is loaded # here we do the shuffle on first forward pass - if weights["exllama_state"] == ExlState.Uninitialized: + if weights["exllama_state"] == ExllamaState.UNINITIALIZED: if self.quant_config.desc_act: weights["g_idx"] = torch.argsort(weights["g_idx"]).to( torch.int) else: weights["g_idx"] = torch.empty((1, 1), device="meta") - weights["exllama_state"] = ExlState.Ready + weights["exllama_state"] = ExllamaState.READY ops.gptq_shuffle(weights["qweight"], weights["g_idx"]) output = ops.gptq_gemm(reshaped_x, weights["qweight"], weights["qzeros"], weights["scales"], weights["g_idx"], - weights["exllama_state"] == ExlState.Ready) + weights["exllama_state"] == ExllamaState.READY) if bias is not None: output = output + bias return output.reshape(out_shape) From 5156579697ef6b6f652db91f3c78ffe2181ca998 Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Fri, 15 Dec 2023 09:31:59 +0000 Subject: [PATCH 15/18] Minor --- csrc/ops.h | 21 +++++++++------------ 1 file changed, 9 insertions(+), 12 deletions(-) diff --git a/csrc/ops.h b/csrc/ops.h index 577084203616..9340a60da141 100644 --- a/csrc/ops.h +++ b/csrc/ops.h @@ -78,17 +78,14 @@ 
void squeezellm_gemm( torch::Tensor mul, torch::Tensor lookup_table); -torch::Tensor gptq_gemm -( - torch::Tensor a, - torch::Tensor b_q_weight, - torch::Tensor b_gptq_qzeros, - torch::Tensor b_gptq_scales, - torch::Tensor b_g_idx, - bool use_exllama -); +torch::Tensor gptq_gemm( + torch::Tensor a, + torch::Tensor b_q_weight, + torch::Tensor b_gptq_qzeros, + torch::Tensor b_gptq_scales, + torch::Tensor b_g_idx, + bool use_exllama); void gptq_shuffle( - torch::Tensor q_weight, - torch::Tensor q_perm -); + torch::Tensor q_weight, + torch::Tensor q_perm); From 1f3f6eee6196c51d9a95428bb9dfb63ce6ee5d26 Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Fri, 15 Dec 2023 10:15:26 +0000 Subject: [PATCH 16/18] Support Mixtral --- vllm/model_executor/models/mixtral.py | 11 +++++++++-- 1 file changed, 9 insertions(+), 2 deletions(-) diff --git a/vllm/model_executor/models/mixtral.py b/vllm/model_executor/models/mixtral.py index b11e3713fd4d..a3f1582a3473 100644 --- a/vllm/model_executor/models/mixtral.py +++ b/vllm/model_executor/models/mixtral.py @@ -153,7 +153,7 @@ def __init__( self.gate = ReplicatedLinear(config.hidden_size, self.num_total_experts, bias=False, - linear_method=linear_method) + linear_method=None) def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: batch_size, sequence_length, hidden_dim = hidden_states.shape @@ -418,11 +418,18 @@ def load_weights(self, for (param_name, weight_name, shard_id) in stacked_params_mapping: if weight_name not in name: continue - param = params_dict[name.replace(weight_name, param_name)] + name = name.replace(weight_name, param_name) + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: + continue + param = params_dict[name] weight_loader = param.weight_loader weight_loader(param, loaded_weight, shard_id) break else: + # Skip loading extra bias for GPTQ models. 
+ if name.endswith(".bias") and name not in params_dict: + continue param = params_dict[name] weight_loader = getattr(param, "weight_loader", default_weight_loader) From 99cc231fdb857fc23dcc1b42b8068a95895c22c3 Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Fri, 15 Dec 2023 10:48:07 +0000 Subject: [PATCH 17/18] Ignore warning --- vllm/model_executor/layers/linear.py | 20 +++++++++++-------- .../layers/quantization/gptq.py | 3 ++- 2 files changed, 14 insertions(+), 9 deletions(-) diff --git a/vllm/model_executor/layers/linear.py b/vllm/model_executor/layers/linear.py index 1fd7cf4b2591..5190de65d795 100644 --- a/vllm/model_executor/layers/linear.py +++ b/vllm/model_executor/layers/linear.py @@ -303,10 +303,12 @@ def weight_loader(self, loaded_weight = loaded_weight.narrow(output_dim, start_idx, shard_size) else: - logger.warning( - "Loading a weight without `output_dim` attribute in " - "MergedColumnParallelLinear, assume the weight is " - "the same for all partitions.") + ignore_warning = getattr(param, "ignore_warning", False) + if not ignore_warning: + logger.warning( + "Loading a weight without `output_dim` attribute in " + "MergedColumnParallelLinear, assume the weight is " + "the same for all partitions.") assert param_data.shape == loaded_weight.shape param_data.copy_(loaded_weight) @@ -426,10 +428,12 @@ def weight_loader(self, loaded_weight = loaded_weight.narrow(output_dim, start_idx, shard_size) else: - logger.warning( - "Loading a weight without `output_dim` attribute in " - "QKVParallelLinear, assume the weight is the same " - "for all partitions.") + ignore_warning = getattr(param, "ignore_warning", False) + if not ignore_warning: + logger.warning( + "Loading a weight without `output_dim` attribute in " + "QKVParallelLinear, assume the weight is the same " + "for all partitions.") assert param_data.shape == loaded_weight.shape param_data.copy_(loaded_weight) diff --git a/vllm/model_executor/layers/quantization/gptq.py 
b/vllm/model_executor/layers/quantization/gptq.py index 18dc229af58c..8fe96e7ddb98 100644 --- a/vllm/model_executor/layers/quantization/gptq.py +++ b/vllm/model_executor/layers/quantization/gptq.py @@ -150,7 +150,8 @@ def create_weights( ), requires_grad=False, ) - set_weight_attrs(g_idx, {"input_dim": 0}) + # Ignore warning from fused linear layers such as QKVParallelLinear. + set_weight_attrs(g_idx, {"input_dim": 0, "ignore_warning": True}) qzeros = Parameter( torch.empty( scale_and_zero_size, From 17fcdd24b9240ac08ea15f02f6361565664fe44e Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Fri, 15 Dec 2023 10:58:01 +0000 Subject: [PATCH 18/18] Fix squeezellm --- vllm/model_executor/layers/quantization/squeezellm.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/model_executor/layers/quantization/squeezellm.py b/vllm/model_executor/layers/quantization/squeezellm.py index a153f044117f..1932bd145076 100644 --- a/vllm/model_executor/layers/quantization/squeezellm.py +++ b/vllm/model_executor/layers/quantization/squeezellm.py @@ -94,7 +94,7 @@ def create_weights(self, input_size_per_partition: int, }) lookup_table = Parameter( torch.empty( - output_size_per_partition, + output_size, self.quant_config.weight_bits**2, device="cuda", dtype=params_dtype,