diff --git a/backends/vulkan/runtime/graph/ops/impl/Permute.cpp b/backends/vulkan/runtime/graph/ops/impl/Permute.cpp index 3bc6ca52c60..7c1b800821f 100644 --- a/backends/vulkan/runtime/graph/ops/impl/Permute.cpp +++ b/backends/vulkan/runtime/graph/ops/impl/Permute.cpp @@ -6,8 +6,11 @@ * LICENSE file in the root directory of this source tree. */ +#include + #include +#include #include #include #include @@ -20,53 +23,51 @@ using api::utils::uvec4; void check_args( const vTensor& in, - const IntListPtr& permute_dims, + const std::vector& permute_dims, const vTensor& out) { VK_CHECK_COND(check_memory_layout_is(in, api::kChannelsPacked)); VK_CHECK_COND(check_memory_layout_is(out, api::kChannelsPacked)); - int64_t in_dim = in.dim(); + // This implementation does not require the input tensor to have the same + // dim size as the argument. The code will work as long as the input tensor's + // dim size does not exceed the length of the permute dim array. When it is + // shorter, the code assumes a size of 1 at the higher dimensions. 
+ + int64_t out_dim = out.dim(); VK_CHECK_COND( - in_dim == permute_dims->size(), - "Input tensor dim size must match argument"); + out_dim == permute_dims.size(), + "Output tensor dim size must match argument"); } void add_permute_node( ComputeGraph& graph, ValueRef in, - ValueRef permute_dims_ref, + const std::vector& permute_dims, ValueRef out) { vTensorPtr t_in = graph.get_tensor(in); vTensorPtr t_out = graph.get_tensor(out); - IntListPtr permute_dims = graph.get_int_list(permute_dims_ref); - check_args(*t_in, permute_dims, *t_out); - uvec4 in_size{1u, 1u, 1u, 1u}, out_size{1u, 1u, 1u, 1u}; uvec4 out_dims{0u, 1u, 2u, 3u}; - int64_t in_dim = t_in->dim(); - - std::vector seen(in_dim); - for (int i = 0; i < in_dim; i++) { - int64_t permute_dim = (*permute_dims)[i]; + int64_t out_dim = t_out->dim(); + std::vector seen(out_dim); + for (int i = 0; i < t_out->dim(); i++) { + int64_t permute_dim = permute_dims[i]; VK_CHECK_COND( !seen[permute_dim], "Argument dim ", permute_dim, " is repeated"); seen[permute_dim] = true; - // Map to 4D tensor dims. 
- in_size.data[(4u - in_dim) + i] = t_in->size(i); - out_size.data[(4u - in_dim) + i] = t_in->size(permute_dim); - out_dims.data[(4u - in_dim) + i] = permute_dim + (4u - in_dim); + out_dims.data[(4u - out_dim) + i] = permute_dim + (4u - out_dim); } std::string kernel_name = "permute"; kernel_name.reserve(kShaderNameReserve); add_dtype_suffix(kernel_name, *t_out); - uint32_t out_channels = out_size.data[1u]; - uint32_t in_channels = in_size.data[1u]; + uint32_t out_channels = dim_at(t_out->sizes()); + uint32_t in_channels = dim_at(t_in->sizes()); uint32_t out_c_aligned = api::utils::align_up(out_channels, 4u); uint32_t in_c_aligned = api::utils::align_up(in_channels, 4u); @@ -98,6 +99,16 @@ void add_permute_node( {})); } +void add_permute_node( + ComputeGraph& graph, + ValueRef in, + ValueRef permute_dims_ref, + ValueRef out) { + IntListPtr permute_dims = graph.get_int_list(permute_dims_ref); + + add_permute_node(graph, in, *permute_dims, out); +} + void permute(ComputeGraph& graph, const std::vector& args) { return add_permute_node(graph, args[0], args[1], args[2]); } diff --git a/backends/vulkan/runtime/graph/ops/impl/Permute.h b/backends/vulkan/runtime/graph/ops/impl/Permute.h new file mode 100644 index 00000000000..941a8896fe2 --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/impl/Permute.h @@ -0,0 +1,25 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. 
+ */ + +#pragma once + +#include + +#include + +#include + +namespace vkcompute { + +void add_permute_node( + ComputeGraph& graph, + ValueRef in, + const std::vector& permute_dims, + ValueRef out); + +} // namespace vkcompute diff --git a/backends/vulkan/runtime/graph/ops/impl/Unsqueeze.cpp b/backends/vulkan/runtime/graph/ops/impl/Unsqueeze.cpp new file mode 100644 index 00000000000..c8ada796e8e --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/impl/Unsqueeze.cpp @@ -0,0 +1,52 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include + +#include +#include +#include + +namespace vkcompute { + +void add_unsqueeze_node( + ComputeGraph& graph, + ValueRef in, + ValueRef dim_ref, + ValueRef out) { + vTensorPtr t_in = graph.get_tensor(in); + vTensorPtr t_out = graph.get_tensor(out); + + VK_CHECK_COND( + t_in->dim() < 4, "Cannot unsqueeze a tensor with more than 3 dimensions"); + + int64_t dim = graph.extract_scalar(dim_ref); + int64_t out_dim = t_out->dim(); + + std::vector permute_dims(out_dim); + for (int i = 1; i <= dim; i++) { + permute_dims[i - 1] = i; + } + permute_dims[dim] = 0; + + for (int i = dim + 1; i < out_dim; i++) { + permute_dims[i] = i; + } + + add_permute_node(graph, in, permute_dims, out); +} + +void unsqueeze(ComputeGraph& graph, const std::vector& args) { + return add_unsqueeze_node(graph, args[0], args[1], args[2]); +} + +REGISTER_OPERATORS { + VK_REGISTER_OP(aten.unsqueeze_copy.default, unsqueeze); +} + +} // namespace vkcompute diff --git a/backends/vulkan/test/op_tests/cases.py b/backends/vulkan/test/op_tests/cases.py index bca58744dd8..9c4ed7dacd7 100644 --- a/backends/vulkan/test/op_tests/cases.py +++ b/backends/vulkan/test/op_tests/cases.py @@ -236,6 +236,7 @@ def get_permute_inputs(): ((9, 2), [1, 0]), ] ) + test_suite.layouts = ["api::kChannelsPacked"] 
return test_suite @@ -334,6 +335,32 @@ def get_slice_inputs(): return test_suite +def get_unsqueeze_inputs(): + test_suite = VkTestSuite( + [ + ((2, 3, 4), 0), + ((1, 1, 1), 0), + ((1, 1, 1), 1), + ((1, 1, 1), 2), + ((1, 1, 1), 3), + ((9, 9, 9), 0), + ((9, 9, 9), 1), + ((9, 9, 9), 2), + ((9, 9, 9), 3), + ((9, 9), 0), + ((9, 9), 1), + ((9, 9), 2), + ((9,), 0), + ((9,), 1), + ] + ) + test_suite.layouts = [ + "api::kChannelsPacked", + ] + test_suite.data_gen = "make_seq_tensor" + return test_suite + + test_suites = { "aten.add.Tensor": get_binary_elementwise_inputs(), "aten.sub.Tensor": get_binary_elementwise_inputs(), @@ -350,4 +377,5 @@ def get_slice_inputs(): "aten.permute_copy.default": get_permute_inputs(), "aten.view_copy.default": get_view_inputs(), "aten.slice_copy.Tensor": get_slice_inputs(), + "aten.unsqueeze_copy.default": get_unsqueeze_inputs(), }