Skip to content
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
Show all changes
27 commits
Select commit Hold shift + click to select a range
82e6b2e
Add gptq implementation compatible with awq interface
chu-tianxiang Sep 18, 2023
612d7b1
Add more models
chu-tianxiang Sep 25, 2023
049a37c
fix bug in model loading
chu-tianxiang Sep 25, 2023
5563578
Add fallback kernel for desc act models
chu-tianxiang Sep 27, 2023
0470121
Fix engine args and opt model
chu-tianxiang Sep 27, 2023
92c7f8d
Merge main branch
chu-tianxiang Oct 8, 2023
f9d0ccc
Add mistral model
chu-tianxiang Oct 9, 2023
cbf9433
Fix bug in gpt layer
chu-tianxiang Oct 11, 2023
a7b391d
Fix conflict
chu-tianxiang Oct 24, 2023
b51ebb7
Merge main branch
chu-tianxiang Oct 24, 2023
9a99461
Fix squeezellm
chu-tianxiang Oct 24, 2023
2593dfe
Use exllama v2 kernels for better performance
chu-tianxiang Nov 2, 2023
97072a7
Add Yi and ChatGLM GPTQ support
chu-tianxiang Nov 14, 2023
2d8dc1d
Fix chatglm
chu-tianxiang Nov 14, 2023
22ea9ce
merge main
chu-tianxiang Dec 1, 2023
17b6f2b
Fix phi model
chu-tianxiang Dec 1, 2023
62bd8ce
move post init to first forward pass to make code cleaner
chu-tianxiang Dec 2, 2023
e1c4c25
merge main
chu-tianxiang Dec 3, 2023
b6b8c63
Update GPTQ kernel and fix minor problems
chu-tianxiang Dec 10, 2023
1bcb832
Merge main
chu-tianxiang Dec 10, 2023
d1954ab
Fix typo
chu-tianxiang Dec 11, 2023
514021c
Merge branch 'main' into gptq_hf
WoosukKwon Dec 15, 2023
62d6760
Minor fix
WoosukKwon Dec 15, 2023
5156579
Minor
WoosukKwon Dec 15, 2023
1f3f6ee
Support Mixtral
WoosukKwon Dec 15, 2023
99cc231
Ignore warning
WoosukKwon Dec 15, 2023
17fcdd2
Fix squeezellm
WoosukKwon Dec 15, 2023
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
53 changes: 52 additions & 1 deletion csrc/quantization.cpp
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
#include <cstdint>
#include <torch/extension.h>

torch::Tensor awq_gemm(
Expand All @@ -7,9 +8,59 @@ torch::Tensor awq_gemm(
torch::Tensor _zeros,
int split_k_iters);

void gptq_set_tuning_params(
int matmul_recons_thd,
bool matmul_fused_remap,
bool matmul_no_half2);

void gptq_prepare_buffers(
torch::Device device,
torch::Tensor temp_state,
torch::Tensor temp_dq);

uintptr_t gptq_make_q4(
torch::Tensor qweight,
torch::Tensor qzeros,
torch::Tensor scales,
torch::Tensor g_idx,
int device);

void gptq_q4_matmul(
torch::Tensor x,
uintptr_t w,
torch::Tensor out);

void gptq_descact_matmul(
torch::Tensor vec,
torch::Tensor mat,
torch::Tensor mul,
torch::Tensor scales,
torch::Tensor zeros,
torch::Tensor g_idx);

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def(
"awq_gemm",
&awq_gemm,
"Quantized GEMM for AWQ");
}
m.def(
"gptq_set_tuning_params",
&gptq_set_tuning_params,
"Set tuning params for GPTQ");
m.def(
"gptq_prepare_buffers",
&gptq_prepare_buffers,
"Prepare buffers for GPTQ");
m.def(
"gptq_make_q4",
&gptq_make_q4,
"Preprocess weight for GPTQ");
m.def(
"gptq_q4_matmul",
&gptq_q4_matmul,
"Quantized GEMM for GPTQ");
m.def(
"gptq_descact_matmul",
&gptq_descact_matmul,
"Quantized GEMM for GPTQ for parallelized desc_act layer");
}
58 changes: 58 additions & 0 deletions csrc/quantization/gptq/cu_compat.cuh
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
// Adapted from turboderp exllama: https://github.com/turboderp/exllama

