Skip to content

Commit 80981e7

Browse files
committed
introduce triton sdpa kernel to cuda backend
Pull Request resolved: #15859 **Introduce Triton SDPA Kernel to CUDA Backend** This diff introduces a kernel-generator-driven (https://github.com/meta-pytorch/KernelAgent), Triton-optimized implementation of the scaled dot-product attention (SDPA) kernel for the CUDA backend. The new kernel is designed to replace the default Edge SDPA operator during graph transformation, accelerating model inference and removing the need for SDPA decomposition. **Changes** * Added a new file `sdpa.py` under the `fbcode/executorch/backends/cuda/triton/kernels` directory, which contains the Triton-optimized SDPA kernel implementation. * Added a new pass, `fbcode/executorch/backends/cuda/triton/replacement_pass`, which replaces the specified existing edge ops with their target Triton kernels. * Added tests for SDPA export using the Triton kernel. Without the Triton kernel, the SDPA model cannot be exported. **Purpose** The purpose of this diff is to provide a high-performance SDPA kernel for the CUDA backend, which can be used to accelerate attention-based models on NVIDIA GPUs. ghstack-source-id: 324134561 @exported-using-ghexport Differential Revision: [D87259044](https://our.internmc.facebook.com/intern/diff/D87259044/)
1 parent 179a155 commit 80981e7

File tree

12 files changed

+622
-7
lines changed

12 files changed

+622
-7
lines changed

.github/workflows/cuda.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -71,7 +71,7 @@ jobs:
7171
strategy:
7272
fail-fast: false
7373
matrix:
74-
model: [linear, add, add_mul, resnet18, conv1d]
74+
model: [linear, add, add_mul, resnet18, conv1d, sdpa]
7575
with:
7676
timeout: 90
7777
runner: linux.g5.4xlarge.nvidia.gpu

backends/cuda/TARGETS

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ runtime.python_library(
1111
"//executorch/...",
1212
],
1313
deps = [
14+
":triton_replacement_pass",
1415
"//caffe2:torch",
1516
"//executorch/backends/aoti/passes:passes",
1617
"//executorch/exir/_serialize:lib",
@@ -32,3 +33,33 @@ runtime.python_library(
3233
"//executorch/backends/aoti:aoti_partitioner",
3334
],
3435
)
36+
37+
runtime.python_library(
38+
name = "triton_kernels",
39+
srcs = [
40+
"triton/kernels/__init__.py",
41+
"triton/kernels/sdpa.py",
42+
],
43+
visibility = [
44+
"//executorch/backends/cuda/...",
45+
],
46+
deps = [
47+
"//caffe2:torch",
48+
],
49+
)
50+
51+
runtime.python_library(
52+
name = "triton_replacement_pass",
53+
srcs = [
54+
"triton/__init__.py",
55+
"triton/replacement_pass.py",
56+
],
57+
visibility = [
58+
"//executorch/...",
59+
],
60+
deps = [
61+
":triton_kernels",
62+
"//caffe2:torch",
63+
"//executorch/exir/dialects:lib",
64+
],
65+
)

backends/cuda/cuda_backend.py

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,10 @@
1616
from executorch.backends.aoti.passes.replace_view_copy_with_view import (
1717
ReplaceViewCopyWithViewPass,
1818
)
19+
20+
from executorch.backends.cuda.triton.replacement_pass import (
21+
ReplaceEdgeOpWithTritonOpPass,
22+
)
1923
from executorch.exir._serialize._named_data_store import NamedDataStore
2024
from executorch.exir._warnings import experimental
2125
from executorch.exir.backend.backend_details import (
@@ -27,7 +31,7 @@
2731
from torch._inductor.codegen.cpp_wrapper_cpu import CppWrapperCpu
2832
from torch._inductor.decomposition import conv1d_to_conv2d
2933
from torch.export.passes import move_to_device_pass
30-
from torch.nn.attention import SDPBackend
34+
3135

3236
cuda_decomposition_table = {
3337
torch.ops.aten.conv1d.default: conv1d_to_conv2d,
@@ -127,6 +131,9 @@ def preprocess( # noqa: C901
127131
# replace slice_copy.Tensor with slice.Tensor, select_copy.int with select.int
128132
ReplaceViewCopyWithViewPass()(cuda_edge_program.graph_module)
129133

134+
# Replace aten ops with triton ops
135+
ReplaceEdgeOpWithTritonOpPass()(cuda_edge_program.graph_module)
136+
130137
cuda_edge_program = cuda_edge_program.run_decompositions(
131138
cuda_decomposition_table
132139
)
@@ -188,11 +195,7 @@ def preprocess( # noqa: C901
188195
}
189196
)
190197

191-
with collect_unsupported_fallback_kernels(), torch.nn.attention.sdpa_kernel(
192-
[
193-
SDPBackend.MATH # pyre-ignore[16]: Module `torch.nn.attention` has no attribute `SDPBackend`.
194-
]
195-
), torch.no_grad():
198+
with collect_unsupported_fallback_kernels(), torch.no_grad():
196199
# torch._logging.set_logs(post_grad_graphs=True)
197200
# Here we should expect 1 so file and 1 weight blob in the same directory.
198201
paths = torch._inductor.aot_compile(edge_program_module, tuple(user_input_placeholders), options=options) # type: ignore[arg-type]

backends/cuda/tests/TARGETS

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ python_unittest_remote_gpu(
1919
"//executorch/exir:lib",
2020
"//executorch/exir/backend:backend_api",
2121
"//executorch/exir/backend:compile_spec_schema",
22+
"//executorch/examples/models/toy_model:toy_model",
2223
],
2324
keep_gpu_sections = True,
2425
)

backends/cuda/tests/test_cuda_export.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
import torch
1111
from executorch.backends.cuda.cuda_backend import CudaBackend
1212
from executorch.backends.cuda.cuda_partitioner import CudaPartitioner
13+
from executorch.examples.models.toy_model import SdpaModule
1314
from executorch.exir import EdgeCompileConfig, to_edge_transform_and_lower
1415
from torch.export import export
1516

@@ -270,3 +271,20 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
270271
# Test export
271272
edge_program_manager = self._export_to_cuda_with_lower(module, inputs)
272273
self.assertIsNotNone(edge_program_manager, "Conv1d operation export failed")
274+
275+
def test_sdpa_single_kernel(self):
276+
"""
277+
Test CUDA export for model containing single SDPA kernel.
278+
SDPA: Scaled Dot Product Attention
279+
"""
280+
281+
sdpa = SdpaModule()
282+
283+
# Test export
284+
edge_program_manager = self._export_to_cuda_with_lower(
285+
sdpa.get_eager_model(), sdpa.get_example_inputs()
286+
)
287+
self.assertIsNotNone(
288+
edge_program_manager,
289+
"SDPA single kernel operation export failed",
290+
)

backends/cuda/triton/__init__.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
# Import all kernels to ensure @triton_op decorators are executed
8+
# and ops are registered to torch.ops.triton namespace
9+
from executorch.backends.cuda.triton import kernels # noqa: F401
10+
11+
from executorch.backends.cuda.triton.replacement_pass import (
12+
ReplaceEdgeOpWithTritonOpPass,
13+
)
14+
15+
__all__ = [
16+
"ReplaceEdgeOpWithTritonOpPass",
17+
]
Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
from executorch.backends.cuda.triton.kernels.sdpa import sdpa
8+
9+
__all__ = [
10+
"sdpa",
11+
]

0 commit comments

Comments
 (0)