Skip to content
Closed
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions backends/vulkan/partitioner/vulkan_partitioner.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,8 @@ def is_node_supported(self, submodules, node: torch.fx.Node) -> bool:
exir_ops.edge.aten.max_pool2d_with_indices.default,
# Sum
exir_ops.edge.aten.sum.dim_IntList,
# Convolution operators
exir_ops.edge.aten.convolution.default,
# Other
operator.getitem,
]
Expand Down
136 changes: 136 additions & 0 deletions backends/vulkan/runtime/graph/ops/glsl/conv2d.glsl
Original file line number Diff line number Diff line change
@@ -0,0 +1,136 @@
/*
* Copyright (c) Meta Platforms, Inc. and affiliates.
* All rights reserved.
*
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree.
*/

#version 450 core

#define PRECISION ${PRECISION}

#include "indexing_utils.h"

// Descriptor interface of the conv2d shader. Everything lives in set 0; the
// binding numbers below are what main() relies on.
layout(std430) buffer;

// binding 0: output image. main() stores exactly one texel per invocation at
//            the invocation's global position.
// binding 1: input image, fetched in main() as (x, y, z4) where z4 indexes a
//            group of four input channels.
// binding 2: weights, already rearranged by a prepacking step (see the
//            comments in main() for the layout that is expected here).
// binding 3: bias, fetched in main() at (pos.z, 0).
layout(set = 0, binding = 0, ${IMAGE_FORMAT[DTYPE]}) uniform PRECISION restrict writeonly ${IMAGE_T[NDIM][DTYPE]} image_out;
layout(set = 0, binding = 1) uniform PRECISION sampler3D image_in;
layout(set = 0, binding = 2) uniform PRECISION sampler2D kernel_in;
layout(set = 0, binding = 3) uniform PRECISION sampler2D bias_in;

// Extents of the output image. main() compares data.xyz against the global
// invocation position and returns early for out-of-range invocations.
layout(set = 0, binding = 4) uniform PRECISION restrict OutExtents {
uvec4 data;
}
out_extents;

// Extents of the input image. Only data.xy is read by main(), to clamp the
// end of the input window.
layout(set = 0, binding = 5) uniform PRECISION restrict InExtents {
uvec4 data;
}
in_extents;

// Convolution hyper-parameters. Each ivec2 is combined component-wise with
// pos.xy in main(), i.e. .x applies to the x axis and .y to the y axis.
layout(set = 0, binding = 6) uniform PRECISION restrict Params {
ivec2 kernel_size;
ivec2 stride;
ivec2 padding;
ivec2 dilation;
}
params;

// overlay_region: added to the window origin in main() to obtain the
//                 exclusive end of the input window.
// in_group_size:  divided by 4 in main() to get the number of 4-channel
//                 input groups to iterate over.
// If fields are separated, SwiftShader cannot identify in_group_size.
layout(set = 0, binding = 7) uniform PRECISION restrict ExtraParams {
ivec2 overlay_region;
int in_group_size;
}
extra_params;

// Work-group size is supplied through specialization constants 0, 1 and 2.
layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in;

/*
 * Computes a 2D convolution. Each shader invocation calculates the output at
 * a single output location (one texel of image_out, i.e. four output
 * channels at one x/y position).
 *
 * Inputs are read from image_in, the prepacked weights from kernel_in and the
 * bias from bias_in. Padding is treated as constant zero padding: kernel taps
 * that fall outside the input image are skipped instead of being read.
 */