#ifndef _cuda_compat_cuh
#define _cuda_compat_cuh

// atomicAdd for half types, to support CC < 7.x

__device__ __forceinline__ void atomicAdd_half(half* address, half val)
{
unsigned int * address_as_ui = (unsigned int *) ((char *)address - ((size_t)address & 2));
unsigned int old = *address_as_ui;
unsigned int assumed;

do
{
assumed = old;
__half_raw hsum;
hsum.x = (size_t)address & 2 ? (old >> 16) : (old & 0xffff);
half tmpres = __hadd(hsum, val);
hsum = __half_raw(tmpres);
old = (size_t)address & 2 ? (old & 0xffff) | (hsum.x << 16) : (old & 0xffff0000) | hsum.x;
old = atomicCAS(address_as_ui, assumed, old);
}
while (assumed != old);
}

// atomicAdd for half2 types

__device__ __forceinline__ void atomicAdd_half2(half2* address, half2 val)
{
unsigned int* address_as_ui = (unsigned int*)address;
unsigned int old = *address_as_ui;
unsigned int assumed;
do
{
assumed = old;
half2 old_val = *((half2*)&old);
half2 new_val = __hadd2(old_val, val);
old = atomicCAS(address_as_ui, assumed, *((unsigned int*)&new_val));
}
while (assumed != old);
}

//

#if defined(__CUDA_ARCH__) || defined(USE_ROCM)
#if __CUDA_ARCH__ < 700 || defined(USE_ROCM)

__device__ __forceinline__ void atomicAdd(half* address, half val) { atomicAdd_half(address, val); }

#if __CUDA_ARCH__ < 600 || defined(USE_ROCM)
__device__ __forceinline__ void atomicAdd(half2* address, half2 val) { atomicAdd_half2(address, val); }
#endif

#endif
#endif

#endif
75 changes: 75 additions & 0 deletions csrc/quantization/gptq/cuda_buffers.cu
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
// Adapted from turboderp exllama: https://github.com/turboderp/exllama

#define _cuda_buffers_cu
#include "cuda_buffers.cuh"

CudaBuffers* g_buffers[CUDA_MAX_DEVICES] = {NULL};
// __constant__ half2 q4_table[16][256];
// half2 q4_table_host[16][256];
// bool q4_table_init = false;

CudaBuffers::CudaBuffers
(
int _device,
int _temp_state_size,
half* _temp_state,
half* _temp_dq
) :
device(_device),
temp_state_size(_temp_state_size),
temp_state(_temp_state),
temp_dq(_temp_dq)
{
cudaSetDevice(_device);

cudaStreamCreate(&alt_stream_1);
cudaStreamCreate(&alt_stream_2);
cudaStreamCreate(&alt_stream_3);
cudaEventCreate(&alt_stream_1_done);
cudaEventCreate(&alt_stream_2_done);
cudaEventCreate(&alt_stream_3_done);
}

CudaBuffers::~CudaBuffers()
{
cudaStreamDestroy(alt_stream_1);
cudaStreamDestroy(alt_stream_2);
cudaStreamDestroy(alt_stream_3);
cudaEventDestroy(alt_stream_1_done);
cudaEventDestroy(alt_stream_2_done);
cudaEventDestroy(alt_stream_3_done);
}

CudaBuffers* get_buffers(const int device_index)
{
return g_buffers[device_index];
}

void prepare_buffers_cuda
(
int _device,
int _temp_state_size,
half* _temp_state,
half* _temp_dq
)
{
CudaBuffers* buffers = new CudaBuffers
(
_device,
_temp_state_size,
_temp_state,
_temp_dq
);

g_buffers[_device] = buffers;
}

