diff --git a/backends/vulkan/runtime/graph/ops/glsl/conv2d_dw.glsl b/backends/vulkan/runtime/graph/ops/glsl/conv2d_dw.glsl
new file mode 100644
index 00000000000..ab15658111f
--- /dev/null
+++ b/backends/vulkan/runtime/graph/ops/glsl/conv2d_dw.glsl
@@ -0,0 +1,83 @@
+/*
+ * Copyright (c) Meta Platforms, Inc. and affiliates.
+ * All rights reserved.
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+#version 450 core
+
+#define PRECISION ${PRECISION}
+
+#include "indexing_utils.h"
+
+layout(std430) buffer;
+
+layout(set = 0, binding = 0, ${IMAGE_FORMAT[DTYPE]}) uniform PRECISION restrict writeonly ${IMAGE_T[NDIM][DTYPE]} image_out;
+layout(set = 0, binding = 1) uniform PRECISION sampler3D image_in;
+layout(set = 0, binding = 2) uniform PRECISION sampler2D kernel_in;
+layout(set = 0, binding = 3) uniform PRECISION sampler2D bias_in;
+
+layout(set = 0, binding = 4) uniform PRECISION restrict OutExtents {
+  uvec4 data;
+}
+out_extents;
+
+layout(set = 0, binding = 5) uniform PRECISION restrict InExtents {
+  uvec4 data;
+}
+in_extents;
+
+layout(set = 0, binding = 6) uniform PRECISION restrict Params {
+  ivec2 kernel_size;
+  ivec2 stride;
+  ivec2 padding;
+  ivec2 dilation;
+}
+params;
+
+// If fields are separated, SwiftShader cannot identify in_group_size.
+layout(set = 0, binding = 7) uniform PRECISION restrict ExtraParams {
+  ivec2 overlay_region;
+  int in_group_size;
+}
+extra_params;
+
+layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in;
+
+/*
+ * Computes a depthwise convolution. Each shader invocation calculates the
+ * output at a single output location.
+ */
+void main() {
+  const ivec3 pos = ivec3(gl_GlobalInvocationID);
+
+  if (any(greaterThanEqual(pos, out_extents.data.xyz))) {
+    return;
+  }
+
+  // Compute the index of the top-left element of the overlay region. Negative
+  // indices indicate that the top-left element is in a region added by padding.
+  const ivec2 ipos = pos.xy * params.stride - params.padding;
+
+  // Compute the start and end of the input indices to load. Padding is assumed
+  // to be constant 0 padding, so reads from the padding region are skipped.
+  const ivec2 start = ipos;
+  const ivec2 end = ipos + extra_params.overlay_region.xy;
+
+  ${VEC4_T[DTYPE]} sum = texelFetch(bias_in, ivec2(pos.z, 0), 0);
+  int kx = 0;
+  for (int y = start.y; y < end.y; y += params.dilation.y) {
+    for (int x = start.x; x < end.x; x += params.dilation.x) {
+      // The weight kernel was rearranged so that every NxN filter is flattened
+      // to fit in one row. The filters were then stacked on top of each other
+      // vertically.
+      const ${VEC4_T[DTYPE]} in_texel = texelFetch(image_in, ivec3(x, y, pos.z), 0);
+      sum = fma(in_texel, texelFetch(kernel_in, ivec2(kx, pos.z), 0), sum);
+      ++kx;
+    }
+  }
+
+  imageStore(image_out, pos, sum);
+}
diff --git a/backends/vulkan/runtime/graph/ops/glsl/conv2d_dw.yaml b/backends/vulkan/runtime/graph/ops/glsl/conv2d_dw.yaml
new file mode 100644
index 00000000000..560887f3dc1
--- /dev/null
+++ b/backends/vulkan/runtime/graph/ops/glsl/conv2d_dw.yaml
@@ -0,0 +1,18 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+ +conv2d_dw: + parameter_names_with_default_values: + NDIM: 3 + DTYPE: float + generate_variant_forall: + DTYPE: + - VALUE: half + SUFFIX: half + - VALUE: float + SUFFIX: float + shader_variants: + - NAME: conv2d_dw diff --git a/backends/vulkan/runtime/graph/ops/glsl/conv2d_dw_prepack_weights.glsl b/backends/vulkan/runtime/graph/ops/glsl/conv2d_dw_prepack_weights.glsl new file mode 100644 index 00000000000..ef2b54ba354 --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/glsl/conv2d_dw_prepack_weights.glsl @@ -0,0 +1,107 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#version 450 core + +#define PRECISION ${PRECISION} + +#include "indexing_utils.h" + +layout(std430) buffer; + +layout(set = 0, binding = 0, ${IMAGE_FORMAT[DTYPE]}) uniform PRECISION restrict writeonly ${IMAGE_T[2][DTYPE]} image_out; +layout(set = 0, binding = 1) buffer PRECISION restrict readonly Buffer { + ${T[DTYPE]} data[]; +} +buffer_in; + +// Corresponds to {1,4,3,9} in the example below. +layout(set = 0, binding = 2) uniform PRECISION restrict GpuSizes { + ivec4 data; +} +gpu_sizes; + +// Corresponds to {3,3,1,11} in the example below. +layout(set = 0, binding = 3) uniform PRECISION restrict OriginalSizes { + ivec4 data; +} +original_sizes; + +// Corresponds to {1,12} in the example below. +layout(set = 0, binding = 4) uniform PRECISION restrict PaddedSizes { + ivec2 data; +} +padded_sizes; + +layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in; + +/* + * Computes special prepacking for a depthwise convolution. Each shader invocation + * calculates the input buffer location to read into the desired texel. This + * packing was originally developed on CPU and that approach is described in the + * rest of this comment. Refer to the code-level comments, for how we translate + * it to GPU by reversing the steps. + * + * Consider an example weight tensor of size {11,1,3,3}. The following + * transformations will be applied. + * + * 1. Pad the N dim so that it is a multiple of 4. In this case, 1 + * batch of padding is added, producing a tensor of size {12,1,3,3}. + * at::pad(x, {0,0,0,0,0,0,0,1}, "constant", 0); + * + * 2. Flatten the last two dims by reshaping the tensor: + * x.reshape({12,1,9}); + * + * 3. "Fold" the N dim into the C dim. Split the tensor along the N dim so that + * each split has 4 channels. + * x.reshape({3,4,1,9}); + * + * 4. Stack the batches on each other vertically by permuting the N and C dims + * and reshaping the tensor. + * x.permute({1,0,2,3}).reshape({4,3,9}); + */ +void main() { + const ivec3 pos = ivec3(gl_GlobalInvocationID); + const ivec4 coord = POS_TO_COORD_CHANNELS_PACKED(pos, gpu_sizes.data); + + if (any(greaterThanEqual(coord, gpu_sizes.data))) { + return; + } + + // As in usual staging shaders, map from GPU texel position to normal CPU + // buffer indices: (9,3) -> (4,3,9) + const int base_index = COORD_TO_BUFFER_IDX(coord, gpu_sizes.data); + const ivec4 p0 = + base_index + ivec4(0, 1, 2, 3) * STRIDE_CHANNELS_PACKED(gpu_sizes.data); + + // Re-map the normal CPU buffer indices to special indices, through a series + // of mappings: reshape is a no-op to the underlying indices, so we only map + // for pad and permute. 
+  const int Np = padded_sizes.data.x;
+  const int N = original_sizes.data.w;
+  const int C = original_sizes.data.z;
+  const int H = original_sizes.data.y;
+  const int W = original_sizes.data.x;
+
+  // Undo step 4 permute: (4,3,1,9) -> (3,4,1,9)
+  const ivec4 p1 = SWAP_ADJ_DIMS(p0, 4, (Np / 4), (C * H * W));
+
+  // Undo step 1 pad: (12,1,3,3) -> (11,1,3,3)
+  // For values in the padded region, write zero instead of buffer data.
+  const ivec4 n = p1 / (C * H * W);
+  const ivec4 mask = ivec4(greaterThanEqual(n, ivec4(N)));
+
+  ${T[DTYPE]} val_x = mix(buffer_in.data[p1.x], 0, mask.x);
+  ${T[DTYPE]} val_y = mix(buffer_in.data[p1.y], 0, mask.y);
+  ${T[DTYPE]} val_z = mix(buffer_in.data[p1.z], 0, mask.z);
+  ${T[DTYPE]} val_w = mix(buffer_in.data[p1.w], 0, mask.w);
+
+  ${VEC4_T[DTYPE]} texel = ${VEC4_T[DTYPE]}(val_x, val_y, val_z, val_w);
+
+  imageStore(image_out, pos.xy, texel);
+}
diff --git a/backends/vulkan/runtime/graph/ops/glsl/conv2d_dw_prepack_weights.yaml b/backends/vulkan/runtime/graph/ops/glsl/conv2d_dw_prepack_weights.yaml
new file mode 100644
index 00000000000..e7fc5f797c8
--- /dev/null
+++ b/backends/vulkan/runtime/graph/ops/glsl/conv2d_dw_prepack_weights.yaml
@@ -0,0 +1,17 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+conv2d_dw_prepack_weights:
+  parameter_names_with_default_values:
+    DTYPE: float
+  generate_variant_forall:
+    DTYPE:
+      - VALUE: half
+        SUFFIX: half
+      - VALUE: float
+        SUFFIX: float
+  shader_variants:
+    - NAME: conv2d_dw_prepack_weights
diff --git a/backends/vulkan/runtime/graph/ops/impl/Conv2d.cpp b/backends/vulkan/runtime/graph/ops/impl/Conv2d.cpp
index 5d4b36f03f3..d0b7c89ea5a 100644
--- a/backends/vulkan/runtime/graph/ops/impl/Conv2d.cpp
+++ b/backends/vulkan/runtime/graph/ops/impl/Conv2d.cpp
@@ -80,15 +80,27 @@ ValueRef prepack_biases(ComputeGraph& graph, const ValueRef vref) {
   return v;
 }
 
+enum class Conv2dMethod : uint8_t {
+  Depthwise,
+  SlidingWindow,
+  Transposed,
+};
+
 api::ShaderInfo get_conv2d_shader(
     const vTensor& t_out,
     const bool prepack_weights,
-    const bool transposed) {
+    const Conv2dMethod method) {
   std::stringstream kernel_name;
-  if (transposed) {
-    kernel_name << "conv_transpose2d";
-  } else {
-    kernel_name << "conv2d";
+  switch (method) {
+    case Conv2dMethod::Depthwise:
+      kernel_name << "conv2d_dw";
+      break;
+    case Conv2dMethod::SlidingWindow:
+      kernel_name << "conv2d";
+      break;
+    case Conv2dMethod::Transposed:
+      kernel_name << "conv_transpose2d";
+      break;
   }
   if (prepack_weights) {
     kernel_name << "_prepack_weights";
@@ -98,23 +110,53 @@ api::ShaderInfo get_conv2d_shader(
   return VK_KERNEL_FROM_STR(kernel_name.str());
 }
 
-ValueRef prepack_weights(
-    ComputeGraph& graph,
-    const ValueRef vref,
-    const bool transposed) {
-  const auto original_sizes = graph.get_val(vref).toTensorRef().sizes;
-
+std::vector<int64_t> get_final_sizes(
+    const std::vector<int64_t>& original_sizes,
+    const Conv2dMethod method) {
   int64_t batch_padded =
       api::utils::align_up(api::utils::val_at(-4, original_sizes), INT64_C(4));
   int64_t channels_padded =
       api::utils::align_up(api::utils::val_at(-3, original_sizes), INT64_C(4));
+  int64_t channels = api::utils::val_at(-3, original_sizes);
   int64_t height = api::utils::val_at(-2, original_sizes);
   int64_t width = api::utils::val_at(-1, original_sizes);
 
-  const auto final_sizes = std::vector<int64_t>{
-      4,
-      transposed ? channels_padded * height / 4 : batch_padded * height / 4,
-      transposed ? batch_padded * width : channels_padded * width};
+  switch (method) {
+    case Conv2dMethod::Depthwise:
+      return std::vector<int64_t>{
+          4, batch_padded * channels / 4, height * width};
+    case Conv2dMethod::SlidingWindow:
+      return std::vector<int64_t>{
+          4, batch_padded * height / 4, channels_padded * width};
+    case Conv2dMethod::Transposed:
+      return std::vector<int64_t>{
+          4, channels_padded * height / 4, batch_padded * width};
+  }
+}
+
+std::vector<int64_t> get_padded_sizes(
+    const std::vector<int64_t>& original_sizes,
+    const Conv2dMethod method) {
+  int64_t batch_padded =
+      api::utils::align_up(api::utils::val_at(-4, original_sizes), INT64_C(4));
+  int64_t channels_padded =
+      api::utils::align_up(api::utils::val_at(-3, original_sizes), INT64_C(4));
+
+  switch (method) {
+    case Conv2dMethod::Depthwise:
+      return std::vector<int64_t>{-1, batch_padded};
+    case Conv2dMethod::SlidingWindow:
+    case Conv2dMethod::Transposed:
+      return std::vector<int64_t>{batch_padded, channels_padded};
+  }
+}
+
+ValueRef prepack_weights(
+    ComputeGraph& graph,
+    const ValueRef vref,
+    const Conv2dMethod method) {
+  const auto original_sizes = graph.get_val(vref).toTensorRef().sizes;
+  const auto final_sizes = get_final_sizes(original_sizes, method);
 
   ValueRef v = graph.add_tensor(
       final_sizes,
@@ -127,9 +169,9 @@ ValueRef prepack_weights(
   api::utils::uvec3 local_size = adaptive_work_group_size(global_size);
 
   api::ShaderInfo shader =
-      get_conv2d_shader(t, /*prepack_weights = */ true, transposed);
+      get_conv2d_shader(t, /*prepack_weights = */ true, method);
 
-  const auto padded_sizes = std::vector<int64_t>{batch_padded, channels_padded};
+  const auto padded_sizes = get_padded_sizes(original_sizes, method);
 
   graph.prepack_nodes().emplace_back(new PrepackNode(
       graph,
@@ -197,6 +239,24 @@ void check_conv2d_params(const KernelParams& p, const bool transposed) {
   }
 }
 
+Conv2dMethod get_conv2d_method(
+    ComputeGraph& graph,
+    const ValueRef weight,
+    const int64_t groups,
+    const bool transposed) {
+  const auto weight_sizes = graph.get_val(weight).toTensorRef().sizes;
+  if (!transposed && weight_sizes.at(0) == groups && weight_sizes.at(1) == 1) {
+    return Conv2dMethod::Depthwise;
+  }
+  if (groups > 1) {
+    VK_THROW("aten.convolution.default: groups > 1 is not supported yet!");
+  }
+  if (transposed) {
+    return Conv2dMethod::Transposed;
+  }
+  return Conv2dMethod::SlidingWindow;
+}
+
 void add_conv2d_node(
     ComputeGraph& graph,
     const ValueRef in,
@@ -207,11 +267,16 @@ void add_conv2d_node(
     const ValueRef dilation,
     const ValueRef transposed,
     const ValueRef output_padding,
+    const ValueRef groups,
    const ValueRef out) {
   const bool transposed_val = graph.get_val(transposed).toBool();
+  const int64_t groups_val = graph.get_val(groups).toInt();
+
+  const Conv2dMethod method =
+      get_conv2d_method(graph, weight, groups_val, transposed_val);
 
   ValueRef arg_in = prepack_if_tensor_ref(graph, in);
-  ValueRef arg_weight = prepack_weights(graph, weight, transposed_val);
+  ValueRef arg_weight = prepack_weights(graph, weight, method);
   ValueRef arg_bias = prepack_biases(graph, bias);
 
   vTensor& t_in = graph.get_val(arg_in).toTensor();
@@ -234,7 +299,7 @@ void add_conv2d_node(
   check_conv2d_params(kernel_params, transposed_val);
 
   api::ShaderInfo shader =
-      get_conv2d_shader(t_out, /*prepack_weights = */ false, transposed_val);
+      get_conv2d_shader(t_out, /*prepack_weights = */ false, method);
 
   graph.execute_nodes().emplace_back(new ExecuteNode(
       graph,
@@ -257,10 +322,6 @@
 }
 
 void conv2d(ComputeGraph& graph, const std::vector<ValueRef>& args) {
-  const int64_t groups = graph.get_val(args[8]).toInt();
-  if (groups > 1) {
-    VK_THROW("aten.convolution.default: groups > 1 is not supported yet!");
-  }
   return add_conv2d_node(
       graph,
       args[0],
@@ -271,6 +332,7 @@ void conv2d(ComputeGraph& graph, const std::vector<ValueRef>& args) {
       args[5],
       args[6],
       args[7],
+      args[8],
       args[9]);
 }
 
diff --git a/backends/vulkan/test/test_vulkan_delegate.py b/backends/vulkan/test/test_vulkan_delegate.py
index 37b0b691b3f..e8220d8522d 100644
--- a/backends/vulkan/test/test_vulkan_delegate.py
+++ b/backends/vulkan/test/test_vulkan_delegate.py
@@ -551,3 +551,28 @@ def forward(self, x):
             sample_inputs,
             memory_layouts=[vk_graph_schema.VkMemoryLayout.TENSOR_CHANNELS_PACKED],
         )
+
+    def test_vulkan_backend_conv2d_dw(self):
+        class Conv2dModule(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.conv = torch.nn.Conv2d(
+                    in_channels=8,
+                    out_channels=8,
+                    kernel_size=3,
+                    padding=1,
+                    groups=8,
+                    bias=True,
+                )
+
+            def forward(self, x):
+                return self.conv(x)
+
+        conv2d_module = Conv2dModule()
+        sample_inputs = (torch.randn(size=(1, 8, 72, 96), dtype=torch.float32),)
+
+        self.lower_module_and_test_output(
+            conv2d_module,
+            sample_inputs,
+            memory_layouts=[vk_graph_schema.VkMemoryLayout.TENSOR_CHANNELS_PACKED],
+        )
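The dispatch rule added in get_conv2d_method() can be mirrored outside the graph code to see when the new depthwise path is taken. The Python sketch below is illustrative only; the helper name is_depthwise and the plain-tuple weight sizes are assumptions made here, not part of the patch.

# Hedged mirror of the C++ dispatch rule; not the actual implementation.
def is_depthwise(weight_sizes, groups, transposed):
    # Depthwise path: not transposed, weight N dim equals groups, C dim is 1.
    return not transposed and weight_sizes[0] == groups and weight_sizes[1] == 1

# The test module above (Conv2d(8, 8, kernel_size=3, groups=8)) has a weight of
# shape (8, 1, 3, 3), so it is routed to the conv2d_dw* shaders.
assert is_depthwise((8, 1, 3, 3), groups=8, transposed=False)

# A grouped but non-depthwise weight is not matched and still hits the
# "groups > 1 is not supported yet" error in get_conv2d_method().
assert not is_depthwise((8, 2, 3, 3), groups=4, transposed=False)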
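For the depthwise case, get_final_sizes() reduces to simple arithmetic. The sketch below recomputes it for the {11,1,3,3} example used in the prepack shader comments, assuming the NCHW weight layout described there; align_up is a local stand-in for api::utils::align_up.

# Recompute the Depthwise branch of get_final_sizes() by hand (illustrative).
def align_up(x, m):
    return (x + m - 1) // m * m

N, C, H, W = 11, 1, 3, 3           # weight sizes from the shader comment
batch_padded = align_up(N, 4)      # 12
final_sizes = [4, batch_padded * C // 4, H * W]
assert final_sizes == [4, 3, 9]    # the {4,3,9} packed layout the shader reads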
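The CPU-side packing that conv2d_dw_prepack_weights.glsl reverses can be replayed with torch to check the shapes quoted in the shader's header comment. The tensor name x and the use of torch.nn.functional.pad are illustrative choices here; the four steps follow the comment's at::pad / reshape / permute sequence for an {11,1,3,3} weight.

import torch

# 0. Example depthwise weight: N=11 filters, C=1, 3x3 kernel.
x = torch.randn(11, 1, 3, 3)

# 1. Pad the N dim up to a multiple of 4: {11,1,3,3} -> {12,1,3,3}.
x = torch.nn.functional.pad(x, (0, 0, 0, 0, 0, 0, 0, 1), mode="constant", value=0)

# 2. Flatten the last two dims: {12,1,3,3} -> {12,1,9}.
x = x.reshape(12, 1, 9)

# 3. "Fold" the N dim into the C dim, 4 channels per split: {12,1,9} -> {3,4,1,9}.
x = x.reshape(3, 4, 1, 9)

# 4. Stack the batches vertically by permuting N and C, then reshaping:
#    {3,4,1,9} -> {4,3,9}, so each length-9 row holds one flattened 3x3 filter.
x = x.permute(1, 0, 2, 3).reshape(4, 3, 9)

assert x.shape == (4, 3, 9)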