void main() {
  const ivec3 pos = ivec3(gl_GlobalInvocationID);

  // Invocations outside the output image have nothing to write. The extents
  // are cast explicitly so both operands are signed, as is done for
  // in_extents below.
  if (any(greaterThanEqual(pos, ivec3(out_extents.data.xyz)))) {
    return;
  }

  // Compute the index of the top-left element of the overlay region. Negative
  // indices indicate that the top-left element is in a region added by padding.
  const ivec2 ipos = pos.xy * params.stride - params.padding;

  // Compute the first kernel tap (in "canonical" kernel indices) whose input
  // location lies inside the image. Tap k reads input location
  // ipos + k * dilation, so the number of taps to skip is
  // ceil(max(0, -ipos) / dilation). Rounding up matters: with truncating
  // division the first input location read would not sit on the dilation
  // grid whenever -ipos is not a multiple of the dilation, and every tap
  // would then be paired with the wrong input element.
  ivec2 kstart =
      (max(ivec2(0), -ipos) + params.dilation - 1) / params.dilation;

  // Compute the start and end of the input indices to load. Padding is assumed
  // to be constant 0 padding, so reads from the padding region are skipped.
  // start is derived from kstart so that it is always aligned with a tap.
  // NOTE(review): overlay_region is presumably the dilated kernel footprint
  // computed by the host - confirm.
  const ivec2 start = ipos + kstart * params.dilation;
  const ivec2 end =
      min(ipos + extra_params.overlay_region.xy, ivec2(in_extents.data.xy));

  // During prepacking, the weight tensor was rearranged in order to optimize
  // for data access linearity in this shader. Therefore we need to adjust the
  // canonical coordinates to the corresponding index in the rearranged weight
  // tensor. The x-coordinate is multiplied by 4 since each group of 4 channels
  // is folded into the X axis. The y-coordinate is offset based on the z-
  // coordinate because the 2D planes were stacked atop each other vertically.
  kstart.x *= 4;
  kstart.y += pos.z * params.kernel_size.y;

  // Perform the convolution by iterating over the overlay region. The
  // accumulator starts from the bias of this group of output channels.
  ${VEC4_T[DTYPE]} sum = texelFetch(bias_in, ivec2(pos.z, 0), 0);
  const int ic4 = extra_params.in_group_size / 4;
  // Each group of 4 input channels occupies kernel_size.x * 4 texels along
  // the x axis of kernel_in, hence the per-group advance of kstart.x.
  for (int z4 = 0; z4 < ic4; ++z4, kstart.x += params.kernel_size.x * 4) {
    for (int y = start.y, ky = kstart.y; y < end.y; y += params.dilation.y, ++ky) {
      for (int x = start.x, kx = kstart.x; x < end.x; x += params.dilation.x, kx += 4) {
        const ${VEC4_T[DTYPE]} in_texel = texelFetch(image_in, ivec3(x, y, z4), 0);

        // To explain the calculation below, the contents of in_texel and the
        // group of 4 texels loaded from kernel_in are shown:
        //
        //   in_texel               kernel_in
        //    -x->                   ---x--->
        //   +---+              +----+----+----+----+
        // ^ | w |           ^  | D0 | D1 | D2 | D3 |
        // | +---+           |  +----+----+----+----+
        // | | z |           |  | C0 | C1 | C2 | C3 |
        // z +---+           z  +----+----+----+----+
        // | | y |           |  | B0 | B1 | B2 | B3 |
        // | +---+           |  +----+----+----+----+
        //   | x |              | A0 | A1 | A2 | A3 |
        //   +---+              +----+----+----+----+
        //
        // In the kernel_in graphic, cells sharing the same letter are from
        // the same batch/output channel index, and the number denotes a unique
        // channel index. To calculate the output texel, the following
        // calculation is performed:
        //
        //  +---+ +----+   +---+ +----+   +---+ +----+   +---+ +----+
        //  | x | | D0 |   | y | | D1 |   | z | | D2 |   | w | | D3 |
        //  +---+ +----+   +---+ +----+   +---+ +----+   +---+ +----+
        //  | x | | C0 |   | y | | C1 |   | z | | C2 |   | w | | C3 |
        //  +---+X+----+ + +---+X+----+ + +---+X+----+ + +---+X+----+
        //  | x | | B0 |   | y | | B1 |   | z | | B2 |   | w | | B3 |
        //  +---+ +----+   +---+ +----+   +---+ +----+   +---+ +----+
        //  | x | | A0 |   | y | | A1 |   | z | | A2 |   | w | | A3 |
        //  +---+ +----+   +---+ +----+   +---+ +----+   +---+ +----+
        //
        // which is expressed in the following statements.

        const ${VEC4_T[DTYPE]} ktex_0 = texelFetch(kernel_in, ivec2(kx + 0, ky), 0);
        sum = fma(in_texel.xxxx, ktex_0, sum);

        const ${VEC4_T[DTYPE]} ktex_1 = texelFetch(kernel_in, ivec2(kx + 1, ky), 0);
        sum = fma(in_texel.yyyy, ktex_1, sum);

        const ${VEC4_T[DTYPE]} ktex_2 = texelFetch(kernel_in, ivec2(kx + 2, ky), 0);
        sum = fma(in_texel.zzzz, ktex_2, sum);

        const ${VEC4_T[DTYPE]} ktex_3 = texelFetch(kernel_in, ivec2(kx + 3, ky), 0);
        sum = fma(in_texel.wwww, ktex_3, sum);
      }
    }
  }

  imageStore(image_out, pos, sum);
}
18 changes: 18 additions & 0 deletions backends/vulkan/runtime/graph/ops/glsl/conv2d.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# Shader codegen settings for conv2d.glsl.
conv2d:
# Default values for the template placeholders. Both NDIM and DTYPE are
# referenced by conv2d.glsl (image type / image format / vec4 type).
parameter_names_with_default_values:
NDIM: 3
DTYPE: float
# NOTE(review): presumably one shader is generated per DTYPE entry below, with
# SUFFIX appended to the variant name - confirm against the codegen script.
generate_variant_forall:
DTYPE:
- VALUE: half
SUFFIX: half
- VALUE: float
SUFFIX: float
# Base name(s) of the generated shader variants.
shader_variants:
- NAME: conv2d
131 changes: 131 additions & 0 deletions backends/vulkan/runtime/graph/ops/glsl/conv2d_prepack_weights.glsl
Original file line number Diff line number Diff line change
@@ -0,0 +1,131 @@
/*
* Copyright (c) Meta Platforms, Inc. and affiliates.
* All rights reserved.
*
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree.
*/

