From 44099685b9f8e6493f72f0970cafe9c0eb787130 Mon Sep 17 00:00:00 2001
From: Justin Yip
Date: Mon, 29 Apr 2024 11:14:37 -0700
Subject: [PATCH] [ET-VK][14/n] Add operators to Partitioner

Failure for Partial models: P1225514296

Differential Revision: [D56695929](https://our.internmc.facebook.com/intern/diff/D56695929/)

[ghstack-poisoned]
---
 .../vulkan/partitioner/vulkan_partitioner.py  |   6 ++
 .../vulkan/runtime/graph/ops/impl/Split.cpp   |   5 +-
 .../serialization/vulkan_graph_builder.py     |   2 +-
 backends/vulkan/test/test_vulkan_delegate.py  | 124 ++++++++++++++++++
 4 files changed, 134 insertions(+), 3 deletions(-)

diff --git a/backends/vulkan/partitioner/vulkan_partitioner.py b/backends/vulkan/partitioner/vulkan_partitioner.py
index e9ec9f2d84c..46dcb7c7268 100644
--- a/backends/vulkan/partitioner/vulkan_partitioner.py
+++ b/backends/vulkan/partitioner/vulkan_partitioner.py
@@ -56,6 +56,12 @@ def is_node_supported(self, submodules, node: torch.fx.Node) -> bool:
             exir_ops.edge.aten.select_copy.int,
             exir_ops.edge.aten.unsqueeze_copy.default,
             exir_ops.edge.aten.view_copy.default,
+            # Copy-related operators
+            exir_ops.edge.aten.clone.default,
+            exir_ops.edge.aten.cat.default,
+            exir_ops.edge.aten.split_with_sizes_copy.default,
+            exir_ops.edge.aten.split.Tensor,
+            exir_ops.edge.aten.slice_copy.Tensor,
             # Other
             operator.getitem,
             exir_ops.edge.aten.full.default,
diff --git a/backends/vulkan/runtime/graph/ops/impl/Split.cpp b/backends/vulkan/runtime/graph/ops/impl/Split.cpp
index 3b40871a791..2d218f722a2 100644
--- a/backends/vulkan/runtime/graph/ops/impl/Split.cpp
+++ b/backends/vulkan/runtime/graph/ops/impl/Split.cpp
@@ -106,7 +106,7 @@ void add_split_with_sizes_default_node(
   add_split_with_sizes_default_node(graph, in, split_sizes, dim, out);
 }
 
-void split_with_sizes_default(
+void split_with_sizes_copy_default(
     ComputeGraph& graph,
     const std::vector<ValueRef>& args) {
   add_split_with_sizes_default_node(graph, args[0], args[1], args[2], args[3]);
@@ -134,7 +134,8 @@ void split_tensor(ComputeGraph& graph, const std::vector<ValueRef>& args) {
 }
 
 REGISTER_OPERATORS {
-  VK_REGISTER_OP(aten.split_with_sizes.default, split_with_sizes_default);
+  VK_REGISTER_OP(
+      aten.split_with_sizes_copy.default, split_with_sizes_copy_default);
   VK_REGISTER_OP(aten.split.Tensor, split_tensor);
 }
 
diff --git a/backends/vulkan/serialization/vulkan_graph_builder.py b/backends/vulkan/serialization/vulkan_graph_builder.py
index 9c12cb4a010..fe8e9246953 100644
--- a/backends/vulkan/serialization/vulkan_graph_builder.py
+++ b/backends/vulkan/serialization/vulkan_graph_builder.py
@@ -133,7 +133,7 @@ def create_node_value(self, node: Node) -> int:
             new_id = self.create_tensor_value(spec, constant_id)
             self.node_to_value_ids[node] = new_id
             return new_id
-        elif isinstance(spec, tuple):
+        elif isinstance(spec, (tuple, list)):
             # Create a Value for each element in the tuple, wrap Values in a
             # ValueList, and map the Node to the ValueList id.
             new_id = self.create_value_list_value(spec)
diff --git a/backends/vulkan/test/test_vulkan_delegate.py b/backends/vulkan/test/test_vulkan_delegate.py
index a458fc1c24e..b498feb5957 100644
--- a/backends/vulkan/test/test_vulkan_delegate.py
+++ b/backends/vulkan/test/test_vulkan_delegate.py
@@ -80,6 +80,10 @@ def run_test(memory_layout):
             compile_options = {
                 "memory_layout_override": memory_layout,
             }
+
+            # At a minimum, the model should run in eager mode.
+            eager_output = model(*sample_inputs)
+
             program: ExportedProgram = export(
                 model, sample_inputs, dynamic_shapes=dynamic_shapes
             )
@@ -798,3 +802,123 @@ def forward(self, x):
             sample_inputs,
             memory_layouts=[vk_graph_schema.VkMemoryLayout.TENSOR_CHANNELS_PACKED],
         )
+
+    def DISABLED_test_vulkan_backend_permute_copy(self):
+        # aten.permute_copy.default is not enabled yet in the partitioner.
+        class PermuteModule(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+
+            def forward(self, x):
+                return torch.permute(x, [3, 0, 2, 1])
+
+        sample_inputs = (torch.randn(size=(3, 6, 2, 7), dtype=torch.float32),)
+
+        self.lower_module_and_test_output(
+            PermuteModule(),
+            sample_inputs,
+            memory_layouts=[vk_graph_schema.VkMemoryLayout.TENSOR_CHANNELS_PACKED],
+        )
+
+    def test_vulkan_backend_cat(self):
+        class TestModule(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+
+            def forward(self, x, y, z, w):
+                return torch.cat([x, y, z, w], dim=1)
+
+        sample_inputs = (
+            torch.randn(size=(3, 6, 2, 7), dtype=torch.float32),
+            torch.randn(size=(3, 1, 2, 7), dtype=torch.float32),
+            torch.randn(size=(3, 9, 2, 7), dtype=torch.float32),
+            torch.randn(size=(3, 3, 2, 7), dtype=torch.float32),
+        )
+
+        self.lower_module_and_test_output(
+            TestModule(),
+            sample_inputs,
+            memory_layouts=[vk_graph_schema.VkMemoryLayout.TENSOR_CHANNELS_PACKED],
+        )
+
+    def test_vulkan_backend_slice(self):
+        class TestModule(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+
+            def forward(self, x):
+                return x[:, 2:9:2, :]
+
+        sample_inputs = (torch.randn(size=(3, 13, 7, 3), dtype=torch.float32),)
+
+        self.lower_module_and_test_output(
+            TestModule(),
+            sample_inputs,
+            memory_layouts=[vk_graph_schema.VkMemoryLayout.TENSOR_CHANNELS_PACKED],
+        )
+
+    def test_vulkan_backend_split_with_sizes(self):
+        class TestModule(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+
+            def forward(self, x):
+                return torch.split(x, (3, 6, 1, 3), dim=1)
+
+        sample_inputs = (torch.randn(size=(3, 13, 7, 3), dtype=torch.float32),)
+
+        self.lower_module_and_test_output(
+            TestModule(),
+            sample_inputs,
+            memory_layouts=[vk_graph_schema.VkMemoryLayout.TENSOR_CHANNELS_PACKED],
+        )
+
+    def test_vulkan_backend_split_tensor(self):
+        class TestModule(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+
+            def forward(self, x):
+                return torch.tensor_split(x, 2, dim=1)
+
+        sample_inputs = (torch.randn(size=(3, 14, 7, 3), dtype=torch.float32),)
+
+        self.lower_module_and_test_output(
+            TestModule(),
+            sample_inputs,
+            memory_layouts=[vk_graph_schema.VkMemoryLayout.TENSOR_CHANNELS_PACKED],
+        )
+
+    def test_vulkan_backend_clone(self):
+        class TestModule(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+
+            def forward(self, x):
+                return torch.clone(x)
+
+        sample_inputs = (torch.randn(size=(3, 14, 7, 3), dtype=torch.float32),)
+
+        self.lower_module_and_test_output(
+            TestModule(),
+            sample_inputs,
+            memory_layouts=[vk_graph_schema.VkMemoryLayout.TENSOR_CHANNELS_PACKED],
+        )
+
+    def DISABLED_test_vulkan_backend_t_default(self):
+        # aten.permute_copy.default is not enabled yet in the partitioner.
+        class TestModule(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+
+            def forward(self, x):
+                # torch.t is actually exported as aten::permute.
+                return torch.t(x)
+
+        sample_inputs = (torch.randn(size=(3, 14), dtype=torch.float32),)
+
+        self.lower_module_and_test_output(
+            TestModule(),
+            sample_inputs,
+            memory_layouts=[vk_graph_schema.VkMemoryLayout.TENSOR_CHANNELS_PACKED],
+        )
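
---

Not part of the patch: a minimal sketch of how the newly partitioned copy ops could be exercised end to end, outside the test harness. It assumes the standard torch.export -> to_edge -> to_backend flow and that vulkan_partitioner.py exposes a `VulkanPartitioner` class (the module name and `VulkanSupportedOperators` come from this diff; the class name and exact imports are assumptions to adjust for your tree).

```python
import torch
from executorch.backends.vulkan.partitioner.vulkan_partitioner import (
    VulkanPartitioner,
)
from executorch.exir import to_edge
from torch.export import export


class CopyOpsModule(torch.nn.Module):
    # Chains the copy-related ops this diff enables in the partitioner.
    def forward(self, x, y):
        z = torch.cat([x, y], dim=1)  # aten.cat.default
        z = z[:, 1:7:2, :]  # aten.slice_copy.Tensor
        # Returns a tuple, exercising the tuple/list spec path added to
        # vulkan_graph_builder.py (aten.split_with_sizes_copy.default).
        return torch.split(z, (1, 2), dim=1)


inputs = (torch.randn(2, 4, 3), torch.randn(2, 4, 3))
edge = to_edge(export(CopyOpsModule(), inputs))
# Nodes accepted by VulkanSupportedOperators.is_node_supported are tagged
# for delegation to the Vulkan backend; everything else stays in the graph.
edge = edge.to_backend(VulkanPartitioner())
print(edge.exported_program().graph_module.graph)
```

The printed graph should show the delegated calls replacing the ops registered in Split.cpp via VK_REGISTER_OP (e.g. aten.split_with_sizes_copy.default after the rename in this diff); the DISABLED_ tests above can be re-enabled once aten.permute_copy.default is added to the partitioner the same way.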