diff --git a/backends/vulkan/runtime/graph/ops/glsl/conv1d.glsl b/backends/vulkan/runtime/graph/ops/glsl/conv1d.glsl index b7ce7996c48..b77f171dcc9 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/conv1d.glsl +++ b/backends/vulkan/runtime/graph/ops/glsl/conv1d.glsl @@ -21,78 +21,104 @@ layout(set = 0, binding = 1) uniform PRECISION sampler3D image_in; layout(set = 0, binding = 2) uniform PRECISION sampler3D kernel_in; layout(set = 0, binding = 3) uniform PRECISION sampler3D bias_in; -layout(set = 0, binding = 4) uniform PRECISION restrict Out_channels { - int data; -} -out_channels; - -layout(set = 0, binding = 5) uniform PRECISION restrict In_length { - int data; -} -in_length; - -layout(set = 0, binding = 6) uniform PRECISION restrict Kernel_size { - int data; -} -kernel_size; +layout(set = 0, binding = 4) uniform PRECISION restrict OutLimits { + ivec3 out_limits; +}; + +layout(set = 0, binding = 5) uniform PRECISION restrict InSizes { + ivec4 in_sizes; +}; + +layout(set = 0, binding = 6) uniform PRECISION restrict Params { + int kernel_size; + int stride; + int padding; + int dilation; + int in_group_size; + int out_group_size; +}; layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in; -/* - * This implementation optimize for simplicity (and partially performance) for a - * (1, C, L) where C == groups. Hence we only focus on calculating the rolling - * kernel of the L dimension. - */ +// Let us define +// +// input = (N, in_C, in_L), +// output = (N, out_C, out_L), +// groups = G, +// kernel = K, +// +// which results in shapes +// +// weight = (out_C, in_C / G, K), +// bias = (out_C,). +// +// This implementation performs out_C shader invocations, where each invocation +// calculates the rolling kernel of the length dimension for each batch, i.e., +// computes out_L * N results. +// +// Note that we can rewrite this implementation as out_L * out_C * ceil(N / 4) +// shader invocations, where each invocation computes 1 result. But that +// performs worse. void main() { const ivec3 pos = ivec3(gl_GlobalInvocationID); - // The global workgroup should have taken care of it. We only perform one - // work item for each 1d tensor on lengths - if (pos.x >= 1) { + if (any(greaterThanEqual(pos, out_limits))) { return; } - int c = pos.y; - if (c >= out_channels.data) { - return; - } - - // Assume n = 1, do not handle n > 1 case for now. - int n = pos.z; - if (n >= 1) { - return; - } - - vec4 bias = texelFetch(bias_in, ivec3(c, 0, 0), 0); - - for (int i = 0; i < in_length.data - kernel_size.data + 1; ++i) { - vec4 v = vec4(0); - for (int k = 0; k < kernel_size.data; ++k) { - const ivec3 in_pos = ivec3(i+k, c, 0); - const vec4 input_value = texelFetch(image_in, in_pos, 0); - - // Note that we are reading weight in the inner loop, this could be - // improved by moving it before the outer loop. Since the weight vector is - // contant for the entire call. - - // weight in input-space: (c, 0, k); - // notice that c is 4-packed. We need to mod 4 to get the actual weight. - const ivec3 w_pos = ivec3(k, 0, c / 4); - const vec4 weight = texelFetch(kernel_in, w_pos, 0); - - float w = weight.x; - if (c % 4 == 1) { - w = weight.y; - } else if (c % 4 == 2) { - w = weight.z; - } else if (c % 4 == 3) { - w = weight.w; + int in_length = in_sizes.x; + int batch_size = in_sizes.z; + + // "out_c" is the output's channel index where we write our result. + // Across shader invocations, this is the only value that varies. + int out_c = pos.y; + vec4 bias = texelFetch(bias_in, ivec3(out_c, 0, 0), 0); + + // "in_c" tracks the input's channel start index. + // We iterate over the input group that corresponds to the output group. + int c_start = (out_c / out_group_size) * in_group_size; + int c_end = c_start + in_group_size; + + // "in_l" tracks the input's length start index for our input-kernel overlay + // region. + int l_start = -padding; + int l_end = in_length + padding - dilation * (kernel_size - 1); + + // Since the input/output tensors are channel-packed, which is along the + // batch dimension, we can batch-read/write four elements at a time. + for (int n = 0; n < batch_size; n += 4) { + // "out_l" tracks the output's length index where we write our result. + int out_l = 0; + + for (int in_l = l_start; in_l < l_end; in_l += stride, ++out_l) { + vec4 sum = vec4(0); + + for (int in_c = c_start; in_c < c_end; ++in_c) { + // "k" tracks the kernel's index for our input-kernel computation. + // It reads out-of-bound zeros, but trying to avoid them complicates + // for-loop conditions, which results in worse performance. + for (int k = 0; k < kernel_size; k += 4) { + // Since the weight tensor is width-packed, which is along the length + // dimension, we can batch-read four elements at a time. + const ivec3 w_pos = ivec3(k / 4, in_c % in_group_size, out_c); + const vec4 weight = texelFetch(kernel_in, w_pos, 0); + + const ivec3 in_pos_0 = ivec3(in_l + k * dilation, in_c, n / 4); + sum = fma(weight.xxxx, texelFetch(image_in, in_pos_0, 0), sum); + + const ivec3 in_pos_1 = ivec3(in_l + (k+1) * dilation, in_c, n / 4); + sum = fma(weight.yyyy, texelFetch(image_in, in_pos_1, 0), sum); + + const ivec3 in_pos_2 = ivec3(in_l + (k+2) * dilation, in_c, n / 4); + sum = fma(weight.zzzz, texelFetch(image_in, in_pos_2, 0), sum); + + const ivec3 in_pos_3 = ivec3(in_l + (k+3) * dilation, in_c, n / 4); + sum = fma(weight.wwww, texelFetch(image_in, in_pos_3, 0), sum); + } } - v += w * input_value.x; + ivec3 out_pos = ivec3(out_l, out_c, n / 4); + imageStore(image_out, out_pos, sum + bias.x); } - - ivec3 out_pos = ivec3(i, c, 0); - imageStore(image_out, out_pos, vec4(v.x + bias.x, 0, 0, 0)); } } diff --git a/backends/vulkan/runtime/graph/ops/impl/Convolution.cpp b/backends/vulkan/runtime/graph/ops/impl/Convolution.cpp index 20d7c9256bb..d40352d2240 100644 --- a/backends/vulkan/runtime/graph/ops/impl/Convolution.cpp +++ b/backends/vulkan/runtime/graph/ops/impl/Convolution.cpp @@ -61,6 +61,11 @@ void resize_conv1d_node( vTensorPtr out = graph->get_tensor(args[0].refs[0]); vTensorPtr self = graph->get_tensor(args[1].refs[0]); TensorRefPtr weight_ref = graph->get_tref(extra_args[0]); + + int64_t stride_size = graph->get_int_list(extra_args[1])->at(0); + int64_t padding_size = graph->get_int_list(extra_args[2])->at(0); + int64_t dilation_size = graph->get_int_list(extra_args[3])->at(0); + const std::vector& weight_sizes = weight_ref->sizes; const std::vector& in_sizes = self->sizes(); @@ -71,8 +76,9 @@ void resize_conv1d_node( int64_t in_length = in_sizes.at(2); new_out_sizes.at(0) = in_sizes.at(0); - new_out_sizes.at(1) = in_sizes.at(1); - new_out_sizes.at(2) = in_length - kernel_size + 1; + new_out_sizes.at(1) = weight_sizes.at(0); + new_out_sizes.at(2) = calc_out_size( + in_length, kernel_size, stride_size, padding_size, dilation_size, false); out->virtual_resize(new_out_sizes); } @@ -244,10 +250,6 @@ ValueRef prepack_weights( } void check_conv_args(const vTensor& in, const vTensor& out) { - if (in.sizes().at(0) > 1) { - VK_THROW( - "aten.convolution.default: input batch size > 1 is not supported yet!"); - } VK_CHECK_COND(check_memory_layout_is(in, api::kChannelsPacked)); VK_CHECK_COND(check_memory_layout_is(out, api::kChannelsPacked)); } @@ -260,7 +262,7 @@ struct Conv2dParams final { Conv2dParams create_conv2d_params( ComputeGraph& graph, const ValueRef weight, - const KernelParams& p, + const Kernel2dParams& p, const bool transposed) { const auto& overlay_region = api::utils::make_ivec2({ p.kernel_size.data[0] + @@ -275,7 +277,7 @@ Conv2dParams create_conv2d_params( return {overlay_region, in_group_size}; } -void check_conv2d_params(const KernelParams& p, const bool transposed) { +void check_conv2d_params(const Kernel2dParams& p, const bool transposed) { if (transposed) { if (p.dilation.data[0] > 1 || p.dilation.data[1] > 1) { VK_THROW( @@ -342,12 +344,15 @@ void add_conv2d_node( vTensorPtr t_in = graph.get_tensor(arg_in); vTensorPtr t_out = graph.get_tensor(out); + if (t_in->sizes().at(0) > 1) { + VK_THROW("conv2d: input batch size > 1 is not supported yet!"); + } check_conv_args(*t_in, *t_out); api::utils::uvec3 global_size = t_out->extents(); api::utils::uvec3 local_size = adaptive_work_group_size(global_size); - KernelParams kernel_params = create_kernel_params( + Kernel2dParams kernel_params = create_kernel2d_params( graph, weight, /*kernel_size_only = */ false, @@ -395,8 +400,7 @@ void add_conv1d_node( const ValueRef groups, const ValueRef out) { ValueRef arg_in = prepack_if_tensor_ref(graph, in); - ValueRef arg_weight = - prepack_if_tensor_ref(graph, weight, graph.memory_layout_of(arg_in)); + ValueRef arg_weight = prepack_if_tensor_ref(graph, weight, api::kWidthPacked); ValueRef arg_bias = prepack_biases( graph, bias, @@ -414,37 +418,29 @@ void add_conv1d_node( std::vector in_sizes = t_in->sizes(); std::vector weight_sizes = t_weight->sizes(); std::vector out_sizes = t_out->sizes(); - IntListPtr stride_sizes = graph.get_int_list(stride); - IntListPtr padding_sizes = graph.get_int_list(padding); - IntListPtr dilation_sizes = graph.get_int_list(dilation); - int64_t weight_out_channels = weight_sizes.at(0); - int64_t kernel_size = weight_sizes.at(2); - int64_t in_length = in_sizes.at(2); - - VK_CHECK_COND(in_sizes.size() == 3, "input must be a 3-dim tensor"); - VK_CHECK_COND(weight_sizes.size() == 3, "weight must be a 3-dim tensor"); - VK_CHECK_COND( - stride_sizes->size() == 1 && stride_sizes->at(0) == 1, - "stride must be 1"); - VK_CHECK_COND( - padding_sizes->size() == 1 && padding_sizes->at(0) == 0, - "padding must be 0"); - VK_CHECK_COND( - dilation_sizes->size() == 1 && dilation_sizes->at(0) == 1, - "dilation must be 1"); - VK_CHECK_COND( - groups_val == in_sizes.at(1), "groups must be equal to in_channels"); - VK_CHECK_COND( - groups_val == weight_sizes.at(0), - "groups must be equal to weight_sizes.at(0)"); - VK_CHECK_COND(weight_sizes.at(1) == 1, "weight_sizes.at(1) must be 1"); check_conv_args(*t_in, *t_out); - api::utils::uvec3 global_size = { - 1, static_cast(weight_out_channels), 1}; + int32_t in_channels = in_sizes.at(1); + int32_t out_channels = weight_sizes.at(0); + int32_t kernel_size = weight_sizes.at(2); + int32_t stride_size = graph.get_int_list(stride)->at(0); + int32_t padding_size = graph.get_int_list(padding)->at(0); + int32_t dilation_size = graph.get_int_list(dilation)->at(0); + int32_t in_group_size = static_cast(in_channels / groups_val); + int32_t out_group_size = static_cast(out_channels / groups_val); + + api::utils::uvec3 global_size = {1, static_cast(out_channels), 1}; api::utils::uvec3 local_size = {1, 1, 1}; + Kernel1dParams kernel_params = { + kernel_size, + stride_size, + padding_size, + dilation_size, + in_group_size, + out_group_size}; + std::string kernel_name("conv1d"); kernel_name.reserve(kShaderNameReserve); @@ -460,15 +456,15 @@ void add_conv1d_node( {{arg_in, arg_weight, arg_bias}, api::MemoryAccessType::READ}}, // Shader params buffers { - graph.create_params_buffer(weight_out_channels), - graph.create_params_buffer(in_length), - graph.create_params_buffer(kernel_size), + t_out->texture_limits_ubo(), + t_in->sizes_ubo(), + graph.create_params_buffer(kernel_params), }, // Specialization Constants {}, // Resizing Logic resize_conv1d_node, - {weight})); + {weight, stride, padding, dilation})); } void conv(ComputeGraph& graph, const std::vector& args) { diff --git a/backends/vulkan/runtime/graph/ops/impl/Pool.cpp b/backends/vulkan/runtime/graph/ops/impl/Pool.cpp index 1a8a258627e..87aed6e273f 100644 --- a/backends/vulkan/runtime/graph/ops/impl/Pool.cpp +++ b/backends/vulkan/runtime/graph/ops/impl/Pool.cpp @@ -76,7 +76,7 @@ void add_max_pool2d_node( std::string kernel_name("max_pool2d"); add_dtype_suffix(kernel_name, *t_out); - KernelParams kernel_params = create_kernel_params( + Kernel2dParams kernel_params = create_kernel2d_params( graph, kernel_size, /*kernel_size_only = */ true, diff --git a/backends/vulkan/runtime/graph/ops/impl/utils/KernelUtils.cpp b/backends/vulkan/runtime/graph/ops/impl/utils/KernelUtils.cpp index d342c4521f6..6b823fe30cd 100644 --- a/backends/vulkan/runtime/graph/ops/impl/utils/KernelUtils.cpp +++ b/backends/vulkan/runtime/graph/ops/impl/utils/KernelUtils.cpp @@ -26,7 +26,7 @@ api::utils::ivec2 make_ivec2_kernel_size( } } -KernelParams create_kernel_params( +Kernel2dParams create_kernel2d_params( ComputeGraph& graph, const ValueRef weight, const bool kernel_size_only, diff --git a/backends/vulkan/runtime/graph/ops/impl/utils/KernelUtils.h b/backends/vulkan/runtime/graph/ops/impl/utils/KernelUtils.h index fafb00e126c..eb0215bfd59 100644 --- a/backends/vulkan/runtime/graph/ops/impl/utils/KernelUtils.h +++ b/backends/vulkan/runtime/graph/ops/impl/utils/KernelUtils.h @@ -16,14 +16,23 @@ namespace vkcompute { -struct KernelParams final { +struct Kernel2dParams final { api::utils::ivec2 kernel_size; api::utils::ivec2 stride; api::utils::ivec2 padding; api::utils::ivec2 dilation; }; -KernelParams create_kernel_params( +struct Kernel1dParams final { + int kernel_size; + int stride; + int padding; + int dilation; + int in_group_size; + int out_group_size; +}; + +Kernel2dParams create_kernel2d_params( ComputeGraph& graph, const ValueRef weight, const bool kernel_size_only, @@ -31,6 +40,14 @@ KernelParams create_kernel_params( const ValueRef padding, const ValueRef dilation); +int64_t calc_out_size( + const int64_t in_size, + const int64_t kernel_size, + const int64_t stride, + const int64_t padding, + const int64_t dilation, + const bool ceil_mode); + std::vector calc_out_sizes_hw( ComputeGraph& graph, const std::vector& in_sizes, diff --git a/backends/vulkan/test/op_tests/cases.py b/backends/vulkan/test/op_tests/cases.py index 3df9140ef43..bca58744dd8 100644 --- a/backends/vulkan/test/op_tests/cases.py +++ b/backends/vulkan/test/op_tests/cases.py @@ -135,6 +135,17 @@ def get_conv_inputs(): [0], 6, ), + ( + (2, 20, 30), + (10, 4, 6), + (10,), + [5], + [5], + [3], + False, + [0], + 5, + ), ( (1, 9, 11), (9, 1, 3), @@ -146,6 +157,17 @@ def get_conv_inputs(): [0], 9, ), + ( + (5, 15, 30), + (20, 3, 3), + None, + [3], + [5], + [7], + False, + [0], + 5, + ), ] ) return test_suite diff --git a/backends/vulkan/test/test_vulkan_delegate.py b/backends/vulkan/test/test_vulkan_delegate.py index 1cce125a816..d802c4446e0 100644 --- a/backends/vulkan/test/test_vulkan_delegate.py +++ b/backends/vulkan/test/test_vulkan_delegate.py @@ -653,10 +653,13 @@ class Conv1dModule(torch.nn.Module): def __init__(self): super().__init__() self.conv = torch.nn.Conv1d( - in_channels=6, - out_channels=6, - kernel_size=3, - groups=6, + in_channels=20, + out_channels=10, + kernel_size=6, + stride=5, + padding=5, + dilation=3, + groups=5, bias=True, ) @@ -664,7 +667,7 @@ def forward(self, x): return self.conv(x) conv1d_module = Conv1dModule() - sample_inputs = (torch.randn(size=(1, 6, 7), dtype=torch.float32),) + sample_inputs = (torch.randn(size=(3, 20, 30), dtype=torch.float32),) self.lower_module_and_test_output( conv1d_module,