2 changes: 1 addition & 1 deletion .github/workflows/cuda.yml
@@ -71,7 +71,7 @@ jobs:
     strategy:
       fail-fast: false
       matrix:
-        model: [linear, add, add_mul, resnet18, conv1d]
+        model: [linear, add, add_mul, resnet18, conv1d, sdpa]
     with:
       timeout: 90
       runner: linux.g5.4xlarge.nvidia.gpu
31 changes: 31 additions & 0 deletions backends/cuda/TARGETS
@@ -11,6 +11,7 @@ runtime.python_library(
         "//executorch/...",
     ],
     deps = [
+        ":triton_replacement_pass",
         "//caffe2:torch",
         "//executorch/backends/aoti/passes:passes",
         "//executorch/exir/_serialize:lib",
@@ -32,3 +33,33 @@
         "//executorch/backends/aoti:aoti_partitioner",
     ],
 )
+
+runtime.python_library(
+    name = "triton_kernels",
+    srcs = [
+        "triton/kernels/__init__.py",
+        "triton/kernels/sdpa.py",
+    ],
+    visibility = [
+        "//executorch/backends/cuda/...",
+    ],
+    deps = [
+        "//caffe2:torch",
+    ],
+)
+
+runtime.python_library(
+    name = "triton_replacement_pass",
+    srcs = [
+        "triton/__init__.py",
+        "triton/replacement_pass.py",
+    ],
+    visibility = [
+        "//executorch/...",
+    ],
+    deps = [
+        ":triton_kernels",
+        "//caffe2:torch",
+        "//executorch/exir/dialects:lib",
+    ],
+)
15 changes: 9 additions & 6 deletions backends/cuda/cuda_backend.py
@@ -16,6 +16,10 @@
 from executorch.backends.aoti.passes.replace_view_copy_with_view import (
     ReplaceViewCopyWithViewPass,
 )
+
+from executorch.backends.cuda.triton.replacement_pass import (
+    ReplaceEdgeOpWithTritonOpPass,
+)
 from executorch.exir._serialize._named_data_store import NamedDataStore
 from executorch.exir._warnings import experimental
 from executorch.exir.backend.backend_details import (
@@ -27,7 +31,7 @@
 from torch._inductor.codegen.cpp_wrapper_cpu import CppWrapperCpu
 from torch._inductor.decomposition import conv1d_to_conv2d
 from torch.export.passes import move_to_device_pass
-from torch.nn.attention import SDPBackend
+
 
 cuda_decomposition_table = {
     torch.ops.aten.conv1d.default: conv1d_to_conv2d,
@@ -127,6 +131,9 @@ def preprocess(  # noqa: C901
         # replace slice_copy.Tensor with slice.Tensor, select_copy.int with select.int
         ReplaceViewCopyWithViewPass()(cuda_edge_program.graph_module)
 
+        # Replace aten ops with triton ops
+        ReplaceEdgeOpWithTritonOpPass()(cuda_edge_program.graph_module)
+
         cuda_edge_program = cuda_edge_program.run_decompositions(
             cuda_decomposition_table
         )
@@ -188,11 +195,7 @@ def preprocess(  # noqa: C901
             }
         )
 
-        with collect_unsupported_fallback_kernels(), torch.nn.attention.sdpa_kernel(
-            [
-                SDPBackend.MATH  # pyre-ignore[16]: Module `torch.nn.attention` has no attribute `SDPBackend`.
-            ]
-        ), torch.no_grad():
+        with collect_unsupported_fallback_kernels(), torch.no_grad():
             # torch._logging.set_logs(post_grad_graphs=True)
             # Here we should expect 1 so file and 1 weight blob in the same directory.
             paths = torch._inductor.aot_compile(edge_program_module, tuple(user_input_placeholders), options=options)  # type: ignore[arg-type]
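Note: the diff invokes `ReplaceEdgeOpWithTritonOpPass` here but does not include `triton/replacement_pass.py` itself. As a rough, hypothetical sketch of how such a pass is typically shaped in ExecuTorch — assuming the Triton kernel is registered as `torch.ops.triton.sdpa` and the pass follows the usual `ExportPass` pattern — it might look like this (illustrative only, not the code under review):

```python
# Hypothetical sketch only; the real pass lives in
# backends/cuda/triton/replacement_pass.py, which this diff does not show.
import torch
from executorch.exir.pass_base import ExportPass


class ReplaceEdgeOpWithTritonOpPassSketch(ExportPass):
    """Swap selected aten ops for Triton-backed custom ops."""

    def call_operator(self, op, args, kwargs, meta):
        # Assumed mapping; the real pass may match edge-dialect op variants
        # instead (its TARGETS entry depends on //executorch/exir/dialects:lib).
        # torch.ops.triton.sdpa only resolves once the kernels module has been
        # imported and its @triton_op registrations have run.
        replacements = {
            torch.ops.aten.scaled_dot_product_attention.default: torch.ops.triton.sdpa,
        }
        return super().call_operator(replacements.get(op, op), args, kwargs, meta)
```

Running the pass before `run_decompositions` matters: once the call targets `torch.ops.triton.sdpa`, the decomposition table no longer breaks SDPA apart, which is presumably why the `SDPBackend.MATH` context manager above could be dropped.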
1 change: 1 addition & 0 deletions backends/cuda/tests/TARGETS
@@ -19,6 +19,7 @@ python_unittest_remote_gpu(
         "//executorch/exir:lib",
         "//executorch/exir/backend:backend_api",
         "//executorch/exir/backend:compile_spec_schema",
+        "//executorch/examples/models/toy_model:toy_model",
     ],
     keep_gpu_sections = True,
 )
18 changes: 18 additions & 0 deletions backends/cuda/tests/test_cuda_export.py
@@ -10,6 +10,7 @@
 import torch
 from executorch.backends.cuda.cuda_backend import CudaBackend
 from executorch.backends.cuda.cuda_partitioner import CudaPartitioner
+from executorch.examples.models.toy_model import SdpaModule
 from executorch.exir import EdgeCompileConfig, to_edge_transform_and_lower
 from torch.export import export
 
@@ -270,3 +271,20 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
         # Test export
         edge_program_manager = self._export_to_cuda_with_lower(module, inputs)
         self.assertIsNotNone(edge_program_manager, "Conv1d operation export failed")
+
+    def test_sdpa_single_kernel(self):
+        """
+        Test CUDA export for model containing single SDPA kernel.
+        SDPA: Scaled Dot Product Attention
+        """
+
+        sdpa = SdpaModule()
+
+        # Test export
+        edge_program_manager = self._export_to_cuda_with_lower(
+            sdpa.get_eager_model(), sdpa.get_example_inputs()
+        )
+        self.assertIsNotNone(
+            edge_program_manager,
+            "SDPA single kernel operation export failed",
+        )
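For reference, the test relies on `SdpaModule` from the toy-model examples rather than defining a module inline, and that class is not part of this diff. A hypothetical stand-in that satisfies the interface the test uses (`get_eager_model()` / `get_example_inputs()`) could look like this:

```python
# Hypothetical stand-in for executorch.examples.models.toy_model.SdpaModule;
# the real class and its example input shapes may differ.
import torch


class SdpaModuleSketch(torch.nn.Module):
    def forward(self, q, k, v):
        # A single SDPA call, so the exported graph contains exactly one
        # scaled_dot_product_attention node for the replacement pass to hit.
        return torch.nn.functional.scaled_dot_product_attention(q, k, v)

    def get_eager_model(self) -> torch.nn.Module:
        return self

    def get_example_inputs(self):
        q = torch.randn(1, 8, 32, 64)  # (batch, heads, seq_len, head_dim)
        return (q, torch.randn_like(q), torch.randn_like(q))
```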
17 changes: 17 additions & 0 deletions backends/cuda/triton/__init__.py
@@ -0,0 +1,17 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+# Import all kernels to ensure @triton_op decorators are executed
+# and ops are registered to torch.ops.triton namespace
+from executorch.backends.cuda.triton import kernels  # noqa: F401
+
+from executorch.backends.cuda.triton.replacement_pass import (
+    ReplaceEdgeOpWithTritonOpPass,
+)
+
+__all__ = [
+    "ReplaceEdgeOpWithTritonOpPass",
+]
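The comment in this file is the load-bearing part: importing `kernels` is done purely for its side effect, since the `@triton_op` decorators run at import time and register each kernel under `torch.ops.triton`. The actual `sdpa.py` is not shown in this diff, but the registration pattern — sketched here with a made-up `triton::demo_add` op, using PyTorch's `torch.library.triton_op` / `wrap_triton` API — works roughly like this:

```python
# Illustrative registration sketch; "triton::demo_add" is a made-up op name,
# not the PR's sdpa kernel.
import torch
import triton
import triton.language as tl
from torch.library import triton_op, wrap_triton


@triton.jit
def _add_kernel(x_ptr, y_ptr, out_ptr, n, BLOCK: tl.constexpr):
    offs = tl.program_id(0) * BLOCK + tl.arange(0, BLOCK)
    mask = offs < n
    x = tl.load(x_ptr + offs, mask=mask)
    y = tl.load(y_ptr + offs, mask=mask)
    tl.store(out_ptr + offs, x + y, mask=mask)


@triton_op("triton::demo_add", mutates_args={})
def demo_add(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    out = torch.empty_like(x)
    n = out.numel()
    # wrap_triton makes the raw Triton kernel traceable by torch.export/compile.
    wrap_triton(_add_kernel)[(triton.cdiv(n, 1024),)](x, y, out, n, BLOCK=1024)
    return out


# After this module is imported, torch.ops.triton.demo_add exists — which is
# exactly why triton/__init__.py imports the kernels package for its side effect.
```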
11 changes: 11 additions & 0 deletions backends/cuda/triton/kernels/__init__.py
@@ -0,0 +1,11 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+from executorch.backends.cuda.triton.kernels.sdpa import sdpa
+
+__all__ = [
+    "sdpa",
+]