diff --git a/backends/vulkan/runtime/graph/ops/impl/Permute.cpp b/backends/vulkan/runtime/graph/ops/impl/Permute.cpp index 578f71964fc..aa2036f9e4e 100644 --- a/backends/vulkan/runtime/graph/ops/impl/Permute.cpp +++ b/backends/vulkan/runtime/graph/ops/impl/Permute.cpp @@ -101,6 +101,7 @@ void permute(ComputeGraph& graph, const std::vector& args) { } REGISTER_OPERATORS { + VK_REGISTER_OP(aten.permute.default, permute); VK_REGISTER_OP(aten.permute_copy.default, permute); } diff --git a/backends/vulkan/test/op_tests/cases.py b/backends/vulkan/test/op_tests/cases.py index 49c3188174b..4e556ce6fd5 100644 --- a/backends/vulkan/test/op_tests/cases.py +++ b/backends/vulkan/test/op_tests/cases.py @@ -206,5 +206,6 @@ def get_permute_inputs(): "aten.full.default": get_full_inputs(), "aten.select.int": get_select_int_inputs(), "aten.select_copy.int": get_select_int_inputs(), + "aten.permute.default": get_permute_inputs(), "aten.permute_copy.default": get_permute_inputs(), } diff --git a/backends/vulkan/test/op_tests/utils/codegen.py b/backends/vulkan/test/op_tests/utils/codegen.py index 0807c07ac89..4de7cf26ee4 100644 --- a/backends/vulkan/test/op_tests/utils/codegen.py +++ b/backends/vulkan/test/op_tests/utils/codegen.py @@ -310,7 +310,9 @@ def declare_vk_out_for(self, ref: Union[ValueRef, List[ValueRef]]) -> str: ret_str += self.declare_vk_out_for(r) return ret_str - return f"at::Tensor vk_{ref.name} = at::empty_like({ref.src_cpp_name});\n" + ret_str = f"at::Tensor vk_{ref.name} = at::empty_like({ref.src_cpp_name})" + ret_str += ".contiguous();\n" + return ret_str def copy_from_staging(self, ref: ValueRefList) -> str: if isinstance(ref, list):