#version 450 core

#define PRECISION ${PRECISION}

#include "indexing_utils.h"

// Descriptor interface of the conv2d weight prepacking shader. Everything
// lives in set 0.
layout(std430) buffer;

// binding 0: 2D output image that receives the rearranged weights; main()
//            stores one texel per invocation at pos.xy.
// binding 1: read-only buffer holding the weight tensor elements; main()
//            indexes it with a flat element index.
layout(set = 0, binding = 0, ${IMAGE_FORMAT[DTYPE]}) uniform PRECISION restrict writeonly ${IMAGE_T[2][DTYPE]} image_out;
layout(set = 0, binding = 1) buffer PRECISION restrict readonly Buffer {
${T[DTYPE]} data[];
}
buffer_in;

// Sizes of the rearranged (packed) tensor; used in main() for the bounds
// check and for the position -> buffer index conversion.
// Corresponds to {1,4,9,24} in the example below.
// NOTE(review): this example is written in N,C,H,W order while the two
// comments further down list values in W,H,C,N order - confirm which
// component order gpu_sizes.data actually uses.
layout(set = 0, binding = 2) uniform PRECISION restrict GpuSizes {
ivec4 data;
}
gpu_sizes;

// Sizes of the weight tensor before padding. main() reads the components as
// data.x = W, data.y = H, data.z = C, data.w = N.
// Corresponds to {3,3,7,10} in the example below.
layout(set = 0, binding = 3) uniform PRECISION restrict OriginalSizes {
ivec4 data;
}
original_sizes;

// Padded channel and batch sizes. Only two components are stored: main()
// reads data.x as the padded C (8 in the example) and data.y as the padded N
// (12 in the example); H and W are taken from original_sizes.
// Corresponds to {3,3,8,12} in the example below.
layout(set = 0, binding = 4) uniform PRECISION restrict PaddedSizes {
ivec2 data;
}
padded_sizes;

