Skip to content

Commit 2311ae0

Browse files
rasmithgarg-amit
authored and committed
[Kernel] [Triton] Memory optimization for awq_gemm and awq_dequantize, 2x throughput (vllm-project#8248)
Signed-off-by: Amit Garg <[email protected]>
1 parent 1447c97 commit 2311ae0

File tree

1 file changed

+23
-11
lines changed

1 file changed

+23
-11
lines changed

vllm/model_executor/layers/quantization/awq_triton.py

+23-11
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ def awq_dequantize_kernel(
2222

2323
# Compute offsets and masks for qweight_ptr.
2424
offsets_y = pid_y * BLOCK_SIZE_Y + tl.arange(0, BLOCK_SIZE_Y)
25-
offsets_x = pid_x * BLOCK_SIZE_X + tl.arange(0, BLOCK_SIZE_X * 8) // 8
25+
offsets_x = pid_x * BLOCK_SIZE_X + tl.arange(0, BLOCK_SIZE_X)
2626
offsets = num_cols * offsets_y[:, None] + offsets_x[None, :]
2727

2828
masks_y = offsets_y < num_rows
@@ -43,6 +43,9 @@ def awq_dequantize_kernel(
4343

4444
# Load the weights.
4545
iweights = tl.load(qweight_ptr + offsets, masks)
46+
iweights = tl.interleave(iweights, iweights)
47+
iweights = tl.interleave(iweights, iweights)
48+
iweights = tl.interleave(iweights, iweights)
4649

4750
# Create reverse AWQ order as tensor: [0, 4, 1, 5, 2, 6, 3, 7]
4851
# that will map given indices to the correct order.
@@ -59,9 +62,8 @@ def awq_dequantize_kernel(
5962
iweights = (iweights >> shifts) & 0xF
6063

6164
# Compute zero offsets and masks.
62-
zero_offsets_y = (pid_y * BLOCK_SIZE_Y // group_size +
63-
tl.arange(0, BLOCK_SIZE_Y) // group_size)
64-
zero_offsets_x = pid_x * BLOCK_SIZE_X + tl.arange(0, BLOCK_SIZE_X * 8) // 8
65+
zero_offsets_y = pid_y * BLOCK_SIZE_Y // group_size + tl.arange(0, 1)
66+
zero_offsets_x = pid_x * BLOCK_SIZE_X + tl.arange(0, BLOCK_SIZE_X)
6567
zero_offsets = num_cols * zero_offsets_y[:, None] + zero_offsets_x[None, :]
6668

6769
zero_masks_y = zero_offsets_y < num_rows // group_size
@@ -70,13 +72,16 @@ def awq_dequantize_kernel(
7072

7173
# Load the zeros.
7274
zeros = tl.load(zeros_ptr + zero_offsets, zero_masks)
75+
zeros = tl.interleave(zeros, zeros)
76+
zeros = tl.interleave(zeros, zeros)
77+
zeros = tl.interleave(zeros, zeros)
78+
zeros = tl.broadcast_to(zeros, (BLOCK_SIZE_Y, BLOCK_SIZE_X * 8))
7379

7480
# Unpack and reorder: shift out the correct 4-bit value and mask.
7581
zeros = (zeros >> shifts) & 0xF
7682

7783
# Compute scale offsets and masks.
78-
scale_offsets_y = (pid_y * BLOCK_SIZE_Y // group_size +
79-
tl.arange(0, BLOCK_SIZE_Y) // group_size)
84+
scale_offsets_y = pid_y * BLOCK_SIZE_Y // group_size + tl.arange(0, 1)
8085
scale_offsets_x = (pid_x * BLOCK_SIZE_X * 8 +
8186
tl.arange(0, BLOCK_SIZE_X * 8))
8287
scale_offsets = (num_cols * 8 * scale_offsets_y[:, None] +
@@ -87,6 +92,7 @@ def awq_dequantize_kernel(
8792

8893
# Load the scales.
8994
scales = tl.load(scales_ptr + scale_offsets, scale_masks)
95+
scales = tl.broadcast_to(scales, (BLOCK_SIZE_Y, BLOCK_SIZE_X * 8))
9096

9197
# Dequantize.
9298
iweights = (iweights - zeros) * scales
@@ -137,12 +143,10 @@ def awq_gemm_kernel(a_ptr, b_ptr, c_ptr, zeros_ptr, scales_ptr, M, N, K,
137143
offsets_am = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
138144
masks_am = offsets_am < M
139145

140-
offsets_bn = (pid_n * (BLOCK_SIZE_N // 8) +
141-
tl.arange(0, BLOCK_SIZE_N) // 8)
146+
offsets_bn = pid_n * (BLOCK_SIZE_N // 8) + tl.arange(0, BLOCK_SIZE_N // 8)
142147
masks_bn = offsets_bn < N // 8
143148

144-
offsets_zn = (pid_n * (BLOCK_SIZE_N // 8) +
145-
tl.arange(0, BLOCK_SIZE_N) // 8)
149+
offsets_zn = pid_n * (BLOCK_SIZE_N // 8) + tl.arange(0, BLOCK_SIZE_N // 8)
146150
masks_zn = offsets_zn < N // 8
147151

148152
offsets_sn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
@@ -165,22 +169,30 @@ def awq_gemm_kernel(a_ptr, b_ptr, c_ptr, zeros_ptr, scales_ptr, M, N, K,
165169

166170
masks_b = masks_k[:, None] & masks_bn[None, :]
167171
b = tl.load(b_ptrs, mask=masks_b)
172+
b = tl.interleave(b, b)
173+
b = tl.interleave(b, b)
174+
b = tl.interleave(b, b)
168175

169176
# Dequantize b.
170177
offsets_szk = (
171178
(BLOCK_SIZE_K * SPLIT_K * k + pid_z * BLOCK_SIZE_K) // group_size +
172-
tl.arange(0, BLOCK_SIZE_K) // group_size)
179+
tl.arange(0, 1))
173180
offsets_z = (N // 8) * offsets_szk[:, None] + offsets_zn[None, :]
174181
masks_zk = offsets_szk < K // group_size
175182
masks_z = masks_zk[:, None] & masks_zn[None, :]
176183
zeros_ptrs = zeros_ptr + offsets_z
177184
zeros = tl.load(zeros_ptrs, mask=masks_z)
185+
zeros = tl.interleave(zeros, zeros)
186+
zeros = tl.interleave(zeros, zeros)
187+
zeros = tl.interleave(zeros, zeros)
188+
zeros = tl.broadcast_to(zeros, (BLOCK_SIZE_K, BLOCK_SIZE_N))
178189

179190
offsets_s = N * offsets_szk[:, None] + offsets_sn[None, :]
180191
masks_sk = offsets_szk < K // group_size
181192
masks_s = masks_sk[:, None] & masks_sn[None, :]
182193
scales_ptrs = scales_ptr + offsets_s
183194
scales = tl.load(scales_ptrs, mask=masks_s)
195+
scales = tl.broadcast_to(scales, (BLOCK_SIZE_K, BLOCK_SIZE_N))
184196

185197
b = (b >> shifts) & 0xF
186198
zeros = (zeros >> shifts) & 0xF

0 commit comments

Comments (0)