Skip to content

Commit 5076749

Browse files
authored
Fixing issues with using buffers for quantized linear weights.
Differential Revision: D84870681 Pull Request resolved: pytorch#15306
1 parent afe1dda commit 5076749

File tree

5 files changed

+47
-22
lines changed

5 files changed

+47
-22
lines changed

backends/vulkan/runtime/graph/ops/glsl/linear_qcsnw_tiled.glsl

Lines changed: 13 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -18,17 +18,17 @@
1818

1919
${define_required_extensions(DTYPE)}
2020

21-
$if WEIGHT_STORAGE == "buffer":
22-
${define_required_extensions("int8")}
23-
2421
layout(std430) buffer;
2522

2623
${layout_declare_tensor(B, "w", "t_out", DTYPE, OUT_STORAGE, is_scalar_array=False)}
2724
${layout_declare_tensor(B, "r", "t_in", DTYPE, IN_STORAGE, is_scalar_array=False)}
28-
$if QUANT_NBITS == 4:
29-
${layout_declare_tensor(B, "r", "t_weight", "uint8", WEIGHT_STORAGE, is_scalar_array=False)}
25+
$if WEIGHT_STORAGE == "buffer":
26+
${layout_declare_tensor(B, "r", "t_weight", "uint", WEIGHT_STORAGE, is_scalar_array=True)}
3027
$else:
31-
${layout_declare_tensor(B, "r", "t_weight", "int8", WEIGHT_STORAGE, is_scalar_array=False)}
28+
$if QUANT_NBITS == 4:
29+
${layout_declare_tensor(B, "r", "t_weight", "uint8", WEIGHT_STORAGE, is_scalar_array=False)}
30+
$else:
31+
${layout_declare_tensor(B, "r", "t_weight", "int8", WEIGHT_STORAGE, is_scalar_array=False)}
3232
${layout_declare_tensor(B, "r", "t_scales", DTYPE, SCALES_STORAGE, is_scalar_array=False)}
3333

3434

