Skip to content

Commit b81e0e4

Browse files
trivedivivek authored and facebook-github-bot committed
Converting all uint16 to int in quantized mat mul shader to improve perf. (#15193)
Summary:
## This Diff
This diff improves the performance of the quantized matrix multiplication shader in the ExecuTorch Vulkan runtime by converting all `uint16` to `int` in the shader code.
Reviewed By: SS-JIA
Differential Revision: D84777696
1 parent 77441de commit b81e0e4

File tree

1 file changed

+14
-18
lines changed

1 file changed

+14
-18
lines changed

backends/vulkan/runtime/graph/ops/glsl/linear_qcsnw_tiled.glsl

Lines changed: 14 additions & 18 deletions
Original file line number | Diff line number | Diff line change
@@ -21,8 +21,6 @@ ${define_required_extensions(DTYPE)}
2121
$if WEIGHT_STORAGE == "buffer":
2222
${define_required_extensions("int8")}
2323

24-
#extension GL_EXT_control_flow_attributes : require
25-
2624
layout(std430) buffer;
2725

2826
${layout_declare_tensor(B, "w", "t_out", DTYPE, OUT_STORAGE, is_scalar_array=False)}
@@ -49,20 +47,18 @@ layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in;
4947
void main() {
5048
// txcol stands for "texel column". One txcol corresponds to 4 scalar columns.
5149
$if TILE_TXCOLS > 1:
52-
const uint16_t global_wg_x = uint16_t(divup(out_sizes.x, 4 * TILE_TXCOLS));
53-
const uint16_t out_txcol = uint16_t(
54-
(gl_GlobalInvocationID.x % global_wg_x) * TILE_TXCOLS);
50+
const int global_wg_x = divup(out_sizes.x, 4 * TILE_TXCOLS);
51+
const int out_txcol = (int(gl_GlobalInvocationID.x) % global_wg_x) * TILE_TXCOLS;
5552
$else:
56-
const uint16_t global_wg_x = uint16_t(divup4(out_sizes.x));
57-
const uint16_t out_txcol = uint16_t(gl_GlobalInvocationID.x % global_wg_x);
53+
const int global_wg_x = divup4(out_sizes.x);
54+
const int out_txcol = int(gl_GlobalInvocationID.x) % global_wg_x;
5855

59-
const uint16_t out_row = uint16_t(
60-
(gl_GlobalInvocationID.x / global_wg_x) * TILE_ROWS);
56+
const int out_row = (int(gl_GlobalInvocationID.x) / global_wg_x) * TILE_ROWS;
6157

6258
$if QUANT_NBITS == 4:
63-
const uint16_t weight_txcol = uint16_t(out_txcol / 2);
59+
const int weight_txcol = out_txcol / 2;
6460

65-
if (out_row >= uint16_t(out_sizes.y)) {
61+
if (out_row >= int(out_sizes.y)) {
6662
return;
6763
}
6864

@@ -73,9 +69,9 @@ void main() {
7369
sums[r][${c}] = VEC4_T(0.0);
7470
}
7571

76-
for (uint16_t pos = uint16_t(0), txpos = uint16_t(0);
77-
pos < uint16_t(in_sizes.x);
78-
pos += uint16_t(4), txpos += uint16_t(1)) {
72+
for (int pos = 0, txpos = 0;
73+
pos < in_sizes.x;
74+
pos += 4, txpos += 1) {
7975

8076
T mat1[TILE_ROWS][4];
8177

@@ -91,7 +87,7 @@ void main() {
9187
mat1[i][2] = tmp.z;
9288
mat1[i][3] = tmp.w;
9389
$else:
94-
VEC4_T tmp = VEC4_T(texelFetch(t_in, u16vec3(txpos, out_row + i, 0), 0));
90+
VEC4_T tmp = VEC4_T(texelFetch(t_in, ivec3(txpos, out_row + i, 0), 0));
9591
mat1[i][0] = tmp.x;
9692
mat1[i][1] = tmp.y;
9793
mat1[i][2] = tmp.z;
@@ -117,7 +113,7 @@ void main() {
117113
packed_weight_tex = t_weight[qmat2_bufi + ${c}]
118114
$else:
119115
packed_weight_tex = texelFetch(
120-
t_weight, u16vec2(weight_txcol + ${c}, pos + r), 0);
116+
t_weight, ivec2(weight_txcol + ${c}, pos + r), 0);
121117

122118
qmat2[${c}] = (VEC4_T(packed_weight_tex >> 4) - 8.0);
123119
qmat2[${c + 1}] = (VEC4_T(packed_weight_tex & 0x0F) - 8.0);
@@ -128,7 +124,7 @@ void main() {
128124
qmat2[${c}] = t_weight[qmat2_bufi + ${c}];
129125
$else:
130126
qmat2[${c}] = VEC4_T(
131-
texelFetch(t_weight, u16vec2(out_txcol + ${c}, pos + r), 0));
127+
texelFetch(t_weight, ivec2(out_txcol + ${c}, pos + r), 0));
132128

133129
for (int tr = 0; tr < TILE_ROWS; ++tr) {
134130
$for c in range(TILE_TXCOLS):
@@ -143,7 +139,7 @@ void main() {
143139
scales[${c}] = VEC4_T(t_scales[out_txcol + ${c}]);
144140
$else:
145141
scales[${c}] = VEC4_T(
146-
texelFetch(t_scales, u16vec2(out_txcol + ${c}, 0), 0));
142+
texelFetch(t_scales, ivec2(out_txcol + ${c}, 0), 0));
147143

148144
// Store to output tensor
149145
$if OUT_STORAGE == "buffer":

0 commit comments

Comments
 (0)