diff --git a/backends/vulkan/runtime/api/Tensor.h b/backends/vulkan/runtime/api/Tensor.h
index 53dbfecffe6..787e8111204 100644
--- a/backends/vulkan/runtime/api/Tensor.h
+++ b/backends/vulkan/runtime/api/Tensor.h
@@ -220,6 +220,10 @@ class vTensor final {
    */
   const api::BufferBindInfo texture_limits_ubo();
 
+  inline const api::utils::ivec3 texture_limits() const {
+    return texture_limits_.limits;
+  }
+
   inline size_t numel() const {
     return api::utils::multiply_integers(sizes());
   }
diff --git a/backends/vulkan/runtime/api/Utils.h b/backends/vulkan/runtime/api/Utils.h
index 3b0139b8efb..d12844bbf1e 100644
--- a/backends/vulkan/runtime/api/Utils.h
+++ b/backends/vulkan/runtime/api/Utils.h
@@ -262,12 +262,23 @@ inline std::ostream& operator<<(std::ostream& os, const uvec3& v) {
   return os;
 }
 
+inline std::ostream& operator<<(std::ostream& os, const ivec3& v) {
+  os << "(" << v.data[0u] << ", " << v.data[1u] << ", " << v.data[2u] << ")";
+  return os;
+}
+
 inline std::ostream& operator<<(std::ostream& os, const uvec4& v) {
   os << "(" << v.data[0u] << ", " << v.data[1u] << ", " << v.data[2u] << ", "
      << v.data[3u] << ")";
   return os;
 }
 
+inline std::ostream& operator<<(std::ostream& os, const ivec4& v) {
+  os << "(" << v.data[0u] << ", " << v.data[1u] << ", " << v.data[2u] << ", "
+     << v.data[3u] << ")";
+  return os;
+}
+
 //
 // std::vector Handling
 //
@@ -298,6 +309,25 @@ inline ivec2 make_ivec2(
   }
 }
 
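+/*
+ * Converts a vector of 3 int64_t values into an ivec3. When reverse is true,
+ * the element order is flipped, e.g. make_ivec3({1, 2, 3}, true) yields
+ * (3, 2, 1); this is handy when converting NCHW-ordered sizes to WHCN order.
+ */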
+inline ivec3 make_ivec3(
+    const std::vector<int64_t>& ints,
+    bool reverse = false) {
+  VK_CHECK_COND(ints.size() == 3);
+  if (reverse) {
+    return {
+        safe_downcast<int32_t>(ints[2]),
+        safe_downcast<int32_t>(ints[1]),
+        safe_downcast<int32_t>(ints[0]),
+    };
+  } else {
+    return {
+        safe_downcast<int32_t>(ints[0]),
+        safe_downcast<int32_t>(ints[1]),
+        safe_downcast<int32_t>(ints[2]),
+    };
+  }
+}
+
 inline ivec4 make_ivec4(
     const std::vector<int64_t>& ints,
     bool reverse = false) {
@@ -338,6 +368,13 @@ inline ivec3 make_ivec3(uvec3 ints) {
       safe_downcast<int32_t>(ints.data[0u]),
       safe_downcast<int32_t>(ints.data[1u]),
       safe_downcast<int32_t>(ints.data[2u])};
 }
 
+inline uvec3 make_uvec3(ivec3 ints) {
+  return {
+      safe_downcast<uint32_t>(ints.data[0u]),
+      safe_downcast<uint32_t>(ints.data[1u]),
+      safe_downcast<uint32_t>(ints.data[2u])};
+}
+
 /*
  * Given a vector of up to 4 uint64_t representing the sizes of a tensor,
  * constructs a uvec4 containing those elements in reverse order.
diff --git a/backends/vulkan/runtime/graph/Logging.h b/backends/vulkan/runtime/graph/Logging.h
index 5ee068100fd..447d52d16bd 100644
--- a/backends/vulkan/runtime/graph/Logging.h
+++ b/backends/vulkan/runtime/graph/Logging.h
@@ -34,6 +34,14 @@ inline std::ostream& operator<<(std::ostream& os, const api::utils::uvec4& v) {
   return api::utils::operator<<(os, v);
 }
 
+inline std::ostream& operator<<(std::ostream& os, const api::utils::ivec3& v) {
+  return api::utils::operator<<(os, v);
+}
+
+inline std::ostream& operator<<(std::ostream& os, const api::utils::ivec4& v) {
+  return api::utils::operator<<(os, v);
+}
+
 template <typename T>
 inline std::ostream& operator<<(std::ostream& os, const std::optional<T>& opt) {
   os << "[";
diff --git a/backends/vulkan/runtime/graph/ops/glsl/copy_offset.glsl b/backends/vulkan/runtime/graph/ops/glsl/copy_offset.glsl
new file mode 100644
index 00000000000..17b3e06e61e
--- /dev/null
+++ b/backends/vulkan/runtime/graph/ops/glsl/copy_offset.glsl
@@ -0,0 +1,54 @@
+/*
+ * Copyright (c) Meta Platforms, Inc. and affiliates.
+ * All rights reserved.
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+#version 450 core
+
+#define PRECISION ${PRECISION}
+
+#define VEC4_T ${texel_type(DTYPE)}
+
+layout(std430) buffer;
+
+#include "indexing_utils.h"
+
+layout(set = 0, binding = 0, ${IMAGE_FORMAT[DTYPE]}) uniform PRECISION restrict writeonly ${IMAGE_T[NDIM][DTYPE]} image_out;
+layout(set = 0, binding = 1) uniform PRECISION sampler3D image_in;
+
+layout(set = 0, binding = 2) uniform PRECISION restrict OutLimits {
+  ivec3 out_limits;
+};
+
+layout(set = 0, binding = 3) uniform PRECISION restrict InLimits {
+  ivec3 in_limits;
+};
+
+layout(set = 0, binding = 4) uniform PRECISION restrict CopyArgs {
+  ivec3 range;
+  int unused0;
+  ivec3 src_offset;
+  int unused1;
+  ivec3 dst_offset;
+  int unused2;
+};
+
+layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in;
+
+void main() {
+  const ivec3 pos = ivec3(gl_GlobalInvocationID);
+
+  const ivec3 out_pos = pos + dst_offset;
+  const ivec3 in_pos = pos + src_offset;
+
+  if (any(greaterThanEqual(pos, range))) {
+    return;
+  }
+
+  imageStore(image_out, out_pos, texelFetch(image_in, in_pos, 0));
+}
diff --git a/backends/vulkan/runtime/graph/ops/glsl/copy_offset.yaml b/backends/vulkan/runtime/graph/ops/glsl/copy_offset.yaml
new file mode 100644
index 00000000000..4a31ba6bbca
--- /dev/null
+++ b/backends/vulkan/runtime/graph/ops/glsl/copy_offset.yaml
@@ -0,0 +1,10 @@
+copy_offset:
+  parameter_names_with_default_values:
+    DTYPE: float
+    NDIM: 3
+  generate_variant_forall:
+    DTYPE:
+      - VALUE: half
+      - VALUE: float
+  shader_variants:
+    - NAME: copy_offset
diff --git a/backends/vulkan/runtime/graph/ops/glsl/repeat_channel.glsl b/backends/vulkan/runtime/graph/ops/glsl/repeat_channel.glsl
new file mode 100644
index 00000000000..42c7f86aea8
--- /dev/null
+++ b/backends/vulkan/runtime/graph/ops/glsl/repeat_channel.glsl
@@ -0,0 +1,58 @@
+/*
+ * Copyright (c) Meta Platforms, Inc. and affiliates.
+ * All rights reserved.
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+#version 450 core
+
+#define PRECISION ${PRECISION}
+
+#define VEC4_T ${texel_type(DTYPE)}
+
+layout(std430) buffer;
+
+#include "indexing_utils.h"
+
+layout(set = 0, binding = 0, ${IMAGE_FORMAT[DTYPE]}) uniform PRECISION restrict writeonly ${IMAGE_T[NDIM][DTYPE]} image_out;
+layout(set = 0, binding = 1) uniform PRECISION sampler3D image_in;
+
+layout(set = 0, binding = 2) uniform PRECISION restrict RepeatArgs {
+  // With an input of size (n, c_i, h, w) and a channel repeat count r,
+  // out_sizes == (n, c_i * r, h, w).
+  ivec4 out_sizes;
+  ivec4 in_sizes;
+};
+
+layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in;
+
+layout(constant_id = 3) const int packed_dim = C_DIM;
+
+void main() {
+  const ivec3 out_pos = ivec3(gl_GlobalInvocationID);
+
+  const ivec4 out_whcn = to_tensor_idx(out_pos, out_sizes, packed_dim);
+
+  if (any(greaterThanEqual(out_whcn, out_sizes))) {
+    return;
+  }
+
+  VEC4_T v;
+  // Loop over the four elements in the output texel, compute the source
+  // element for each, and fetch it. This is not the most efficient algorithm,
+  // since the same input texel is likely fetched multiple times in this loop.
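+  // The output channels are the input channels tiled repeat-count times, so
+  // output channel c reads from input channel c % in_sizes.z. For example,
+  // with in_sizes.z == 3 and a repeat of 2, output channels 0..5 read from
+  // input channels 0, 1, 2, 0, 1, 2.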
+
+  for (int i = 0; i < 4; i++) {
+    ivec4 in_whcn = out_whcn;
+    in_whcn.z = (out_whcn.z + i) % in_sizes.z;
+
+    ivec4 in_elem_pos = to_texture_elem_pos(in_whcn, in_sizes, packed_dim);
+
+    v[i] = VEC4_T(texelFetch(image_in, in_elem_pos.xyz, 0))[in_elem_pos.w];
+  }
+
+  imageStore(image_out, out_pos, v);
+}
diff --git a/backends/vulkan/runtime/graph/ops/glsl/repeat_channel.yaml b/backends/vulkan/runtime/graph/ops/glsl/repeat_channel.yaml
new file mode 100644
index 00000000000..4147e82965a
--- /dev/null
+++ b/backends/vulkan/runtime/graph/ops/glsl/repeat_channel.yaml
@@ -0,0 +1,10 @@
+repeat_channel:
+  parameter_names_with_default_values:
+    DTYPE: float
+    NDIM: 3
+  generate_variant_forall:
+    DTYPE:
+      - VALUE: half
+      - VALUE: float
+  shader_variants:
+    - NAME: repeat_channel
diff --git a/backends/vulkan/runtime/graph/ops/impl/Copy.cpp b/backends/vulkan/runtime/graph/ops/impl/Copy.cpp
new file mode 100644
index 00000000000..0a5e20e4f7c
--- /dev/null
+++ b/backends/vulkan/runtime/graph/ops/impl/Copy.cpp
@@ -0,0 +1,70 @@
+/*
+ * Copyright (c) Meta Platforms, Inc. and affiliates.
+ * All rights reserved.
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+#include <executorch/backends/vulkan/runtime/graph/ops/impl/Copy.h>
+
+#include <executorch/backends/vulkan/runtime/graph/ops/impl/utils/KernelUtils.h>
+#include <executorch/backends/vulkan/runtime/graph/ops/impl/utils/TensorUtils.h>
+#include <executorch/backends/vulkan/runtime/graph/ops/utils/ShaderNameUtils.h>
+
+namespace vkcompute {
+
+void add_copy_offset_node(
+    ComputeGraph& graph,
+    const ValueRef in,
+    const api::utils::ivec3& range,
+    const api::utils::ivec3& src_offset,
+    const api::utils::ivec3& dst_offset,
+    const ValueRef out) {
+  vTensorPtr t_in = graph.get_tensor(in);
+  vTensorPtr t_out = graph.get_tensor(out);
+
+  VK_CHECK_COND(check_memory_layout_is(*t_in, api::kChannelsPacked));
+  VK_CHECK_COND(check_memory_layout_is(*t_out, api::kChannelsPacked));
+
+  std::string kernel_name = "copy_offset";
+  kernel_name.reserve(kShaderNameReserve);
+  add_dtype_suffix(kernel_name, *t_out);
+
+  api::utils::uvec3 global_size = api::utils::make_uvec3(range);
+  api::utils::uvec3 local_size = adaptive_work_group_size(global_size);
+
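+  // Note: the explicit int32_t pads below mirror the std140 layout of the
+  // CopyArgs uniform block in the shader, where each ivec3 member is aligned
+  // to 16 bytes and the following int occupies the fourth component.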
+  const struct Block final {
+    api::utils::ivec3 range;
+    int32_t unused0;
+    api::utils::ivec3 src_offset;
+    int32_t unused1;
+    api::utils::ivec3 dst_offset;
+    int32_t unused2;
+  } offset_params{
+      range,
+      0,
+      src_offset,
+      0,
+      dst_offset,
+      0,
+  };
+
+  auto shader = VK_KERNEL_FROM_STR(kernel_name);
+
+  graph.execute_nodes().emplace_back(new ExecuteNode(
+      graph,
+      VK_KERNEL_FROM_STR(kernel_name),
+      global_size,
+      local_size,
+      // Inputs and Outputs
+      {{out, api::MemoryAccessType::WRITE}, {in, api::MemoryAccessType::READ}},
+      // Parameter buffers
+      {t_out->texture_limits_ubo(),
+       t_in->texture_limits_ubo(),
+       graph.create_params_buffer(offset_params)},
+      // Specialization Constants
+      {}));
+}
+
+} // namespace vkcompute
diff --git a/backends/vulkan/runtime/graph/ops/impl/Copy.h b/backends/vulkan/runtime/graph/ops/impl/Copy.h
new file mode 100644
index 00000000000..6e0deb6b74e
--- /dev/null
+++ b/backends/vulkan/runtime/graph/ops/impl/Copy.h
@@ -0,0 +1,25 @@
+/*
+ * Copyright (c) Meta Platforms, Inc. and affiliates.
+ * All rights reserved.
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+#pragma once
+
+#include <executorch/backends/vulkan/runtime/api/api.h>
+
+#include <executorch/backends/vulkan/runtime/graph/ComputeGraph.h>
+
+namespace vkcompute {
+
+void add_copy_offset_node(
+    ComputeGraph& graph,
+    const ValueRef in,
+    const api::utils::ivec3& range,
+    const api::utils::ivec3& src_offset,
+    const api::utils::ivec3& dst_offset,
+    const ValueRef out);
+
+} // namespace vkcompute
diff --git a/backends/vulkan/runtime/graph/ops/impl/Permute.cpp b/backends/vulkan/runtime/graph/ops/impl/Permute.cpp
index 7c1b800821f..14b77e3b451 100644
--- a/backends/vulkan/runtime/graph/ops/impl/Permute.cpp
+++ b/backends/vulkan/runtime/graph/ops/impl/Permute.cpp
@@ -21,6 +21,8 @@ using api::utils::ivec3;
 using api::utils::uvec2;
 using api::utils::uvec4;
 
+namespace {
+
 void check_args(
     const vTensor& in,
     const std::vector<int64_t>& permute_dims,
@@ -39,6 +41,8 @@ void check_args(
       "Output tensor dim size must match argument");
 }
 
+} // namespace
+
 void add_permute_node(
     ComputeGraph& graph,
     ValueRef in,
diff --git a/backends/vulkan/runtime/graph/ops/impl/Repeat.cpp b/backends/vulkan/runtime/graph/ops/impl/Repeat.cpp
new file mode 100644
index 00000000000..dedc7978ada
--- /dev/null
+++ b/backends/vulkan/runtime/graph/ops/impl/Repeat.cpp
@@ -0,0 +1,213 @@
+/*
+ * Copyright (c) Meta Platforms, Inc. and affiliates.
+ * All rights reserved.
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+#include <executorch/backends/vulkan/runtime/graph/ops/OperatorRegistry.h>
+
+#include <executorch/backends/vulkan/runtime/graph/ops/impl/Copy.h>
+#include <executorch/backends/vulkan/runtime/graph/ops/impl/utils/DimUtils.h>
+#include <executorch/backends/vulkan/runtime/graph/ops/impl/utils/KernelUtils.h>
+#include <executorch/backends/vulkan/runtime/graph/ops/impl/utils/TensorUtils.h>
+
+#include <executorch/backends/vulkan/runtime/graph/ops/utils/ShaderNameUtils.h>
+
+namespace vkcompute {
+
+namespace {
+
+void check_args(
+    const vTensor& in,
+    const std::vector<int64_t>& repeats,
+    const vTensor& out) {
+  VK_CHECK_COND(check_memory_layout_is(in, api::kChannelsPacked));
+  VK_CHECK_COND(check_memory_layout_is(out, api::kChannelsPacked));
+
+  int64_t in_dim = in.dim();
+  VK_CHECK_COND(
+      in_dim <= repeats.size(),
+      "Input tensor dim size must not be greater than the repeat argument's size");
+
+  VK_CHECK_COND(
+      dim_at<Dim4D::Width>(in.sizes()) * dim_at<Dim4D::Width>(repeats) ==
+          dim_at<Dim4D::Width>(out.sizes()),
+      "Output's width doesn't match input's width * repeat count");
+
+  VK_CHECK_COND(
+      dim_at<Dim4D::Height>(in.sizes()) * dim_at<Dim4D::Height>(repeats) ==
+          dim_at<Dim4D::Height>(out.sizes()),
+      "Output's height doesn't match input's height * repeat count");
+
+  VK_CHECK_COND(
+      dim_at<Dim4D::Channel>(in.sizes()) * dim_at<Dim4D::Channel>(repeats) ==
+          dim_at<Dim4D::Channel>(out.sizes()),
+      "Output's channel doesn't match input's channel * repeat count");
+
+  VK_CHECK_COND(
+      dim_at<Dim4D::Batch>(in.sizes()) * dim_at<Dim4D::Batch>(repeats) ==
+          dim_at<Dim4D::Batch>(out.sizes()),
+      "Output's batch doesn't match input's batch * repeat count");
+}
+
+} // namespace
+
+void add_repeat_channel_node(
+    ComputeGraph& graph,
+    ValueRef in,
+    int64_t repeat_channel,
+    ValueRef out,
+    api::utils::ivec3& running_range) {
+  vTensorPtr t_in = graph.get_tensor(in);
+  vTensorPtr t_out = graph.get_tensor(out);
+
+  std::string kernel_name = "repeat_channel";
+  kernel_name.reserve(kShaderNameReserve);
+  add_dtype_suffix(kernel_name, *t_out);
+
+  const std::vector<int64_t>& in_sizes = t_in->sizes();
+
+  int32_t in_width =
+      api::utils::safe_downcast<int32_t>(dim_at<Dim4D::Width>(in_sizes));
+  int32_t in_height =
+      api::utils::safe_downcast<int32_t>(dim_at<Dim4D::Height>(in_sizes));
+  int32_t in_channel =
+      api::utils::safe_downcast<int32_t>(dim_at<Dim4D::Channel>(in_sizes));
+  int32_t in_batch =
+      api::utils::safe_downcast<int32_t>(dim_at<Dim4D::Batch>(in_sizes));
+
+  int32_t out_channel = repeat_channel * in_channel;
+
+  api::utils::ivec4 out_whcn_sizes{in_width, in_height, out_channel, in_batch};
+
+  api::utils::ivec4 in_whcn_sizes{in_width, in_height, in_channel, in_batch};
+
+  // Channel packed global work ids
+  running_range.data[2] =
+      out_whcn_sizes.data[3] * api::utils::div_up(out_whcn_sizes.data[2], 4);
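+  // With channel packing, four consecutive channels share one texel, so the
+  // texture's Z extent is batch * ceil(out_channels / 4); e.g. N == 2 with 10
+  // output channels yields 2 * ceil(10 / 4) == 6 texel layers.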
+  api::utils::uvec3 global_size = api::utils::make_uvec3(running_range);
+  api::utils::uvec3 local_size = adaptive_work_group_size(global_size);
+
+  const struct Block final {
+    api::utils::ivec4 out_sizes;
+    api::utils::ivec4 in_size;
+  } repeat_channel_args{
+      out_whcn_sizes,
+      in_whcn_sizes,
+  };
+
+  auto shader = VK_KERNEL_FROM_STR(kernel_name);
+
+  graph.execute_nodes().emplace_back(new ExecuteNode(
+      graph,
+      VK_KERNEL_FROM_STR(kernel_name),
+      global_size,
+      local_size,
+      // Inputs and Outputs
+      {{out, api::MemoryAccessType::WRITE}, {in, api::MemoryAccessType::READ}},
+      // Parameter buffers
+      {graph.create_params_buffer(repeat_channel_args)},
+      // Specialization Constants
+      {SV(t_out->gpu_memory_layout_int())}));
+}
+
+void add_repeat_node(
+    ComputeGraph& graph,
+    ValueRef in,
+    ValueRef repeats_ref,
+    ValueRef out) {
+  std::vector<int64_t> repeats = *(graph.get_int_list(repeats_ref));
+
+  vTensorPtr t_in = graph.get_tensor(in);
+  vTensorPtr t_out = graph.get_tensor(out);
+  check_args(*t_in, repeats, *t_out);
+
+  // In this function, we expand the dimensions in the following order:
+  // 1. Channel
+  // 2. Width
+  // 3. Height
+  // 4. Batch
+  // After expanding a dimension, we update "running_range", since subsequent
+  // copies will need to cover the newly expanded area.
+
+  api::utils::ivec3 running_range = t_in->texture_limits();
+
+  const std::vector<int64_t>& in_sizes = t_in->sizes();
+
+  // Since we use channel packing, repeating the channel dimension is the most
+  // complicated and time-consuming step, as we need to reason over misaligned
+  // channels. Hence we expand it first to minimize cost. This first step also
+  // copies the input texture over to the output; all subsequent steps read
+  // from and write to the output tensor itself.
+
+  if (int64_t channel_repeat = dim_at<Dim4D::Channel>(repeats);
+      channel_repeat == 1) {
+    // If there is no channel repeat, short-cut to a direct copy
+    api::utils::ivec3 src_offset = api::utils::make_ivec3({0, 0, 0}, false);
+    api::utils::ivec3 dst_offset = api::utils::make_ivec3({0, 0, 0}, false);
+
+    add_copy_offset_node(graph, in, running_range, src_offset, dst_offset, out);
+
+  } else {
+    add_repeat_channel_node(graph, in, channel_repeat, out, running_range);
+  }
+
+  // TODO: refactor width, height, and batch into a common helper function.
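+  // Each remaining dimension is expanded by copying the already-built region
+  // onto itself at shifted offsets. For example, with input width W and a
+  // width repeat of 3, the region is copied to x offsets W and 2 * W, after
+  // which running_range.data[0] grows to 3 * W.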
+
+  // Width
+  if (int64_t width_repeat = dim_at<Dim4D::Width>(repeats); width_repeat > 1) {
+    api::utils::ivec3 src_offset = api::utils::make_ivec3({0, 0, 0}, false);
+
+    for (int i = 1; i < width_repeat; ++i) {
+      api::utils::ivec3 dst_offset = api::utils::make_ivec3(
+          {i * dim_at<Dim4D::Width>(in_sizes), 0, 0}, false);
+
+      add_copy_offset_node(
+          graph, out, running_range, src_offset, dst_offset, out);
+    }
+
+    running_range.data[0] = running_range.data[0] * width_repeat;
+  }
+
+  // Height
+  if (int64_t height_repeat = dim_at<Dim4D::Height>(repeats);
+      height_repeat > 1) {
+    api::utils::ivec3 src_offset = api::utils::make_ivec3({0, 0, 0}, false);
+
+    for (int i = 1; i < height_repeat; ++i) {
+      api::utils::ivec3 dst_offset = api::utils::make_ivec3(
+          {0, i * dim_at<Dim4D::Height>(in_sizes), 0}, false);
+
+      add_copy_offset_node(
+          graph, out, running_range, src_offset, dst_offset, out);
+    }
+
+    running_range.data[1] = running_range.data[1] * height_repeat;
+  }
+
+  // Batch
+  if (int64_t batch_repeat = dim_at<Dim4D::Batch>(repeats); batch_repeat > 1) {
+    api::utils::ivec3 src_offset = api::utils::make_ivec3({0, 0, 0}, false);
+
+    for (int i = 1; i < batch_repeat; ++i) {
+      api::utils::ivec3 dst_offset =
+          api::utils::make_ivec3({0, 0, i * running_range.data[2]}, false);
+
+      add_copy_offset_node(
+          graph, out, running_range, src_offset, dst_offset, out);
+    }
+
+    running_range.data[2] = running_range.data[2] * batch_repeat;
+  }
+}
+
+void repeat(ComputeGraph& graph, const std::vector<ValueRef>& args) {
+  add_repeat_node(graph, args[0], args[1], args[2]);
+}
+
+REGISTER_OPERATORS {
+  VK_REGISTER_OP(aten.repeat.default, repeat);
+}
+
+} // namespace vkcompute
diff --git a/backends/vulkan/test/op_tests/cases.py b/backends/vulkan/test/op_tests/cases.py
index 8101c1d6fe2..2a100b92e38 100644
--- a/backends/vulkan/test/op_tests/cases.py
+++ b/backends/vulkan/test/op_tests/cases.py
@@ -377,6 +377,54 @@ def get_clone_inputs():
             ((XS,),),
         ]
     )
+    test_suite.layouts = [
+        "api::kChannelsPacked",
+    ]
+    test_suite.data_gen = "make_seq_tensor"
+    return test_suite
+
+
+def get_repeat_inputs():
+    test_suite = VkTestSuite(
+        [
+            # Repeat channels only (the most challenging case)
+            ((3, XS, S), [2, 1, 1]),
+            ((7, XS, S), [4, 1, 1]),
+            ((1, 7, XS, S), [1, 4, 1, 1]),
+            ((3, 7, XS, S), [1, 4, 1, 1]),
+            # Repeat channels together with other dims
+            ((1, 7, XS, S), [1, 4, 1, 3]),
+            ((3, 7, XS, S), [1, 4, 1, 3]),
+            ((3, 7, XS, S), [1, 4, 3, 1]),
+            ((3, 7, XS, S), [1, 4, 3, 3]),
+            # Repeat batch
+            ((3, 7, XS, S), [3, 4, 3, 3]),
+            ((3, 7, XS, S), [3, 1, 3, 3]),
+            # More cases
+            ((3, 7, 1, 1), [1, 4, 1, 1]),
+            ((2, 3), [1, 4]),
+            ((2, 3), [4, 1]),
+            ((2, 3), [4, 4]),
+            ((S1, S2, S2), [1, 3, 1]),
+            ((S1, S2, S2), [1, 3, 3]),
+            ((S1, S2, S2), [3, 3, 1]),
+            ((S1, S2, S2), [3, 3, 3]),
+            ((S1, S2, S2, S2), [1, 1, 3, 1]),
+            ((S1, S2, S2, S2), [1, 1, 1, 3]),
+            ((S1, S2, S2, S2), [1, 1, 3, 3]),
+            ((S1, S2, S2, S2), [1, 3, 1, 3]),
+            ((S1, S2, S2, S2), [3, 3, 3, 3]),
+            ((S1, S2, S2, S2), [3, 3, 1, 1]),
+            # Expanding cases: repeats has more dims than the input tensor
+            ((2, 3), [3, 1, 4]),
+            ((2, 3), [3, 3, 2, 4]),
+        ]
+    )
+    test_suite.layouts = [
+        "api::kChannelsPacked",
+    ]
+    test_suite.data_gen = "make_seq_tensor"
+    test_suite.dtypes = ["at::kFloat"]
     return test_suite
@@ -398,4 +446,5 @@
     "aten.slice_copy.Tensor": get_slice_inputs(),
     "aten.unsqueeze_copy.default": get_unsqueeze_inputs(),
     "aten.clone.default": get_clone_inputs(),
+    "aten.repeat.default": get_repeat_inputs(),
 }
diff --git a/backends/vulkan/test/op_tests/utils/codegen.py b/backends/vulkan/test/op_tests/utils/codegen.py
index 1e28519ebfb..f0e5547b4fe 100644
--- a/backends/vulkan/test/op_tests/utils/codegen.py
+++ b/backends/vulkan/test/op_tests/utils/codegen.py
@@ -31,8 +31,10 @@
 from torchgen.api import cpp
 from torchgen.api.types import CppSignatureGroup
 
-from torchgen.gen import generate_static_dispatch_backend_call
-from torchgen.model import NativeFunction
+from torchgen.gen import generate_static_dispatch_backend_call, translate_args
+
+from torchgen.gen_aoti_c_shim import gen_static_dispatch_backend_call_signature
+from torchgen.model import NativeFunction, Variant
 
 ##################################
 ## Custom Test Suite Definition ##
@@ -183,10 +185,23 @@ def create_aten_fn_call(self) -> str:
         func_call = generate_static_dispatch_backend_call(
             self.f_sig, self.f, TestSuiteGen.backend_key
         )[7:].replace("::cpu", "")
+
+        return func_call
+
+    def create_aten_method_call(self) -> str:
+        # For functions with only a Method variant, we fall back to the
+        # function declared in MethodOperators.h. The method is declared as
+        # at::_ops::{name}::call(*), and ATEN_FN is a handy macro.
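+        # e.g. aten.repeat.default only has a method variant, so this path
+        # generates a call roughly of the form: ATEN_FN(repeat)(self, repeats);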
+        cpp_sig = gen_static_dispatch_backend_call_signature(self.f_sig, self.f)
+        exprs = translate_args(self.f_sig, cpp_sig)
+        func_call = f"ATEN_FN({self.f_sig.name()})({exprs});"
         return func_call
 
     def create_out_src(self) -> str:
-        return f"{self.out.cpp_type} out = " + self.create_aten_fn_call()
+        if Variant.function in self.f.variants:
+            return f"{self.out.cpp_type} out = " + self.create_aten_fn_call() + "\n"
+        else:
+            return f"{self.out.cpp_type} out = " + self.create_aten_method_call() + "\n"
 
 
 ## Graph code generation utils
@@ -353,7 +368,6 @@ def check_graph_out(self, ref: ValueRefList) -> str:
 
     def gen_graph_build_code(self) -> str:
         graph_build = self.create_out_src()
-
         for aten_arg in self.args:
             graph_build += self.create_value_for(self.refs[aten_arg.name])