
Commit f6654f2

[#5255][autodeploy] Update FuseAllreduceResidualRMSNorm to use pattern matcher utility; remove fuse_collective (#7545)
Signed-off-by: Frida Hou <[email protected]>
Signed-off-by: Fridah-nv <[email protected]>
1 parent 744246d commit f6654f2

File tree

13 files changed: +161 -302 lines changed

tensorrt_llm/_torch/auto_deploy/config/default.yaml

Lines changed: 4 additions & 2 deletions
@@ -32,6 +32,7 @@ transforms:
     stage: pattern_matcher
   match_repeat_kv:
     stage: pattern_matcher
+    run_shape_prop: true
   match_eager_attention:
     stage: pattern_matcher
   match_grouped_attention:
@@ -111,13 +112,14 @@ transforms:
     enabled: true
   fuse_allreduce_residual_rmsnorm:
     stage: post_load_fusion
-  fuse_collectives:
-    stage: post_load_fusion
+    # TODO (lucaslie): add backend selection as part of configurable inference optimizers
+    # check if we can fuse rmsnorm
   fuse_rmsnorm:
     # TODO (lucaslie): add backend selection as part of configurable inference optimizers
     # check if we can fuse rmsnorm
     stage: post_load_fusion
     backend: flashinfer
+    requires_shape_prop: true
   ############################################################################################
   # SWITCH TO CACHED+FLATTENED ATTENTION + INITIALIZE CACHES
   ############################################################################################
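Note: a minimal sketch of inspecting the per-transform flags introduced above with plain PyYAML; this is not the TRT-LLM config loader, and the file path is assumed relative to the repo root.

```python
# Minimal sketch: read the transform config above and check per-transform flags.
# This is NOT how auto_deploy loads its config; it only illustrates the YAML layout.
import yaml

with open("tensorrt_llm/_torch/auto_deploy/config/default.yaml") as f:
    cfg = yaml.safe_load(f)

rmsnorm_cfg = cfg["transforms"]["fuse_rmsnorm"]
print(rmsnorm_cfg["stage"])                           # post_load_fusion
print(rmsnorm_cfg.get("requires_shape_prop", False))  # True after this change

repeat_kv_cfg = cfg["transforms"]["match_repeat_kv"]
print(repeat_kv_cfg.get("run_shape_prop", False))     # True after this change
```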

tensorrt_llm/_torch/auto_deploy/transform/interface.py

Lines changed: 8 additions & 2 deletions
@@ -5,6 +5,7 @@

 import time
 from abc import ABC, abstractmethod
+from contextlib import nullcontext
 from enum import Enum
 from functools import total_ordering, wraps
 from typing import Any, Callable, Dict, Mapping, Tuple, Type, Union, final
@@ -19,6 +20,7 @@
     canonicalize_graph,
     lift_to_meta,
     named_graphmodules,
+    placeholders_on_meta,
     run_shape_prop,
 )
 from ..utils.logger import ad_logger
@@ -416,11 +418,13 @@ def _run_pre_cleanup(self, gm: GraphModule, info: TransformInfo) -> Tuple[bool,
         is_clean = info.is_clean
         has_valid_shapes = is_clean and info.has_valid_shapes

+        use_meta = isinstance(gm, GraphModule) and placeholders_on_meta(gm)
+
         # check if run cleanup depending on the config and info
         if self.config.requires_shape_prop and not has_valid_shapes:
             self._log_info("running pre-cleanup with shape_prop")
             canonicalize_graph(gm)
-            with lift_to_meta(gm):
+            with lift_to_meta(gm) if use_meta else nullcontext():
                 run_shape_prop(gm)
             is_clean = True
             has_valid_shapes = True
@@ -444,11 +448,13 @@ def _run_post_cleanup(self, gm: GraphModule, info: TransformInfo) -> TransformIn
         if not self.config.run_graph_cleanup:
             return info

+        use_meta = isinstance(gm, GraphModule) and placeholders_on_meta(gm)
+
         # check if run cleanup depending on the config and info
         if self.config.run_shape_prop and not (info.is_clean and info.has_valid_shapes):
             self._log_info("running post-cleanup with shape_prop")
             canonicalize_graph(gm)
-            with lift_to_meta(gm):
+            with lift_to_meta(gm) if use_meta else nullcontext():
                 run_shape_prop(gm)
         elif self.config.run_graph_cleanup and not info.is_clean:
             self._log_info("running post-cleanup (no shape_prop)")

tensorrt_llm/_torch/auto_deploy/transform/library/attention.py

Lines changed: 0 additions & 3 deletions
@@ -303,9 +303,6 @@ def register_repeat_kv(patterns: ADPatternMatcherPass):

         num_kv_patterns = _apply_pattern(gm, "Repeat KV", register_repeat_kv)

-        if num_kv_patterns > 0:
-            self.config.run_shape_prop = True
-
         info = TransformInfo(
             skipped=False,
             num_matches=num_kv_patterns,
Lines changed: 79 additions & 166 deletions
@@ -1,13 +1,11 @@
-import operator
 from typing import Tuple

 import torch
 from torch.fx import GraphModule

-from ...distributed.trtllm import is_trtllm_op_available
 from ...models.factory import ModelFactory
 from ...shim.interface import CachedSequenceInterface
-from ...utils.node_utils import get_op_overload_packet, get_user_if_pattern_match, is_op
+from ...utils.pattern_matcher import ADPatternMatcherPass, register_ad_pattern
 from ..interface import BaseTransform, SharedConfig, TransformInfo, TransformRegistry

 # TODO: This is an overly simplified model that works well for vanilla Llama models.
@@ -18,187 +16,102 @@
 # * ...


-@TransformRegistry.register("fuse_collectives")
-class FuseCollectives(BaseTransform):
+def _allreduce_residual_rmsnorm_pattern(
+    x: torch.Tensor, residual: torch.Tensor, weight: torch.Tensor, eps: float = 0.1253
+):
     """
-    Fuses all_reduce ops with preceding (quantized) linear ops into a single fused node for improved performance.
+    Reference PyTorch composition of:
+        y = all_reduce(x)
+        z = residual + y
+        normed = RMSNorm(z, weight, eps)
+    Returns (normed, z)
     """

-    def _apply(
-        self,
-        gm: GraphModule,
-        cm: CachedSequenceInterface,
-        factory: ModelFactory,
-        shared_config: SharedConfig,
-    ) -> Tuple[GraphModule, TransformInfo]:
-        num_gemm_collective_fusions = 0
-
-        # lookup for fused ops
-        # TODO: avoid this hardcoded lookup, e.g., by generating fused ops on the fly.
-        lookup = {
-            torch.ops.auto_deploy.torch_linear_simple: torch.ops.auto_deploy.trtllm_dist_fused_linear_all_reduce,
-            torch.ops.aten.linear: torch.ops.auto_deploy.trtllm_dist_fused_linear_all_reduce,
-            torch.ops.auto_deploy.torch_quant_fp8_linear: torch.ops.auto_deploy.torch_quant_fused_fp8_linear_all_reduce,
-        }
-
-        # go through all nodes and find all_reduce nodes
-        for node in gm.graph.nodes:
-            if not is_op(node, torch.ops.auto_deploy.torch_dist_all_reduce):
-                continue
-
-            # check if args are as expected
-            assert len(node.args) == 1 and not len(node.kwargs), (
-                "Unexpected args/kwargs for all_reduce"
-            )
-
-            # retrieve parent and check a few conditions on the parent node
-            parent_node = node.args[0]
-            if not is_op(parent_node, lookup.keys()):
-                continue
-            if len(parent_node.users) > 1:
-                continue
-
-            with gm.graph.inserting_before(node):
-                # insert fused node
-                fused_linear_collective_node = gm.graph.call_function(
-                    lookup[get_op_overload_packet(parent_node.target)],
-                    args=parent_node.args,
-                    kwargs=parent_node.kwargs,
-                )
-                node.replace_all_uses_with(fused_linear_collective_node)
-                gm.graph.erase_node(node)
-                gm.graph.erase_node(parent_node)
-                num_gemm_collective_fusions += 1
+    input_dtype = x.dtype
+    hidden_states = torch.ops.auto_deploy.torch_dist_all_reduce(x)
+    add = residual + hidden_states

-        info = TransformInfo(
-            skipped=False,
-            num_matches=num_gemm_collective_fusions,
-            is_clean=False,
-            has_valid_shapes=False,
-        )
+    hidden_states = add.to(torch.float32)
+    variance = hidden_states.pow(2).mean(-1, keepdim=True)
+    hidden_states = hidden_states * torch.rsqrt(variance + eps)

-        return gm, info
+    normed = weight * hidden_states.to(input_dtype)

+    return normed, add

-@TransformRegistry.register("fuse_allreduce_residual_rmsnorm")
-class FuseAllreduceResidualRMSNorm(BaseTransform):
-    """Essentially, this transformation fuses the following operators into one allreduce trtllm implementation.
-
-    * target pattern:
-        x = all_reduce(x)
-        y = x + residual
-        return rmsnorm(y), y
-    * replacement:
-        fused_allreduce_residual_rmsnorm(x, residual, rmsnorm_weight, rmsnorm_eps)

+def _allreduce_residual_rmsnorm_pattern2(
+    x: torch.Tensor, residual: torch.Tensor, weight: torch.Tensor, eps: float = 0.1253
+):
+    """
+    Reference PyTorch composition of:
+        y = all_reduce(x)
+        z = y + residual
+        normed = RMSNorm(z, weight, eps)
+    Returns (normed, z)
     """

+    input_dtype = x.dtype
+    hidden_states = torch.ops.auto_deploy.torch_dist_all_reduce(x)
+    add = hidden_states + residual
+
+    hidden_states = add.to(torch.float32)
+    variance = hidden_states.pow(2).mean(-1, keepdim=True)
+    hidden_states = hidden_states * torch.rsqrt(variance + eps)
+
+    normed = weight * hidden_states.to(input_dtype)
+
+    return normed, add
+
+
+def _allreduce_residual_rmsnorm_repl(
+    x: torch.Tensor, residual: torch.Tensor, weight: torch.Tensor, eps: float
+):
+    return torch.ops.dist.fused_allreduce_residual_rmsnorm(x, residual, weight, eps)
+
+
+@TransformRegistry.register("fuse_allreduce_residual_rmsnorm")
+class FuseAllreduceResidualRMSNorm(BaseTransform):
+    """Fuse (allreduce + residual add + RMSNorm) into one fused op with tuple output."""
+
     def _apply(
         self,
         gm: GraphModule,
         cm: CachedSequenceInterface,
         factory: ModelFactory,
         shared_config: SharedConfig,
     ) -> Tuple[GraphModule, TransformInfo]:
-        if not is_trtllm_op_available():
-            return gm, TransformInfo(
-                skipped=True, num_matches=0, is_clean=True, has_valid_shapes=True
-            )
-
-        num_ar_r_rms_fusions = 0
-
-        def trace_and_fuse(allreduce_node, graph):
-            # Check if all_reduce is followed by addition
-            users = list(allreduce_node.users.keys())
-            if len(users) != 1:
-                return  # Skip if all_reduce has more than one consumer
-            add_node = users[0]
-
-            # Traverse nodes for RMSNorm pattern which is composed of to_copy, pow, mean, add, refer
-            # the Huggingface LlamaRMSNorm implementation as example for more details
-            to_copy_1 = get_user_if_pattern_match(add_node, [torch.ops.aten.add, operator.add], 2)
-            # operand of pow and mul
-            pow_node = get_user_if_pattern_match(
-                to_copy_1, [torch.ops.aten._to_copy, torch.ops.aten.to], 2
-            )
-            mean_node = get_user_if_pattern_match(pow_node, torch.ops.aten.pow, 1)
-            add_eps_node = get_user_if_pattern_match(mean_node, torch.ops.aten.mean, 1)
-            rsqrt_node = get_user_if_pattern_match(
-                add_eps_node, [torch.ops.aten.add, operator.add], 1
-            )
-            mul_node_1 = get_user_if_pattern_match(rsqrt_node, torch.ops.aten.rsqrt, 1)
-            to_copy_2 = get_user_if_pattern_match(mul_node_1, torch.ops.aten.mul, 1)
-            mul_node_2 = get_user_if_pattern_match(
-                to_copy_2, [torch.ops.aten._to_copy, torch.ops.aten.to], 1
-            )
-            # check args of ops: pow(2) and mean(-1)
-            ARGS_MATCH = pow_node is not None and pow_node.args[1] == 2  # exponent
-            ARGS_MATCH &= mean_node is not None and mean_node.args[1] == [-1]  # dimensions
-
-            # Match found: Replace with fused operation
-            if (
-                to_copy_1
-                and pow_node
-                and mean_node
-                and add_eps_node
-                and rsqrt_node
-                and mul_node_1
-                and to_copy_2
-                and mul_node_2
-                and ARGS_MATCH
-            ):
-                # Gather the inputs for the custom operation
-                tensor = allreduce_node.args[0]
-                # Identify the residual argument in the add operation
-                # One of the args in add_node.args is the output of all_reduce
-                # The same idea also applies to norm_weight
-                residual = (
-                    add_node.args[0] if add_node.args[1] is allreduce_node else add_node.args[1]
-                )
-                norm_weight = (
-                    mul_node_2.args[0] if mul_node_2.args[1] is to_copy_2 else mul_node_2.args[1]
-                )
-                eps = add_eps_node.args[1]
-
-                # Insert nodes
-                with graph.inserting_before(allreduce_node):
-                    fused_node = graph.call_function(
-                        torch.ops.dist.fused_allreduce_residual_rmsnorm,
-                        args=(
-                            tensor,
-                            residual,
-                            norm_weight,
-                            eps,
-                        ),
-                    )
-                    # Extract outputs from the tuple returned by `fused_node`
-                    final_output_node = gm.graph.create_node(
-                        "call_function",
-                        target=operator.getitem,
-                        args=(fused_node, 0),
-                    )
-                    add_output_node = gm.graph.create_node(
-                        "call_function",
-                        target=operator.getitem,
-                        args=(fused_node, 1),
-                    )
-
-                # Replace all uses of rmsnorm_node with final_output_node
-                mul_node_2.replace_all_uses_with(final_output_node)
-
-                # Replace all uses of add_node with add_output_node
-                add_node.replace_all_uses_with(add_output_node)
-
-                nonlocal num_ar_r_rms_fusions
-                num_ar_r_rms_fusions += 1
-
-        # Traverse all nodes
-        for node in gm.graph.nodes:
-            if is_op(node, torch.ops.auto_deploy.torch_dist_all_reduce):
-                trace_and_fuse(allreduce_node=node, graph=gm.graph)
+        patterns = ADPatternMatcherPass()
+
+        # Dummy shapes for tracing
+        bsz, hidden = 8, 512
+        dummy_args = [
+            torch.randn(bsz, hidden, device="meta", dtype=torch.bfloat16),  # x
+            torch.randn(bsz, hidden, device="meta", dtype=torch.bfloat16),  # residual
+            torch.randn(hidden, device="meta", dtype=torch.bfloat16),  # weight
+            0.1253,  # eps
+        ]
+
+        register_ad_pattern(
+            search_fn=_allreduce_residual_rmsnorm_pattern,
+            replace_fn=_allreduce_residual_rmsnorm_repl,
+            patterns=patterns,
+            dummy_args=dummy_args,
+            op_ignore_types={torch.ops.aten.to.dtype: (torch.dtype,)},
+            scalar_workaround={"eps": 0.1253},
+        )
+        register_ad_pattern(
+            search_fn=_allreduce_residual_rmsnorm_pattern2,
+            replace_fn=_allreduce_residual_rmsnorm_repl,
+            patterns=patterns,
+            dummy_args=dummy_args,
+            op_ignore_types={torch.ops.aten.to.dtype: (torch.dtype,)},
+            scalar_workaround={"eps": 0.1253},
+        )
+
+        num_matches = patterns.apply(gm.graph)

         info = TransformInfo(
-            skipped=False, num_matches=num_ar_r_rms_fusions, is_clean=False, has_valid_shapes=False
+            skipped=False, num_matches=num_matches, is_clean=False, has_valid_shapes=False
         )
-
         return gm, info
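Note: for intuition, the traced pattern can be mirrored in plain PyTorch. The sketch below assumes a single rank, so the all_reduce is dropped (it would be the identity there), and it does not call the fused TRT-LLM op.

```python
# Single-rank mirror of the (allreduce + residual add + RMSNorm) pattern.
# all_reduce is omitted (identity on one rank); this is NOT the fused kernel.
import torch


def allreduce_residual_rmsnorm_ref(x, residual, weight, eps=1e-6):
    add = residual + x                           # residual add (all_reduce(x) == x here)
    hs = add.to(torch.float32)                   # RMSNorm computed in fp32 for stability
    variance = hs.pow(2).mean(-1, keepdim=True)
    hs = hs * torch.rsqrt(variance + eps)
    normed = weight * hs.to(x.dtype)
    return normed, add                           # same tuple layout as the fused op


x = torch.randn(8, 512, dtype=torch.bfloat16)
residual = torch.randn_like(x)
weight = torch.randn(512, dtype=torch.bfloat16)
normed, added = allreduce_residual_rmsnorm_ref(x, residual, weight)
print(normed.shape, added.shape)  # torch.Size([8, 512]) torch.Size([8, 512])
```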

tensorrt_llm/_torch/auto_deploy/transform/library/fused_moe.py

Lines changed: 3 additions & 3 deletions
@@ -500,7 +500,7 @@ def target_op(self):
         return torch.ops.auto_deploy.torch_linear_simple

     def moe_op(self):
-        return torch.ops.auto_deploy.torch_moe
+        return torch.ops.auto_deploy.torch_moe.default

     def scale_arg_indices(self) -> Dict[str, int]:
         return {}
@@ -517,7 +517,7 @@ def target_op(self):
         return torch.ops.auto_deploy.torch_quant_fp8_linear

     def moe_op(self):
-        return torch.ops.auto_deploy.torch_quant_fp8_moe
+        return torch.ops.auto_deploy.torch_quant_fp8_moe.default

     def scale_arg_indices(self) -> Dict[str, int]:
         return {"input_scale": 3, "weight_scale": 4}
@@ -534,7 +534,7 @@ def target_op(self):
         return torch.ops.auto_deploy.torch_quant_nvfp4_linear

     def moe_op(self):
-        return torch.ops.auto_deploy.torch_quant_nvfp4_moe
+        return torch.ops.auto_deploy.torch_quant_nvfp4_moe.default

     def scale_arg_indices(self) -> Dict[str, int]:
         return {"input_scale": 3, "weight_scale": 4, "alpha": 5}
