Skip to content

Commit c3d692a

Browse files
trivedivivekfacebook-github-bot
authored andcommitted
Doubling tile texel col count for mat mul op to improve performance. (#15192)
Summary: ### Summary This change doubled tile texel column count for 8 bit matrix multiplication operation to improve performance. Differential Revision: D84679398
1 parent 4101c56 commit c3d692a

File tree

5 files changed

+18
-8
lines changed

5 files changed

+18
-8
lines changed

backends/vulkan/runtime/graph/ops/glsl/linear_qcsnw_coop.glsl

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -98,12 +98,17 @@ void main() {
9898
// Preload weight tensor
9999
[[unroll]] for (int r = 0; r < 4; r++) {
100100
$if QUANT_NBITS == 4:
101+
$if WEIGHT_STORAGE == "buffer":
102+
u8vec4 packed_weight_tex;
103+
$else:
104+
uvec4 packed_weight_tex;
105+
101106
$for c in range(0, TILE_TXCOLS, 2):
102107
$if WEIGHT_STORAGE == "buffer":
103108
qmat2_bufi = (pos + r) * weight_row_txstride + weight_txcol;
104-
const u8vec4 packed_weight_tex = t_weight[qmat2_bufi + ${c}]
109+
packed_weight_tex = t_weight[qmat2_bufi + ${c}]
105110
$else:
106-
const uvec4 packed_weight_tex = texelFetch(
111+
packed_weight_tex = texelFetch(
107112
t_weight, ivec2(weight_txcol + ${c}, pos + r), 0);
108113

109114
qmat2[r][${c}] = (VEC4_T((packed_weight_tex & 0xF0) >> 4) - 8.0);

backends/vulkan/runtime/graph/ops/glsl/linear_qcsnw_coop.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ linear_qcsnw_coop:
1212
WEIGHT_STORAGE: texture2d
1313
SCALES_STORAGE: texture2d
1414
TILE_ROWS: 4
15-
TILE_TXCOLS: 1
15+
TILE_TXCOLS: 2
1616
QUANT_NBITS: 8
1717
generate_variant_forall:
1818
TILE_ROWS:

backends/vulkan/runtime/graph/ops/glsl/linear_qcsnw_tiled.glsl

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -106,12 +106,17 @@ void main() {
106106
for (int r = 0; r < 4; r++) {
107107
VEC4_T qmat2[TILE_TXCOLS];
108108
$if QUANT_NBITS == 4:
109+
$if WEIGHT_STORAGE == "buffer":
110+
u8vec4 packed_weight_tex;
111+
$else:
112+
uvec4 packed_weight_tex;
113+
109114
$for c in range(0, TILE_TXCOLS, 2):
110115
$if WEIGHT_STORAGE == "buffer":
111116
qmat2_bufi = (pos + r) * weight_row_txstride + weight_txcol;
112-
const u8vec4 packed_weight_tex = t_weight[qmat2_bufi + ${c}]
117+
packed_weight_tex = t_weight[qmat2_bufi + ${c}]
113118
$else:
114-
const uvec4 packed_weight_tex = texelFetch(
119+
packed_weight_tex = texelFetch(
115120
t_weight, u16vec2(weight_txcol + ${c}, pos + r), 0);
116121

117122
qmat2[${c}] = (VEC4_T(packed_weight_tex >> 4) - 8.0);

backends/vulkan/runtime/graph/ops/glsl/linear_qcsnw_tiled.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ linear_qcsnw_tiled:
1212
WEIGHT_STORAGE: texture2d
1313
SCALES_STORAGE: texture2d
1414
TILE_ROWS: 4
15-
TILE_TXCOLS: 1
15+
TILE_TXCOLS: 2
1616
QUANT_NBITS: 8
1717
generate_variant_forall:
1818
TILE_ROWS:

backends/vulkan/runtime/graph/ops/impl/QuantizedLinearQCSNW.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -73,7 +73,7 @@ utils::uvec3 linear_qcsnw_tiled_global_wg_size(
7373
}
7474

7575
// Number of output texels in the output tile
76-
uint32_t out_tile_ntxcols = 1;
76+
uint32_t out_tile_ntxcols = 2;
7777
if (quant_nbits == 4) {
7878
out_tile_ntxcols = 2;
7979
}
@@ -324,7 +324,7 @@ void add_linear_qcsnw_tiled_node(
324324
}
325325

326326
// Number of output texels in the output tile
327-
uint32_t out_tile_ntxcols = 1;
327+
uint32_t out_tile_ntxcols = 2;
328328
if (quant_nbits == 4) {
329329
out_tile_ntxcols = 2;
330330
}

0 commit comments

Comments
 (0)