diff --git a/backends/vulkan/partitioner/vulkan_partitioner.py b/backends/vulkan/partitioner/vulkan_partitioner.py index b5df34b08cd..a4cf74097c4 100644 --- a/backends/vulkan/partitioner/vulkan_partitioner.py +++ b/backends/vulkan/partitioner/vulkan_partitioner.py @@ -48,6 +48,8 @@ def is_node_supported(self, submodules, node: torch.fx.Node) -> bool: exir_ops.edge.aten.max_pool2d_with_indices.default, # Sum exir_ops.edge.aten.sum.dim_IntList, + # Convolution operators + exir_ops.edge.aten.convolution.default, # Other operator.getitem, ] diff --git a/backends/vulkan/runtime/graph/ops/glsl/conv2d.glsl b/backends/vulkan/runtime/graph/ops/glsl/conv2d.glsl new file mode 100644 index 00000000000..30051e5f5a3 --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/glsl/conv2d.glsl @@ -0,0 +1,130 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#version 450 core + +#define PRECISION ${PRECISION} + +#include "indexing_utils.h" + +layout(std430) buffer; + +layout(set = 0, binding = 0, ${IMAGE_FORMAT[DTYPE]}) uniform PRECISION restrict writeonly ${IMAGE_T[NDIM][DTYPE]} image_out; +layout(set = 0, binding = 1) uniform PRECISION sampler3D image_in; +layout(set = 0, binding = 2) uniform PRECISION sampler2D kernel_in; +layout(set = 0, binding = 3) uniform PRECISION sampler2D bias_in; + +layout(set = 0, binding = 4) uniform PRECISION restrict OutExtents { + uvec4 data; +} +out_extents; + +layout(set = 0, binding = 5) uniform PRECISION restrict InExtents { + uvec4 data; +} +in_extents; + +layout(set = 0, binding = 6) uniform PRECISION restrict Params { + ivec2 kernel_size; + ivec2 stride; + ivec2 padding; + ivec2 dilation; +} +params; + +// If fields are separated, SwiftShader cannot identify in_group_size. +layout(set = 0, binding = 7) uniform PRECISION restrict ExtraParams { + ivec2 overlay_region; + int in_group_size; +} +extra_params; + +layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in; + +/* + * Computes a 2D convolution. Each shader invocation calculates the output at + * a single output location. + */ +void main() { + const ivec3 pos = ivec3(gl_GlobalInvocationID); + + if (any(greaterThanEqual(pos, out_extents.data.xyz))) { + return; + } + + // Compute the index of the top-left element of the overlay region. Negative + // indices indicate that the top-left element is in a region added by padding. + const ivec2 ipos = pos.xy * params.stride - params.padding; + + // Compute the start and end of the input indices to load. Padding is assumed + // to be constant 0 padding, so reads from the padding region are skipped. + const ivec2 start = max(ivec2(0), ipos); + const ivec2 end = min(ipos + extra_params.overlay_region.xy, ivec2(in_extents.data.xy)); + // Compute the start of the kernel based on how far we are skipping ahead when + // reading the input. Note that these are "canonical" indices. + ivec2 kstart = (start - ipos) / params.dilation; + // During prepacking, the weight tensor was rearranged in order to optimize + // for data access linearity in this shader. Therefore we need to adjust the + // canonical coordinates to the corresponding index in the rearranged weight + // tensor. The x-coordinate is multiplied by 4 since each group of 4 channels + // is folded into the X axis. The y-coordinate is offset based on the z- + // coordinate because the 2D planes were stacked atop each other vertically.
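+  // For example, with a 3x3 kernel, stride 1, padding 1 and dilation 1, the
+  // output position (0, 0) has ipos = (-1, -1), start = (0, 0) and canonical
+  // kstart = (1, 1), since the first kernel row and column overlap only the
+  // padding region and are skipped. After the adjustment below, kstart
+  // becomes (4, 1 + pos.z * 3).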
+ kstart.x *= 4; + kstart.y += pos.z * params.kernel_size.y; + + // Perform the convolution by iterating over the overlay region. + ${VEC4_T[DTYPE]} sum = texelFetch(bias_in, ivec2(pos.z, 0), 0); + const int ic4 = extra_params.in_group_size / 4; + for (int z4 = 0; z4 < ic4; ++z4, kstart.x += params.kernel_size.x * 4) { + for (int y = start.y, ky = kstart.y; y < end.y; y += params.dilation.y, ++ky) { + for (int x = start.x, kx = kstart.x; x < end.x; x += params.dilation.x, kx += 4) { + const ${VEC4_T[DTYPE]} in_texel = texelFetch(image_in, ivec3(x, y, z4), 0); + const ivec4 kxs = kx + ivec4(0, 1, 2, 3); + + // To explain the calculation below, the contents of in_texel and the + // group of 4 texels loaded from kernel_in are shown: + // + // in_texel kernel_in + // -x-> ---x---> + // +---+ +----+----+----+----+ + // ^ | w | ^ | D0 | D1 | D2 | D3 | + // | +---+ | +----+----+----+----+ + // | | z | | | C0 | C1 | C2 | C3 | + // z +---+ z +----+----+----+----+ + // | | y | | | B0 | B1 | B2 | B3 | + // | +---+ | +----+----+----+----+ + // | x | | A0 | A1 | A2 | A3 | + // +---+ +----+----+----+----+ + // + // In the kernel_in graphic, cells sharing the same letter are from + // the same batch/output channel index, and the number denotes a unique + // channel index. To calculate the output texel, the following + // calculation is performed: + // + // +---+ +----+ +---+ +----+ +---+ +----+ +---+ +----+ + // | x | | D0 | | y | | D1 | | z | | D2 | | w | | D3 | + // +---+ +----+ +---+ +----+ +---+ +----+ +---+ +----+ + // | x | | C0 | | y | | C1 | | z | | C2 | | w | | C3 | + // +---+X+----+ + +---+X+----+ + +---+X+----+ + +---+X+----+ + // | x | | B0 | | y | | B1 | | z | | B2 | | w | | B3 | + // +---+ +----+ +---+ +----+ +---+ +----+ +---+ +----+ + // | x | | A0 | | y | | A1 | | z | | A2 | | w | | A3 | + // +---+ +----+ +---+ +----+ +---+ +----+ +---+ +----+ + // + // which is expressed in the following statements. + + sum = fma(in_texel.xxxx, texelFetch(kernel_in, ivec2(kxs.x, ky), 0), sum); + sum = fma(in_texel.yyyy, texelFetch(kernel_in, ivec2(kxs.y, ky), 0), sum); + sum = fma(in_texel.zzzz, texelFetch(kernel_in, ivec2(kxs.z, ky), 0), sum); + sum = fma(in_texel.wwww, texelFetch(kernel_in, ivec2(kxs.w, ky), 0), sum); + } + } + } + + imageStore(image_out, pos, sum); +} diff --git a/backends/vulkan/runtime/graph/ops/glsl/conv2d.yaml b/backends/vulkan/runtime/graph/ops/glsl/conv2d.yaml new file mode 100644 index 00000000000..6764a2daa75 --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/glsl/conv2d.yaml @@ -0,0 +1,18 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +conv2d: + parameter_names_with_default_values: + NDIM: 3 + DTYPE: float + generate_variant_forall: + DTYPE: + - VALUE: half + SUFFIX: half + - VALUE: float + SUFFIX: float + shader_variants: + - NAME: conv2d diff --git a/backends/vulkan/runtime/graph/ops/glsl/conv2d_prepack_weights.glsl b/backends/vulkan/runtime/graph/ops/glsl/conv2d_prepack_weights.glsl new file mode 100644 index 00000000000..26b4fa0d76f --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/glsl/conv2d_prepack_weights.glsl @@ -0,0 +1,131 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. 
+ */ + +#version 450 core + +#define PRECISION ${PRECISION} + +#include "indexing_utils.h" + +layout(std430) buffer; + +layout(set = 0, binding = 0, ${IMAGE_FORMAT[DTYPE]}) uniform PRECISION restrict writeonly ${IMAGE_T[2][DTYPE]} image_out; +layout(set = 0, binding = 1) buffer PRECISION restrict readonly Buffer { + ${T[DTYPE]} data[]; +} +buffer_in; + +// Corresponds to {1,4,9,24} in the example below. +layout(set = 0, binding = 2) uniform PRECISION restrict GpuSizes { + ivec4 data; +} +gpu_sizes; + +// Corresponds to {3,3,7,10} in the example below. +layout(set = 0, binding = 3) uniform PRECISION restrict OriginalSizes { + ivec4 data; +} +original_sizes; + +// Corresponds to {8,12} in the example below. +layout(set = 0, binding = 4) uniform PRECISION restrict PaddedSizes { + ivec2 data; +} +padded_sizes; + +layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in; + +/* + * Computes special prepacking for a 2D convolution. Each shader invocation + * calculates the input buffer location to read into the desired texel. This + * packing was originally developed on CPU and that approach is described in the + * rest of this comment. Refer to the code-level comments, for how we translate + * it to GPU by reversing the steps. + * + * Consider an example weight tensor of size {10,7,3,3}. The following + * transformations will be applied. + * + * 1. Pad the N and C dims so that both are a multiple of 4. In this case, 2 + * batches and 1 channel of padding are added, producing a tensor of size + * {12,8,3,3}. + * at::pad(x, {0,0,0,0,0,1,0,2}, "constant", 0); + * + * 2. Split the tensor along the C dim so that each split has 4 channels. + * x.reshape({12,2,4,3,3}); + * + * 3. For each split, "fold" the C dim into the W dim. Suppose the first rows + * at H=0 of the split have values + * 0,1,2 | 10,11,12 | 20,21,22 | 30,31,32 + * + * where | denotes a channel boundary. Then, the goal is to combine those rows + * into one row with the values + * 0, 10, 20, 30, 1, 11, 21, 31, 2, 12, 22, 32 + * + * x.permute({0,1,3,4,2}).reshape({12,2,3,12}); + * + * 4. Stack the splits belonging to the same batch horizontally by swapping the + * C and H dims. + * x.permute({0,2,1,3}).reshape({12,3,24}); + * + * 5. Repeat a similar process to "fold" the N dim into the C dim. Split along + * the N dim so that each split has 4 batches. + * x.reshape({3,4,3,24}); + * + * 6. Stack the batches on each other vertically by swapping the N and C dims. + * x.permute({1,0,2,3}).reshape({4,9,24}); + */ +void main() { + const ivec3 pos = ivec3(gl_GlobalInvocationID); + const ivec4 coord = POS_TO_COORD_CHANNELS_PACKED(pos, gpu_sizes.data); + + if (any(greaterThanEqual(coord, gpu_sizes.data))) { + return; + } + + // As in usual staging shaders, map from GPU texel position to normal CPU + // buffer indices: (24,9) -> (4,9,24) + const int base_index = COORD_TO_BUFFER_IDX(coord, gpu_sizes.data); + const ivec4 p0 = + base_index + ivec4(0, 1, 2, 3) * STRIDE_CHANNELS_PACKED(gpu_sizes.data); + + // Re-map the normal CPU buffer indices to special indices, through a series + // of mappings: reshape is a no-op to the underlying indices, so we only map + // for pad and permute. 
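+  // For the {10,7,3,3} example above, original_sizes.data is (3,3,7,10) and
+  // padded_sizes.data is (8,12), so W = 3, H = 3, C = 7, N = 10, Cp = 8 and
+  // Np = 12.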
+ const int Np = padded_sizes.data.y; + const int Cp = padded_sizes.data.x; + const int N = original_sizes.data.w; + const int C = original_sizes.data.z; + const int H = original_sizes.data.y; + const int W = original_sizes.data.x; + + // Undo step 6 permute: (4,3,3,24) -> (3,4,3,24) + // Undo step 4 permute: (12,3,2,12) -> (12,2,3,12) + // Undo step 3 permute, part 1: (12,2,3h,3w,4) -> (12,2,3h,4,3w) + // Undo step 3 permute, part 2: (12,2,3h,4,3w) -> (12,2,4,3h,3w) + const ivec4 p1 = SWAP_ADJ_DIMS(p0, 4, (Np / 4), (H * Cp * W)); + const ivec4 p2 = SWAP_ADJ_DIMS(p1, H, (Cp / 4), (W * 4)); + const ivec4 p3 = SWAP_ADJ_DIMS(p2, W, 4, 1); + const ivec4 p4 = SWAP_ADJ_DIMS(p3, H, 4, W); + + // Undo step 1 pad: (12,8,3,3) -> (10,7,3,3) + // For values in the padded region, write zero instead of buffer data. + const ivec4 c = p4 % (Cp * H * W) / (H * W); + const ivec4 n = p4 / (Cp * H * W); + const ivec4 p5 = p4 - n * (Cp - C) * H * W; + const ivec4 mask = ivec4(greaterThanEqual(c, ivec4(C))) | + ivec4(greaterThanEqual(n, ivec4(N))); + + ${T[DTYPE]} val_x = mix(buffer_in.data[p5.x], 0, mask.x); + ${T[DTYPE]} val_y = mix(buffer_in.data[p5.y], 0, mask.y); + ${T[DTYPE]} val_z = mix(buffer_in.data[p5.z], 0, mask.z); + ${T[DTYPE]} val_w = mix(buffer_in.data[p5.w], 0, mask.w); + + ${VEC4_T[DTYPE]} texel = ${VEC4_T[DTYPE]}(val_x, val_y, val_z, val_w); + + imageStore(image_out, pos.xy, texel); +} diff --git a/backends/vulkan/runtime/graph/ops/glsl/conv2d_prepack_weights.yaml b/backends/vulkan/runtime/graph/ops/glsl/conv2d_prepack_weights.yaml new file mode 100644 index 00000000000..277df2619ff --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/glsl/conv2d_prepack_weights.yaml @@ -0,0 +1,17 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +conv2d_prepack_weights: + parameter_names_with_default_values: + DTYPE: float + generate_variant_forall: + DTYPE: + - VALUE: half + SUFFIX: half + - VALUE: float + SUFFIX: float + shader_variants: + - NAME: conv2d_prepack_weights diff --git a/backends/vulkan/runtime/graph/ops/glsl/indexing_utils.h b/backends/vulkan/runtime/graph/ops/glsl/indexing_utils.h index c76f054ec67..a5ed5b5f182 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/indexing_utils.h +++ b/backends/vulkan/runtime/graph/ops/glsl/indexing_utils.h @@ -44,3 +44,14 @@ #define STRIDE_WIDTH_PACKED(vec) (1) #define STRIDE_HEIGHT_PACKED(vec) (vec.x) + +// Given a buffer(1-D) index cur, compute a new index where the corresponding +// tensor(N-D)'s adjacent dimensions are swapped. The parameters x,y and plane +// describe sizes. As an example, let's say we want to swap dimensions 0,1 for a +// tensor of shape {4,3,2,24} to obtain {3,4,2,24}. Then, x=4, y=3 and +// plane=2*24=48. +#define SWAP_ADJ_DIMS(cur, x, y, plane) \ + cur + \ + plane*( \ + (1 - y) * ((cur % (x * y * plane)) / (y * plane)) + \ + (x - 1) * ((cur % (y * plane)) / plane)) diff --git a/backends/vulkan/runtime/graph/ops/impl/Conv2d.cpp b/backends/vulkan/runtime/graph/ops/impl/Conv2d.cpp new file mode 100644 index 00000000000..03f9b40c2f3 --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/impl/Conv2d.cpp @@ -0,0 +1,251 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree.
+ */ + +#include + +#include + +#include + +#include +#include + +#include + +namespace vkcompute { + +void resize_conv2d_node( + ComputeGraph* graph, + const std::vector& args, + const std::vector& extra_args) { + vTensor& out = graph->get_val(args[0].refs[0]).toTensor(); + vTensor& self = graph->get_val(args[1].refs[0]).toTensor(); + + size_t ndim = self.sizes().size(); + std::vector new_out_sizes(ndim); + + // Batch, Channel + if (ndim == 4) { + new_out_sizes.at(ndim - 4) = self.sizes().at(ndim - 4); + } + const auto weight_sizes = graph->get_val(extra_args[0]).toTensorRef().sizes; + new_out_sizes.at(ndim - 3) = weight_sizes.at(ndim - 4); + + // Height, Width + const auto new_out_sizes_hw = calc_out_sizes_hw( + *graph, + self.sizes(), + extra_args[0], + /*kernel_size_only = */ false, + extra_args[1], + extra_args[2], + extra_args[3]); + new_out_sizes.at(ndim - 2) = new_out_sizes_hw.at(0); + new_out_sizes.at(ndim - 1) = new_out_sizes_hw.at(1); + + out.virtual_resize(new_out_sizes); +} + +ValueRef prepack_biases(ComputeGraph& graph, const ValueRef vref) { + if (graph.get_val(vref).isNone()) { + VK_THROW("aten.convolution.default: Null bias is not supported yet!"); + } + + ValueRef v = graph.add_tensor_like( + vref, + api::StorageType::TEXTURE_2D, + api::GPUMemoryLayout::TENSOR_WIDTH_PACKED); + vTensor& t = graph.get_val(v).toTensor(); + + api::ShaderInfo shader = get_nchw_to_image_shader(t); + + api::utils::uvec3 global_size = t.extents(); + api::utils::uvec3 local_size = adaptive_work_group_size(global_size); + + graph.prepack_nodes().emplace_back(new PrepackNode( + graph, + shader, + global_size, + local_size, + vref, + v, + {t.gpu_sizes_ubo(), t.cpu_sizes_ubo()})); + + return v; +} + +api::ShaderInfo get_conv2d_shader(const vTensor& t_out, bool prepack_weights) { + std::stringstream kernel_name; + kernel_name << "conv2d"; + if (prepack_weights) { + kernel_name << "_prepack_weights"; + } + apply_dtype_suffix(kernel_name, t_out); + + return VK_KERNEL_FROM_STR(kernel_name.str()); +} + +ValueRef prepack_weights(ComputeGraph& graph, const ValueRef vref) { + const auto original_sizes = graph.get_val(vref).toTensorRef().sizes; + + int64_t batch_padded = + api::utils::align_up(api::utils::val_at(-4, original_sizes), INT64_C(4)); + int64_t channels_padded = + api::utils::align_up(api::utils::val_at(-3, original_sizes), INT64_C(4)); + int64_t height = api::utils::val_at(-2, original_sizes); + int64_t width = api::utils::val_at(-1, original_sizes); + + const auto final_sizes = std::vector{ + 4, batch_padded * height / 4, channels_padded * width}; + + ValueRef v = graph.add_tensor( + final_sizes, + graph.get_val(vref).toTensorRef().dtype, + api::StorageType::TEXTURE_2D, + api::GPUMemoryLayout::TENSOR_CHANNELS_PACKED); + vTensor& t = graph.get_val(v).toTensor(); + + api::utils::uvec3 global_size = t.extents(); + api::utils::uvec3 local_size = adaptive_work_group_size(global_size); + + api::ShaderInfo shader = get_conv2d_shader(t, /*prepack_weights = */ true); + + const auto padded_sizes = std::vector{batch_padded, channels_padded}; + + graph.prepack_nodes().emplace_back(new PrepackNode( + graph, + shader, + global_size, + local_size, + vref, + v, + {t.gpu_sizes_ubo(), + graph.create_params_buffer( + api::utils::make_ivec4(original_sizes, /*reverse = */ true)), + graph.create_params_buffer( + api::utils::make_ivec2(padded_sizes, /*reverse = */ true))})); + + return v; +} + +void check_conv2d_args(const vTensor& in, const vTensor& out) { + if (in.sizes().at(0) > 1) { + VK_THROW( + 
"aten.convolution.default: input batch size > 1 is not supported yet!"); + } + VK_CHECK_COND( + check_memory_layout_is(in, api::GPUMemoryLayout::TENSOR_CHANNELS_PACKED)); + VK_CHECK_COND(check_memory_layout_is( + out, api::GPUMemoryLayout::TENSOR_CHANNELS_PACKED)); +} + +struct Conv2dParams final { + api::utils::ivec2 overlay_region; + int in_group_size; +}; + +Conv2dParams create_conv2d_params( + ComputeGraph& graph, + const ValueRef weight, + const KernelParams& p) { + const auto overlay_region = api::utils::make_ivec2({ + p.kernel_size.data[0] + + (p.kernel_size.data[0] - 1) * (p.dilation.data[0] - 1), + p.kernel_size.data[1] + + (p.kernel_size.data[1] - 1) * (p.dilation.data[1] - 1), + }); + const auto weight_sizes = graph.get_val(weight).toTensorRef().sizes; + const int32_t in_group_size = api::utils::safe_downcast( + api::utils::align_up(weight_sizes.at(1), INT64_C(4))); + return {overlay_region, in_group_size}; +} + +void check_conv2d_params(const KernelParams& p) { + if ((p.padding.data[0] > 0 && p.kernel_size.data[0] > 1 && + p.dilation.data[0] > 1) || + (p.padding.data[1] > 0 && p.kernel_size.data[1] > 1 && + p.dilation.data[1] > 1)) { + VK_THROW( + "aten.convolution.default: padding > 0 while dilation, kernel_size > 1 is not supported yet!"); + } +} + +void add_conv2d_node( + ComputeGraph& graph, + const ValueRef in, + const ValueRef weight, + const ValueRef bias, + const ValueRef stride, + const ValueRef padding, + const ValueRef dilation, + const ValueRef out) { + ValueRef arg_in = prepack_if_tensor_ref(graph, in); + ValueRef arg_weight = prepack_weights(graph, weight); + ValueRef arg_bias = prepack_biases(graph, bias); + + vTensor& t_in = graph.get_val(arg_in).toTensor(); + vTensor& t_out = graph.get_val(out).toTensor(); + + check_conv2d_args(t_in, t_out); + + api::utils::uvec3 global_size = t_out.virtual_extents(); + api::utils::uvec3 local_size = adaptive_work_group_size(global_size); + + KernelParams kernel_params = create_kernel_params( + graph, + weight, + /*kernel_size_only = */ false, + stride, + padding, + dilation); + Conv2dParams extra_params = + create_conv2d_params(graph, weight, kernel_params); + + check_conv2d_params(kernel_params); + + api::ShaderInfo shader = + get_conv2d_shader(t_out, /*prepack_weights = */ false); + + graph.execute_nodes().emplace_back(new ExecuteNode( + graph, + shader, + global_size, + local_size, + // Inputs and Outputs + {{out, api::MemoryAccessType::WRITE}, + {{arg_in, arg_weight, arg_bias}, api::MemoryAccessType::READ}}, + // Shader params buffers + { + t_out.extents_ubo(), + t_in.extents_ubo(), + graph.create_params_buffer(kernel_params), + graph.create_params_buffer(extra_params), + }, + // Resizing + resize_conv2d_node, + {weight, stride, padding, dilation})); +} + +void conv2d(ComputeGraph& graph, const std::vector& args) { + const bool transposed = graph.get_val(args[6]).toBool(); + if (transposed) { + VK_THROW("aten.convolution.default: transpose is not supported yet!"); + } + const int64_t groups = graph.get_val(args[8]).toInt(); + if (groups > 1) { + VK_THROW("aten.convolution.default: groups > 1 is not supported yet!"); + } + return add_conv2d_node( + graph, args[0], args[1], args[2], args[3], args[4], args[5], args[9]); +} + +REGISTER_OPERATORS { + VK_REGISTER_OP(aten.convolution.default, conv2d); +} + +} // namespace vkcompute diff --git a/backends/vulkan/runtime/graph/ops/impl/Pool.cpp b/backends/vulkan/runtime/graph/ops/impl/Pool.cpp index 1b2276ad7cc..d5f16cd98a8 100644 --- 
a/backends/vulkan/runtime/graph/ops/impl/Pool.cpp +++ b/backends/vulkan/runtime/graph/ops/impl/Pool.cpp @@ -39,6 +39,7 @@ void resize_max_pool2d_node( *graph, self.sizes(), extra_args[0], + /*kernel_size_only = */ true, extra_args[1], extra_args[2], extra_args[3], @@ -81,8 +82,13 @@ void add_max_pool2d_node( kernel_name << "max_pool2d"; apply_dtype_suffix(kernel_name, t_out); - KernelParams kernel_params = - create_kernel_params(graph, kernel_size, stride, padding, dilation); + KernelParams kernel_params = create_kernel_params( + graph, + kernel_size, + /*kernel_size_only = */ true, + stride, + padding, + dilation); graph.execute_nodes().emplace_back(new ExecuteNode( graph, diff --git a/backends/vulkan/runtime/graph/ops/impl/utils/KernelUtils.cpp b/backends/vulkan/runtime/graph/ops/impl/utils/KernelUtils.cpp index d1d006f39f9..de55b296b9a 100644 --- a/backends/vulkan/runtime/graph/ops/impl/utils/KernelUtils.cpp +++ b/backends/vulkan/runtime/graph/ops/impl/utils/KernelUtils.cpp @@ -15,14 +15,27 @@ api::utils::ivec2 make_ivec2_from_list(ComputeGraph& graph, ValueRef vref) { graph.get_val(vref).toIntList(), /*reverse = */ true); } +api::utils::ivec2 make_ivec2_kernel_size( + ComputeGraph& graph, + const ValueRef weight, + const bool kernel_size_only) { + if (kernel_size_only) { + return make_ivec2_from_list(graph, weight); + } else { + const auto weight_sizes = graph.get_val(weight).toTensorRef().sizes; + return api::utils::make_ivec2({weight_sizes.at(3), weight_sizes.at(2)}); + } +} + KernelParams create_kernel_params( ComputeGraph& graph, - const ValueRef kernel_size, + const ValueRef weight, + const bool kernel_size_only, const ValueRef stride, const ValueRef padding, const ValueRef dilation) { return { - make_ivec2_from_list(graph, kernel_size), + make_ivec2_kernel_size(graph, weight, kernel_size_only), make_ivec2_from_list(graph, stride), make_ivec2_from_list(graph, padding), make_ivec2_from_list(graph, dilation), @@ -49,7 +62,8 @@ int64_t calc_out_size( std::vector calc_out_sizes_hw( ComputeGraph& graph, const std::vector& in_sizes, - const ValueRef kernel_size, + const ValueRef weight, + const bool kernel_size_only, const ValueRef stride, const ValueRef padding, const ValueRef dilation, @@ -57,11 +71,13 @@ std::vector calc_out_sizes_hw( const int64_t ndim = in_sizes.size(); std::vector out_sizes(2); - const auto kernel_vec = make_ivec2_from_list(graph, kernel_size); + const auto kernel_vec = + make_ivec2_kernel_size(graph, weight, kernel_size_only); const auto stride_vec = make_ivec2_from_list(graph, stride); const auto padding_vec = make_ivec2_from_list(graph, padding); const auto dilation_vec = make_ivec2_from_list(graph, dilation); - const bool ceil_mode_val = graph.get_val(ceil_mode).toBool(); + const bool ceil_mode_val = + ceil_mode == kDummyValueRef ? 
false : graph.get_val(ceil_mode).toBool(); // Height out_sizes.at(0) = calc_out_size( diff --git a/backends/vulkan/runtime/graph/ops/impl/utils/KernelUtils.h b/backends/vulkan/runtime/graph/ops/impl/utils/KernelUtils.h index b5e946e9413..923b3d8fd74 100644 --- a/backends/vulkan/runtime/graph/ops/impl/utils/KernelUtils.h +++ b/backends/vulkan/runtime/graph/ops/impl/utils/KernelUtils.h @@ -25,7 +25,8 @@ struct KernelParams final { KernelParams create_kernel_params( ComputeGraph& graph, - const ValueRef kernel_size, + const ValueRef weight, + const bool kernel_size_only, const ValueRef stride, const ValueRef padding, const ValueRef dilation); @@ -33,10 +34,11 @@ KernelParams create_kernel_params( std::vector calc_out_sizes_hw( ComputeGraph& graph, const std::vector& in_sizes, - const ValueRef kernel_size, + const ValueRef weight, + const bool kernel_size_only, const ValueRef stride, const ValueRef padding, const ValueRef dilation, - const ValueRef ceil_mode); + const ValueRef ceil_mode = kDummyValueRef); } // namespace vkcompute diff --git a/backends/vulkan/test/test_vulkan_delegate.py b/backends/vulkan/test/test_vulkan_delegate.py index d90cfad7bbe..d305fd19663 100644 --- a/backends/vulkan/test/test_vulkan_delegate.py +++ b/backends/vulkan/test/test_vulkan_delegate.py @@ -496,3 +496,30 @@ def forward(self, x): sample_inputs, memory_layouts=[vk_graph_schema.VkMemoryLayout.TENSOR_CHANNELS_PACKED], ) + + def test_vulkan_backend_conv2d(self): + class Conv2dModule(torch.nn.Module): + def __init__(self): + super().__init__() + self.conv = torch.nn.Conv2d( + in_channels=6, + out_channels=8, + kernel_size=(3, 3), + padding=(2, 3), + stride=(1, 2), + dilation=1, + groups=1, + bias=True, + ) + + def forward(self, x): + return self.conv(x) + + conv2d_module = Conv2dModule() + sample_inputs = (torch.randn(size=(1, 6, 40, 50), dtype=torch.float32),) + + self.lower_module_and_test_output( + conv2d_module, + sample_inputs, + memory_layouts=[vk_graph_schema.VkMemoryLayout.TENSOR_CHANNELS_PACKED], + ) diff --git a/backends/vulkan/test/utils/test_utils.cpp b/backends/vulkan/test/utils/test_utils.cpp index a4e3b2acb29..0f0edafe75a 100644 --- a/backends/vulkan/test/utils/test_utils.cpp +++ b/backends/vulkan/test/utils/test_utils.cpp @@ -54,6 +54,40 @@ void record_image_to_nchw_op( v_src.cpu_sizes_ubo()->buffer()); } +void record_conv2d_prepack_weights_op( + api::Context* const context, + api::VulkanBuffer& src_buffer, + vTensor& v_dst, + const std::vector& original_sizes, + const std::vector& padded_sizes) { + api::PipelineBarrier pipeline_barrier{}; + + std::stringstream kernel_name; + kernel_name << "conv2d_prepack_weights"; + apply_dtype_suffix(kernel_name, v_dst); + api::ShaderInfo shader = VK_KERNEL_FROM_STR(kernel_name.str()); + + api::UniformParamsBuffer original_sizes_ubo( + context, api::utils::make_ivec4(original_sizes, /*reverse = */ true)); + api::UniformParamsBuffer padded_sizes_ubo( + context, api::utils::make_ivec2(padded_sizes, /*reverse = */ true)); + + context->submit_compute_job( + shader, + pipeline_barrier, + v_dst.virtual_extents(), + adaptive_work_group_size(v_dst.virtual_extents()), + VK_NULL_HANDLE, + v_dst.image( + pipeline_barrier, + api::PipelineStage::COMPUTE, + api::MemoryAccessType::WRITE), + src_buffer, + v_dst.gpu_sizes_ubo()->buffer(), + original_sizes_ubo.buffer(), + padded_sizes_ubo.buffer()); +} + void record_binary_op( api::Context* const context, const std::string& op_name, diff --git a/backends/vulkan/test/utils/test_utils.h b/backends/vulkan/test/utils/test_utils.h index 
8dcba015520..2d7d0b0746f 100644 --- a/backends/vulkan/test/utils/test_utils.h +++ b/backends/vulkan/test/utils/test_utils.h @@ -81,6 +81,13 @@ void record_image_to_nchw_op( vTensor& v_src, api::VulkanBuffer& dst_buffer); +void record_conv2d_prepack_weights_op( + api::Context* const context, + api::VulkanBuffer& src_buffer, + vTensor& v_dst, + const std::vector& original_sizes, + const std::vector& padded_sizes); + void record_binary_op( api::Context* const context, const std::string& op_name, diff --git a/backends/vulkan/test/vulkan_compute_api_test.cpp b/backends/vulkan/test/vulkan_compute_api_test.cpp index 3a67e21fb44..c8e58c25cd2 100644 --- a/backends/vulkan/test/vulkan_compute_api_test.cpp +++ b/backends/vulkan/test/vulkan_compute_api_test.cpp @@ -1172,3 +1172,57 @@ TEST(VulkanComputeGraphOpsTest, max_pool2d_smoke_test) { /*base_val = */ 10.0f, kernel); } + +TEST(VulkanComputeGraphOpsTest, conv2d_prepack_test) { + const auto original_sizes = std::vector{2, 3, 1, 2}; + const auto padded_sizes = std::vector{4, 4}; + const auto gpu_sizes = std::vector{4, 1, 8}; + + vTensor vten = vTensor( + api::context(), + gpu_sizes, + api::kFloat, + api::StorageType::TEXTURE_2D, + api::GPUMemoryLayout::TENSOR_CHANNELS_PACKED); + + // Create and fill input staging buffer + const int64_t in_numel = api::utils::multiply_integers(original_sizes); + api::StorageBuffer staging_buffer_in(api::context(), api::kFloat, in_numel); + + std::vector data_in(in_numel); + for (int i = 0; i < in_numel; i++) { + data_in[i] = i + 1; + } + copy_ptr_to_staging( + data_in.data(), staging_buffer_in, sizeof(float) * in_numel); + + // Output staging buffer + const int64_t out_numel = + padded_sizes[0] * padded_sizes[1] * original_sizes[2] * original_sizes[3]; + api::StorageBuffer staging_buffer_out(api::context(), api::kFloat, out_numel); + + // Copy data in and out of the tensor + record_conv2d_prepack_weights_op( + api::context(), + staging_buffer_in.buffer(), + vten, + original_sizes, + padded_sizes); + record_image_to_nchw_op(api::context(), vten, staging_buffer_out.buffer()); + + // Execute command buffer + submit_to_gpu(); + + // Extract data from output staging buffer + std::vector data_out(out_numel); + copy_staging_to_ptr( + staging_buffer_out, data_out.data(), sizeof(float) * out_numel); + + // Check data matches results copied from ATen-VK + std::vector data_out_expected = {1, 3, 5, 0, 2, 4, 6, 0, 7, 9, 11, + 0, 8, 10, 12, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}; + for (int i = 0; i < vten.numel(); i++) { + CHECK_VALUE(data_out, i, data_out_expected[i]); + } +}
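For reference, the packing that conv2d_prepack_weights.glsl reverses can be reproduced on the CPU with plain PyTorch ops, following the six steps listed in that shader's comment. The sketch below is not part of this diff, and the helper name prepack_conv2d_weights is made up for illustration; applied to the (2, 3, 1, 2) weight used in conv2d_prepack_test, it produces the same values as data_out_expected.

```python
# Hypothetical reference implementation of the weight prepacking described in
# conv2d_prepack_weights.glsl (illustration only, not part of this change).
import torch
import torch.nn.functional as F


def prepack_conv2d_weights(w: torch.Tensor) -> torch.Tensor:
    N, C, H, W = w.shape
    Np = (N + 3) // 4 * 4  # batch count padded up to a multiple of 4
    Cp = (C + 3) // 4 * 4  # channel count padded up to a multiple of 4
    # Step 1: zero-pad the N and C dims to multiples of 4.
    w = F.pad(w, [0, 0, 0, 0, 0, Cp - C, 0, Np - N])
    # Steps 2-3: split C into groups of 4 and fold each group into the W dim.
    w = w.reshape(Np, Cp // 4, 4, H, W).permute(0, 1, 3, 4, 2)
    w = w.reshape(Np, Cp // 4, H, W * 4)
    # Step 4: stack the channel groups of each batch horizontally.
    w = w.permute(0, 2, 1, 3).reshape(Np, H, Cp * W)
    # Steps 5-6: split N into groups of 4 and stack the groups vertically.
    w = w.reshape(Np // 4, 4, H, Cp * W).permute(1, 0, 2, 3)
    return w.reshape(4, Np * H // 4, Cp * W)


# A (2, 3, 1, 2) weight filled with 1..12 packs into a (4, 1, 8) tensor whose
# flattened contents match data_out_expected in conv2d_prepack_test.
w = torch.arange(1, 13, dtype=torch.float32).reshape(2, 3, 1, 2)
print(prepack_conv2d_weights(w).flatten().tolist())
# [1.0, 3.0, 5.0, 0.0, 2.0, 4.0, 6.0, 0.0, 7.0, 9.0, 11.0, 0.0, 8.0, 10.0, 12.0, 0.0, 0.0, ...]
```

The final (4, Np * H / 4, Cp * W) shape is the same final_sizes computed by prepack_weights() in Conv2d.cpp, and matches the gpu_sizes of {4, 1, 8} used by the test.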