95 changes: 95 additions & 0 deletions backends/vulkan/runtime/graph/ops/impl/Cat.cpp
@@ -0,0 +1,95 @@
/*
* Copyright (c) Meta Platforms, Inc. and affiliates.
* All rights reserved.
*
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree.
*/

#include <executorch/backends/vulkan/runtime/graph/ops/OperatorRegistry.h>

#include <executorch/backends/vulkan/runtime/graph/ops/impl/Copy.h>
#include <executorch/backends/vulkan/runtime/graph/ops/impl/utils/DimUtils.h>
#include <executorch/backends/vulkan/runtime/graph/ops/impl/utils/KernelUtils.h>
#include <executorch/backends/vulkan/runtime/graph/ops/impl/utils/TensorUtils.h>
#include <executorch/backends/vulkan/runtime/graph/ops/utils/ShaderNameUtils.h>

namespace vkcompute {

void add_cat_default_node(
ComputeGraph& graph,
ValueRef in_list_ref,
ValueRef dim_ref,
ValueRef out) {
ValueListPtr input_list = graph.get_value_list(in_list_ref);

for (ValueRef input_ref : *input_list) {
vTensorPtr t_in = graph.get_tensor(input_ref);
VK_CHECK_COND(check_memory_layout_is(*t_in, api::kChannelsPacked));
}

int64_t dim = graph.extract_scalar<int64_t>(dim_ref);
vTensorPtr t_out = graph.get_tensor(out);

NchwDim nchw_dim = normalize_to_nchw_dim(*t_out, dim);

// TODO: Find ways to factor out the similar code for width, height, and batch
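  // For example (a sketch): concatenating tensors of widths 3 and 4 along
  // DimWidth copies the first tensor at dst x-offset 0, then advances
  // dst_offset.data[0] by that tensor's texture extent so the second tensor
  // lands at x-offset 3.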
if (nchw_dim == DimWidth) {
api::utils::ivec3 src_offset = api::utils::make_ivec3({0, 0, 0}, false);
api::utils::ivec3 dst_offset = api::utils::make_ivec3({0, 0, 0}, false);

for (ValueRef input_ref : *input_list) {
vTensorPtr t_in = graph.get_tensor(input_ref);
api::utils::ivec3 range = t_in->texture_limits();
add_copy_offset_node(
graph, input_ref, range, src_offset, dst_offset, out);
dst_offset.data[0] += range.data[0];
}

} else if (nchw_dim == DimHeight) {
api::utils::ivec3 src_offset = api::utils::make_ivec3({0, 0, 0}, false);
api::utils::ivec3 dst_offset = api::utils::make_ivec3({0, 0, 0}, false);

for (ValueRef input_ref : *input_list) {
vTensorPtr t_in = graph.get_tensor(input_ref);
api::utils::ivec3 range = t_in->texture_limits();
add_copy_offset_node(
graph, input_ref, range, src_offset, dst_offset, out);
dst_offset.data[1] += range.data[1];
}
} else if (nchw_dim == DimBatch) {
api::utils::ivec3 src_offset = api::utils::make_ivec3({0, 0, 0}, false);
api::utils::ivec3 dst_offset = api::utils::make_ivec3({0, 0, 0}, false);

for (ValueRef input_ref : *input_list) {
vTensorPtr t_in = graph.get_tensor(input_ref);
api::utils::ivec3 range = t_in->texture_limits();
add_copy_offset_node(
graph, input_ref, range, src_offset, dst_offset, out);
dst_offset.data[2] += range.data[2];
}
} else if (nchw_dim == DimChannel) {
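    // Unlike the cases above, channel offsets are tracked in logical channel
    // counts rather than texture extents: with channels-packed layout, four
    // channels share one texel, so copies whose offsets are not texel-aligned
    // need repacking, which is presumably why a dedicated channel-copy node
    // is used here.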
int32_t src_offset = 0;
int32_t dst_offset = 0;

for (ValueRef input_ref : *input_list) {
vTensorPtr t_in = graph.get_tensor(input_ref);
int32_t range = dim_at<Dim4D::Channel>(t_in->sizes());
add_copy_channel_offset_node(
graph, input_ref, range, src_offset, dst_offset, out);
dst_offset += range;
}
} else {
VK_THROW("Unexpected value of nchw_dim=", nchw_dim);
}
}

void cat_default(ComputeGraph& graph, const std::vector<ValueRef>& args) {
add_cat_default_node(graph, args[0], args[1], args[2]);
}

REGISTER_OPERATORS {
VK_REGISTER_OP(aten.cat.default, cat_default);
}

} // namespace vkcompute
17 changes: 15 additions & 2 deletions backends/vulkan/runtime/graph/ops/impl/Copy.cpp
@@ -93,10 +93,23 @@ void add_copy_channel_offset_node(

VK_CHECK_COND(
dim_at<Dim4D::Channel>(in_sizes) >= src_channel_offset + channel_range,
"Source channel plus range should be less than or equal to input tensor's channel size");
"Src channel (",
src_channel_offset,
") and range (",
channel_range,
") should be less than or equal to input tensor's channel size (",
dim_at<Dim4D::Channel>(in_sizes),
")");

VK_CHECK_COND(
dim_at<Dim4D::Channel>(out_sizes) >= dst_channel_offset + channel_range,
"Source channel and range should be less than or equal to input tensor's channel size");
"Dst channel (",
dst_channel_offset,
") and range (",
channel_range,
") should be less than or equal to output tensor's channel size (",
dim_at<Dim4D::Channel>(out_sizes),
")");

VK_CHECK_COND(channel_range >= 0, "Channel range must be non-negative");
VK_CHECK_COND(
46 changes: 46 additions & 0 deletions backends/vulkan/runtime/graph/ops/impl/utils/DimUtils.h
@@ -70,4 +70,50 @@ uint32_t dim_at(const vTensor& v_in) {
return dim_at<N>(v_in.sizes());
}

// A canonical way to represent dimensions as an enum. Intended to use the
// same values as Dim4D to ease potential future refactoring.

enum NchwDim {
DimWidth = 1,
DimHeight = 2,
DimChannel = 3,
DimBatch = 4,
};

/*
 * Returns the NchwDim for a given tensor and a user-provided dim. This
 * normalization is needed because user-facing dims are "big-endian" when
 * referring to an NCHW dimension: dim=0 refers to the batch dimension of a
 * 4d tensor but to the height dimension of a 2d tensor. Yet in the common
 * channels-packed texture representation, a 2d tensor has exactly the same
 * layout as a 4d tensor whose batch and channel sizes equal 1. Returning a
 * canonical dimension simplifies dimension reasoning in the code.
 */

inline NchwDim normalize_to_nchw_dim(const vTensor& v_in, int32_t dim) {
return static_cast<NchwDim>(v_in.dim() - dim);
}
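// A sketch of the mapping: for a 4d tensor, dim=0 maps to 4 (DimBatch) and
// dim=3 maps to 1 (DimWidth); for a 2d tensor, dim=0 maps to 2 (DimHeight),
// matching the layout of a 4d tensor with batch and channel sizes of 1.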

inline std::ostream& operator<<(std::ostream& os, NchwDim nchw_dim) {
switch (nchw_dim) {
case DimWidth:
os << "DimWidth";
break;
case DimHeight:
os << "DimHeight";
break;
case DimChannel:
os << "DimChannel";
break;
case DimBatch:
os << "DimBatch";
break;
default:
os << "DimUnknown";
break;
}
return os;
}

} // namespace vkcompute
47 changes: 47 additions & 0 deletions backends/vulkan/test/op_tests/cases.py
@@ -428,6 +428,52 @@ def get_repeat_inputs():
return test_suite


def get_cat_inputs():
# TensorList must be specified as a list of tuples
test_suite = VkTestSuite(
[
# Cat on Height
([(S1, S1, 3, 5), (S1, S1, 4, 5)], 2),
([(S1, 3, 5), (S1, 4, 5)], 1),
([(3, 5), (4, 5)], 0),
([(3, 5), (4, 5), (1, 5)], 0),
(
[
(3, 5),
],
0,
),
# Cat on Width
([(S1, S1, 5, 3), (S1, S1, 5, 4)], 3),
([(S1, 5, 3), (S1, 5, 4)], 2),
([(5, 3), (5, 4)], 1),
([(5, 3), (5, 4), (5, 1)], 1),
(
[
(5, 4),
],
1,
),
([(5,), (6,)], 0),
# Cat on Batch
([(S, S1, 5, 4), (S1, S1, 5, 4)], 0),
([(S, XS, 5, 4), (S1, XS, 5, 4)], 0),
([(S, S2, 5, 4), (S1, S2, 5, 4)], 0),
# Cat on Channel
([(S, 5, 4), (S1, 5, 4), (S2, 5, 4)], 0),
([(XS, 5, 4), (XS, 5, 4), (S2, 5, 4)], 0),
([(XS, S, 5, 4), (XS, S1, 5, 4), (XS, S2, 5, 4)], 1),
([(XS, XS, 5, 4), (XS, XS, 5, 4), (XS, S2, 5, 4)], 1),
]
)
test_suite.layouts = [
"api::kChannelsPacked",
]
test_suite.data_gen = "make_seq_tensor"
test_suite.dtypes = ["at::kFloat"]
return test_suite


test_suites = {
"aten.add.Tensor": get_binary_elementwise_inputs(),
"aten.sub.Tensor": get_binary_elementwise_inputs(),
@@ -447,4 +493,5 @@ def get_repeat_inputs():
"aten.unsqueeze_copy.default": get_unsqueeze_inputs(),
"aten.clone.default": get_clone_inputs(),
"aten.repeat.default": get_repeat_inputs(),
"aten.cat.default": get_cat_inputs(),
}
4 changes: 4 additions & 0 deletions backends/vulkan/test/op_tests/generate_op_tests.py
@@ -16,6 +16,7 @@
TestSuite,
TestSuiteGen,
)
from torchgen import local

from torchgen.gen import parse_native_yaml, ParsedYaml
from torchgen.model import DispatchKey, NativeFunction
@@ -45,6 +46,9 @@ def process_test_suites(
cpp_generator.add_suite(registry_name, f, op_test_suite)


@local.parametrize(
use_const_ref_for_mutable_tensors=False, use_ilistref_for_tensor_lists=False
)
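# Assumption: these torchgen thread-local flags control how C++ signatures are
# rendered (e.g. whether TensorList arguments map to at::TensorList or
# at::ITensorListRef); they are pinned here so generated tests stay consistent.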
def generate_cpp(
native_functions_yaml_path: str, tags_path: str, output_dir: str
) -> None:
82 changes: 68 additions & 14 deletions backends/vulkan/test/op_tests/utils/codegen.py
@@ -12,6 +12,7 @@
AT_INT_ARRAY_REF,
AT_SCALAR,
AT_TENSOR,
AT_TENSOR_LIST,
BOOL,
CppTestFileGen,
DOUBLE,
@@ -28,6 +29,7 @@
THREE_TENSOR_TUPLE,
TWO_TENSOR_TUPLE,
)

from torchgen.api import cpp
from torchgen.api.types import CppSignatureGroup

@@ -75,6 +77,8 @@ class ValueRef:

ValueRefList = Union[ValueRef, List[ValueRef]]

InableCppType = frozenset([AT_TENSOR, AT_TENSOR_LIST])
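# Assumed naming: "Inable" marks C++ argument types whose underlying tensors
# can serve as graph inputs and therefore need staging buffers.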


class ComputeGraphGen:
def __init__(self, op_reg_name: str, f: NativeFunction, suite_def: TestSuite):
@@ -114,7 +118,7 @@ def __init__(self, op_reg_name: str, f: NativeFunction, suite_def: TestSuite):
name=f"{arg.name}_ref",
src_cpp_name=arg.name,
src_cpp_type=cpp_type,
is_in=(cpp_type == AT_TENSOR),
is_in=(cpp_type in InableCppType),
requires_prepack=requires_prepack,
supports_prepack=supports_prepack,
)
@@ -244,6 +248,25 @@ def create_value_for(self, ref: ValueRefList) -> str: # noqa: C901
ret_str += f"{self.graph}{self.dot}add_scalar<int64_t>"
ret_str += f"({ref.src_cpp_name}.value());\n"
return ret_str
elif ref.src_cpp_type == AT_TENSOR_LIST:
assert ref.is_in, "AT_TENSOR_LIST must be an input"
# This logic is a bit convoluted. We need to create an IOValueRef for
# each tensor to facilitate staging. On the other hand, we use the
# .value tensor of each IOValueRef to create a ValueList, which is
# passed to the corresponding ops.
ret_str = f"std::vector<IOValueRef> {ref.name}_io_value_refs;\n"
ret_str += f"std::vector<ValueRef> {ref.name}_value_refs;\n"
ret_str += f"for (int i=0; i < {ref.src_cpp_name}.size(); i++) {{\n"
ret_str += f" {cpp_type} io_value_ref = {self.graph}{self.dot}add_input_tensor(\n"
ret_str += f" {ref.src_cpp_name}[i].sizes().vec(),\n"
ret_str += (
f" from_at_scalartype({ref.src_cpp_name}[i].scalar_type())); \n"
)
ret_str += f" {ref.name}_value_refs.emplace_back(io_value_ref.value);\n"
ret_str += f" {ref.name}_io_value_refs.emplace_back(io_value_ref);\n"
ret_str += "}\n"
ret_str += f"ValueRef {ref.name} = {self.graph}{self.dot}add_value_list(std::move({ref.name}_value_refs));\n"
return ret_str
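# For a hypothetical TensorList argument named "tensors", the emitted C++
# looks roughly like (a sketch; exact formatting comes from the f-strings
# above):
#   std::vector<IOValueRef> tensors_io_value_refs;
#   std::vector<ValueRef> tensors_value_refs;
#   for (int i=0; i < tensors.size(); i++) { ... }
#   ValueRef tensors = graph.add_value_list(std::move(tensors_value_refs));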

ret_str = f"{cpp_type} {ref.name} = {self.graph}{self.dot}"
if ref.src_cpp_type == AT_TENSOR and not prepack:
@@ -288,11 +311,16 @@ def create_op_call(self) -> str:

for aten_arg in self.args:
ref = self.refs[aten_arg.name]
op_create_code += (
f"{ref.name}.value, "
if (ref.is_in and not self.prepack_ref(ref)) or ref.is_out
else f"{ref.name}, "
)
if ref.src_cpp_type == AT_TENSOR_LIST:
# Special case. Underlying tensors are input tensors, but the
# container itself is just a normal value.
op_create_code += f"{ref.name}, "
else:
op_create_code += (
f"{ref.name}.value, "
if (ref.is_in and not self.prepack_ref(ref)) or ref.is_out
else f"{ref.name}, "
)

op_create_code += "out_ref});\n"
return op_create_code
@@ -311,22 +339,46 @@ def set_output(self, ref: ValueRefList) -> str:

def virtual_resize(self, ref: ValueRefList) -> str:
assert isinstance(ref, ValueRef)
assert ref.src_cpp_type == AT_TENSOR and ref.is_in
assert ref.src_cpp_type in InableCppType and ref.is_in
if self.prepack_ref(ref):
return ""
ret_str = f"{self.graph}{self.dot}get_tensor({ref.name}.value)"
ret_str += f"->virtual_resize({ref.src_cpp_name}.sizes().vec());\n"

if ref.src_cpp_type == AT_TENSOR:
ret_str = f"{self.graph}{self.dot}get_tensor({ref.name}.value)"
ret_str += f"->virtual_resize({ref.src_cpp_name}.sizes().vec());\n"
elif ref.src_cpp_type == AT_TENSOR_LIST:
ret_str = ""
ret_str += f"for (int i=0; i < {ref.name}_io_value_refs.size(); i++) {{\n"
ret_str += f" {self.graph}{self.dot}get_tensor({ref.name}_io_value_refs[i].value)"
ret_str += f"->virtual_resize({ref.src_cpp_name}[i].sizes().vec());\n"
ret_str += "}\n"
else:
raise AssertionError(f"{ref.src_cpp_type} not expected")

return ret_str

def copy_into_staging(self, ref: ValueRefList) -> str:
assert isinstance(ref, ValueRef)
assert ref.src_cpp_type == AT_TENSOR and ref.is_in
assert ref.src_cpp_type in InableCppType and ref.is_in

if self.prepack_ref(ref):
return ""
ret_str = f"{self.graph}{self.dot}copy_into_staging("
ret_str += f"{ref.name}.staging, "
ret_str += f"{ref.src_cpp_name}.const_data_ptr(), "
ret_str += f"{ref.src_cpp_name}.numel());\n"

if ref.src_cpp_type == AT_TENSOR:
ret_str = f"{self.graph}{self.dot}copy_into_staging("
ret_str += f"{ref.name}.staging, "
ret_str += f"{ref.src_cpp_name}.const_data_ptr(), "
ret_str += f"{ref.src_cpp_name}.numel());\n"
elif ref.src_cpp_type == AT_TENSOR_LIST:
ret_str = ""
ret_str += f"for (int i=0; i < {ref.name}_io_value_refs.size(); i++) {{\n"
ret_str += f" {self.graph}{self.dot}copy_into_staging("
ret_str += f"{ref.name}_io_value_refs[i].staging, "
ret_str += f"{ref.src_cpp_name}[i].const_data_ptr(), "
ret_str += f"{ref.src_cpp_name}[i].numel());\n"
ret_str += "}\n"
else:
raise AssertionError(f"{ref.src_cpp_type} not expected")
return ret_str

def declare_vk_out_for(self, ref: Union[ValueRef, List[ValueRef]]) -> str:
@@ -547,8 +599,10 @@ def gen_parameterization(self) -> str:
if (!is_close && t1.numel() < 500) {
std::cout << "reference: " << std::endl;
print(t1, 150);
std::cout << std::endl;
std::cout << "vulkan: " << std::endl;
print(t2, 150);
std::cout << std::endl;
}
return is_close;
}