diff --git a/backends/vulkan/runtime/graph/ComputeGraph.h b/backends/vulkan/runtime/graph/ComputeGraph.h
index 2e39ed1bdfc..00a23aa920f 100644
--- a/backends/vulkan/runtime/graph/ComputeGraph.h
+++ b/backends/vulkan/runtime/graph/ComputeGraph.h
@@ -10,6 +10,8 @@ // @lint-ignore-every CLANGTIDY facebook-hte-BadMemberName
+#include <optional>
+
 #include
 
 #include
@@ -184,6 +186,15 @@ class ComputeGraph final {
     VK_THROW("Cannot extract scalar from Value with type ", value.type());
   }
 
+  template <typename T>
+  std::optional<T> extract_optional_scalar(const ValueRef idx) {
+    if (val_is_none(idx)) {
+      return ::std::nullopt;
+    } else {
+      return extract_scalar<T>(idx);
+    }
+  }
+
   inline std::vector<std::unique_ptr<PrepackNode>>& prepack_nodes() {
     return prepack_nodes_;
   }
diff --git a/backends/vulkan/runtime/graph/Logging.h b/backends/vulkan/runtime/graph/Logging.h
index 2c42b78fa5e..5ee068100fd 100644
--- a/backends/vulkan/runtime/graph/Logging.h
+++ b/backends/vulkan/runtime/graph/Logging.h
@@ -10,6 +10,7 @@
 
 #include
 
+#include <optional>
 #include
 #include
 
@@ -33,4 +34,14 @@ inline std::ostream& operator<<(std::ostream& os, const api::utils::uvec4& v) {
   return api::utils::operator<<(os, v);
 }
 
+template <typename T>
+inline std::ostream& operator<<(std::ostream& os, const std::optional<T>& opt) {
+  os << "[";
+  if (opt) {
+    os << opt.value();
+  }
+  os << "]";
+  return os;
+}
+
 } // namespace vkcompute
diff --git a/backends/vulkan/runtime/graph/ops/glsl/indexing_utils.h b/backends/vulkan/runtime/graph/ops/glsl/indexing_utils.h
index 7231003c51b..861f2fc45a8 100644
--- a/backends/vulkan/runtime/graph/ops/glsl/indexing_utils.h
+++ b/backends/vulkan/runtime/graph/ops/glsl/indexing_utils.h
@@ -8,15 +8,15 @@
 
 #define divup4(x) ((x + 3) / 4)
 
-// Input: idx is a ivec4 user-level coordinate, sizes is the tensor shape
-// Output: buffer_idx in the continuous nchw-buffer.
+// Input: idx is an ivec4 user-level (w, h, c, n) coordinate, sizes is the
+// tensor shape. Output: buffer_idx in the continuous nchw-buffer.
 #define to_buffer_i(idx, sizes) \
   (idx.x + idx.y * sizes.x + idx.z * sizes.y * sizes.x + \
    idx.w * sizes.z * sizes.y * sizes.x)
 
 // Inverse of to_buffer_i
 // Input: buffer_idx in the continuous nchw-buffer, sizes is the tensor shape
-// Output: ivec4 user-level coorindate
+// Output: ivec4 user-level (w, h, c, n) coordinate
 #define from_buffer_i(buf_i, sizes) \
   ivec4( \
     buf_i % sizes.x, \
diff --git a/backends/vulkan/runtime/graph/ops/glsl/slice_batch_height_width.glsl b/backends/vulkan/runtime/graph/ops/glsl/slice_batch_height_width.glsl
new file mode 100644
index 00000000000..7b53474e678
--- /dev/null
+++ b/backends/vulkan/runtime/graph/ops/glsl/slice_batch_height_width.glsl
@@ -0,0 +1,59 @@
+/*
+ * Copyright (c) Meta Platforms, Inc. and affiliates.
+ * All rights reserved.
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+#version 450 core
+
+#define PRECISION ${PRECISION}
+
+#define VEC4_T ${texel_type(DTYPE)}
+
+layout(std430) buffer;
+
+#include "indexing_utils.h"
+
+layout(set = 0, binding = 0, ${IMAGE_FORMAT[DTYPE]}) uniform PRECISION restrict writeonly ${IMAGE_T[NDIM][DTYPE]} image_out;
+layout(set = 0, binding = 1) uniform PRECISION sampler3D image_in;
+
+layout(set = 0, binding = 2) uniform PRECISION restrict OutSizes {
+  uvec4 data;
+}
+out_sizes;
+
+layout(set = 0, binding = 3) uniform PRECISION restrict SliceArg {
+  int dim;
+  int offset;
+  int step;
+  // Used when dim=batch. Stride is the # of planes for each batch value.
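+  // Example (assumed values, for illustration only): with 10 channels, each
+  // batch entry spans divup4(10) = 3 texel planes, so stride would be 3.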
+  int stride;
+}
+slice_arg;
+
+layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in;
+
+void main() {
+  const ivec3 out_pos = ivec3(gl_GlobalInvocationID);
+
+  const ivec4 idx = to_tensor_idx_C_packed(out_pos, out_sizes.data);
+
+  if (any(greaterThanEqual(idx, out_sizes.data))) {
+    return;
+  }
+
+  ivec3 in_pos = out_pos;
+
+  int index = out_pos[slice_arg.dim] / slice_arg.stride;
+  int within_stride = out_pos[slice_arg.dim] % slice_arg.stride;
+
+  in_pos[slice_arg.dim] = slice_arg.offset * slice_arg.stride +
+      index * slice_arg.step * slice_arg.stride + within_stride;
+
+  imageStore(image_out, out_pos, texelFetch(image_in, in_pos, 0));
+
+}
+
+
diff --git a/backends/vulkan/runtime/graph/ops/glsl/slice_batch_height_width.yaml b/backends/vulkan/runtime/graph/ops/glsl/slice_batch_height_width.yaml
new file mode 100644
index 00000000000..9e69b09a304
--- /dev/null
+++ b/backends/vulkan/runtime/graph/ops/glsl/slice_batch_height_width.yaml
@@ -0,0 +1,10 @@
+slice_batch_height_width:
+  parameter_names_with_default_values:
+    DTYPE: float
+    NDIM: 3
+  generate_variant_forall:
+    DTYPE:
+      - VALUE: half
+      - VALUE: float
+  shader_variants:
+    - NAME: slice_batch_height_width
diff --git a/backends/vulkan/runtime/graph/ops/glsl/slice_channel.glsl b/backends/vulkan/runtime/graph/ops/glsl/slice_channel.glsl
new file mode 100644
index 00000000000..5b116ec524b
--- /dev/null
+++ b/backends/vulkan/runtime/graph/ops/glsl/slice_channel.glsl
@@ -0,0 +1,85 @@
+/*
+ * Copyright (c) Meta Platforms, Inc. and affiliates.
+ * All rights reserved.
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+#version 450 core
+
+#define PRECISION ${PRECISION}
+
+#define VEC4_T ${texel_type(DTYPE)}
+
+
+#define to_tensor_idx to_tensor_idx_${PACKING}
+#define to_texture_pos_elem to_texture_pos_elem_${PACKING}
+#define get_packed_stride get_packed_stride_${PACKING}
+
+
+layout(std430) buffer;
+
+#include "indexing_utils.h"
+
+layout(set = 0, binding = 0, ${IMAGE_FORMAT[DTYPE]}) uniform PRECISION restrict writeonly ${IMAGE_T[NDIM][DTYPE]} image_out;
+layout(set = 0, binding = 1) uniform PRECISION sampler3D image_in;
+
+layout(set = 0, binding = 2) uniform PRECISION restrict OutSizes {
+  uvec4 data;
+}
+out_sizes;
+
+layout(set = 0, binding = 3) uniform PRECISION restrict OutCpuSizes {
+  uvec4 out_cpu_sizes;
+};
+
+layout(set = 0, binding = 4) uniform PRECISION restrict InGpuSizes {
+  uvec4 in_gpu_sizes;
+};
+
+layout(set = 0, binding = 5) uniform PRECISION restrict SliceArg {
+  int offset;
+  int step;
+}
+slice_arg;
+
+layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in;
+
+void main() {
+  const ivec3 out_pos = ivec3(gl_GlobalInvocationID);
+
+  const ivec4 idx = to_tensor_idx(out_pos, out_sizes.data);
+
+  if (any(greaterThanEqual(idx, out_sizes.data))) {
+    return;
+  }
+
+  // We map the output position via its nchw buffer index. For each element in
+  // the output texel, we compute the source whcn-coordinate by applying the
+  // slice offset and step to the channel value, then convert that
+  // whcn-coordinate back to an input texture position and element index.
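+  //
+  // Example (assumed values, for illustration only): with offset = 1 and
+  // step = 2, output channel c reads input channel 1 + 2 * c, i.e. output
+  // channels 0..3 come from input channels 1, 3, 5, 7.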
+
+  const uint base_index = to_buffer_i(idx, out_cpu_sizes);
+  uvec4 buf_indices =
+      base_index + uvec4(0, 1, 2, 3) * get_packed_stride(out_cpu_sizes);
+
+  vec4 out_texel;
+  for (int i = 0; i < 4; i++) {
+    ivec4 user_coor = from_buffer_i(buf_indices[i], out_cpu_sizes);
+
+    int in_channel = user_coor.z;
+
+    ivec4 in_user_coor = user_coor;
+    in_user_coor.z = slice_arg.offset + in_channel * slice_arg.step;
+
+    ivec4 in_pos_elem = to_texture_pos_elem(
+        in_user_coor,
+        in_gpu_sizes);
+
+    vec4 v = texelFetch(image_in, in_pos_elem.xyz, 0);
+
+    out_texel[i] = v[in_pos_elem.w];
+  }
+  imageStore(image_out, out_pos, out_texel);
+}
diff --git a/backends/vulkan/runtime/graph/ops/glsl/slice_channel.yaml b/backends/vulkan/runtime/graph/ops/glsl/slice_channel.yaml
new file mode 100644
index 00000000000..b5c189fb386
--- /dev/null
+++ b/backends/vulkan/runtime/graph/ops/glsl/slice_channel.yaml
@@ -0,0 +1,11 @@
+slice_channel:
+  parameter_names_with_default_values:
+    DTYPE: float
+    NDIM: 3
+  generate_variant_forall:
+    DTYPE:
+      - VALUE: float
+    PACKING:
+      - VALUE: C_packed
+  shader_variants:
+    - NAME: slice_channel
diff --git a/backends/vulkan/runtime/graph/ops/impl/Slice.cpp b/backends/vulkan/runtime/graph/ops/impl/Slice.cpp
new file mode 100644
index 00000000000..e67a061228d
--- /dev/null
+++ b/backends/vulkan/runtime/graph/ops/impl/Slice.cpp
@@ -0,0 +1,159 @@
+/*
+ * Copyright (c) Meta Platforms, Inc. and affiliates.
+ * All rights reserved.
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+#include <executorch/backends/vulkan/runtime/graph/ops/OperatorRegistry.h>
+
+#include <executorch/backends/vulkan/runtime/api/api.h>
+
+#include <executorch/backends/vulkan/runtime/graph/ops/impl/utils/DimUtils.h>
+#include <executorch/backends/vulkan/runtime/graph/ops/impl/utils/KernelUtils.h>
+#include <executorch/backends/vulkan/runtime/graph/ops/impl/utils/TensorUtils.h>
+#include <executorch/backends/vulkan/runtime/graph/ops/utils/ShaderNameUtils.h>
+
+namespace vkcompute {
+
+void add_slice_tensor_out_node(
+    ComputeGraph& graph,
+    ValueRef in,
+    ValueRef dim_ref,
+    ValueRef opt_start_ref,
+    ValueRef opt_end_ref,
+    ValueRef step_ref,
+    ValueRef out) {
+  vTensorPtr t_in = graph.get_tensor(in);
+  vTensorPtr t_out = graph.get_tensor(out);
+
+  VK_CHECK_COND(check_memory_layout_is(*t_in, api::kChannelsPacked));
+  VK_CHECK_COND(check_memory_layout_is(*t_out, api::kChannelsPacked));
+
+  // Need to normalize the dim, since it may be negative.
+  int64_t dim = graph.extract_scalar<int64_t>(dim_ref);
+
+  VK_CHECK_COND(
+      -t_in->dim() <= dim && dim < t_in->dim(),
+      "dim must be in range of [-self.dim(), self.dim()), but current dim's value is ",
+      dim,
+      " and self.dim() = ",
+      t_in->dim());
+
+  dim = normalize(dim, t_in->dim());
+
+  // Create a dim value as if the underlying tensor were 4-dimensional (NCHW).
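+  // For example, for a 3-d (c, h, w) tensor, dim == 0 (channels) maps to
+  // nchw_dim == 1 and dim == 2 (width) maps to nchw_dim == 3.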
+  int64_t nchw_dim = dim + (4 - t_in->dim());
+
+  std::optional<int64_t> opt_start =
+      graph.extract_optional_scalar<int64_t>(opt_start_ref);
+  std::optional<int64_t> opt_end =
+      graph.extract_optional_scalar<int64_t>(opt_end_ref);
+  int64_t step = graph.extract_scalar<int64_t>(step_ref);
+
+  const auto in_sizes = t_in->sizes();
+  const auto out_sizes = t_out->sizes();
+
+  int64_t start = opt_start.value_or(0);
+  int64_t end = opt_end.value_or(in_sizes[dim]);
+
+  VK_CHECK_COND((0 <= start) && (start < in_sizes[dim]));
+  VK_CHECK_COND((0 <= end) && (end <= in_sizes[dim]));
+
+  if (nchw_dim == 1) {
+    // Slice by channel
+    std::string kernel_name = "slice_channel";
+    kernel_name.reserve(kShaderNameReserve);
+    add_dtype_suffix(kernel_name, *t_out);
+    add_memory_layout_suffix(kernel_name, *t_out);
+
+    api::utils::uvec3 global_size = t_out->extents();
+    api::utils::uvec3 local_size = adaptive_work_group_size(global_size);
+
+    const struct Block final {
+      int offset;
+      int step;
+    } params{
+        static_cast<int32_t>(start),
+        static_cast<int32_t>(step),
+    };
+
+    graph.execute_nodes().emplace_back(new ExecuteNode(
+        graph,
+        VK_KERNEL_FROM_STR(kernel_name),
+        global_size,
+        local_size,
+        {{out, api::MemoryAccessType::WRITE},
+         {in, api::MemoryAccessType::READ}},
+        {t_out->gpu_sizes_ubo(),
+         t_out->cpu_sizes_ubo(),
+         t_in->gpu_sizes_ubo(),
+         graph.create_params_buffer(params)}));
+
+  } else {
+    // GPU texture coordinates are (x, y, z)
+    int64_t gpu_dim = -1;
+    int64_t stride = 1;
+    if (nchw_dim == 3) {
+      gpu_dim = 0; // width: x dimension in gpu
+      VK_CHECK_COND(out_sizes[dim] == (1 + (end - start - 1) / step));
+    } else if (nchw_dim == 2) {
+      gpu_dim = 1; // height: y dimension
+      VK_CHECK_COND(out_sizes[dim] == (1 + (end - start - 1) / step));
+    } else if (nchw_dim == 0) {
+      gpu_dim = 2; // batch: z dimension
+
+      // Due to channel packing, each batch value spans `stride` texel planes
+      int64_t n_channels = dim_at<Dim4D::Channel>(in_sizes);
+      stride = api::utils::div_up(n_channels, 4ll);
+    } else {
+      VK_THROW("Unexpected nchw_dim!");
+    }
+
+    std::string kernel_name = "slice_batch_height_width";
+    kernel_name.reserve(kShaderNameReserve);
+    add_dtype_suffix(kernel_name, *t_out);
+
+    api::utils::uvec3 global_size = t_out->extents();
+    api::utils::uvec3 local_size = adaptive_work_group_size(global_size);
+
+    const struct Block final {
+      int dim;
+      int offset;
+      int step;
+      int stride;
+    } params{
+        static_cast<int32_t>(gpu_dim),
+        static_cast<int32_t>(start),
+        static_cast<int32_t>(step),
+        static_cast<int32_t>(stride),
+    };
+
+    graph.execute_nodes().emplace_back(new ExecuteNode(
+        graph,
+        VK_KERNEL_FROM_STR(kernel_name),
+        global_size,
+        local_size,
+        {{out, api::MemoryAccessType::WRITE},
+         {in, api::MemoryAccessType::READ}},
+        {t_out->gpu_sizes_ubo(), graph.create_params_buffer(params)}));
+  }
+}
+
+void slice_tensor_out(ComputeGraph& graph, const std::vector<ValueRef>& args) {
+  return add_slice_tensor_out_node(
+      graph,
+      args[0],
+      args[1], // dim
+      args[2], // optional start
+      args[3], // optional end
+      args[4], // step
+      args[5]);
+}
+
+REGISTER_OPERATORS {
+  VK_REGISTER_OP(aten.slice_copy.Tensor, slice_tensor_out);
+}
+
+} // namespace vkcompute
diff --git a/backends/vulkan/test/op_tests/cases.py b/backends/vulkan/test/op_tests/cases.py
index d5d0c5a6e56..4413709fad8 100644
--- a/backends/vulkan/test/op_tests/cases.py
+++ b/backends/vulkan/test/op_tests/cases.py
@@ -5,6 +5,8 @@
 # LICENSE file in the root directory of this source tree.
+from collections import namedtuple + from executorch.backends.vulkan.test.op_tests.utils.codegen import VkTestSuite @@ -221,6 +223,73 @@ def get_view_inputs(): return test_suite +def get_slice_inputs(): + Test = namedtuple("VkSliceTest", ["self", "dim", "start", "end", "step"]) + Test.__new__.__defaults__ = (None, 0, None, None, 1) + + # Slice by width and height + test_cases = [ + Test(self=[1, 1, 4, 10], dim=3, start=3), + Test(self=[1, 1, 4, 10], dim=3, start=3, step=2), + Test(self=[1, 1, 4, 10], dim=3, start=3, end=4, step=2), + Test(self=[1, 1, 4, 10], dim=2, start=3), + Test(self=[9, 9, 9, 9], dim=2, start=0, end=9, step=1), + Test(self=[9, 9, 9, 9], dim=2, start=1, end=8, step=1), + Test(self=[9, 9, 9, 9], dim=2, start=1, end=2, step=1), + Test(self=[9, 9, 9, 9], dim=3, start=1, end=5, step=1), + Test(self=[9, 9, 9, 9], dim=3, start=1, end=5, step=2), + Test(self=[9, 9, 9, 9], dim=-1, start=1, end=5, step=2), + Test(self=[9, 9, 9, 9], dim=-2, start=1, end=5, step=2), + Test(self=[9, 9, 9], dim=1, start=2, step=1), + Test(self=[9, 9, 9], dim=1, start=2, step=2), + Test(self=[9, 9, 9], dim=2, start=2, step=1), + Test(self=[9, 9, 9], dim=2, start=2, step=2), + Test(self=[9, 9], dim=0, start=2, step=1), + Test(self=[9, 9], dim=0, start=2, step=2), + Test(self=[9, 9], dim=1, start=2, step=1), + Test(self=[9, 9], dim=1, start=2, step=2), + ] + + # Slice by batch + test_cases += [ + Test(self=[6, 5, 3, 2], dim=0), + Test(self=[6, 5, 3, 2], dim=0, step=2), + Test(self=[13, 13, 3, 2], dim=0, step=2), + Test(self=[13, 13, 3, 2], dim=0, start=1, step=2), + Test(self=[13, 13, 3, 2], dim=0, start=1, step=5), + Test(self=[13, 13, 3, 2], dim=0, start=1, step=20), + Test(self=[13, 2, 3, 2], dim=0, start=1, step=2), + Test(self=[13, 2, 3, 2], dim=0, start=1, step=5), + Test(self=[13, 2, 3, 2], dim=0, start=1, step=20), + ] + + # Slice by channel + test_cases += [ + Test(self=[2, 5, 1, 10], dim=1), + Test(self=[2, 5, 1, 10], dim=1, start=1), + Test(self=[2, 5, 1, 10], dim=1, start=1, step=2), + Test(self=[5, 13, 1, 10], dim=1), + Test(self=[5, 13, 1, 10], dim=1, start=1), + Test(self=[5, 13, 1, 10], dim=1, start=1, step=2), + Test(self=[5, 13, 1, 10], dim=1, start=1, step=5), + Test(self=[5, 13, 1, 10], dim=1, start=1, step=20), + Test(self=[13, 1, 10], dim=0), + Test(self=[13, 1, 10], dim=0, start=1), + Test(self=[13, 1, 10], dim=0, start=1, step=2), + Test(self=[13, 1, 10], dim=0, start=1, step=5), + Test(self=[13, 1, 10], dim=0, start=1, step=20), + ] + + test_suite = VkTestSuite([tuple(tc) for tc in test_cases]) + + test_suite.dtypes = ["at::kFloat"] + test_suite.layouts = [ + "api::kChannelsPacked", + ] + test_suite.data_gen = "make_seq_tensor" + return test_suite + + test_suites = { "aten.add.Tensor": get_binary_elementwise_inputs(), "aten.sub.Tensor": get_binary_elementwise_inputs(), @@ -236,4 +305,5 @@ def get_view_inputs(): "aten.permute.default": get_permute_inputs(), "aten.permute_copy.default": get_permute_inputs(), "aten.view_copy.default": get_view_inputs(), + "aten.slice_copy.Tensor": get_slice_inputs(), } diff --git a/backends/vulkan/test/op_tests/utils/codegen.py b/backends/vulkan/test/op_tests/utils/codegen.py index 4de7cf26ee4..b1c08e6d0d8 100644 --- a/backends/vulkan/test/op_tests/utils/codegen.py +++ b/backends/vulkan/test/op_tests/utils/codegen.py @@ -4,8 +4,8 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. 
+import re
 from dataclasses import dataclass
-
 from typing import Any, List, Optional, Union
 
 from executorch.backends.vulkan.test.op_tests.utils.codegen_base import (
@@ -19,6 +19,7 @@
     OPT_AT_TENSOR,
     OPT_BOOL,
     OPT_DEVICE,
+    OPT_INT64,
     OPT_LAYOUT,
     OPT_SCALARTYPE,
     TestSuite,
@@ -43,6 +44,7 @@ def __init__(self, input_cases: List[Any]):
         super().__init__(input_cases)
         self.storage_types: List[str] = ["api::kTexture3D"]
         self.layouts: List[str] = ["api::kChannelsPacked"]
+        self.data_gen: str = "make_rand_tensor"
 
 
 ##########################
@@ -219,6 +221,13 @@ def create_value_for(self, ref: ValueRefList) -> str:  # noqa: C901
             ret_str += f"from_at_scalartype({ref.src_cpp_name}->scalar_type()), "
             ret_str += f"{ref.src_cpp_name}->const_data_ptr()); \n"
             return ret_str
+        elif ref.src_cpp_type == OPT_INT64:
+            ret_str = f"{cpp_type} {ref.name} = "
+            ret_str += f"!{ref.src_cpp_name}.has_value() ? "
+            ret_str += f"{self.graph}{self.dot}add_none() : "
+            ret_str += f"{self.graph}{self.dot}add_scalar<int64_t>"
+            ret_str += f"({ref.src_cpp_name}.value());\n"
+            return ret_str
 
         ret_str = f"{cpp_type} {ref.name} = {self.graph}{self.dot}"
         if ref.src_cpp_type == AT_TENSOR and not prepack:
@@ -383,14 +392,22 @@ def gen_conditional_skips(self) -> str:
 
     def gen_op_check_fn(self) -> str:
         op_name = self.f.func.name.unambiguous_name()
-        op_check_fn = self.gen_decl(f"check_{op_name}") + " {"
+        op_check_fn = self.gen_decl(f"check_{op_name}") + " {\n"
         if self.should_prepack:
             op_check_fn = self.gen_decl(f"prepacked_check_{op_name}") + " {"
-        op_check_fn += self.gen_conditional_skips()
-        op_check_fn += self.gen_graph_build_code()
-        op_check_fn += self.gen_graph_exec_code()
-        op_check_fn += self.check_graph_out(self.refs["out"])
-        op_check_fn += "}\n"
+
+        op_check_fn_body = ""
+        op_check_fn_body += self.gen_conditional_skips()
+        op_check_fn_body += self.gen_graph_build_code()
+        op_check_fn_body += self.gen_graph_exec_code()
+        op_check_fn_body += self.check_graph_out(self.refs["out"])
+
+        # Add two levels of indent for readability
+        op_check_fn_body = re.sub(r"^", " ", op_check_fn_body, flags=re.M)
+
+        op_check_fn += op_check_fn_body + "\n"
+        op_check_fn += " }\n"
+
         return op_check_fn
diff --git a/backends/vulkan/test/op_tests/utils/codegen_base.py b/backends/vulkan/test/op_tests/utils/codegen_base.py
index ff3509db5ca..986526fbdcc 100644
--- a/backends/vulkan/test/op_tests/utils/codegen_base.py
+++ b/backends/vulkan/test/op_tests/utils/codegen_base.py
@@ -22,6 +22,7 @@
 INT = "int64_t"
 OPT_AT_TENSOR = "::std::optional<at::Tensor>"
 OPT_BOOL = "::std::optional<bool>"
+OPT_INT64 = "::std::optional<int64_t>"
 OPT_DEVICE = "::std::optional<at::Device>"
 OPT_LAYOUT = "::std::optional<at::Layout>"
 OPT_SCALARTYPE = "::std::optional<at::ScalarType>"
@@ -105,21 +106,18 @@ def gen_case_name(self, inputs: List[Any], prepack: bool = False) -> str:
             for size in arg_sizes_or_val:
                 name_str += str(size) + "x"
             name_str = name_str[:-1]
-            # minus sign is a invalid char for test case. change to "n".
-            name_str = name_str.replace("-", "n")
-
         elif isinstance(arg_sizes_or_val, list):
             for size in arg_sizes_or_val:
                 name_str += str(size) + "c"
             name_str = name_str[:-1]
-            # minus sign is a invalid char for test case. change to "n".
-            name_str = name_str.replace("-", "n")
-
         else:
             name_str += str(arg_sizes_or_val).replace(".", "p")
+
+        # a minus sign is an invalid char for a test case name; change it to "n".
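+        # e.g., the dim=-1 and dim=-2 slice cases above render as "n1" and
+        # "n2" in the generated test names.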
+ name_str = name_str.replace("-", "n") return name_str - def create_input_data(self, arg: Argument, data: Any) -> str: + def create_input_data(self, arg: Argument, data: Any) -> str: # noqa: C901 ctype = cpp.argumenttype_type(arg.type, mutable=arg.is_write, binds=arg.name) cpp_type = ctype.cpp_type(strip_ref=True) @@ -129,7 +127,7 @@ def create_input_data(self, arg: Argument, data: Any) -> str: ret_str = f"{cpp_type} {arg.name} = " if cpp_type == AT_TENSOR: - ret_str += f"make_rand_tensor({init_list_str(data)}, test_dtype);" + ret_str += f"{self.suite_def.data_gen}({init_list_str(data)}, test_dtype);" elif cpp_type == OPT_AT_TENSOR: if str(data) == "None": ret_str += "std::nullopt;" @@ -145,6 +143,11 @@ def create_input_data(self, arg: Argument, data: Any) -> str: ret_str += f"{str(data).lower()};" elif cpp_type == DOUBLE: ret_str += f"{str(data).lower()};" + elif cpp_type == OPT_INT64: + if str(data) == "None": + ret_str += "std::nullopt;" + else: + ret_str += f"{str(data)};" elif ( cpp_type == OPT_SCALARTYPE or cpp_type == OPT_LAYOUT