diff --git a/backends/vulkan/runtime/graph/ops/impl/Select.cpp b/backends/vulkan/runtime/graph/ops/impl/Select.cpp index c66d98af4a2..e0412450ed6 100644 --- a/backends/vulkan/runtime/graph/ops/impl/Select.cpp +++ b/backends/vulkan/runtime/graph/ops/impl/Select.cpp @@ -127,6 +127,7 @@ void select_int(ComputeGraph& graph, const std::vector& args) { REGISTER_OPERATORS { VK_REGISTER_OP(aten.select.int, select_int); + VK_REGISTER_OP(aten.select_copy.int, select_int); } } // namespace vkcompute diff --git a/backends/vulkan/test/op_tests/cases.py b/backends/vulkan/test/op_tests/cases.py index fe8d7c25e01..7ebe7bbcffa 100644 --- a/backends/vulkan/test/op_tests/cases.py +++ b/backends/vulkan/test/op_tests/cases.py @@ -182,4 +182,5 @@ def get_select_int_inputs(): "aten.native_layer_norm.default": get_native_layer_norm_inputs(), "aten.full.default": get_full_inputs(), "aten.select.int": get_select_int_inputs(), + "aten.select_copy.int": get_select_int_inputs(), }