// Work-group size is supplied through specialization constants 0, 1 and 2.
layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in;

/*
 * Computes special prepacking for a 2D convolution. Each shader invocation
 * calculates the input buffer location to read into the desired texel. This
 * packing was originally developed on CPU and that approach is described in the
 * rest of this comment. Refer to the code-level comments for how we translate
 * it to GPU by reversing the steps.
 *
 * Consider an example weight tensor of size {10,7,3,3}. The following
 * transformations will be applied.
 *
 * 1. Pad the N and C dims so that both are a multiple of 4. In this case, 2
 *    batches and 1 channel of padding are added, producing a tensor of size
 *    {12,8,3,3}.
 *      at::pad(x, {0,0,0,0,0,2,0,1}, "constant", 0);
 *
 * 2. Split the tensor along the C dim so that each split has 4 channels.
 *      x.reshape({12,2,4,3,3});
 *
 * 3. For each split, "fold" the C dim into the W dim. Suppose the first rows
 *    at H=0 of the split have values
 *      0,1,2 | 10,11,12 | 20,21,22 | 30,31,32
 *
 *    where | denotes a channel boundary. Then, the goal is to combine those rows
 *    into one row with the values
 *      0, 10, 20, 30, 1, 11, 21, 31, 2, 12, 22, 32
 *
 *      x.permute({0,1,3,4,2}).reshape({12,2,3,12});
 *
 * 4. Stack the splits belonging to the same batch horizontally by swapping the
 *    C and H dims.
 *      x.permute({0,2,1,3}).reshape({12,3,24});
 *
 * 5. Repeat a similar process to "fold" the N dim into the C dim. Split along
 *    the N dim so that each split has 4 batches.
 *      x.reshape({3,4,3,24});
 *
 * 6. Stack the batches on each other vertically by swapping the N and C dims.
 *      x.permute({1,0,2,3}).reshape({4,9,24});
 */
