diff --git a/backends/vulkan/runtime/graph/ops/impl/Cat.cpp b/backends/vulkan/runtime/graph/ops/impl/Cat.cpp new file mode 100644 index 00000000000..08363fa71e4 --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/impl/Cat.cpp @@ -0,0 +1,95 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include + +#include +#include +#include +#include +#include + +namespace vkcompute { + +void add_cat_default_node( + ComputeGraph& graph, + ValueRef in_list_ref, + ValueRef dim_ref, + ValueRef out) { + ValueListPtr input_list = graph.get_value_list(in_list_ref); + + for (ValueRef input_ref : *input_list) { + vTensorPtr t_in = graph.get_tensor(input_ref); + VK_CHECK_COND(check_memory_layout_is(*t_in, api::kChannelsPacked)); + } + + int64_t dim = graph.extract_scalar(dim_ref); + vTensorPtr t_out = graph.get_tensor(out); + + NchwDim nchw_dim = normalize_to_nchw_dim(*t_out, dim); + + // TODO: Find ways to factor out the similar code for width, height, and batch + if (nchw_dim == DimWidth) { + api::utils::ivec3 src_offset = api::utils::make_ivec3({0, 0, 0}, false); + api::utils::ivec3 dst_offset = api::utils::make_ivec3({0, 0, 0}, false); + + for (ValueRef input_ref : *input_list) { + vTensorPtr t_in = graph.get_tensor(input_ref); + api::utils::ivec3 range = t_in->texture_limits(); + add_copy_offset_node( + graph, input_ref, range, src_offset, dst_offset, out); + dst_offset.data[0] += range.data[0]; + } + + } else if (nchw_dim == DimHeight) { + api::utils::ivec3 src_offset = api::utils::make_ivec3({0, 0, 0}, false); + api::utils::ivec3 dst_offset = api::utils::make_ivec3({0, 0, 0}, false); + + for (ValueRef input_ref : *input_list) { + vTensorPtr t_in = graph.get_tensor(input_ref); + api::utils::ivec3 range = t_in->texture_limits(); + add_copy_offset_node( + graph, input_ref, range, src_offset, 
dst_offset, out); + dst_offset.data[1] += range.data[1]; + } + } else if (nchw_dim == DimBatch) { + api::utils::ivec3 src_offset = api::utils::make_ivec3({0, 0, 0}, false); + api::utils::ivec3 dst_offset = api::utils::make_ivec3({0, 0, 0}, false); + + for (ValueRef input_ref : *input_list) { + vTensorPtr t_in = graph.get_tensor(input_ref); + api::utils::ivec3 range = t_in->texture_limits(); + add_copy_offset_node( + graph, input_ref, range, src_offset, dst_offset, out); + dst_offset.data[2] += range.data[2]; + } + } else if (nchw_dim == DimChannel) { + int32_t src_offset = 0; + int32_t dst_offset = 0; + + for (ValueRef input_ref : *input_list) { + vTensorPtr t_in = graph.get_tensor(input_ref); + int32_t range = dim_at(t_in->sizes()); + add_copy_channel_offset_node( + graph, input_ref, range, src_offset, dst_offset, out); + dst_offset += range; + } + } else { + VK_THROW("Unexpected value of nchw_dim=", nchw_dim); + } +} + +void cat_default(ComputeGraph& graph, const std::vector& args) { + add_cat_default_node(graph, args[0], args[1], args[2]); +} + +REGISTER_OPERATORS { + VK_REGISTER_OP(aten.cat.default, cat_default); +} + +} // namespace vkcompute diff --git a/backends/vulkan/runtime/graph/ops/impl/Copy.cpp b/backends/vulkan/runtime/graph/ops/impl/Copy.cpp index d599a00c2eb..5ca4973e56f 100644 --- a/backends/vulkan/runtime/graph/ops/impl/Copy.cpp +++ b/backends/vulkan/runtime/graph/ops/impl/Copy.cpp @@ -93,10 +93,23 @@ void add_copy_channel_offset_node( VK_CHECK_COND( dim_at(in_sizes) >= src_channel_offset + channel_range, - "Source channel plus range should be less than or equal to input tensor's channel size"); + "Src channel (", + src_channel_offset, + ") and range (", + channel_range, + ") should be less than or equal to input tensor's channel size (", + dim_at(in_sizes), + ")"); + VK_CHECK_COND( dim_at(out_sizes) >= dst_channel_offset + channel_range, - "Source channel and range should be less than or equal to input tensor's channel size"); + "Dst channel (", 
+ dst_channel_offset, + ") and range (", + channel_range, + ") should be less than or equal to output tensor's channel size (", + dim_at(out_sizes), + ")"); VK_CHECK_COND(channel_range >= 0, "Channel range must be non-negative"); VK_CHECK_COND( diff --git a/backends/vulkan/runtime/graph/ops/impl/utils/DimUtils.h b/backends/vulkan/runtime/graph/ops/impl/utils/DimUtils.h index 92eba407d83..e7b9a614e28 100644 --- a/backends/vulkan/runtime/graph/ops/impl/utils/DimUtils.h +++ b/backends/vulkan/runtime/graph/ops/impl/utils/DimUtils.h @@ -70,4 +70,50 @@ uint32_t dim_at(const vTensor& v_in) { return dim_at(v_in.sizes()); } +// A canonical way to represent dimensions as an enum. Intended to use the same +// values as Dim4D for potential future refactoring. + +enum NchwDim { + DimWidth = 1, + DimHeight = 2, + DimChannel = 3, + DimBatch = 4, +}; + +/* Returns the NchwDim that corresponds to a user-provided dim of a tensor. + * This normalization is needed because user-facing dims are counted from the + * outermost dimension ("big-endian"): dim=0 refers to the batch dimension of + * a 4d tensor but to the height dimension of a 2d tensor. In the + * channel-packed texture representation, however, a 2d tensor has exactly the + * same layout as a 4d tensor whose batch and channel sizes are both 1. This + * function returns a canonical dimension, counted from the innermost (width) + * dimension, to simplify dimension reasoning in the code. 
+ * + */ + +inline NchwDim normalize_to_nchw_dim(const vTensor& v_in, int32_t dim) { + return static_cast<NchwDim>(v_in.dim() - dim); +} + +inline std::ostream& operator<<(std::ostream& os, NchwDim nchw_dim) { + switch (nchw_dim) { + case DimWidth: + os << "DimWidth"; + break; + case DimHeight: + os << "DimHeight"; + break; + case DimChannel: + os << "DimChannel"; + break; + case DimBatch: + os << "DimBatch"; + break; + default: + os << "DimUnknown"; + break; + } + return os; +} + } // namespace vkcompute diff --git a/backends/vulkan/test/op_tests/cases.py b/backends/vulkan/test/op_tests/cases.py index 2a100b92e38..f0659ad8232 100644 --- a/backends/vulkan/test/op_tests/cases.py +++ b/backends/vulkan/test/op_tests/cases.py @@ -428,6 +428,52 @@ def get_repeat_inputs(): return test_suite +def get_cat_inputs(): + # TensorList must be specified as list of tuples + test_suite = VkTestSuite( + [ + # Cat on Height + ([(S1, S1, 3, 5), (S1, S1, 4, 5)], 2), + ([(S1, 3, 5), (S1, 4, 5)], 1), + ([(3, 5), (4, 5)], 0), + ([(3, 5), (4, 5), (1, 5)], 0), + ( + [ + (3, 5), + ], + 0, + ), + # Cat on Width + ([(S1, S1, 5, 3), (S1, S1, 5, 4)], 3), + ([(S1, 5, 3), (S1, 5, 4)], 2), + ([(5, 3), (5, 4)], 1), + ([(5, 3), (5, 4), (5, 1)], 1), + ( + [ + (5, 4), + ], + 1, + ), + ([(5,), (6,)], 0), + # Cat on Batch + ([(S, S1, 5, 4), (S1, S1, 5, 4)], 0), + ([(S, XS, 5, 4), (S1, XS, 5, 4)], 0), + ([(S, S2, 5, 4), (S1, S2, 5, 4)], 0), + # Cat on Channel + ([(S, 5, 4), (S1, 5, 4), (S2, 5, 4)], 0), + ([(XS, 5, 4), (XS, 5, 4), (S2, 5, 4)], 0), + ([(XS, S, 5, 4), (XS, S1, 5, 4), (XS, S2, 5, 4)], 1), + ([(XS, XS, 5, 4), (XS, XS, 5, 4), (XS, S2, 5, 4)], 1), + ] + ) + test_suite.layouts = [ + "api::kChannelsPacked", + ] + test_suite.data_gen = "make_seq_tensor" + test_suite.dtypes = ["at::kFloat"] + return test_suite + + test_suites = { "aten.add.Tensor": get_binary_elementwise_inputs(), "aten.sub.Tensor": get_binary_elementwise_inputs(), @@ -447,4 +493,5 @@ def get_repeat_inputs(): "aten.unsqueeze_copy.default": 
get_unsqueeze_inputs(), "aten.clone.default": get_clone_inputs(), "aten.repeat.default": get_repeat_inputs(), + "aten.cat.default": get_cat_inputs(), } diff --git a/backends/vulkan/test/op_tests/generate_op_tests.py b/backends/vulkan/test/op_tests/generate_op_tests.py index ef4dc0af91a..71047ac6f49 100644 --- a/backends/vulkan/test/op_tests/generate_op_tests.py +++ b/backends/vulkan/test/op_tests/generate_op_tests.py @@ -16,6 +16,7 @@ TestSuite, TestSuiteGen, ) +from torchgen import local from torchgen.gen import parse_native_yaml, ParsedYaml from torchgen.model import DispatchKey, NativeFunction @@ -45,6 +46,9 @@ def process_test_suites( cpp_generator.add_suite(registry_name, f, op_test_suite) +@local.parametrize( + use_const_ref_for_mutable_tensors=False, use_ilistref_for_tensor_lists=False +) def generate_cpp( native_functions_yaml_path: str, tags_path: str, output_dir: str ) -> None: diff --git a/backends/vulkan/test/op_tests/utils/codegen.py b/backends/vulkan/test/op_tests/utils/codegen.py index f0e5547b4fe..ac5e25fa596 100644 --- a/backends/vulkan/test/op_tests/utils/codegen.py +++ b/backends/vulkan/test/op_tests/utils/codegen.py @@ -12,6 +12,7 @@ AT_INT_ARRAY_REF, AT_SCALAR, AT_TENSOR, + AT_TENSOR_LIST, BOOL, CppTestFileGen, DOUBLE, @@ -28,6 +29,7 @@ THREE_TENSOR_TUPLE, TWO_TENSOR_TUPLE, ) + from torchgen.api import cpp from torchgen.api.types import CppSignatureGroup @@ -75,6 +77,8 @@ class ValueRef: ValueRefList = Union[ValueRef, List[ValueRef]] +InableCppType = frozenset([AT_TENSOR, AT_TENSOR_LIST]) + class ComputeGraphGen: def __init__(self, op_reg_name: str, f: NativeFunction, suite_def: TestSuite): @@ -114,7 +118,7 @@ def __init__(self, op_reg_name: str, f: NativeFunction, suite_def: TestSuite): name=f"{arg.name}_ref", src_cpp_name=arg.name, src_cpp_type=cpp_type, - is_in=(cpp_type == AT_TENSOR), + is_in=(cpp_type in InableCppType), requires_prepack=requires_prepack, supports_prepack=supports_prepack, ) @@ -244,6 +248,25 @@ def create_value_for(self, 
ref: ValueRefList) -> str: # noqa: C901 ret_str += f"{self.graph}{self.dot}add_scalar" ret_str += f"({ref.src_cpp_name}.value());\n" return ret_str + elif ref.src_cpp_type == AT_TENSOR_LIST: + assert ref.is_in, "AT_TENSOR_LIST must be an input" + # This logic is a bit convoluted. We need to create an IOValueRef for + # each tensor, to facilitate staging. On the other hand, we will + # use the .value tensor to create a ValueList, which will be passed + # to the corresponding ops. + ret_str = f"std::vector<IOValueRef> {ref.name}_io_value_refs;\n" + ret_str += f"std::vector<ValueRef> {ref.name}_value_refs;\n" + ret_str += f"for (int i=0; i < {ref.src_cpp_name}.size(); i++) {{\n" + ret_str += f" {cpp_type} io_value_ref = {self.graph}{self.dot}add_input_tensor(\n" + ret_str += f" {ref.src_cpp_name}[i].sizes().vec(),\n" + ret_str += ( + f" from_at_scalartype({ref.src_cpp_name}[i].scalar_type())); \n" + ) + ret_str += f" {ref.name}_value_refs.emplace_back(io_value_ref.value);\n" + ret_str += f" {ref.name}_io_value_refs.emplace_back(io_value_ref);\n" + ret_str += "}\n" + ret_str += f"ValueRef {ref.name} = {self.graph}{self.dot}add_value_list(std::move({ref.name}_value_refs));\n" + return ret_str ret_str = f"{cpp_type} {ref.name} = {self.graph}{self.dot}" if ref.src_cpp_type == AT_TENSOR and not prepack: @@ -288,11 +311,16 @@ def create_op_call(self) -> str: for aten_arg in self.args: ref = self.refs[aten_arg.name] - op_create_code += ( - f"{ref.name}.value, " - if (ref.is_in and not self.prepack_ref(ref)) or ref.is_out - else f"{ref.name}, " - ) + if ref.src_cpp_type == AT_TENSOR_LIST: + # Special case. Underlying tensors are input tensors, but the + # container itself is just a normal value. 
+ op_create_code += f"{ref.name}, " + else: + op_create_code += ( + f"{ref.name}.value, " + if (ref.is_in and not self.prepack_ref(ref)) or ref.is_out + else f"{ref.name}, " + ) op_create_code += "out_ref});\n" return op_create_code @@ -311,22 +339,46 @@ def set_output(self, ref: ValueRefList) -> str: def virtual_resize(self, ref: ValueRefList) -> str: assert isinstance(ref, ValueRef) - assert ref.src_cpp_type == AT_TENSOR and ref.is_in + assert ref.src_cpp_type in InableCppType and ref.is_in if self.prepack_ref(ref): return "" - ret_str = f"{self.graph}{self.dot}get_tensor({ref.name}.value)" - ret_str += f"->virtual_resize({ref.src_cpp_name}.sizes().vec());\n" + + if ref.src_cpp_type == AT_TENSOR: + ret_str = f"{self.graph}{self.dot}get_tensor({ref.name}.value)" + ret_str += f"->virtual_resize({ref.src_cpp_name}.sizes().vec());\n" + elif ref.src_cpp_type == AT_TENSOR_LIST: + ret_str = "" + ret_str += f"for (int i=0; i < {ref.name}_io_value_refs.size(); i++) {{\n" + ret_str += f" {self.graph}{self.dot}get_tensor({ref.name}_io_value_refs[i].value)" + ret_str += f"->virtual_resize({ref.src_cpp_name}[i].sizes().vec());\n" + ret_str += "}\n" + else: + raise AssertionError(f"{ref.src_cpp_type} not expected") + return ret_str def copy_into_staging(self, ref: ValueRefList) -> str: assert isinstance(ref, ValueRef) - assert ref.src_cpp_type == AT_TENSOR and ref.is_in + assert ref.src_cpp_type in InableCppType and ref.is_in + if self.prepack_ref(ref): return "" - ret_str = f"{self.graph}{self.dot}copy_into_staging(" - ret_str += f"{ref.name}.staging, " - ret_str += f"{ref.src_cpp_name}.const_data_ptr(), " - ret_str += f"{ref.src_cpp_name}.numel());\n" + + if ref.src_cpp_type == AT_TENSOR: + ret_str = f"{self.graph}{self.dot}copy_into_staging(" + ret_str += f"{ref.name}.staging, " + ret_str += f"{ref.src_cpp_name}.const_data_ptr(), " + ret_str += f"{ref.src_cpp_name}.numel());\n" + elif ref.src_cpp_type == AT_TENSOR_LIST: + ret_str = "" + ret_str += f"for (int i=0; i < 
{ref.name}_io_value_refs.size(); i++) {{\n" + ret_str += f" {self.graph}{self.dot}copy_into_staging(" + ret_str += f"{ref.name}_io_value_refs[i].staging, " + ret_str += f"{ref.src_cpp_name}[i].const_data_ptr(), " + ret_str += f"{ref.src_cpp_name}[i].numel());\n" + ret_str += "}\n" + else: + raise AssertionError(f"{ref.src_cpp_type} not expected") return ret_str def declare_vk_out_for(self, ref: Union[ValueRef, List[ValueRef]]) -> str: @@ -547,8 +599,10 @@ def gen_parameterization(self) -> str: if (!is_close && t1.numel() < 500) { std::cout << "reference: " << std::endl; print(t1, 150); + std::cout << std::endl; std::cout << "vulkan: " << std::endl; print(t2, 150); + std::cout << std::endl; } return is_close; } diff --git a/backends/vulkan/test/op_tests/utils/codegen_base.py b/backends/vulkan/test/op_tests/utils/codegen_base.py index d5feada1df8..e9fbe0b1b29 100644 --- a/backends/vulkan/test/op_tests/utils/codegen_base.py +++ b/backends/vulkan/test/op_tests/utils/codegen_base.py @@ -4,7 +4,8 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. 
-from typing import Any, List +import re +from typing import Any, List, Tuple from torchgen.api import cpp from torchgen.api.types import CppSignatureGroup @@ -17,6 +18,7 @@ AT_INT_ARRAY_REF = "at::IntArrayRef" AT_SCALAR = "at::Scalar" AT_TENSOR = "at::Tensor" +AT_TENSOR_LIST = "at::TensorList" BOOL = "bool" DOUBLE = "double" INT = "int64_t" @@ -57,8 +59,8 @@ class GeneratedOpsTest_{op_name} : public ::testing::Test {{ test_suite_template = """ TEST_P(GeneratedOpsTest_{op_name}, {case_name}) {{ - {create_ref_data} - {create_and_check_out} +{create_ref_data} +{create_and_check_out} }} """ @@ -97,6 +99,9 @@ def __init__(self, f: NativeFunction, test_suite: TestSuite): self.f, method=False, fallback_binding=self.f.manual_cpp_binding ).most_faithful_signature() + def gen_case_name_tuple(self, t: Tuple) -> str: + return "x".join([str(e) for e in t]) + def gen_case_name(self, inputs: List[Any], prepack: bool = False) -> str: name_str = self.op_name if prepack: @@ -104,13 +109,15 @@ def gen_case_name(self, inputs: List[Any], prepack: bool = False) -> str: for arg_sizes_or_val in inputs: name_str += "_" if isinstance(arg_sizes_or_val, tuple): - for size in arg_sizes_or_val: - name_str += str(size) + "x" - name_str = name_str[:-1] + name_str += self.gen_case_name_tuple(arg_sizes_or_val) elif isinstance(arg_sizes_or_val, list): + lst = [] for size in arg_sizes_or_val: - name_str += str(size) + "c" - name_str = name_str[:-1] + if isinstance(size, tuple): + lst.append(self.gen_case_name_tuple(size)) + else: + lst.append(str(size)) + name_str += "c".join(lst) else: name_str += str(arg_sizes_or_val).replace(".", "p") @@ -122,6 +129,15 @@ def create_input_data(self, arg: Argument, data: Any) -> str: # noqa: C901 ctype = cpp.argumenttype_type(arg.type, mutable=arg.is_write, binds=arg.name) cpp_type = ctype.cpp_type(strip_ref=True) + # Early exit for AT_TENSOR_LIST: unlike the other types, it needs + # multiple lines of construction. 
+ if cpp_type == AT_TENSOR_LIST: + ret_str = f"std::vector<{AT_TENSOR}> {arg.name}_vec;\n" + for elem in data: + ret_str += f"{arg.name}_vec.emplace_back({self.suite_def.data_gen}({init_list_str(elem)}, test_dtype));\n" + ret_str += f"{cpp_type} {arg.name} = {arg.name}_vec;\n" + return ret_str + "\n" + if cpp_type == AT_INT_ARRAY_REF: ret_str = f"std::vector {arg.name} = " else: @@ -169,6 +185,7 @@ def gen_create_ref_data(self, inputs: List[Any]) -> str: arg_data = get_or_return_default(arg, inputs, i) ref_code += self.create_input_data(arg, arg_data) + ref_code = re.sub(r"^", " ", ref_code, flags=re.M) return ref_code def gen_create_and_check_out(self, prepack=False) -> str: @@ -179,6 +196,7 @@ def gen_create_and_check_out(self, prepack=False) -> str: arg = binding.argument test_str += f"{arg.name}, " test_str = test_str[:-2] + ");" + test_str = re.sub(r"^", " ", test_str, flags=re.M) return test_str def gen_parameterization(self) -> str: