| 
22 | 22 | from executorch.backends.cadence.aot.remove_ops import (  | 
23 | 23 |     RemoveAliasCopyOpPass,  | 
24 | 24 |     RemoveBranchedQuantDequant,  | 
 | 25 | +    RemoveCatFromSliceCopyPass,  | 
25 | 26 |     RemoveCloneOpPass,  | 
26 | 27 |     RemoveContiguousOpPass,  | 
27 | 28 |     RemoveDetachCopyPass,  | 
@@ -741,3 +742,54 @@ def forward(self, x):  | 
741 | 742 |                 },  | 
742 | 743 |             )  | 
743 | 744 |         )  | 
 | 745 | + | 
 | 746 | +    def test_remove_cat_from_slice_copy_all_removal(self) -> None:  | 
 | 747 | +        class M(torch.nn.Module):  | 
 | 748 | +            def __init__(self):  | 
 | 749 | +                super().__init__()  | 
 | 750 | + | 
 | 751 | +            def forward(self, x, y):  | 
 | 752 | +                x1 = torch.cat((x, y), 0)  # (2, 4)  | 
 | 753 | +                return torch.slice_copy(x1, dim=0, start=0, end=1)  | 
 | 754 | + | 
 | 755 | +        inputs = tuple(torch.randn(2, 4) for _ in range(2))  | 
 | 756 | +        graph_module = export_to_edge(M(), inputs).exported_program().graph_module  | 
 | 757 | +        p = RemoveCatFromSliceCopyPass()  | 
 | 758 | +        graph_module = cast(PassResult, p(graph_module)).graph_module  | 
 | 759 | + | 
 | 760 | +        # Ensure both cat nodes were removed  | 
 | 761 | +        self.assertEqual(count_node(graph_module, exir_ops.edge.aten.cat.default), 0)  | 
 | 762 | + | 
 | 763 | +    def test_remove_cat_from_slice_copy_no_removal(self) -> None:  | 
 | 764 | +        class M(torch.nn.Module):  | 
 | 765 | +            def __init__(self):  | 
 | 766 | +                super().__init__()  | 
 | 767 | + | 
 | 768 | +            def forward(self, x, y):  | 
 | 769 | +                x1 = torch.cat((x, y), 0)  # (2, 4)  | 
 | 770 | +                return torch.slice_copy(x1, dim=0, start=0, end=3)  | 
 | 771 | + | 
 | 772 | +        inputs = tuple(torch.randn(2, 4) for _ in range(2))  | 
 | 773 | +        graph_module = export_to_edge(M(), inputs).exported_program().graph_module  | 
 | 774 | +        p = RemoveCatFromSliceCopyPass()  | 
 | 775 | +        graph_module = cast(PassResult, p(graph_module)).graph_module  | 
 | 776 | + | 
 | 777 | +        # Ensure both cat nodes were removed  | 
 | 778 | +        self.assertEqual(count_node(graph_module, exir_ops.edge.aten.cat.default), 1)  | 
 | 779 | + | 
 | 780 | +    def test_remove_cat_from_slice_copy_zero_range(self) -> None:  | 
 | 781 | +        class M(torch.nn.Module):  | 
 | 782 | +            def __init__(self):  | 
 | 783 | +                super().__init__()  | 
 | 784 | + | 
 | 785 | +            def forward(self, x, y):  | 
 | 786 | +                x1 = torch.cat((x, y), 0)  # (2, 4)  | 
 | 787 | +                return torch.slice_copy(x1, dim=0, start=0, end=0)  | 
 | 788 | + | 
 | 789 | +        inputs = tuple(torch.randn(2, 4) for _ in range(2))  | 
 | 790 | +        graph_module = export_to_edge(M(), inputs).exported_program().graph_module  | 
 | 791 | +        p = RemoveCatFromSliceCopyPass()  | 
 | 792 | +        graph_module = cast(PassResult, p(graph_module)).graph_module  | 
 | 793 | + | 
 | 794 | +        # Ensure both cat nodes were removed  | 
 | 795 | +        self.assertEqual(count_node(graph_module, exir_ops.edge.aten.cat.default), 0)  | 
0 commit comments