void main() {
  const ivec3 pos = ivec3(gl_GlobalInvocationID);
  const ivec4 coord = POS_TO_COORD_CHANNELS_PACKED(pos, gpu_sizes.data);

  // Invocations that map outside the packed tensor have nothing to write.
  if (any(greaterThanEqual(coord, gpu_sizes.data))) {
    return;
  }

  // As in usual staging shaders, map from GPU texel position to normal CPU
  // buffer indices: (24,9) -> (4,9,24)
  // p0 holds one flat index per component of the texel being written.
  const int base_index = COORD_TO_BUFFER_IDX(coord, gpu_sizes.data);
  const ivec4 p0 =
      base_index + ivec4(0, 1, 2, 3) * STRIDE_CHANNELS_PACKED(gpu_sizes.data);

  // Re-map the normal CPU buffer indices to special indices, through a series
  // of mappings: reshape is a no-op to the underlying indices, pad is hard, and
  // permute is one of the hardest math problems I've ever solved.
  const int Np = padded_sizes.data.y;
  const int Cp = padded_sizes.data.x;
  const int N = original_sizes.data.w;
  const int C = original_sizes.data.z;
  const int H = original_sizes.data.y;
  const int W = original_sizes.data.x;

  // Undo step 6 permute: (4,3,3,24) -> (3,4,3,24)
  // Undo step 4 permute: (12,3,2,12) -> (12,2,3,12)
  // Undo step 3 permute, part 1: (12,2,3h,3w,4) -> (12,2,3h,4,3w)
  // Undo step 3 permute, part 2: (12,2,3h,4,3w) -> (12,2,4,3h,3w)
  const ivec4 p1 = SWAP_ADJ_DIMS(p0, 4, (Np / 4), (H * Cp * W));
  const ivec4 p2 = SWAP_ADJ_DIMS(p1, H, (Cp / 4), (W * 4));
  const ivec4 p3 = SWAP_ADJ_DIMS(p2, W, 4, 1);
  const ivec4 p4 = SWAP_ADJ_DIMS(p3, H, 4, W);

  // Undo step 1 pad: (12,8,3,3) -> (10,7,3,3)
  // For values in the padded region, write zero instead of buffer data.
  // c and n are the channel and batch coordinates of p4 in the padded tensor;
  // p5 removes the (Cp - C) padded channels of every preceding batch to get
  // the index into the unpadded buffer.
  const ivec4 c = p4 % (Cp * H * W) / (H * W);
  const ivec4 n = p4 / (Cp * H * W);
  const ivec4 p5 = p4 - n * (Cp - C) * H * W;
  const ivec4 mask = ivec4(greaterThanEqual(c, ivec4(C))) |
      ivec4(greaterThanEqual(n, ivec4(N)));

  // Select with ?: rather than mix(): only the chosen operand of ?: is
  // evaluated, so elements in the padded region never touch buffer_in. For
  // padded batches p5 can point past the N*C*H*W elements of the buffer, and
  // mix() would additionally let a NaN/Inf that was read leak through the
  // multiplication by zero.
  const ${T[DTYPE]} zero = ${T[DTYPE]}(0);
  ${T[DTYPE]} val_x = (mask.x != 0) ? zero : buffer_in.data[p5.x];
  ${T[DTYPE]} val_y = (mask.y != 0) ? zero : buffer_in.data[p5.y];
  ${T[DTYPE]} val_z = (mask.z != 0) ? zero : buffer_in.data[p5.z];
  ${T[DTYPE]} val_w = (mask.w != 0) ? zero : buffer_in.data[p5.w];

  ${VEC4_T[DTYPE]} texel = ${VEC4_T[DTYPE]}(val_x, val_y, val_z, val_w);

  imageStore(image_out, pos.xy, texel);
}
17 changes: 17 additions & 0 deletions backends/vulkan/runtime/graph/ops/glsl/conv2d_prepack_weights.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# Shader codegen settings for conv2d_prepack_weights.glsl.
conv2d_prepack_weights:
# Default values for the template placeholders. Only DTYPE is needed: the
# shader hard-codes a 2D output image, so there is no NDIM parameter here.
parameter_names_with_default_values:
DTYPE: float
# NOTE(review): presumably one shader is generated per DTYPE entry below, with
# SUFFIX appended to the variant name - confirm against the codegen script.
generate_variant_forall:
DTYPE:
- VALUE: half
SUFFIX: half
- VALUE: float
SUFFIX: float
# Base name(s) of the generated shader variants.
shader_variants:
- NAME: conv2d_prepack_weights
11 changes: 11 additions & 0 deletions backends/vulkan/runtime/graph/ops/glsl/indexing_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -44,3 +44,14 @@
#define STRIDE_WIDTH_PACKED(vec) (1)

#define STRIDE_HEIGHT_PACKED(vec) (vec.x)

// Given a buffer (1-D) index cur, compute a new index where the corresponding
// tensor (N-D)'s adjacent dimensions are swapped. The parameters x, y and plane
// describe sizes. As an example, let's say we want to swap dimensions 0,1 for a
// tensor of shape {4,3,2,24} to obtain {3,4,2,24}. Then, x=4, y=3 and
// plane=2*24=48.
//
// With i = (cur % (x*y*plane)) / (y*plane) and j = (cur % (y*plane)) / plane
// being the coordinates along the two swapped dimensions, the element moves
// from offset i*y*plane + j*plane to j*x*plane + i*plane, which is a change of
// plane * (i*(1 - y) + j*(x - 1)).
//
// The expansion and every argument are parenthesized so the macro can be used
// inside larger expressions and with compound arguments. Arguments are
// evaluated more than once, so they must be free of side effects.
#define SWAP_ADJ_DIMS(cur, x, y, plane)                                       \
  ((cur) +                                                                    \
   (plane) *                                                                  \
       ((1 - (y)) * (((cur) % ((x) * (y) * (plane))) / ((y) * (plane))) +     \
        ((x) - 1) * (((cur) % ((y) * (plane))) / (plane))))
Loading