|
1 | | -import operator |
2 | 1 | from typing import Tuple |
3 | 2 |
|
4 | 3 | import torch |
5 | 4 | from torch.fx import GraphModule |
6 | 5 |
|
7 | | -from ...distributed.trtllm import is_trtllm_op_available |
8 | 6 | from ...models.factory import ModelFactory |
9 | 7 | from ...shim.interface import CachedSequenceInterface |
10 | | -from ...utils.node_utils import get_op_overload_packet, get_user_if_pattern_match, is_op |
| 8 | +from ...utils.pattern_matcher import ADPatternMatcherPass, register_ad_pattern |
11 | 9 | from ..interface import BaseTransform, SharedConfig, TransformInfo, TransformRegistry |
12 | 10 |
|
13 | 11 | # TODO: This is an overly simplified model that works well for vanilla Llama models. |
|
18 | 16 | # * ... |
19 | 17 |
|
20 | 18 |
|
21 | | -@TransformRegistry.register("fuse_collectives") |
22 | | -class FuseCollectives(BaseTransform): |
| 19 | +def _allreduce_residual_rmsnorm_pattern( |
| 20 | + x: torch.Tensor, residual: torch.Tensor, weight: torch.Tensor, eps: float = 0.1253 |
| 21 | +): |
23 | 22 | """ |
24 | | - Fuses all_reduce ops with preceding (quantized) linear ops into a single fused node for improved performance. |
| 23 | + Reference PyTorch composition of: |
| 24 | + y = all_reduce(x) |
| 25 | + z = residual + y |
| 26 | + normed = RMSNorm(z, weight, eps) |
| 27 | + Returns (normed, z) |
25 | 28 | """ |
26 | 29 |
|
27 | | - def _apply( |
28 | | - self, |
29 | | - gm: GraphModule, |
30 | | - cm: CachedSequenceInterface, |
31 | | - factory: ModelFactory, |
32 | | - shared_config: SharedConfig, |
33 | | - ) -> Tuple[GraphModule, TransformInfo]: |
34 | | - num_gemm_collective_fusions = 0 |
35 | | - |
36 | | - # lookup for fused ops |
37 | | - # TODO: avoid this hardcoded lookup, e.g., by generating fused ops on the fly. |
38 | | - lookup = { |
39 | | - torch.ops.auto_deploy.torch_linear_simple: torch.ops.auto_deploy.trtllm_dist_fused_linear_all_reduce, |
40 | | - torch.ops.aten.linear: torch.ops.auto_deploy.trtllm_dist_fused_linear_all_reduce, |
41 | | - torch.ops.auto_deploy.torch_quant_fp8_linear: torch.ops.auto_deploy.torch_quant_fused_fp8_linear_all_reduce, |
42 | | - } |
43 | | - |
44 | | - # go through all nodes and find all_reduce nodes |
45 | | - for node in gm.graph.nodes: |
46 | | - if not is_op(node, torch.ops.auto_deploy.torch_dist_all_reduce): |
47 | | - continue |
48 | | - |
49 | | - # check if args are as expected |
50 | | - assert len(node.args) == 1 and not len(node.kwargs), ( |
51 | | - "Unexpected args/kwargs for all_reduce" |
52 | | - ) |
53 | | - |
54 | | - # retrieve parent and check a few conditions on the parent node |
55 | | - parent_node = node.args[0] |
56 | | - if not is_op(parent_node, lookup.keys()): |
57 | | - continue |
58 | | - if len(parent_node.users) > 1: |
59 | | - continue |
60 | | - |
61 | | - with gm.graph.inserting_before(node): |
62 | | - # insert fused node |
63 | | - fused_linear_collective_node = gm.graph.call_function( |
64 | | - lookup[get_op_overload_packet(parent_node.target)], |
65 | | - args=parent_node.args, |
66 | | - kwargs=parent_node.kwargs, |
67 | | - ) |
68 | | - node.replace_all_uses_with(fused_linear_collective_node) |
69 | | - gm.graph.erase_node(node) |
70 | | - gm.graph.erase_node(parent_node) |
71 | | - num_gemm_collective_fusions += 1 |
| 30 | + input_dtype = x.dtype |
| 31 | + hidden_states = torch.ops.auto_deploy.torch_dist_all_reduce(x) |
| 32 | + add = residual + hidden_states |
72 | 33 |
|
73 | | - info = TransformInfo( |
74 | | - skipped=False, |
75 | | - num_matches=num_gemm_collective_fusions, |
76 | | - is_clean=False, |
77 | | - has_valid_shapes=False, |
78 | | - ) |
| 34 | + hidden_states = add.to(torch.float32) |
| 35 | + variance = hidden_states.pow(2).mean(-1, keepdim=True) |
| 36 | + hidden_states = hidden_states * torch.rsqrt(variance + eps) |
79 | 37 |
|
80 | | - return gm, info |
| 38 | + normed = weight * hidden_states.to(input_dtype) |
81 | 39 |
|
| 40 | + return normed, add |
82 | 41 |
|
83 | | -@TransformRegistry.register("fuse_allreduce_residual_rmsnorm") |
84 | | -class FuseAllreduceResidualRMSNorm(BaseTransform): |
85 | | - """Essentially, this transformation fuses the following operators into one allreduce trtllm implementation. |
86 | | -
|
87 | | - * target pattern: |
88 | | - x = all_reduce(x) |
89 | | - y = x + residual |
90 | | - return rmsnorm(y), y |
91 | | - * replacement: |
92 | | - fused_allreduce_residual_rmsnorm(x, residual, rmsnorm_weight, rmsnorm_eps) |
93 | 42 |
|
| 43 | +def _allreduce_residual_rmsnorm_pattern2( |
| 44 | + x: torch.Tensor, residual: torch.Tensor, weight: torch.Tensor, eps: float = 0.1253 |
| 45 | +): |
| 46 | + """ |
| 47 | + Reference PyTorch composition of: |
| 48 | + y = all_reduce(x) |
| 49 | + z = y + residual |
| 50 | + normed = RMSNorm(z, weight, eps) |
| 51 | + Returns (normed, z) |
94 | 52 | """ |
95 | 53 |
|
| 54 | + input_dtype = x.dtype |
| 55 | + hidden_states = torch.ops.auto_deploy.torch_dist_all_reduce(x) |
| 56 | + add = hidden_states + residual |
| 57 | + |
| 58 | + hidden_states = add.to(torch.float32) |
| 59 | + variance = hidden_states.pow(2).mean(-1, keepdim=True) |
| 60 | + hidden_states = hidden_states * torch.rsqrt(variance + eps) |
| 61 | + |
| 62 | + normed = weight * hidden_states.to(input_dtype) |
| 63 | + |
| 64 | + return normed, add |
| 65 | + |
| 66 | + |
| 67 | +def _allreduce_residual_rmsnorm_repl( |
| 68 | + x: torch.Tensor, residual: torch.Tensor, weight: torch.Tensor, eps: float |
| 69 | +): |
| 70 | + return torch.ops.dist.fused_allreduce_residual_rmsnorm(x, residual, weight, eps) |
| 71 | + |
| 72 | + |
| 73 | +@TransformRegistry.register("fuse_allreduce_residual_rmsnorm") |
| 74 | +class FuseAllreduceResidualRMSNorm(BaseTransform): |
| 75 | + """Fuse (allreduce + residual add + RMSNorm) into one fused op with tuple output.""" |
| 76 | + |
96 | 77 | def _apply( |
97 | 78 | self, |
98 | 79 | gm: GraphModule, |
99 | 80 | cm: CachedSequenceInterface, |
100 | 81 | factory: ModelFactory, |
101 | 82 | shared_config: SharedConfig, |
102 | 83 | ) -> Tuple[GraphModule, TransformInfo]: |
103 | | - if not is_trtllm_op_available(): |
104 | | - return gm, TransformInfo( |
105 | | - skipped=True, num_matches=0, is_clean=True, has_valid_shapes=True |
106 | | - ) |
107 | | - |
108 | | - num_ar_r_rms_fusions = 0 |
109 | | - |
110 | | - def trace_and_fuse(allreduce_node, graph): |
111 | | - # Check if all_reduce is followed by addition |
112 | | - users = list(allreduce_node.users.keys()) |
113 | | - if len(users) != 1: |
114 | | - return # Skip if all_reduce has more than one consumer |
115 | | - add_node = users[0] |
116 | | - |
117 | | - # Traverse nodes for RMSNorm pattern which is composed of to_copy, pow, mean, add, refer |
118 | | - # the Huggingface LlamaRMSNorm implementation as example for more details |
119 | | - to_copy_1 = get_user_if_pattern_match(add_node, [torch.ops.aten.add, operator.add], 2) |
120 | | - # operand of pow and mul |
121 | | - pow_node = get_user_if_pattern_match( |
122 | | - to_copy_1, [torch.ops.aten._to_copy, torch.ops.aten.to], 2 |
123 | | - ) |
124 | | - mean_node = get_user_if_pattern_match(pow_node, torch.ops.aten.pow, 1) |
125 | | - add_eps_node = get_user_if_pattern_match(mean_node, torch.ops.aten.mean, 1) |
126 | | - rsqrt_node = get_user_if_pattern_match( |
127 | | - add_eps_node, [torch.ops.aten.add, operator.add], 1 |
128 | | - ) |
129 | | - mul_node_1 = get_user_if_pattern_match(rsqrt_node, torch.ops.aten.rsqrt, 1) |
130 | | - to_copy_2 = get_user_if_pattern_match(mul_node_1, torch.ops.aten.mul, 1) |
131 | | - mul_node_2 = get_user_if_pattern_match( |
132 | | - to_copy_2, [torch.ops.aten._to_copy, torch.ops.aten.to], 1 |
133 | | - ) |
134 | | - # check args of ops: pow(2) and mean(-1) |
135 | | - ARGS_MATCH = pow_node is not None and pow_node.args[1] == 2 # exponent |
136 | | - ARGS_MATCH &= mean_node is not None and mean_node.args[1] == [-1] # dimensions |
137 | | - |
138 | | - # Match found: Replace with fused operation |
139 | | - if ( |
140 | | - to_copy_1 |
141 | | - and pow_node |
142 | | - and mean_node |
143 | | - and add_eps_node |
144 | | - and rsqrt_node |
145 | | - and mul_node_1 |
146 | | - and to_copy_2 |
147 | | - and mul_node_2 |
148 | | - and ARGS_MATCH |
149 | | - ): |
150 | | - # Gather the inputs for the custom operation |
151 | | - tensor = allreduce_node.args[0] |
152 | | - # Identify the residual argument in the add operation |
153 | | - # One of the args in add_node.args is the output of all_reduce |
154 | | - # The same idea also applies to norm_weight |
155 | | - residual = ( |
156 | | - add_node.args[0] if add_node.args[1] is allreduce_node else add_node.args[1] |
157 | | - ) |
158 | | - norm_weight = ( |
159 | | - mul_node_2.args[0] if mul_node_2.args[1] is to_copy_2 else mul_node_2.args[1] |
160 | | - ) |
161 | | - eps = add_eps_node.args[1] |
162 | | - |
163 | | - # Insert nodes |
164 | | - with graph.inserting_before(allreduce_node): |
165 | | - fused_node = graph.call_function( |
166 | | - torch.ops.dist.fused_allreduce_residual_rmsnorm, |
167 | | - args=( |
168 | | - tensor, |
169 | | - residual, |
170 | | - norm_weight, |
171 | | - eps, |
172 | | - ), |
173 | | - ) |
174 | | - # Extract outputs from the tuple returned by `fused_node` |
175 | | - final_output_node = gm.graph.create_node( |
176 | | - "call_function", |
177 | | - target=operator.getitem, |
178 | | - args=(fused_node, 0), |
179 | | - ) |
180 | | - add_output_node = gm.graph.create_node( |
181 | | - "call_function", |
182 | | - target=operator.getitem, |
183 | | - args=(fused_node, 1), |
184 | | - ) |
185 | | - |
186 | | - # Replace all uses of rmsnorm_node with final_output_node |
187 | | - mul_node_2.replace_all_uses_with(final_output_node) |
188 | | - |
189 | | - # Replace all uses of add_node with add_output_node |
190 | | - add_node.replace_all_uses_with(add_output_node) |
191 | | - |
192 | | - nonlocal num_ar_r_rms_fusions |
193 | | - num_ar_r_rms_fusions += 1 |
194 | | - |
195 | | - # Traverse all nodes |
196 | | - for node in gm.graph.nodes: |
197 | | - if is_op(node, torch.ops.auto_deploy.torch_dist_all_reduce): |
198 | | - trace_and_fuse(allreduce_node=node, graph=gm.graph) |
| 84 | + patterns = ADPatternMatcherPass() |
| 85 | + |
| 86 | + # Dummy shapes for tracing |
| 87 | + bsz, hidden = 8, 512 |
| 88 | + dummy_args = [ |
| 89 | + torch.randn(bsz, hidden, device="meta", dtype=torch.bfloat16), # x |
| 90 | + torch.randn(bsz, hidden, device="meta", dtype=torch.bfloat16), # residual |
| 91 | + torch.randn(hidden, device="meta", dtype=torch.bfloat16), # weight |
| 92 | + 0.1253, # eps |
| 93 | + ] |
| 94 | + |
| 95 | + register_ad_pattern( |
| 96 | + search_fn=_allreduce_residual_rmsnorm_pattern, |
| 97 | + replace_fn=_allreduce_residual_rmsnorm_repl, |
| 98 | + patterns=patterns, |
| 99 | + dummy_args=dummy_args, |
| 100 | + op_ignore_types={torch.ops.aten.to.dtype: (torch.dtype,)}, |
| 101 | + scalar_workaround={"eps": 0.1253}, |
| 102 | + ) |
| 103 | + register_ad_pattern( |
| 104 | + search_fn=_allreduce_residual_rmsnorm_pattern2, |
| 105 | + replace_fn=_allreduce_residual_rmsnorm_repl, |
| 106 | + patterns=patterns, |
| 107 | + dummy_args=dummy_args, |
| 108 | + op_ignore_types={torch.ops.aten.to.dtype: (torch.dtype,)}, |
| 109 | + scalar_workaround={"eps": 0.1253}, |
| 110 | + ) |
| 111 | + |
| 112 | + num_matches = patterns.apply(gm.graph) |
199 | 113 |
|
200 | 114 | info = TransformInfo( |
201 | | - skipped=False, num_matches=num_ar_r_rms_fusions, is_clean=False, has_valid_shapes=False |
| 115 | + skipped=False, num_matches=num_matches, is_clean=False, has_valid_shapes=False |
202 | 116 | ) |
203 | | - |
204 | 117 | return gm, info |
0 commit comments