|
16 | 16 | from executorch.backends.aoti.passes.replace_view_copy_with_view import ( |
17 | 17 | ReplaceViewCopyWithViewPass, |
18 | 18 | ) |
| 19 | + |
| 20 | +from executorch.backends.cuda.triton.replacement_pass import ( |
| 21 | + ReplaceEdgeOpWithTritonOpPass, |
| 22 | +) |
19 | 23 | from executorch.exir._serialize._named_data_store import NamedDataStore |
20 | 24 | from executorch.exir._warnings import experimental |
21 | 25 | from executorch.exir.backend.backend_details import ( |
|
27 | 31 | from torch._inductor.codegen.cpp_wrapper_cpu import CppWrapperCpu |
28 | 32 | from torch._inductor.decomposition import conv1d_to_conv2d |
29 | 33 | from torch.export.passes import move_to_device_pass |
30 | | -from torch.nn.attention import SDPBackend |
| 34 | + |
31 | 35 |
|
32 | 36 | cuda_decomposition_table = { |
33 | 37 | torch.ops.aten.conv1d.default: conv1d_to_conv2d, |
@@ -127,6 +131,9 @@ def preprocess( # noqa: C901 |
127 | 131 | # replace slice_copy.Tensor with slice.Tensor, select_copy.int with select.int |
128 | 132 | ReplaceViewCopyWithViewPass()(cuda_edge_program.graph_module) |
129 | 133 |
|
| 134 | + # Replace aten ops with triton ops |
| 135 | + ReplaceEdgeOpWithTritonOpPass()(cuda_edge_program.graph_module) |
| 136 | + |
130 | 137 | cuda_edge_program = cuda_edge_program.run_decompositions( |
131 | 138 | cuda_decomposition_table |
132 | 139 | ) |
@@ -188,11 +195,7 @@ def preprocess( # noqa: C901 |
188 | 195 | } |
189 | 196 | ) |
190 | 197 |
|
191 | | - with collect_unsupported_fallback_kernels(), torch.nn.attention.sdpa_kernel( |
192 | | - [ |
193 | | - SDPBackend.MATH # pyre-ignore[16]: Module `torch.nn.attention` has no attribute `SDPBackend`. |
194 | | - ] |
195 | | - ), torch.no_grad(): |
| 198 | + with collect_unsupported_fallback_kernels(), torch.no_grad(): |
196 | 199 | # torch._logging.set_logs(post_grad_graphs=True) |
197 | 200 | # Here we expect one .so file and one weight blob in the same directory. |
198 | 201 | paths = torch._inductor.aot_compile(edge_program_module, tuple(user_input_placeholders), options=options) # type: ignore[arg-type] |
|
0 commit comments