@@ -22,7 +22,7 @@ def awq_dequantize_kernel(
22
22
23
23
# Compute offsets and masks for qweight_ptr.
24
24
offsets_y = pid_y * BLOCK_SIZE_Y + tl .arange (0 , BLOCK_SIZE_Y )
25
- offsets_x = pid_x * BLOCK_SIZE_X + tl .arange (0 , BLOCK_SIZE_X * 8 ) // 8
25
+ offsets_x = pid_x * BLOCK_SIZE_X + tl .arange (0 , BLOCK_SIZE_X )
26
26
offsets = num_cols * offsets_y [:, None ] + offsets_x [None , :]
27
27
28
28
masks_y = offsets_y < num_rows
@@ -43,6 +43,9 @@ def awq_dequantize_kernel(
43
43
44
44
# Load the weights.
45
45
iweights = tl .load (qweight_ptr + offsets , masks )
46
+ iweights = tl .interleave (iweights , iweights )
47
+ iweights = tl .interleave (iweights , iweights )
48
+ iweights = tl .interleave (iweights , iweights )
46
49
47
50
# Create reverse AWQ order as tensor: [0, 4, 1, 5, 2, 6, 3, 7]
48
51
# that will map given indices to the correct order.
@@ -59,9 +62,8 @@ def awq_dequantize_kernel(
59
62
iweights = (iweights >> shifts ) & 0xF
60
63
61
64
# Compute zero offsets and masks.
62
- zero_offsets_y = (pid_y * BLOCK_SIZE_Y // group_size +
63
- tl .arange (0 , BLOCK_SIZE_Y ) // group_size )
64
- zero_offsets_x = pid_x * BLOCK_SIZE_X + tl .arange (0 , BLOCK_SIZE_X * 8 ) // 8
65
+ zero_offsets_y = pid_y * BLOCK_SIZE_Y // group_size + tl .arange (0 , 1 )
66
+ zero_offsets_x = pid_x * BLOCK_SIZE_X + tl .arange (0 , BLOCK_SIZE_X )
65
67
zero_offsets = num_cols * zero_offsets_y [:, None ] + zero_offsets_x [None , :]
66
68
67
69
zero_masks_y = zero_offsets_y < num_rows // group_size
@@ -70,13 +72,16 @@ def awq_dequantize_kernel(
70
72
71
73
# Load the zeros.
72
74
zeros = tl .load (zeros_ptr + zero_offsets , zero_masks )
75
+ zeros = tl .interleave (zeros , zeros )
76
+ zeros = tl .interleave (zeros , zeros )
77
+ zeros = tl .interleave (zeros , zeros )
78
+ zeros = tl .broadcast_to (zeros , (BLOCK_SIZE_Y , BLOCK_SIZE_X * 8 ))
73
79
74
80
# Unpack and reorder: shift out the correct 4-bit value and mask.
75
81
zeros = (zeros >> shifts ) & 0xF
76
82
77
83
# Compute scale offsets and masks.
78
- scale_offsets_y = (pid_y * BLOCK_SIZE_Y // group_size +
79
- tl .arange (0 , BLOCK_SIZE_Y ) // group_size )
84
+ scale_offsets_y = pid_y * BLOCK_SIZE_Y // group_size + tl .arange (0 , 1 )
80
85
scale_offsets_x = (pid_x * BLOCK_SIZE_X * 8 +
81
86
tl .arange (0 , BLOCK_SIZE_X * 8 ))
82
87
scale_offsets = (num_cols * 8 * scale_offsets_y [:, None ] +
@@ -87,6 +92,7 @@ def awq_dequantize_kernel(
87
92
88
93
# Load the scales.
89
94
scales = tl .load (scales_ptr + scale_offsets , scale_masks )
95
+ scales = tl .broadcast_to (scales , (BLOCK_SIZE_Y , BLOCK_SIZE_X * 8 ))
90
96
91
97
# Dequantize.
92
98
iweights = (iweights - zeros ) * scales
@@ -137,12 +143,10 @@ def awq_gemm_kernel(a_ptr, b_ptr, c_ptr, zeros_ptr, scales_ptr, M, N, K,
137
143
offsets_am = pid_m * BLOCK_SIZE_M + tl .arange (0 , BLOCK_SIZE_M )
138
144
masks_am = offsets_am < M
139
145
140
- offsets_bn = (pid_n * (BLOCK_SIZE_N // 8 ) +
141
- tl .arange (0 , BLOCK_SIZE_N ) // 8 )
146
+ offsets_bn = pid_n * (BLOCK_SIZE_N // 8 ) + tl .arange (0 , BLOCK_SIZE_N // 8 )
142
147
masks_bn = offsets_bn < N // 8
143
148
144
- offsets_zn = (pid_n * (BLOCK_SIZE_N // 8 ) +
145
- tl .arange (0 , BLOCK_SIZE_N ) // 8 )
149
+ offsets_zn = pid_n * (BLOCK_SIZE_N // 8 ) + tl .arange (0 , BLOCK_SIZE_N // 8 )
146
150
masks_zn = offsets_zn < N // 8
147
151
148
152
offsets_sn = pid_n * BLOCK_SIZE_N + tl .arange (0 , BLOCK_SIZE_N )
@@ -165,22 +169,30 @@ def awq_gemm_kernel(a_ptr, b_ptr, c_ptr, zeros_ptr, scales_ptr, M, N, K,
165
169
166
170
masks_b = masks_k [:, None ] & masks_bn [None , :]
167
171
b = tl .load (b_ptrs , mask = masks_b )
172
+ b = tl .interleave (b , b )
173
+ b = tl .interleave (b , b )
174
+ b = tl .interleave (b , b )
168
175
169
176
# Dequantize b.
170
177
offsets_szk = (
171
178
(BLOCK_SIZE_K * SPLIT_K * k + pid_z * BLOCK_SIZE_K ) // group_size +
172
- tl .arange (0 , BLOCK_SIZE_K ) // group_size )
179
+ tl .arange (0 , 1 ) )
173
180
offsets_z = (N // 8 ) * offsets_szk [:, None ] + offsets_zn [None , :]
174
181
masks_zk = offsets_szk < K // group_size
175
182
masks_z = masks_zk [:, None ] & masks_zn [None , :]
176
183
zeros_ptrs = zeros_ptr + offsets_z
177
184
zeros = tl .load (zeros_ptrs , mask = masks_z )
185
+ zeros = tl .interleave (zeros , zeros )
186
+ zeros = tl .interleave (zeros , zeros )
187
+ zeros = tl .interleave (zeros , zeros )
188
+ zeros = tl .broadcast_to (zeros , (BLOCK_SIZE_K , BLOCK_SIZE_N ))
178
189
179
190
offsets_s = N * offsets_szk [:, None ] + offsets_sn [None , :]
180
191
masks_sk = offsets_szk < K // group_size
181
192
masks_s = masks_sk [:, None ] & masks_sn [None , :]
182
193
scales_ptrs = scales_ptr + offsets_s
183
194
scales = tl .load (scales_ptrs , mask = masks_s )
195
+ scales = tl .broadcast_to (scales , (BLOCK_SIZE_K , BLOCK_SIZE_N ))
184
196
185
197
b = (b >> shifts ) & 0xF
186
198
zeros = (zeros >> shifts ) & 0xF
0 commit comments