@@ -62,11 +62,12 @@ void main() {
6262 return ;
6363 }
6464
65- VEC4_T sums[TILE_ROWS][ TILE_TXCOLS];
65+ T sums[TILE_ROWS * TILE_TXCOLS * 4 ];
6666
6767 for (int r = 0 ; r < TILE_ROWS; ++ r) {
6868 $for c in range(TILE_TXCOLS):
69- sums[r][${c}] = VEC4_T(0.0 );
69+ $for j in range(4 ):
70+ sums[r * TILE_TXCOLS * 4 + ${c} * 4 + ${j}] = T(0.0 );
7071 }
7172
7273 const int in_row_txstride = div4(in_sizes.x);
@@ -75,22 +76,16 @@ void main() {
7576 txpos < in_row_txstride;
7677 pos += 4 , txpos += 1 ) {
7778
78- T mat1[TILE_ROWS][ 4 ];
79+ T mat1[TILE_ROWS * 4 ];
7980
8081 // Preload input tensor
8182 for (int i = 0 ; i < TILE_ROWS; i++ ) {
8283 $if IN_STORAGE == "buffer ":
83- VEC4_T tmp = t_in[(out_row + i) * in_row_txstride + txpos];
84- mat1[i][0 ] = tmp.x;
85- mat1[i][1 ] = tmp.y;
86- mat1[i][2 ] = tmp.z;
87- mat1[i][3 ] = tmp.w;
84+ VEC4_T mat1_vec4 = t_in[(out_row + i) * in_row_txstride + txpos];
8885 $else :
89- VEC4_T tmp = VEC4_T(texelFetch(t_in, ivec3 (txpos, out_row + i, 0 ), 0 ));
90- mat1[i][0 ] = tmp.x;
91- mat1[i][1 ] = tmp.y;
92- mat1[i][2 ] = tmp.z;
93- mat1[i][3 ] = tmp.w;
86+ VEC4_T mat1_vec4 = VEC4_T(texelFetch(t_in, ivec3 (txpos, out_row + i, 0 ), 0 ));
87+ $for j in range(4 ):
88+ mat1[i * 4 + ${j}] = mat1_vec4[${j}];
9489 }
9590
9691 $if WEIGHT_STORAGE == "buffer ":
@@ -99,7 +94,9 @@ void main() {
9994
10095 // Preload weight tensor
10196 for (int r = 0 ; r < 4 ; r++ ) {
102- VEC4_T qmat2[TILE_TXCOLS];
97+ T qmat2[TILE_TXCOLS * 4 ];
98+ VEC4_T qmat2_vec4;
99+
103100 $if QUANT_NBITS == 4 :
104101 $if WEIGHT_STORAGE == "buffer ":
105102 u8vec4 packed_weight_tex;
@@ -114,20 +111,31 @@ void main() {
114111 packed_weight_tex = texelFetch(
115112 t_weight, ivec2 (weight_txcol + ${c}, pos + r), 0 );
116113
117- qmat2[${c}] = (VEC4_T(packed_weight_tex >> 4 ) - 8.0 );
118- qmat2[${c + 1 }] = (VEC4_T(packed_weight_tex & 0x0F) - 8.0 );
114+ qmat2_vec4 = (VEC4_T(packed_weight_tex >> 4 ) - 8.0 );
115+ qmat2[${c} * 4 * TILE_TXCOLS + 0 ] = qmat2_vec4.x;
116+ qmat2[${c} * 4 * TILE_TXCOLS + 1 ] = qmat2_vec4.y;
117+ qmat2[${c} * 4 * TILE_TXCOLS + 2 ] = qmat2_vec4.z;
118+ qmat2[${c} * 4 * TILE_TXCOLS + 3 ] = qmat2_vec4.w;
119+
120+ qmat2_vec4 = (VEC4_T(packed_weight_tex & 0x0F) - 8.0 );
121+ qmat2[${c} * 4 * TILE_TXCOLS + 4 ] = qmat2_vec4.x;
122+ qmat2[${c} * 4 * TILE_TXCOLS + 5 ] = qmat2_vec4.y;
123+ qmat2[${c} * 4 * TILE_TXCOLS + 6 ] = qmat2_vec4.z;
124+ qmat2[${c} * 4 * TILE_TXCOLS + 7 ] = qmat2_vec4.w;
119125 $else :
120126 $for c in range(TILE_TXCOLS):
121127 $if WEIGHT_STORAGE == "buffer ":
122128 qmat2_bufi = (pos + r) * weight_row_txstride + out_txcol;
123- qmat2[${c}] = t_weight[qmat2_bufi + ${c}];
129+ qmat2_vec4 = t_weight[qmat2_bufi + ${c}];
124130 $else :
125- qmat2[${c}] = VEC4_T(
126- texelFetch(t_weight, ivec2 (out_txcol + ${c}, pos + r), 0 ));
131+ qmat2_vec4 = VEC4_T(texelFetch(t_weight, ivec2 (out_txcol + ${c}, pos + r), 0 ));
132+ $for j in range(4 ):
133+ qmat2[${c} * 4 + ${j}] = qmat2_vec4[${j}];
127134
128135 for (int tr = 0 ; tr < TILE_ROWS; ++ tr) {
129136 $for c in range(TILE_TXCOLS):
130- sums[tr][${c}] += qmat2[${c}] * mat1[tr][r];
137+ $for j in range(4 ):
138+ sums[tr * TILE_TXCOLS * 4 + ${c} * 4 + ${j}] += qmat2[${c} * 4 + ${j}] * mat1[tr * 4 + r];
131139 }
132140 }
133141 }
@@ -146,16 +154,22 @@ void main() {
146154 uint out_row_txstride = div4(out_sizes.x);
147155
148156 for (int r = 0 ; r < TILE_ROWS; ++ r) {
157+ VEC4_T scaled_sums;
149158 $for c in range(TILE_TXCOLS):
159+ scaled_sums.x = sums[r * TILE_TXCOLS * 4 + ${c} * 4 + 0 ] * scales[${c}].x;
160+ scaled_sums.y = sums[r * TILE_TXCOLS * 4 + ${c} * 4 + 1 ] * scales[${c}].y;
161+ scaled_sums.z = sums[r * TILE_TXCOLS * 4 + ${c} * 4 + 2 ] * scales[${c}].z;
162+ scaled_sums.w = sums[r * TILE_TXCOLS * 4 + ${c} * 4 + 3 ] * scales[${c}].w;
163+
150164 $if OUT_STORAGE == "buffer ":
151165 if (out_row + r < out_sizes.y) {
152166 out_bufi = (out_row + r) * out_row_txstride + out_txcol;
153- t_out[out_bufi + ${c}] = sums[r][${c}] * scales[${c}] ;
167+ t_out[out_bufi + ${c}] = scaled_sums ;
154168 }
155169 $else :
156170 imageStore(
157171 t_out,
158172 ivec3 (out_txcol + ${c}, out_row + r, 0 ),
159- sums[r][${c}] * scales[${c}] );
173+ scaled_sums );
160174 }
161175}
0 commit comments