Skip to content

Commit cf198e1

Browse files
trivedivivekfacebook-github-bot
authored andcommitted
Improving quantized matmul performance by devectorizing shader. (pytorch#15274)
Summary: This diff improves the performance of quantized matrix multiplication by devectorizing the shader. An example modification is shown below: ```glsl // Before VEC4_T sums[TILE_ROWS][TILE_TXCOLS]; // After T sums[TILE_ROWS * TILE_TXCOLS * 4]; // Before sums[r][${c}] = VEC4_T(0.0); // After for (int j = 0; j < 4; j++) { sums[r * TILE_TXCOLS * 4 + ${c} * 4 + j] = T(0.0); } ``` Reviewed By: SS-JIA Differential Revision: D85023829
1 parent 9397b81 commit cf198e1

File tree

1 file changed

+36
-22
lines changed

1 file changed

+36
-22
lines changed

backends/vulkan/runtime/graph/ops/glsl/linear_qcsnw_tiled.glsl

Lines changed: 36 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -62,11 +62,12 @@ void main() {
6262
return;
6363
}
6464

65-
VEC4_T sums[TILE_ROWS][TILE_TXCOLS];
65+
T sums[TILE_ROWS * TILE_TXCOLS * 4];
6666

6767
for (int r = 0; r < TILE_ROWS; ++r) {
6868
$for c in range(TILE_TXCOLS):
69-
sums[r][${c}] = VEC4_T(0.0);
69+
$for j in range(4):
70+
sums[r * TILE_TXCOLS * 4 + ${c} * 4 + ${j}] = T(0.0);
7071
}
7172

7273
const int in_row_txstride = div4(in_sizes.x);
@@ -75,22 +76,16 @@ void main() {
7576
txpos < in_row_txstride;
7677
pos += 4, txpos += 1) {
7778

78-
T mat1[TILE_ROWS][4];
79+
T mat1[TILE_ROWS * 4];
7980

8081
// Preload input tensor
8182
for (int i = 0; i < TILE_ROWS; i++) {
8283
$if IN_STORAGE == "buffer":
83-
VEC4_T tmp = t_in[(out_row + i) * in_row_txstride + txpos];
84-
mat1[i][0] = tmp.x;
85-
mat1[i][1] = tmp.y;
86-
mat1[i][2] = tmp.z;
87-
mat1[i][3] = tmp.w;
84+
VEC4_T mat1_vec4 = t_in[(out_row + i) * in_row_txstride + txpos];
8885
$else:
89-
VEC4_T tmp = VEC4_T(texelFetch(t_in, ivec3(txpos, out_row + i, 0), 0));
90-
mat1[i][0] = tmp.x;
91-
mat1[i][1] = tmp.y;
92-
mat1[i][2] = tmp.z;
93-
mat1[i][3] = tmp.w;
86+
VEC4_T mat1_vec4 = VEC4_T(texelFetch(t_in, ivec3(txpos, out_row + i, 0), 0));
87+
$for j in range(4):
88+
mat1[i * 4 + ${j}] = mat1_vec4[${j}];
9489
}
9590

9691
$if WEIGHT_STORAGE == "buffer":
@@ -99,7 +94,9 @@ void main() {
9994

10095
// Preload weight tensor
10196
for (int r = 0; r < 4; r++) {
102-
VEC4_T qmat2[TILE_TXCOLS];
97+
T qmat2[TILE_TXCOLS * 4];
98+
VEC4_T qmat2_vec4;
99+
103100
$if QUANT_NBITS == 4:
104101
$if WEIGHT_STORAGE == "buffer":
105102
u8vec4 packed_weight_tex;
@@ -114,20 +111,31 @@ void main() {
114111
packed_weight_tex = texelFetch(
115112
t_weight, ivec2(weight_txcol + ${c}, pos + r), 0);
116113

117-
qmat2[${c}] = (VEC4_T(packed_weight_tex >> 4) - 8.0);
118-
qmat2[${c + 1}] = (VEC4_T(packed_weight_tex & 0x0F) - 8.0);
114+
qmat2_vec4 = (VEC4_T(packed_weight_tex >> 4) - 8.0);
115+
qmat2[${c} * 4 * TILE_TXCOLS + 0] = qmat2_vec4.x;
116+
qmat2[${c} * 4 * TILE_TXCOLS + 1] = qmat2_vec4.y;
117+
qmat2[${c} * 4 * TILE_TXCOLS + 2] = qmat2_vec4.z;
118+
qmat2[${c} * 4 * TILE_TXCOLS + 3] = qmat2_vec4.w;
119+
120+
qmat2_vec4 = (VEC4_T(packed_weight_tex & 0x0F) - 8.0);
121+
qmat2[${c} * 4 * TILE_TXCOLS + 4] = qmat2_vec4.x;
122+
qmat2[${c} * 4 * TILE_TXCOLS + 5] = qmat2_vec4.y;
123+
qmat2[${c} * 4 * TILE_TXCOLS + 6] = qmat2_vec4.z;
124+
qmat2[${c} * 4 * TILE_TXCOLS + 7] = qmat2_vec4.w;
119125
$else:
120126
$for c in range(TILE_TXCOLS):
121127
$if WEIGHT_STORAGE == "buffer":
122128
qmat2_bufi = (pos + r) * weight_row_txstride + out_txcol;
123-
qmat2[${c}] = t_weight[qmat2_bufi + ${c}];
129+
qmat2_vec4 = t_weight[qmat2_bufi + ${c}];
124130
$else:
125-
qmat2[${c}] = VEC4_T(
126-
texelFetch(t_weight, ivec2(out_txcol + ${c}, pos + r), 0));
131+
qmat2_vec4 = VEC4_T(texelFetch(t_weight, ivec2(out_txcol + ${c}, pos + r), 0));
132+
$for j in range(4):
133+
qmat2[${c} * 4 + ${j}] = qmat2_vec4[${j}];
127134

128135
for (int tr = 0; tr < TILE_ROWS; ++tr) {
129136
$for c in range(TILE_TXCOLS):
130-
sums[tr][${c}] += qmat2[${c}] * mat1[tr][r];
137+
$for j in range(4):
138+
sums[tr * TILE_TXCOLS * 4 + ${c} * 4 + ${j}] += qmat2[${c} * 4 + ${j}] * mat1[tr * 4 + r];
131139
}
132140
}
133141
}
@@ -146,16 +154,22 @@ void main() {
146154
uint out_row_txstride = div4(out_sizes.x);
147155

148156
for (int r = 0; r < TILE_ROWS; ++r) {
157+
VEC4_T scaled_sums;
149158
$for c in range(TILE_TXCOLS):
159+
scaled_sums.x = sums[r * TILE_TXCOLS * 4 + ${c} * 4 + 0] * scales[${c}].x;
160+
scaled_sums.y = sums[r * TILE_TXCOLS * 4 + ${c} * 4 + 1] * scales[${c}].y;
161+
scaled_sums.z = sums[r * TILE_TXCOLS * 4 + ${c} * 4 + 2] * scales[${c}].z;
162+
scaled_sums.w = sums[r * TILE_TXCOLS * 4 + ${c} * 4 + 3] * scales[${c}].w;
163+
150164
$if OUT_STORAGE == "buffer":
151165
if (out_row + r < out_sizes.y) {
152166
out_bufi = (out_row + r) * out_row_txstride + out_txcol;
153-
t_out[out_bufi + ${c}] = sums[r][${c}] * scales[${c}];
167+
t_out[out_bufi + ${c}] = scaled_sums;
154168
}
155169
$else:
156170
imageStore(
157171
t_out,
158172
ivec3(out_txcol + ${c}, out_row + r, 0),
159-
sums[r][${c}] * scales[${c}]);
173+
scaled_sums);
160174
}
161175
}

0 commit comments

Comments
 (0)