Commit 0385230

Update on "introduce triton sdpa kernel to cuda backend"
**Introduce Triton SDPA Kernel to CUDA Backend**

This diff introduces a Triton-optimized scaled dot-product attention (SDPA) kernel to the CUDA backend. The new kernel replaces the default Edge SDPA operator during graph transformation, accelerating model inference and removing the need to decompose SDPA.

**Changes**

* Added a new file `sdpa.py` to the `fbcode/executorch/backends/cuda/triton/kernels` directory, containing the Triton-optimized SDPA kernel implementation.
* Added a new file `__init__.py` to `fbcode/executorch/backends/cuda/triton/replacement_pass`, which replaces the given existing Edge ops with the target Triton kernels.
* Added tests for SDPA export with the Triton kernel. Without the Triton kernel, an SDPA model cannot be exported.

**Purpose**

The purpose of this diff is to provide a high-performance SDPA kernel for the CUDA backend that can be used to accelerate attention-based models on NVIDIA GPUs.

Differential Revision: [D87259044](https://our.internmc.facebook.com/intern/diff/D87259044/)

[ghstack-poisoned]
2 parents 1023d93 + 13d1f3a commit 0385230
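
For context, here is a minimal sketch of the export flow this kernel unblocks. The module name `ToyAttention` is hypothetical and CUDA availability is assumed; this is not code from the commit:

```python
import torch
import torch.nn.functional as F


class ToyAttention(torch.nn.Module):
    """Hypothetical toy module; lowering its forward to the Edge dialect
    yields the SDPA operator that the new pass swaps for the Triton kernel."""

    def forward(self, q, k, v):
        return F.scaled_dot_product_attention(q, k, v)


q = k = v = torch.randn(1, 8, 128, 64, device="cuda")
# Without the Triton replacement, the Edge SDPA op would otherwise have
# to be decomposed during export.
ep = torch.export.export(ToyAttention(), (q, k, v))
```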

File tree

12 files changed: +376 −104

backends/cadence/aot/replace_ops.py

Lines changed: 29 additions & 33 deletions
```diff
@@ -69,50 +69,46 @@ def contains_placeholder_or_param(nodes: Iterable[torch.fx.Node]) -> bool:
 
 
 @register_cadence_pass(CadencePassAttribute(opt_level=0))
-class ReplaceLogicalNotBooleanWhereWithWherePass(ExportPass):
+class ReplaceLogicalNotBooleanWhereWithWherePass(RemoveOrReplacePassInterface):
     """
     A where op with a logical_not and a boolean tensor can be replaced
     by a where op with flipped inputs and the initial boolean tensor.
     """
 
-    def replace_logical_nop_where_with_where(
-        self, graph_module: torch.fx.GraphModule
-    ) -> None:
-        graph = graph_module.graph
-        for node in graph.nodes:
-            # We are only interested in where nodes
-            if node.target != exir_ops.edge.aten.where.self:
-                continue
+    @property
+    def targets(self) -> list[EdgeOpOverload]:
+        return [exir_ops.edge.aten.where.self]
 
-            # If the third arg is not a logical_not, bail.
-            if node.args[0].target != exir_ops.edge.aten.logical_not.default:
-                continue
+    def maybe_remove_or_replace(self, node: torch.fx.Node) -> bool:
+        # If the first arg is not a logical_not, bail.
+        if not isinstance(node.args[0], torch.fx.Node):
+            return False
 
-            # Get the third arg node and its input
-            logical_not_node = node.args[0]
-            logical_not_input_node = logical_not_node.args[0]
+        logical_not_node = cast(torch.fx.Node, node.args[0])
+        if logical_not_node.target != exir_ops.edge.aten.logical_not.default:
+            return False
 
-            # If the logical_not input is not a boolean tensor, bail.
-            if logical_not_input_node.meta["val"].dtype != torch.bool:
-                continue
+        # Get the first arg node and its input
+        if not isinstance(logical_not_node.args[0], torch.fx.Node):
+            return False
 
-            # Replace the where op with another one, flipping the inputs and using the boolean
-            # tensor from logical_not.
-            with graph.inserting_before(node):
-                linear_node = graph.call_function(
-                    exir_ops.edge.aten.where.self,
-                    args=(logical_not_node.args[0], node.args[2], node.args[1]),
-                )
-            # Replace all the uses
-            node.replace_all_uses_with(linear_node)
+        logical_not_input_node = cast(torch.fx.Node, logical_not_node.args[0])
 
-            graph_module.recompile()
-            graph_module.graph.eliminate_dead_code()
+        # If the logical_not input is not a boolean tensor, bail.
+        if logical_not_input_node.meta["val"].dtype != torch.bool:
+            return False
 
-    def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
-        self.replace_logical_nop_where_with_where(graph_module)
-        result = super().call(graph_module)
-        return result
+        # Replace the where op with another one, flipping the inputs and using the boolean
+        # tensor from logical_not.
+        with node.graph.inserting_before(node):
+            new_node = node.graph.call_function(
+                exir_ops.edge.aten.where.self,
+                args=(logical_not_input_node, node.args[2], node.args[1]),
+            )
+        new_node.meta = node.meta
+        # Replace all the uses
+        node.replace_all_uses_with(new_node)
+        return True
 
 
 @register_cadence_pass(CadencePassAttribute(opt_level=0))
```
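
The rewrite is sound because of a simple identity: selecting through a negated boolean mask equals selecting through the original mask with the branches swapped. A quick sanity check (illustrative only, using plain `torch` ops rather than Edge dialect nodes):

```python
import torch

cond = torch.rand(4, 8) > 0.5  # boolean tensor, as the pass requires
x, y = torch.randn(4, 8), torch.randn(4, 8)

# where(logical_not(cond), x, y) == where(cond, y, x)
assert torch.equal(
    torch.where(torch.logical_not(cond), x, y),
    torch.where(cond, y, x),
)
```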

backends/cadence/aot/tests/test_replace_ops_passes.py

Lines changed: 112 additions & 0 deletions
```diff
@@ -33,6 +33,7 @@
     ReplaceFunctionallyEquivalentOpTargets,
     ReplaceIm2RowWithViewPass,
     ReplaceLinearWithFullyConnectedOpPass,
+    ReplaceLogicalNotBooleanWhereWithWherePass,
     ReplaceMatmulWithTransposedMatmulPass,
     ReplaceMMWithAddMMPass,
     ReplaceMulTensorWithMulAndFullOpsPass,
@@ -2053,3 +2054,114 @@ def test_replace_quantized_embedding(
             ),
             1,
         )
+
+
+class TestReplaceLogicalNotBooleanWhereWithWherePass(unittest.TestCase):
+    """Tests for the ReplaceLogicalNotBooleanWhereWithWherePass."""
+
+    def test_replace_where_with_logical_not_boolean(self) -> None:
+        """Test that where(logical_not(bool_cond), x, y) is replaced with where(bool_cond, y, x)."""
+        # Setup: Create a graph with where(logical_not(bool_cond), x, y)
+        builder = GraphBuilder()
+        bool_cond_ = torch.randn(4, 8) > 0
+        x_ = torch.randn(4, 8)
+        y_ = torch.randn(4, 8)
+
+        bool_cond = builder.placeholder("bool_cond", bool_cond_)
+        x = builder.placeholder("x", x_)
+        y = builder.placeholder("y", y_)
+
+        # Create logical_not node
+        logical_not = builder.call_operator(
+            op=exir_ops.edge.aten.logical_not.default,
+            args=(bool_cond,),
+        )
+
+        # Create where node using logical_not
+        where_node = builder.call_operator(
+            op=exir_ops.edge.aten.where.self,
+            args=(logical_not, x, y),
+        )
+        builder.output([where_node])
+        original_gm = builder.get_graph_module()
+
+        # Make a copy of the original graph before applying the pass
+        original_gm_copy = copy.deepcopy(original_gm)
+
+        # Execute: Apply the replacement pass
+        p = ReplaceLogicalNotBooleanWhereWithWherePass()
+        result = cast(PassResult, p(original_gm))
+
+        # Assert: Verify the pass modified the graph
+        self.assertTrue(result.modified)
+        graph_after_passes = result.graph_module
+
+        # Assert: Verify logical_not is removed (dead code elimination)
+        self.assertEqual(
+            count_node(graph_after_passes, exir_ops.edge.aten.logical_not.default),
+            0,
+        )
+
+        # Assert: Verify where node still exists
+        self.assertEqual(
+            count_node(graph_after_passes, exir_ops.edge.aten.where.self),
+            1,
+        )
+
+        # Assert: Verify the arguments are flipped (condition uses original bool_cond, x and y are swapped)
+        where_nodes = list(
+            graph_after_passes.graph.find_nodes(
+                op="call_function", target=exir_ops.edge.aten.where.self
+            )
+        )
+        for node in where_nodes:
+            # First arg should be the original bool_cond (not the logical_not)
+            self.assertEqual(node.args[0].name, "bool_cond")
+            # Second and third args should be swapped (y, x instead of x, y)
+            self.assertEqual(node.args[1].name, "y")
+            self.assertEqual(node.args[2].name, "x")
+
+        # Assert: Verify outputs match exactly by running both graphs
+        validate(
+            original_gm_copy,
+            graph_after_passes,
+            (bool_cond_, x_, y_),
+            "ReplaceLogicalNotBooleanWhereWithWherePass",
+        )
+
+    def test_no_replacement_without_logical_not(self) -> None:
+        """Test that the pass does NOT apply when there's no logical_not."""
+        # Setup: Create a graph with where(bool_cond, x, y) without logical_not
+        builder = GraphBuilder()
+        bool_cond = builder.placeholder("bool_cond", torch.randn(4, 8) > 0)
+        x = builder.placeholder("x", torch.randn(4, 8))
+        y = builder.placeholder("y", torch.randn(4, 8))
+
+        # Create where node directly without logical_not
+        where_node = builder.call_operator(
+            op=exir_ops.edge.aten.where.self,
+            args=(bool_cond, x, y),
+        )
+        builder.output([where_node])
+        original_gm = builder.get_graph_module()
+
+        # Execute: Apply the replacement pass
+        p = ReplaceLogicalNotBooleanWhereWithWherePass()
+        result = cast(PassResult, p(original_gm))
+
+        # Assert: Verify the pass did NOT modify the graph
+        self.assertFalse(result.modified)
+        graph_after_passes = result.graph_module
+
+        # Assert: Verify where node still exists unchanged
+        self.assertEqual(
+            count_node(graph_after_passes, exir_ops.edge.aten.where.self),
+            1,
+        )
+
+        for node in graph_after_passes.graph.find_nodes(
+            op="call_function", target=exir_ops.edge.aten.where.self
+        ):
+            self.assertEqual(node.args[0].name, "bool_cond")
+            self.assertEqual(node.args[1].name, "x")
+            self.assertEqual(node.args[2].name, "y")
```

backends/cadence/runtime/runtime.py

Lines changed: 1 addition & 1 deletion
```diff
@@ -45,7 +45,7 @@ def get_op_names(program: et_schema.Program, execution_plan_id: int = 0) -> set[
         op_names |= get_op_names(
             deserialize_pte_binary(
                 program.backend_delegate_data[delegate.processed.index].data
-            )
+            ).program
         )
     return op_names
 
```

backends/cuda/triton/kernels/sdpa.py

Lines changed: 4 additions & 2 deletions
```diff
@@ -72,6 +72,8 @@ def _validate_qkv_shapes(
         triton.Config({"BLOCK_M": 64, "BLOCK_N": 128}, num_stages=4, num_warps=4),
         triton.Config({"BLOCK_M": 128, "BLOCK_N": 64}, num_stages=3, num_warps=4),
         triton.Config({"BLOCK_M": 64, "BLOCK_N": 64}, num_stages=3, num_warps=4),
+        triton.Config({"BLOCK_M": 64, "BLOCK_N": 32}, num_stages=1, num_warps=2),
+        triton.Config({"BLOCK_M": 32, "BLOCK_N": 64}, num_stages=1, num_warps=2),
     ],
     key=["L_Q", "L_KV", "HEAD_DIM"],
 )
@@ -348,8 +350,8 @@ def _sdpa_abstract(
     attn_mask: Optional[torch.Tensor] = None,
     dropout_p: float = 0.0,
     is_causal: bool = False,
-    scale=None,
-    enable_gqa=False,
+    scale: float = 0.0,
+    enable_gq: bool = False,
 ) -> torch.Tensor:
     """
     Abstract/fake implementation for torch.export.
```
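
For readers unfamiliar with the autotuning pattern touched above: `triton.autotune` benchmarks each `triton.Config` the first time the kernel is launched with a new combination of the `key` arguments and caches the winner. A minimal sketch of the decorator structure, with a toy kernel body rather than the SDPA kernel from this file:

```python
import triton
import triton.language as tl


@triton.autotune(
    configs=[
        # Larger tiles for long sequences; smaller low-stage/low-warp
        # tiles (the kind added in this diff) for short ones.
        triton.Config({"BLOCK_M": 64, "BLOCK_N": 64}, num_stages=3, num_warps=4),
        triton.Config({"BLOCK_M": 32, "BLOCK_N": 64}, num_stages=1, num_warps=2),
    ],
    key=["L_Q", "L_KV", "HEAD_DIM"],  # retune when any of these change
)
@triton.jit
def _toy_kernel(out_ptr, L_Q, L_KV, HEAD_DIM,
                BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr):
    # Placeholder body; a real kernel tiles work over BLOCK_M x BLOCK_N.
    pass
```

Adding the two small configs gives the autotuner viable candidates for short sequence lengths, where 128-wide tiles would waste work.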

backends/cuda/triton/replacement_pass.py

Lines changed: 4 additions & 7 deletions
```diff
@@ -31,10 +31,9 @@ class ReplaceEdgeOpWithTritonOpPass(PassBase):
     """
     Pass to replace ATen operators with Triton kernels.
 
-    This pass scans the graph for ATen operators that have registered Triton
-    replacements and replaces them with the optimized Triton implementations.
-
-    It automatically imports EDGE_TO_TRITON_KERNELS from cuda_backend.py.
+    This pass scans the graph for Edge operators that have registered Triton
+    replacements using EDGE_TO_TRITON_KERNELS and replaces them with the
+    optimized Triton implementations.
     """
 
     def __init__(self):
@@ -73,7 +72,7 @@ def call(self, graph_module: GraphModule) -> PassResult:
         # Recompile the graph module after modifications
         graph_module.recompile()
 
-        print(f"Replaced {self._replacement_count} nodes with Triton kernels")
+        logger.info(f"Replaced {self._replacement_count} nodes with Triton kernels")
 
         return PassResult(graph_module, modified)
 
@@ -83,7 +82,6 @@ def _should_replace_node(self, node: Node) -> bool:
 
         Args:
             node: The node to check
-            EDGE_TO_TRITON_KERNELS: Mapping from edge ops to Triton kernels
 
         Returns:
             True if the node should be replaced
@@ -101,7 +99,6 @@ def _replace_node_with_triton(self, graph_module: GraphModule, node: Node) -> No
         Args:
             graph_module: The graph module containing the node
             node: The node to replace
-            EDGE_TO_TRITON_KERNELS: Mapping from edge ops to Triton kernels
         """
         # Get the target operator (should be an exir_ops edge dialect op)
         target = node.target
```
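
Assuming the pass follows the usual `PassBase` calling convention, invoking it might look like the sketch below. The import path is inferred from the file location and `edge_gm` is a placeholder for an already-lowered Edge-dialect `GraphModule`; this is hypothetical driver code, not from this commit:

```python
import logging

# Import path inferred from backends/cuda/triton/replacement_pass.py.
from executorch.backends.cuda.triton.replacement_pass import (
    ReplaceEdgeOpWithTritonOpPass,
)

logging.basicConfig(level=logging.INFO)  # surfaces the logger.info summary

# `edge_gm` stands in for an already-lowered Edge-dialect GraphModule.
result = ReplaceEdgeOpWithTritonOpPass()(edge_gm)
if result.modified:
    edge_gm = result.graph_module  # SDPA nodes now call the Triton kernel
```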

backends/qualcomm/utils/utils.py

Lines changed: 1 addition & 1 deletion
```diff
@@ -197,7 +197,7 @@ def dump_context_from_pte(pte_path) -> List[str]:
     with open(pte_path, "rb") as f:
         program_data = f.read()
 
-    program = deserialize_pte_binary(program_data)
+    program = deserialize_pte_binary(program_data).program
 
     ctx_path = os.path.dirname(pte_path)
     dummy_compiler_specs = generate_qnn_executorch_compiler_spec(
```

codegen/tools/gen_ops_def.py

Lines changed: 1 addition & 1 deletion
```diff
@@ -23,7 +23,7 @@ def get_operators(model_file: str) -> List[Operator]:
     print("Processing model file: ", model_file)
     with open(model_file, "rb") as f:
         flatbuffer = f.read()
-    program = _deserialize_pte_binary(flatbuffer)
+    program = _deserialize_pte_binary(flatbuffer).program
     print(f"Program loaded from model file: {model_file}")
     operators = program.execution_plan[0].operators
     return operators
```

examples/qualcomm/oss_scripts/llama/decoder_utils.py

Lines changed: 1 addition & 1 deletion
```diff
@@ -276,7 +276,7 @@ def __init__(  # noqa: C901
 
         with open(pte_path, "rb") as f:
             program_data = f.read()
-        program = deserialize_pte_binary(program_data)
+        program = deserialize_pte_binary(program_data).program
 
         # Retrieve vocab_size from get_metadata under static_llama that is passed to edge manager
         self.output_vocab_size = None
```

exir/_serialize/__init__.py

Lines changed: 2 additions & 0 deletions
```diff
@@ -8,11 +8,13 @@
 
 from executorch.exir._serialize._program import (
     deserialize_pte_binary as _deserialize_pte_binary,
+    PTEFile as _PTEFile,
     serialize_pte_binary as _serialize_pte_binary,
 )
 
 # Internal APIs that should not be used outside of exir.
 __all__ = [
     "_deserialize_pte_binary",
     "_serialize_pte_binary",
+    "_PTEFile",
 ]
```
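
The `.program` call-site changes across the files above all follow from this export: `deserialize_pte_binary` now returns a `PTEFile` wrapper rather than the flatbuffer `Program` itself. A sketch of the updated pattern, with a placeholder file name:

```python
from executorch.exir._serialize import _deserialize_pte_binary

with open("model.pte", "rb") as f:  # placeholder path
    pte_bytes = f.read()

# The deserializer now returns a PTEFile; the flatbuffer Program
# lives on its `.program` attribute.
program = _deserialize_pte_binary(pte_bytes).program
print(program.execution_plan[0].operators)
```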