void cleanup_buffers_cuda()
{
for (int i = 0; i < CUDA_MAX_DEVICES; i++)
{
if (!g_buffers[i]) continue;
delete g_buffers[i];
g_buffers[i] = NULL;
}
}
55 changes: 55 additions & 0 deletions csrc/quantization/gptq/cuda_buffers.cuh
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
// Adapted from turboderp exllama: https://github.com/turboderp/exllama

#ifndef _cuda_buffers_cuh
#define _cuda_buffers_cuh

#include <cuda_runtime.h>
#include <cuda_fp16.h>
#include <cstdint>
#include <cstdio>

const int CUDA_MAX_DEVICES = 16;

// #ifndef _cuda_buffers_cu
// extern __constant__ half2 q4_table[16][256];
// #endif

class CudaBuffers
{
public:
int device;

half* temp_state; // [max_hidden_rows * intermediate_size]
int temp_state_size;
half* temp_dq; // size of largest quant tensor * 8

cudaStream_t alt_stream_1;
cudaStream_t alt_stream_2;
cudaStream_t alt_stream_3;
cudaEvent_t alt_stream_1_done;
cudaEvent_t alt_stream_2_done;
cudaEvent_t alt_stream_3_done;

CudaBuffers
(
int _device,
int _temp_state_size,
half* _temp_state,
half* _temp_dq
);
~CudaBuffers();
};

CudaBuffers* get_buffers(const int device_index);

void prepare_buffers_cuda
(
int _device,
int _temp_state_size,
half* _temp_state,
half* _temp_dq
);

void cleanup_buffers_cuda();

#endif
63 changes: 63 additions & 0 deletions csrc/quantization/gptq/cuda_func/column_remap.cu
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
// Adapted from turboderp exllama: https://github.com/turboderp/exllama

#include "column_remap.cuh"
#include "../util.cuh"

const int SHUF_BLOCKSIZE_X = 256;
const int SHUF_BLOCKSIZE_Y = 16;

__global__ void column_remap_kernel
(
const half* __restrict__ x,
half* __restrict__ x_new,
const int x_width,
const int x_height,
const uint32_t* x_map
)
{
int x_column = SHUF_BLOCKSIZE_X * blockIdx.x + threadIdx.x;
int x_row = SHUF_BLOCKSIZE_Y * blockIdx.y;
if (x_column >= x_width) return;
//if (x_row >= x_height) return;

int x_stride = x_width;
int x_idx = x_row * x_stride + x_column;

int x_row_end = min(x_row + SHUF_BLOCKSIZE_Y, x_height);
int x_idx_end = x_row_end * x_stride + x_column;

int s_column = x_map[x_column];
int s_idx = x_row * x_stride + s_column;

while (x_idx < x_idx_end)
{
x_new[x_idx] = x[s_idx];
x_idx += x_stride;
s_idx += x_stride;
}
}

// Remap columns in x to correspond to sequential group index before matmul
//
// perform x -> seq_x such that seq_x @ seq_w == x @ w

void column_remap_cuda
(
const half* x,
half* x_new,
const int x_height,
const int x_width,
const uint32_t* x_map
)
{
dim3 threads(SHUF_BLOCKSIZE_X, 1, 1);

dim3 blocks
(
(x_width + SHUF_BLOCKSIZE_X - 1) / SHUF_BLOCKSIZE_X,
(x_height + SHUF_BLOCKSIZE_Y - 1) / SHUF_BLOCKSIZE_Y,
1
);

column_remap_kernel<<<blocks, threads>>>(x, x_new, x_width, x_height, x_map);
}
19 changes: 19 additions & 0 deletions csrc/quantization/gptq/cuda_func/column_remap.cuh
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
// Adapted from turboderp exllama: https://github.com/turboderp/exllama

#ifndef _column_remap_cuh
#define _column_remap_cuh

#include <cuda_runtime.h>
#include <cuda_fp16.h>
#include <cstdint>

void column_remap_cuda
(
const half* x,
half* x_new,
const int x_height,
const int x_width,
const uint32_t* x_map
);

#endif
Loading