|  | 
| 12 | 12 | 
 | 
| 13 | 13 | #define PRECISION ${PRECISION} | 
| 14 | 14 | 
 | 
| 15 |  | -#define VEC4_T ${texel_type(DTYPE)} | 
|  | 15 | +$if DTYPE == "half": | 
|  | 16 | +  #extension GL_EXT_shader_explicit_arithmetic_types_float16 : require | 
|  | 17 | +  #define VEC4_T f16vec4 | 
|  | 18 | +$else: | 
|  | 19 | +  #define VEC4_T ${texel_type(DTYPE)} | 
| 16 | 20 | 
 | 
| 17 |  | -#define TILE_SIZE_X uint16_t(${TILE_SIZE_X}) | 
| 18 |  | -#define TILE_SIZE_Y uint16_t(${TILE_SIZE_Y}) | 
| 19 | 21 | 
 | 
| 20 | 22 | #define op(X, A, B) ${OPERATOR} | 
| 21 | 23 | 
 | 
| @@ -50,119 +52,90 @@ ${layout_declare_spec_const(C, "int", "ngroups", "1")} | 
| 50 | 52 |  * size is only 1x1, making it easier to re-use loaded texels from t_kernel. | 
| 51 | 53 |  */ | 
| 52 | 54 | void main() { | 
| 53 |  | -  const int out_limits_scaled[2] = | 
| 54 |  | -    {(out_limits.x + (TILE_SIZE_X - 1)) / TILE_SIZE_X, | 
| 55 |  | -     (out_limits.y + (TILE_SIZE_Y - 1)) / TILE_SIZE_Y}; | 
| 56 | 55 | 
 | 
| 57 |  | -  const uint16_t div_by_x = uint16_t(gl_GlobalInvocationID.x / out_limits_scaled[0]); | 
| 58 |  | -  const uint16_t out_pos_xy[2] = {uint16_t(gl_GlobalInvocationID.x % out_limits_scaled[0]), div_by_x}; | 
| 59 |  | -  const int out_pos_z = int(gl_GlobalInvocationID.y); | 
|  | 56 | +  int inputAndOutputWidth = out_limits.x; | 
|  | 57 | +  int inputAndOutputHeight = out_limits.y; | 
|  | 58 | +  int outputChannel = out_limits.z*4; | 
| 60 | 59 | 
 | 
| 61 |  | -  // If the top left position is out of bounds, then this invocation will have | 
| 62 |  | -  // no work to do. | 
| 63 |  | -  if (out_pos_xy[1] >= out_limits_scaled[1] || out_pos_z >= out_limits.z) { | 
|  | 60 | +  // Divided by 4 because the input channels are packed | 
|  | 61 | +  int inputChannel = in_group_size/4; | 
|  | 62 | + | 
|  | 63 | +  int threadHW = int(gl_GlobalInvocationID.x); | 
|  | 64 | +  int threadOutChannel = int(gl_GlobalInvocationID.y); | 
|  | 65 | + | 
|  | 66 | +  int xIdx = threadHW % inputAndOutputWidth; | 
|  | 67 | +  int yIdx = threadHW / inputAndOutputWidth; | 
|  | 68 | + | 
|  | 69 | +  if (threadHW >= inputAndOutputWidth * inputAndOutputHeight && threadOutChannel >= outputChannel) { | 
| 64 | 70 |     return; | 
| 65 | 71 |   } | 
| 66 | 72 | 
 | 
| 67 |  | -  // Output position for TILE_SIZE = 2 | 
| 68 |  | -  // +--------+--------+ | 
| 69 |  | -  // | pos[0] | pos[1] | | 
| 70 |  | -  // +--------+--------+ | 
| 71 |  | -  // | pos[2] | pos[3] | | 
| 72 |  | -  // +--------+--------+ | 
| 73 |  | -  uint16_t pos[TILE_SIZE_X * TILE_SIZE_Y * 2]; | 
| 74 |  | -  for (uint16_t y = uint16_t(0), i = uint16_t(0); y < TILE_SIZE_Y; ++y) { | 
| 75 |  | -    for (uint16_t x = uint16_t(0); x < TILE_SIZE_X; ++x) { | 
| 76 |  | -      pos[i * 2] = out_pos_xy[0] * TILE_SIZE_X + x; | 
| 77 |  | -      pos[i * 2 + 1] = out_pos_xy[1] * TILE_SIZE_Y + y; | 
| 78 |  | -      i++; | 
| 79 |  | -    } | 
| 80 |  | -  } | 
|  | 73 | +  VEC4_T outputTexel = VEC4_T(texelFetch(t_bias, ivec2(threadOutChannel, 0), 0)); | 
| 81 | 74 | 
 | 
| 82 |  | -  // Final output array where each element is a tensor value. | 
| 83 |  | -  // Tuple of consecutive 4 elements represents a single output texel. | 
| 84 |  | -  float sum[TILE_SIZE_X * TILE_SIZE_Y * 4]; | 
|  | 75 | +  VEC4_T inputVec; | 
|  | 76 | +  VEC4_T weight1OutputChannelPacked; | 
|  | 77 | +  VEC4_T weight2OutputChannelPacked; | 
|  | 78 | +  VEC4_T weight3OutputChannelPacked; | 
|  | 79 | +  VEC4_T weight4OutputChannelPacked; | 
| 85 | 80 | 
 | 
| 86 |  | -  // Initialize the output array with the bias value | 
| 87 |  | -  for (int i = 0; i < TILE_SIZE_X * TILE_SIZE_Y * 4; i++) { | 
| 88 |  | -    sum[i] = 0; | 
| 89 |  | -  } | 
|  | 81 | +  // By unrolling the loop in sets of 4, this significantly reduces the number of branching instructions | 
|  | 82 | +  // and enables the compiler to rearrange instructions for more efficient memory retrieval and compute | 
|  | 83 | +  for (int inputC = 0; inputC < inputChannel; inputC += 1) { | 
| 90 | 84 | 
 | 
| 91 |  | -  int z4 = 0; | 
| 92 |  | -  // Since the kernel is 1x1, we only have to loop over the depth dimension. | 
| 93 |  | -  for (int z = 0; z < in_group_size; z += 4, ++z4) { | 
| 94 |  | -    // During prepacking, the weight tensor has been permuted so that the | 
| 95 |  | -    // channel (IC) dim is along the x-axis, and the batch (OC) dim is along | 
| 96 |  | -    // the z-axis. | 
| 97 |  | -    float kernel_values[4 * 4]; // 4 channels, 4 elements per channel | 
| 98 |  | - | 
| 99 |  | -    // Load kernel values from texels to array | 
| 100 |  | -    [[unroll]] for (int i = 0; i < 4; ++i) { | 
| 101 |  | -      const vec4 k_tex = texelFetch(t_kernel, ivec2(z + i, out_pos_z), 0); | 
| 102 |  | -      kernel_values[i * 4 + 0] = k_tex.x; | 
| 103 |  | -      kernel_values[i * 4 + 1] = k_tex.y; | 
| 104 |  | -      kernel_values[i * 4 + 2] = k_tex.z; | 
| 105 |  | -      kernel_values[i * 4 + 3] = k_tex.w; | 
| 106 |  | -    } | 
| 107 |  | - | 
| 108 |  | -    for (int i = 0; i < TILE_SIZE_X * TILE_SIZE_Y; ++i) { | 
| 109 |  | -      const vec4 in_tex = texelFetch(t_in, ivec3(pos[i * 2], pos[i * 2 + 1], z4), 0); | 
| 110 |  | -      // Load the input texel into an array | 
| 111 |  | -      float tex_values[4]; | 
| 112 |  | -      tex_values[0] = in_tex.x; | 
| 113 |  | -      tex_values[1] = in_tex.y; | 
| 114 |  | -      tex_values[2] = in_tex.z; | 
| 115 |  | -      tex_values[3] = in_tex.w; | 
| 116 |  | - | 
| 117 |  | -      // For 2x2 tile size algorithm works as follows. | 
| 118 |  | -      // To explain the calculations below, the contents of one in_tex and the | 
| 119 |  | -      // group of 4 texels loaded from t_kernel are shown: | 
| 120 |  | -      // | 
| 121 |  | -      //   in_tex                 t_kernel | 
| 122 |  | -      //    -x->                   ---x---> | 
| 123 |  | -      //   +---+              +----+----+----+----+ | 
| 124 |  | -      // ^ | w |           ^  | D0 | D1 | D2 | D3 | | 
| 125 |  | -      // | +---+           |  +----+----+----+----+ | 
| 126 |  | -      // | | z |           |  | C0 | C1 | C2 | C3 | | 
| 127 |  | -      // z +---+           z  +----+----+----+----+ | 
| 128 |  | -      // | | y |           |  | B0 | B2 | B2 | B3 | | 
| 129 |  | -      // | +---+           |  +----+----+----+----+ | 
| 130 |  | -      //   | x |              | A0 | A1 | A2 | A3 | | 
| 131 |  | -      //   +---+              +----+----+----+----+ | 
| 132 |  | -      // | 
| 133 |  | -      // In the t_kernel graphic, cells sharing the same letter are from | 
| 134 |  | -      // the same batch/output channel index, and the number denotes a unique | 
| 135 |  | -      // channel index. To calculate the output texel, the following | 
| 136 |  | -      // calculation is performed: | 
| 137 |  | -      // | 
| 138 |  | -      //  +---+ +----+   +---+ +----+   +---+ +----+   +---+ +----+ | 
| 139 |  | -      //  | x | | D0 |   | y | | D1 |   | z | | D2 |   | w | | D3 | | 
| 140 |  | -      //  +---+ +----+   +---+ +----+   +---+ +----+   +---+ +----+ | 
| 141 |  | -      //  | x | | C0 |   | y | | C1 |   | z | | C2 |   | w | | C3 | | 
| 142 |  | -      //  +---+X+----+ + +---+X+----+ + +---+X+----+ + +---+X+----+ | 
| 143 |  | -      //  | x | | B0 |   | y | | B1 |   | z | | B2 |   | w | | B3 | | 
| 144 |  | -      //  +---+ +----+   +---+ +----+   +---+ +----+   +---+ +----+ | 
| 145 |  | -      //  | x | | A0 |   | y | | A1 |   | z | | A2 |   | w | | A3 | | 
| 146 |  | -      //  +---+ +----+   +---+ +----+   +---+ +----+   +---+ +----+ | 
| 147 |  | -      // | 
| 148 |  | -      //  which is what is expressed in the following calculations. This is done | 
| 149 |  | -      //  for each output position. | 
| 150 |  | -      for (int j = 0; j < 4; ++j) { | 
| 151 |  | -        sum[i * 4 + j] = tex_values[0] * kernel_values[0 + j] + sum[i * 4 + j]; | 
| 152 |  | -        sum[i * 4 + j] = tex_values[1] * kernel_values[4 + j] + sum[i * 4 + j]; | 
| 153 |  | -        sum[i * 4 + j] = tex_values[2] * kernel_values[8 + j] + sum[i * 4 + j]; | 
| 154 |  | -        sum[i * 4 + j] = tex_values[3] * kernel_values[12 + j] + sum[i * 4 + j]; | 
| 155 |  | -      } | 
| 156 |  | -    } | 
| 157 |  | -  } | 
|  | 85 | +    inputVec = VEC4_T(texelFetch(t_in, ivec3(xIdx, yIdx, inputC), 0)); | 
|  | 86 | + | 
|  | 87 | +    weight1OutputChannelPacked = VEC4_T(texelFetch(t_kernel, ivec2(inputC * 4 + 0, threadOutChannel), 0)); | 
|  | 88 | +    weight2OutputChannelPacked = VEC4_T(texelFetch(t_kernel, ivec2(inputC * 4 + 1, threadOutChannel), 0)); | 
|  | 89 | +    weight3OutputChannelPacked = VEC4_T(texelFetch(t_kernel, ivec2(inputC * 4 + 2, threadOutChannel), 0)); | 
|  | 90 | +    weight4OutputChannelPacked = VEC4_T(texelFetch(t_kernel, ivec2(inputC * 4 + 3, threadOutChannel), 0)); | 
|  | 91 | + | 
|  | 92 | +    outputTexel[0] += dot(inputVec, VEC4_T(weight1OutputChannelPacked[0], weight2OutputChannelPacked[0], weight3OutputChannelPacked[0], weight4OutputChannelPacked[0])); | 
|  | 93 | +    outputTexel[1] += dot(inputVec, VEC4_T(weight1OutputChannelPacked[1], weight2OutputChannelPacked[1], weight3OutputChannelPacked[1], weight4OutputChannelPacked[1])); | 
|  | 94 | +    outputTexel[2] += dot(inputVec, VEC4_T(weight1OutputChannelPacked[2], weight2OutputChannelPacked[2], weight3OutputChannelPacked[2], weight4OutputChannelPacked[2])); | 
|  | 95 | +    outputTexel[3] += dot(inputVec, VEC4_T(weight1OutputChannelPacked[3], weight2OutputChannelPacked[3], weight3OutputChannelPacked[3], weight4OutputChannelPacked[3])); | 
|  | 96 | + | 
|  | 97 | +    inputC += 1; | 
|  | 98 | + | 
|  | 99 | +    inputVec = VEC4_T(texelFetch(t_in, ivec3(xIdx, yIdx, inputC), 0)); | 
| 158 | 100 | 
 | 
| 159 |  | -  const vec4 bias = texelFetch(t_bias, ivec2(out_pos_z, 0), 0); | 
|  | 101 | +    weight1OutputChannelPacked = VEC4_T(texelFetch(t_kernel, ivec2(inputC * 4 + 0, threadOutChannel), 0)); | 
|  | 102 | +    weight2OutputChannelPacked = VEC4_T(texelFetch(t_kernel, ivec2(inputC * 4 + 1, threadOutChannel), 0)); | 
|  | 103 | +    weight3OutputChannelPacked = VEC4_T(texelFetch(t_kernel, ivec2(inputC * 4 + 2, threadOutChannel), 0)); | 
|  | 104 | +    weight4OutputChannelPacked = VEC4_T(texelFetch(t_kernel, ivec2(inputC * 4 + 3, threadOutChannel), 0)); | 
| 160 | 105 | 
 | 
| 161 |  | -  for (int i = 0; i < TILE_SIZE_X * TILE_SIZE_Y; ++i) { | 
| 162 |  | -    const ivec3 pos_l = ivec3(pos[i * 2], pos[i * 2 + 1], out_pos_z); | 
| 163 |  | -    if (all(lessThan(pos_l.xy, out_limits.xy))) { | 
| 164 |  | -      const vec4 out_sum = vec4(sum[i * 4], sum[i * 4 + 1], sum[i * 4 + 2], sum[i * 4 + 3]); | 
| 165 |  | -      imageStore(t_out, pos_l, op(out_sum + bias, out_min, out_max)); | 
| 166 |  | -    } | 
|  | 106 | +    outputTexel[0] += dot(inputVec, VEC4_T(weight1OutputChannelPacked[0], weight2OutputChannelPacked[0], weight3OutputChannelPacked[0], weight4OutputChannelPacked[0])); | 
|  | 107 | +    outputTexel[1] += dot(inputVec, VEC4_T(weight1OutputChannelPacked[1], weight2OutputChannelPacked[1], weight3OutputChannelPacked[1], weight4OutputChannelPacked[1])); | 
|  | 108 | +    outputTexel[2] += dot(inputVec, VEC4_T(weight1OutputChannelPacked[2], weight2OutputChannelPacked[2], weight3OutputChannelPacked[2], weight4OutputChannelPacked[2])); | 
|  | 109 | +    outputTexel[3] += dot(inputVec, VEC4_T(weight1OutputChannelPacked[3], weight2OutputChannelPacked[3], weight3OutputChannelPacked[3], weight4OutputChannelPacked[3])); | 
|  | 110 | + | 
|  | 111 | +    inputC += 1; | 
|  | 112 | + | 
|  | 113 | +    inputVec = VEC4_T(texelFetch(t_in, ivec3(xIdx, yIdx, inputC), 0)); | 
|  | 114 | + | 
|  | 115 | +    weight1OutputChannelPacked = VEC4_T(texelFetch(t_kernel, ivec2(inputC * 4 + 0, threadOutChannel), 0)); | 
|  | 116 | +    weight2OutputChannelPacked = VEC4_T(texelFetch(t_kernel, ivec2(inputC * 4 + 1, threadOutChannel), 0)); | 
|  | 117 | +    weight3OutputChannelPacked = VEC4_T(texelFetch(t_kernel, ivec2(inputC * 4 + 2, threadOutChannel), 0)); | 
|  | 118 | +    weight4OutputChannelPacked = VEC4_T(texelFetch(t_kernel, ivec2(inputC * 4 + 3, threadOutChannel), 0)); | 
|  | 119 | + | 
|  | 120 | +    outputTexel[0] += dot(inputVec, VEC4_T(weight1OutputChannelPacked[0], weight2OutputChannelPacked[0], weight3OutputChannelPacked[0], weight4OutputChannelPacked[0])); | 
|  | 121 | +    outputTexel[1] += dot(inputVec, VEC4_T(weight1OutputChannelPacked[1], weight2OutputChannelPacked[1], weight3OutputChannelPacked[1], weight4OutputChannelPacked[1])); | 
|  | 122 | +    outputTexel[2] += dot(inputVec, VEC4_T(weight1OutputChannelPacked[2], weight2OutputChannelPacked[2], weight3OutputChannelPacked[2], weight4OutputChannelPacked[2])); | 
|  | 123 | +    outputTexel[3] += dot(inputVec, VEC4_T(weight1OutputChannelPacked[3], weight2OutputChannelPacked[3], weight3OutputChannelPacked[3], weight4OutputChannelPacked[3])); | 
|  | 124 | + | 
|  | 125 | +    inputC += 1; | 
|  | 126 | + | 
|  | 127 | +    inputVec = VEC4_T(texelFetch(t_in, ivec3(xIdx, yIdx, inputC), 0)); | 
|  | 128 | + | 
|  | 129 | +    weight1OutputChannelPacked = VEC4_T(texelFetch(t_kernel, ivec2(inputC * 4 + 0, threadOutChannel), 0)); | 
|  | 130 | +    weight2OutputChannelPacked = VEC4_T(texelFetch(t_kernel, ivec2(inputC * 4 + 1, threadOutChannel), 0)); | 
|  | 131 | +    weight3OutputChannelPacked = VEC4_T(texelFetch(t_kernel, ivec2(inputC * 4 + 2, threadOutChannel), 0)); | 
|  | 132 | +    weight4OutputChannelPacked = VEC4_T(texelFetch(t_kernel, ivec2(inputC * 4 + 3, threadOutChannel), 0)); | 
|  | 133 | + | 
|  | 134 | +    outputTexel[0] += dot(inputVec, VEC4_T(weight1OutputChannelPacked[0], weight2OutputChannelPacked[0], weight3OutputChannelPacked[0], weight4OutputChannelPacked[0])); | 
|  | 135 | +    outputTexel[1] += dot(inputVec, VEC4_T(weight1OutputChannelPacked[1], weight2OutputChannelPacked[1], weight3OutputChannelPacked[1], weight4OutputChannelPacked[1])); | 
|  | 136 | +    outputTexel[2] += dot(inputVec, VEC4_T(weight1OutputChannelPacked[2], weight2OutputChannelPacked[2], weight3OutputChannelPacked[2], weight4OutputChannelPacked[2])); | 
|  | 137 | +    outputTexel[3] += dot(inputVec, VEC4_T(weight1OutputChannelPacked[3], weight2OutputChannelPacked[3], weight3OutputChannelPacked[3], weight4OutputChannelPacked[3])); | 
| 167 | 138 |   } | 
|  | 139 | + | 
|  | 140 | +  imageStore(t_out, ivec3(xIdx, yIdx, threadOutChannel), op(vec4(outputTexel), out_min, out_max)); | 
| 168 | 141 | } | 
0 commit comments