@@ -91,22 +91,20 @@ void main() {
9191
$if WEIGHT_STORAGE == "buffer":
9292
uint qmat2_bufi;
9393
uint weight_row_txstride = div4(weight_sizes.x);
94+
uint encoded_weight;
9495

9596
// Preload weight tensor
9697
for (int r = 0; r < 4; r++) {
9798
T qmat2[TILE_TXCOLS * 4];
9899
VEC4_T qmat2_vec4;
100+
uvec4 packed_weight_tex;
99101

100102
$if QUANT_NBITS == 4:
101-
$if WEIGHT_STORAGE == "buffer":
102-
u8vec4 packed_weight_tex;
103-
$else:
104-
uvec4 packed_weight_tex;
105-
106103
$for c in range(0, TILE_TXCOLS, 2):
107104
$if WEIGHT_STORAGE == "buffer":
108105
qmat2_bufi = (pos + r) * weight_row_txstride + weight_txcol;
109-
packed_weight_tex = t_weight[qmat2_bufi + ${c}]
106+
encoded_weight = t_weight[qmat2_bufi + ${c}];
107+
packed_weight_tex = uvec4(encoded_weight & 0xFF, (encoded_weight >> 8) & 0xFF, (encoded_weight >> 16) & 0xFF, encoded_weight >> 24);
110108
$else:
111109
packed_weight_tex = texelFetch(
112110
t_weight, ivec2(weight_txcol + ${c}, pos + r), 0);
@@ -126,7 +124,9 @@ void main() {
126124
$for c in range(TILE_TXCOLS):
127125
$if WEIGHT_STORAGE == "buffer":
128126
qmat2_bufi = (pos + r) * weight_row_txstride + out_txcol;
129-
qmat2_vec4 = t_weight[qmat2_bufi + ${c}];
127+
encoded_weight = t_weight[qmat2_bufi + ${c}];
128+
packed_weight_tex = uvec4(encoded_weight & 0xFF, (encoded_weight >> 8) & 0xFF, (encoded_weight >> 16) & 0xFF, encoded_weight >> 24);
129+
qmat2_vec4 = VEC4_T(packed_weight_tex);
130130
$else:
131131
qmat2_vec4 = VEC4_T(texelFetch(t_weight, ivec2(out_txcol + ${c}, pos + r), 0));
132132
$for j in range(4):

backends/vulkan/runtime/graph/ops/glsl/linear_qcsnw_tiled.yaml

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,8 +35,18 @@ linear_qcsnw_tiled:
3535
- NAME: linear_qcs4w_tiled_texture3d_texture3d_texture2d_texture2d_float
3636
TILE_TXCOLS: 2
3737
QUANT_NBITS: 4
38+
- NAME: linear_qcs4w_tiled_texture3d_texture3d_buffer_texture2d_float
39+
TILE_TXCOLS: 2
40+
QUANT_NBITS: 4
41+
WEIGHT_STORAGE: buffer
3842
- NAME: linear_qcs4w_tiled_buffer_buffer_texture2d_texture2d_float
3943
IN_STORAGE: buffer
4044
OUT_STORAGE: buffer
4145
TILE_TXCOLS: 2
4246
QUANT_NBITS: 4
47+
- NAME: linear_qcs4w_tiled_buffer_buffer_buffer_texture2d_float
48+
IN_STORAGE: buffer
49+
OUT_STORAGE: buffer
50+
WEIGHT_STORAGE: buffer
51+
TILE_TXCOLS: 2
52+
QUANT_NBITS: 4

backends/vulkan/runtime/graph/ops/glsl/pack_int4_linear_weight_transposed_interleaved.glsl

Lines changed: 16 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -12,12 +12,14 @@
1212

1313
$if not NO_INT8_BUFFERS:
1414
${define_required_extensions("uint8")}
15-
$if STORAGE == "buffer":
16-
${define_required_extensions("int8")}
1715

1816
layout(std430) buffer;
1917

20-
${layout_declare_tensor(B, "w", "t_qmat2", "uint8", STORAGE, is_scalar_array=False)}
18+
$if STORAGE == "buffer" and NO_INT8_BUFFERS:
19+
${layout_declare_tensor(B, "w", "t_qmat2", "uint", STORAGE, is_scalar_array=True)}
20+
$else:
21+
${layout_declare_tensor(B, "w", "t_qmat2", "uint8", STORAGE, is_scalar_array=False)}
22+
2123
$if NO_INT8_BUFFERS:
2224
${layout_declare_tensor(B, "r", "nchw_4x2", "uint", "buffer")}
2325
$else:
@@ -35,7 +37,10 @@ $else:
3537
#define BUF_T uint8_t
3638

3739
$if STORAGE == "buffer":
38-
#define UVEC4_T u8vec4
40+
$if NO_INT8_BUFFERS:
41+
#define UVEC4_T uvec4
42+
$else:
43+
#define UVEC4_T u8vec4
3944
$else:
4045
#define UVEC4_T uvec4
4146

@@ -48,7 +53,7 @@ uint get_second(const BUF_T packed) {
4853
}
4954

5055
uint combine(const uint first, const uint second) {
51-
return (first << 4 | second);
56+
return first * 16 + second;
5257
}
5358

5459
$if NO_INT8_BUFFERS:
@@ -155,8 +160,12 @@ void main() {
155160

156161
$if STORAGE == "buffer":
157162
int stride = qmat2_sizes.x >> 2;
158-
t_qmat2[packed_pos.y * stride + packed_pos.x] = out_tex_1;
159-
t_qmat2[(packed_pos.y + 1) * stride + packed_pos.x] = out_tex_2;
163+
$if NO_INT8_BUFFERS:
164+
t_qmat2[packed_pos.y * stride + packed_pos.x] = out_tex_1.x | (out_tex_1.y << 8) | (out_tex_1.z << 16) | (out_tex_1.w << 24);
165+
t_qmat2[(packed_pos.y + 1) * stride + packed_pos.x] = out_tex_2.x | (out_tex_2.y << 8) | (out_tex_2.z << 16) | (out_tex_2.w << 24);
166+
$else:
167+
t_qmat2[packed_pos.y * stride + packed_pos.x] = out_tex_1;
168+
t_qmat2[(packed_pos.y + 1) * stride + packed_pos.x] = out_tex_2;
160169
$else:
161170
imageStore(t_qmat2, packed_pos.xy, out_tex_1);
162171
imageStore(t_qmat2, ivec2(packed_pos.x, packed_pos.y + 1), out_tex_2);

backends/vulkan/runtime/graph/ops/glsl/pack_int4_linear_weight_transposed_interleaved.yaml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,3 +14,6 @@ pack_int4_linear_weight_transposed_interleaved:
1414
STORAGE: buffer
1515
- NAME: pack_int4_linear_weight_transposed_interleaved_nobitw8buffer_texture2d
1616
NO_INT8_BUFFERS: true
17+
- NAME: pack_int4_linear_weight_transposed_interleaved_nobitw8buffer_buffer
18+
STORAGE: buffer
19+
NO_INT8_BUFFERS: true

backends/vulkan/runtime/graph/ops/impl/QuantizedLinearQCSNW.cpp

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -225,7 +225,8 @@ void add_linear_qcs8w_node(
225225
} else {
226226
pcs = {
227227
graph.logical_limits_pc_of(out_W_packed),
228-
graph.sizes_pc_of(mat1_W_packed)};
228+
graph.sizes_pc_of(mat1_W_packed),
229+
graph.sizes_pc_of(q_mat2)};
229230
}
230231

231232
const utils::uvec3 global_wg = {
@@ -351,7 +352,9 @@ void add_linear_qcsnw_tiled_node(
351352
// Shader params buffers
352353
{},
353354
// Push Constants
354-
{{graph.sizes_pc_of(out), graph.sizes_pc_of(mat1)}},
355+
{{graph.sizes_pc_of(out),
356+
graph.sizes_pc_of(mat1),
357+
graph.sizes_pc_of(q_mat2)}},
355358
// Specialization Constants
356359
{},
357360
// Resize Args

0 commit comments

Comments
 (0)