Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion backends/cuda/TARGETS
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ runtime.python_library(
name = "cuda_backend",
srcs = [
"cuda_backend.py",
"replace_slice_copy_with_slice.py",
"replace_view_copy_with_view.py",
],
visibility = [
"//executorch/...",
Expand Down
6 changes: 3 additions & 3 deletions backends/cuda/cuda_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,8 @@
from typing import Any, Dict, final, List, Optional, Set

import torch
from executorch.backends.cuda.replace_slice_copy_with_slice import (
ReplaceSliceCopyWithSlicePass,
from executorch.backends.cuda.replace_view_copy_with_view import (
ReplaceViewCopyWithViewPass,
)
from executorch.exir._serialize._named_data_store import NamedDataStore
from executorch.exir._warnings import experimental
Expand Down Expand Up @@ -124,7 +124,7 @@ def preprocess(
cuda_edge_program = move_to_device_pass(edge_program, "cuda")

# replace slice_copy with slice
ReplaceSliceCopyWithSlicePass()(cuda_edge_program.graph_module)
ReplaceViewCopyWithViewPass()(cuda_edge_program.graph_module)

cuda_edge_program = cuda_edge_program.run_decompositions(
cuda_decomposition_table
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,33 +15,37 @@
from torch import fx


_SLICE_COPY_TARGETS: Tuple[torch._ops.OpOverload | EdgeOpOverload] = (
_VIEW_COPY_TARGETS: Tuple[torch._ops.OpOverload | EdgeOpOverload] = (
torch.ops.aten.slice_copy.Tensor,
ops.edge.aten.slice_copy.Tensor,
torch.ops.aten.select_copy.int,
ops.edge.aten.select_copy.int,
)

_SLICE_TARGETS: Dict[
_VIEW_TARGETS: Dict[
torch._ops.OpOverload | EdgeOpOverload, torch._ops.OpOverload | EdgeOpOverload
] = {
torch.ops.aten.slice_copy.Tensor: torch.ops.aten.slice.Tensor,
ops.edge.aten.slice_copy.Tensor: ops.edge.aten.slice.Tensor,
torch.ops.aten.select_copy.int: torch.ops.aten.select.int,
ops.edge.aten.select_copy.int: ops.edge.aten.select.int,
}


class ReplaceSliceCopyWithSlicePass(ExportPass):
"""Replace non-mutated ``slice_copy`` results with ``slice`` views."""
class ReplaceViewCopyWithViewPass(ExportPass):
"""Replace non-mutated ``view_copy`` type of ops with ``view`` ops."""

def call(self, graph_module: fx.GraphModule) -> PassResult:
graph_changed = False

for node in graph_module.graph.nodes:
if node.op != "call_function" or node.target not in _SLICE_COPY_TARGETS:
if node.op != "call_function" or node.target not in _VIEW_COPY_TARGETS:
continue

if self._has_blocking_user(node, node.users.keys()):
continue

node.target = _SLICE_TARGETS[node.target]
node.target = _VIEW_TARGETS[node.target]
graph_changed = True

if graph_changed:
Expand Down
Loading