Commit 8d705d9

HandH1998 authored and LeiWang1999 committed
Support W4A8 quantization for vllm (vllm-project#5218)
Signed-off-by: LeiWang1999 <[email protected]>
1 parent: ff4532c

15 files changed: 1963 additions, 84 deletions
Lines changed: 11 additions & 0 deletions

@@ -0,0 +1,11 @@
+# bash .buildkite/lm-eval-harness/run-lm-eval-gsm-vllm-baseline.sh -m HandH1998/QQQ-Llama-3-8b-g128 -b 32 -l 1000 -f 5 -t 1
+model_name: "HandH1998/QQQ-Llama-3-8b-g128"
+tasks:
+- name: "gsm8k"
+  metrics:
+  - name: "exact_match,strict-match"
+    value: 0.409
+  - name: "exact_match,flexible-extract"
+    value: 0.406
+limit: 1000
+num_fewshot: 5

.buildkite/lm-eval-harness/configs/models-small.txt

Lines changed: 1 addition & 0 deletions

@@ -7,3 +7,4 @@ Meta-Llama-3-8B-Instruct-Channelwise-compressed-tensors.yaml
 Minitron-4B-Base.yaml
 Qwen2-1.5B-Instruct-INT8-compressed-tensors.yaml
 Qwen2-1.5B-Instruct-FP8W8.yaml
+Meta-Llama-3-8B-QQQ.yaml

CMakeLists.txt

Lines changed: 1 addition & 0 deletions

@@ -170,6 +170,7 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
     "csrc/quantization/awq/gemm_kernels.cu"
     "csrc/quantization/marlin/dense/marlin_cuda_kernel.cu"
     "csrc/quantization/marlin/sparse/marlin_24_cuda_kernel.cu"
+    "csrc/quantization/marlin/qqq/marlin_qqq_gemm_kernel.cu"
     "csrc/quantization/gptq_marlin/gptq_marlin.cu"
     "csrc/quantization/gptq_marlin/gptq_marlin_repack.cu"
     "csrc/quantization/gptq_marlin/awq_marlin_repack.cu"

csrc/ops.h

Lines changed: 7 additions & 0 deletions

@@ -115,6 +115,13 @@ void cutlass_scaled_mm(torch::Tensor& out, torch::Tensor const& a,
                        torch::Tensor const& b_scales,
                        c10::optional<torch::Tensor> const& bias);
 
+torch::Tensor marlin_qqq_gemm(torch::Tensor const& a,
+                              torch::Tensor const& b_q_weight,
+                              torch::Tensor const& s_tok,
+                              torch::Tensor const& s_ch,
+                              torch::Tensor const& s_group,
+                              torch::Tensor& workspace, int64_t size_m,
+                              int64_t size_n, int64_t size_k);
 #endif
 
 void static_scaled_int8_quant(torch::Tensor& out, torch::Tensor const& input,
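The new declaration takes the quantized activations, the packed 4-bit weights, what the parameter names suggest are per-token (s_tok), per-channel (s_ch), and per-group (s_group) scales, plus a workspace buffer and the GEMM dimensions. A minimal host-side sketch of how these arguments might be wired together; the wrapper name and the assumption that m and k can be read off `a` are illustrative, not part of this commit:

// Hypothetical convenience wrapper around the op declared above; everything
// here except the marlin_qqq_gemm signature is an assumption for illustration.
#include <torch/all.h>

#include "ops.h"  // declares marlin_qqq_gemm

// Multiply quantized activations `a` (assumed [size_m, size_k]) by the packed
// 4-bit weights `b_q_weight`, applying the s_tok/s_ch/s_group scales, and
// return the [size_m, size_n] output.
torch::Tensor qqq_matmul(torch::Tensor const& a, torch::Tensor const& b_q_weight,
                         torch::Tensor const& s_tok, torch::Tensor const& s_ch,
                         torch::Tensor const& s_group, torch::Tensor& workspace,
                         int64_t size_n) {
  int64_t size_m = a.size(0);  // token/batch dimension
  int64_t size_k = a.size(1);  // reduction dimension
  return marlin_qqq_gemm(a, b_q_weight, s_tok, s_ch, s_group, workspace,
                         size_m, size_n, size_k);
}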
Lines changed: 32 additions & 0 deletions

@@ -0,0 +1,32 @@
+/*
+ * Modified by HandH1998
+ * Modified by Neural Magic
+ * Copyright (C) Marlin.2024 Elias Frantar
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *       http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#pragma once
+
+constexpr int ceildiv(int a, int b) { return (a + b - 1) / b; }
+
+// Instances of `Vec` are used to organize groups of >>registers<<, as needed
+// for instance as inputs to tensor core operations. Consequently, all
+// corresponding index accesses must be compile-time constants, which is why we
+// extensively use `#pragma unroll` throughout the kernel code to guarantee
+// this.
+template <typename T, int n>
+struct Vec {
+  T elems[n];
+  __device__ T& operator[](int i) { return elems[i]; }
+};
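As a quick illustration of how these two helpers are used by the Marlin-style kernels: `Vec` holds a fixed number of values that the compiler can keep entirely in registers as long as every access is fully unrolled, and `ceildiv` rounds a problem dimension up to a whole number of tiles. A minimal sketch; the alias name, output layout, and include path below are assumptions, not part of this commit:

// Minimal usage sketch; "base.h", Frag4, and the 16-row tile size are
// illustrative assumptions.
#include <cuda_fp16.h>

#include "base.h"

using Frag4 = Vec<half2, 4>;  // four half2 values held in registers

__global__ void zero_fragment(half2* out) {
  Frag4 frag;
#pragma unroll
  for (int i = 0; i < 4; i++)
    frag[i] = __float2half2_rn(0.0f);  // index is a compile-time constant after unrolling
#pragma unroll
  for (int i = 0; i < 4; i++)
    out[4 * threadIdx.x + i] = frag[i];  // write the fragment back out
}

// Host-side helper: how many 16-row tiles are needed to cover `size_m` rows.
inline int num_m_tiles(int size_m) { return ceildiv(size_m, 16); }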
Lines changed: 89 additions & 0 deletions

@@ -0,0 +1,89 @@
+/*
+ * Modified by HandH1998
+ * Modified by Neural Magic
+ * Copyright (C) Marlin.2024 Elias Frantar
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *       http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#pragma once
+
+// Predicated asynchronous global->shared copy; used for inputs A where we apply
+// predication to handle batchsizes that are not multiples of 16.
+__device__ inline void cp_async4_pred(void* smem_ptr, const void* glob_ptr,
+                                      bool pred = true) {
+  const int BYTES = 16;
+  uint32_t smem = static_cast<uint32_t>(__cvta_generic_to_shared(smem_ptr));
+  asm volatile(
+      "{\n"
+      "   .reg .pred p;\n"
+      "   setp.ne.b32 p, %0, 0;\n"
+      "   @p cp.async.cg.shared.global [%1], [%2], %3;\n"
+      "}\n" ::"r"((int)pred),
+      "r"(smem), "l"(glob_ptr), "n"(BYTES));
+}
+
+// Asynchronous global->shared copy
+__device__ inline void cp_async4(void* smem_ptr, const void* glob_ptr) {
+  const int BYTES = 16;
+  uint32_t smem = static_cast<uint32_t>(__cvta_generic_to_shared(smem_ptr));
+  asm volatile(
+      "{\n"
+      "   cp.async.cg.shared.global [%0], [%1], %2;\n"
+      "}\n" ::"r"(smem),
+      "l"(glob_ptr), "n"(BYTES));
+}
+
+// Async copy fence.
+__device__ inline void cp_async_fence() {
+  asm volatile("cp.async.commit_group;\n" ::);
+}
+
+// Wait until at most `n` async copy stages are still pending.
+template <int n>
+__device__ inline void cp_async_wait() {
+  asm volatile("cp.async.wait_group %0;\n" ::"n"(n));
+}
+
+// Wait until barrier reaches `count`, then lock for current threadblock.
+__device__ inline void barrier_acquire(int* lock, int count) {
+  if (threadIdx.x == 0) {
+    int state = -1;
+    do
+      // Guarantee that subsequent writes by this threadblock will be visible
+      // globally.
+      asm volatile("ld.global.acquire.gpu.b32 %0, [%1];\n"
+                   : "=r"(state)
+                   : "l"(lock));
+    while (state != count);
+  }
+  __syncthreads();
+}
+
+// Release barrier and increment visitation count.
+__device__ inline void barrier_release(int* lock, bool reset = false) {
+  __syncthreads();
+  if (threadIdx.x == 0) {
+    if (reset) {
+      lock[0] = 0;
+      return;
+    }
+    int val = 1;
+    // Make sure that all writes since acquiring this barrier are visible
+    // globally, while releasing the barrier.
+    asm volatile("fence.acq_rel.gpu;\n");
+    asm volatile("red.relaxed.gpu.global.add.s32 [%0], %1;\n"
+                 :
+                 : "l"(lock), "r"(val));
+  }
+}
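These helpers are the building blocks of the kernels' global-to-shared-memory software pipeline: issue 16-byte cp.async copies, commit them as a group, and wait until at most `n` groups are still in flight before consuming the data. A minimal sketch of that pattern; the double-buffered layout, the 128-thread block, and the `mem.h` include path are assumptions for illustration, and it assumes compilation for sm_80 or newer, where cp.async is available:

// Illustrative double-buffered copy using the helpers above; launch with a
// single block of 128 threads, `gmem` holding num_chunks * 128 int4 values.
#include <cstdint>

#include "mem.h"  // hypothetical include path for the header above

__global__ void pipelined_copy(const int4* __restrict__ gmem,
                               int4* __restrict__ out, int num_chunks) {
  __shared__ int4 smem[2][128];  // two pipeline stages, one int4 per thread
  int stage = 0;
  // Prefetch the first chunk.
  cp_async4(&smem[0][threadIdx.x], &gmem[threadIdx.x]);
  cp_async_fence();
  for (int chunk = 0; chunk < num_chunks; chunk++) {
    if (chunk + 1 < num_chunks) {
      // Overlap: issue the copy for the next chunk into the other stage.
      cp_async4(&smem[stage ^ 1][threadIdx.x],
                &gmem[(chunk + 1) * 128 + threadIdx.x]);
      cp_async_fence();
      cp_async_wait<1>();  // current chunk has landed; next may still be in flight
    } else {
      cp_async_wait<0>();  // last chunk: drain all outstanding copies
    }
    __syncthreads();
    out[chunk * 128 + threadIdx.x] = smem[stage][threadIdx.x];  // consume
    __syncthreads();  // all reads done before this stage is overwritten
    stage ^= 1;
  }
}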

csrc/quantization/marlin/dense/marlin_cuda_kernel.cu

Lines changed: 6 additions & 84 deletions

@@ -25,30 +25,22 @@
 
 #include <iostream>
 
+#include "common/base.h"
+
+#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
+  #include "common/mem.h"
+#endif
+
 template <typename T>
 inline std::string str(T x) {
   return std::to_string(x);
 }
 
 namespace marlin_dense {
 
-constexpr int ceildiv(int a, int b) { return (a + b - 1) / b; }
-
 #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
 
-// Instances of `Vec` are used to organize groups of >>registers<<, as needed
-// for instance as inputs to tensor core operations. Consequently, all
-// corresponding index accesses must be compile-time constants, which is why we
-// extensively use `#pragma unroll` throughout the kernel code to guarantee
-// this.
-template <typename T, int n>
-struct Vec {
-  T elems[n];
-  __device__ T& operator[](int i) { return elems[i]; }
-};
-
 using I4 = Vec<int, 4>;
-
 // Matrix fragments for tensor core instructions; their precise layout is
 // documented here:
 // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#matrix-fragments-for-mma-m16n8k16-with-floating-point-type
@@ -57,43 +49,6 @@ using FragB = Vec<half2, 2>;
 using FragC = Vec<float, 4>;
 using FragS = Vec<half2, 1>;  // quantization scales
 
-// Predicated asynchronous global->shared copy; used for inputs A where we apply
-// predication to handle batchsizes that are not multiples of 16.
-__device__ inline void cp_async4_pred(void* smem_ptr, const void* glob_ptr,
-                                      bool pred = true) {
-  const int BYTES = 16;
-  uint32_t smem = static_cast<uint32_t>(__cvta_generic_to_shared(smem_ptr));
-  asm volatile(
-      "{\n"
-      "   .reg .pred p;\n"
-      "   setp.ne.b32 p, %0, 0;\n"
-      "   @p cp.async.cg.shared.global [%1], [%2], %3;\n"
-      "}\n" ::"r"((int)pred),
-      "r"(smem), "l"(glob_ptr), "n"(BYTES));
-}
-
-// Asynchronous global->shared copy
-__device__ inline void cp_async4(void* smem_ptr, const void* glob_ptr) {
-  const int BYTES = 16;
-  uint32_t smem = static_cast<uint32_t>(__cvta_generic_to_shared(smem_ptr));
-  asm volatile(
-      "{\n"
-      "   cp.async.cg.shared.global [%0], [%1], %2;\n"
-      "}\n" ::"r"(smem),
-      "l"(glob_ptr), "n"(BYTES));
-}
-
-// Async copy fence.
-__device__ inline void cp_async_fence() {
-  asm volatile("cp.async.commit_group;\n" ::);
-}
-
-// Wait until at most `n` async copy stages are still pending.
-template <int n>
-__device__ inline void cp_async_wait() {
-  asm volatile("cp.async.wait_group %0;\n" ::"n"(n));
-}
-
 // m16n8k16 tensor core mma instruction with fp16 inputs and fp32
 // output/accumulation.
 __device__ inline void mma(const FragA& a_frag, const FragB& frag_b,
@@ -164,39 +119,6 @@ __device__ inline void scale(FragB& frag_b, FragS& frag_s, int i) {
   frag_b[1] = __hmul2(frag_b[1], s);
 }
 
-// Wait until barrier reaches `count`, then lock for current threadblock.
-__device__ inline void barrier_acquire(int* lock, int count) {
-  if (threadIdx.x == 0) {
-    int state = -1;
-    do
-      // Guarantee that subsequent writes by this threadblock will be visible
-      // globally.
-      asm volatile("ld.global.acquire.gpu.b32 %0, [%1];\n"
-                   : "=r"(state)
-                   : "l"(lock));
-    while (state != count);
-  }
-  __syncthreads();
-}
-
-// Release barrier and increment visitation count.
-__device__ inline void barrier_release(int* lock, bool reset = false) {
-  __syncthreads();
-  if (threadIdx.x == 0) {
-    if (reset) {
-      lock[0] = 0;
-      return;
-    }
-    int val = 1;
-    // Make sure that all writes since acquiring this barrier are visible
-    // globally, while releasing the barrier.
-    asm volatile("fence.acq_rel.gpu;\n");
-    asm volatile("red.relaxed.gpu.global.add.s32 [%0], %1;\n"
-                 :
-                 : "l"(lock), "r"(val));
-  }
-}
-
 template <const int threads,          // number of threads in a threadblock
           const int thread_m_blocks,  // number of 16x16 blocks in the m
                                       // dimension (batchsize) of the
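The net effect of this refactor is that `ceildiv`, `Vec`, and the cp.async/barrier helpers now live in the shared `common/` headers, so the dense kernel and the new QQQ GEMM can include them instead of carrying private copies. A small sketch of that reuse pattern; the file, namespace, and kernel below are hypothetical, and only the guarded-include idiom is taken from the diff above:

// Hypothetical translation unit reusing the shared Marlin helpers; the include
// guard mirrors the dense kernel above because cp.async requires sm_80+.
#include "common/base.h"

#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
  #include "common/mem.h"
#endif

namespace example_marlin {

using I4 = Vec<int, 4>;  // register-fragment alias, as in the dense kernel

// Keep the kernel's signature independent of __CUDA_ARCH__ and guard only the
// body, so host and device compilation passes see the same declaration.
__global__ void stage_to_shared(const int4* __restrict__ in,
                                int4* __restrict__ out) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
  __shared__ int4 buf[128];  // assumes a 128-thread block
  cp_async4(&buf[threadIdx.x], &in[threadIdx.x]);  // 16-byte async copy per thread
  cp_async_fence();
  cp_async_wait<0>();  // drain all outstanding copies
  __syncthreads();
  out[threadIdx.x] = buf[threadIdx.x];
#endif
}

}  // namespace example_marlin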
