1717from vllm .platforms import current_platform
1818
1919from .inductor_pass import enable_fake_mode
20- from .matcher_utils import MatcherQuant , MatcherRMSNorm
20+ from .matcher_utils import MatcherFusedAddRMSNorm , MatcherQuant , MatcherRMSNorm
2121from .vllm_inductor_pass import VllmInductorPass , VllmPatternMatcherPass
2222
2323logger = init_logger (__name__ )
@@ -92,7 +92,8 @@ def __init__(self, epsilon: float, key: FusedRMSQuantKey):
9292 f"unsupported fused rmsnorm+quant op for { key } "
9393 self .FUSED_OP = FUSED_OPS [key ]
9494
95- self .rmsnorm_matcher = MatcherRMSNorm (epsilon )
95+ self .rmsnorm_matcher = MatcherRMSNorm (epsilon ) if not key .fused_add \
96+ else MatcherFusedAddRMSNorm (epsilon )
9697 self .quant_matcher = MatcherQuant (key .quant )
9798
9899
@@ -133,8 +134,8 @@ def replacement(input: torch.Tensor, weight: torch.Tensor,
133134 return at [1 ]
134135
135136 inputs = [
136- empty_fp32 ( 5 , 4 ), # input # TODO: rms_input
137- empty_bf16 ( 4 , ), # weight
137+ # input, weight
138+ * self . rmsnorm_matcher . inputs (),
138139 empty_fp32 (1 , 1 ) # scale
139140 ]
140141 pattern (* inputs )
@@ -157,16 +158,16 @@ def __init__(self,
157158
158159 def register (self , pm_pass : PatternMatcherPass ):
159160
def pattern(input: torch.Tensor, weight: torch.Tensor,
            residual: torch.Tensor, scale: torch.Tensor):
    """Search pattern: RMS norm with residual, then quantization by `scale`.

    Arguments are ordered (input, weight, residual, scale). The
    normalization is delegated to ``self.rmsnorm_matcher`` and the
    quantization to ``self.quant_matcher``; the second value returned by
    the quant matcher is not needed here and is discarded.

    Returns:
        A ``(quantized_result, residual)`` pair, where ``residual`` is
        the one handed back by the rmsnorm matcher.
    """
    normed, residual_out = self.rmsnorm_matcher(input, weight, residual)
    quantized, _ = self.quant_matcher(normed, scale)
    return quantized, residual_out
167168
168- def replacement (input : torch .Tensor , residual : torch .Tensor ,
169- weight : torch .Tensor , scale : torch .Tensor ):
169+ def replacement (input : torch .Tensor , weight : torch .Tensor ,
170+ residual : torch .Tensor , scale : torch .Tensor ):
170171 # In case we're matching native rms-norm, conversions might be
171172 # optimized out. We convert here just to be safe.
172173 input = input .to (dtype = torch .float16 ) # TODO model dtype
@@ -185,11 +186,8 @@ def replacement(input: torch.Tensor, residual: torch.Tensor,
185186 return at [1 ], at [2 ]
186187
187188 inputs = [
188- # TODO: maybe 32bit for torch impl? yes to resolve bug
189- # TODO dtype doesn't seem to matter? it does matter for what cvts get traced
190- empty_bf16 (5 , 4 ), # input
191- empty_bf16 (5 , 4 ), # residual
192- empty_bf16 (4 , ), # weight
189+ # input, weight, residual
190+ * self .rmsnorm_matcher .inputs (),
193191 empty_fp32 (1 , 1 ) # scale
194192 ]
195193
@@ -242,15 +240,10 @@ def replacement(input: torch.Tensor, weight: torch.Tensor):
242240 # result, scale
243241 return at [1 ], at [2 ]
244242
245- inputs = [
246- empty_bf16 (5 , 4 ), # input
247- empty_bf16 (4 ), # weight
248- ]
249-
250243 pm .register_replacement (
251244 pattern ,
252245 replacement ,
253- inputs ,
246+ self . rmsnorm_matcher . inputs () ,
254247 pm .fwd_only ,
255248 pm_pass ,
256249 )
@@ -272,16 +265,16 @@ def __init__(self,
272265
273266 def register (self , pm_pass : PatternMatcherPass ):
274267
def pattern(input: torch.Tensor, weight: torch.Tensor,
            residual: torch.Tensor):
    """Search pattern: RMS norm with residual, then scale-producing quant.

    Arguments are ordered (input, weight, residual). The normalization
    is delegated to ``self.rmsnorm_matcher``; ``self.quant_matcher`` is
    called without a scale argument and hands back both the quantized
    tensor and the scale it used.

    Returns:
        A ``(quantized_result, residual, scale)`` triple, where
        ``residual`` is the one handed back by the rmsnorm matcher.
    """
    normed, residual_out = self.rmsnorm_matcher(input, weight, residual)
    quantized, scale = self.quant_matcher(normed)
    return quantized, residual_out, scale
282275
283- def replacement (input : torch .Tensor , residual : torch .Tensor ,
284- weight : torch .Tensor ):
276+ def replacement (input : torch .Tensor , weight : torch .Tensor ,
277+ residual : torch .Tensor ):
285278 # In case we're matching native rms-norm, conversions might be
286279 # optimized out. We convert here just to be safe.
287280 input = input .to (dtype = torch .float16 ) # TODO model dtype
@@ -301,16 +294,10 @@ def replacement(input: torch.Tensor, residual: torch.Tensor,
301294 # result, residual, scale
302295 return at [1 ], at [3 ], at [2 ]
303296
304- inputs = [
305- empty_bf16 (5 , 4 ), # input
306- empty_bf16 (5 , 4 ), # residual
307- empty_bf16 (4 ), # weight
308- ]
309-
310297 pm .register_replacement (
311298 pattern ,
312299 replacement ,
313- inputs ,
300+ self . rmsnorm_matcher . inputs () ,
314301 pm .fwd_only ,
315302 pm_pass ,
316303 )
0 commit comments