Commit fe89c0b

ALL WORKS
Signed-off-by: Luka Govedič <[email protected]>
1 parent cec037e commit fe89c0b

File tree

2 files changed (+110 -62 lines)

vllm/compilation/fusion.py

Lines changed: 17 additions & 30 deletions
```diff
@@ -17,7 +17,7 @@
 from vllm.platforms import current_platform
 
 from .inductor_pass import enable_fake_mode
-from .matcher_utils import MatcherQuant, MatcherRMSNorm
+from .matcher_utils import MatcherFusedAddRMSNorm, MatcherQuant, MatcherRMSNorm
 from .vllm_inductor_pass import VllmInductorPass, VllmPatternMatcherPass
 
 logger = init_logger(__name__)
@@ -92,7 +92,8 @@ def __init__(self, epsilon: float, key: FusedRMSQuantKey):
             f"unsupported fused rmsnorm+quant op for {key}"
         self.FUSED_OP = FUSED_OPS[key]
 
-        self.rmsnorm_matcher = MatcherRMSNorm(epsilon)
+        self.rmsnorm_matcher = MatcherRMSNorm(epsilon) if not key.fused_add \
+            else MatcherFusedAddRMSNorm(epsilon)
         self.quant_matcher = MatcherQuant(key.quant)
 
 
```
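
The dispatch above keeps the residual-free and residual-fusing patterns on separate matcher classes. A standalone sketch of the selection, with stand-in stubs for the key and the matchers (everything suffixed `Stub` is illustrative, not part of the diff):

```python
# Stand-in stubs; only the selection expression mirrors the diff above.
from dataclasses import dataclass


@dataclass(frozen=True)
class FusedRMSQuantKeyStub:
    fused_add: bool  # True if the fused kernel also adds a residual input


class MatcherRMSNormStub:
    def __init__(self, epsilon: float) -> None:
        self.epsilon = epsilon


class MatcherFusedAddRMSNormStub(MatcherRMSNormStub):
    pass


def make_matcher(key: FusedRMSQuantKeyStub, epsilon: float = 1e-6):
    # Same shape as the diff: plain matcher unless key.fused_add is set.
    return (MatcherRMSNormStub(epsilon) if not key.fused_add
            else MatcherFusedAddRMSNormStub(epsilon))


assert isinstance(make_matcher(FusedRMSQuantKeyStub(fused_add=True)),
                  MatcherFusedAddRMSNormStub)
```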

```diff
@@ -133,8 +134,8 @@ def replacement(input: torch.Tensor, weight: torch.Tensor,
             return at[1]
 
         inputs = [
-            empty_fp32(5, 4),  # input  # TODO: rms_input
-            empty_bf16(4, ),  # weight
+            # input, weight
+            *self.rmsnorm_matcher.inputs(),
             empty_fp32(1, 1)  # scale
         ]
         pattern(*inputs)
@@ -157,16 +158,16 @@ def __init__(self,
 
     def register(self, pm_pass: PatternMatcherPass):
 
-        def pattern(input: torch.Tensor, residual: torch.Tensor,
-                    weight: torch.Tensor, scale: torch.Tensor):
+        def pattern(input: torch.Tensor, weight: torch.Tensor,
+                    residual: torch.Tensor, scale: torch.Tensor):
             result_rms, residual = self.rmsnorm_matcher(
                 input, weight, residual)
             result, _ = self.quant_matcher(result_rms, scale)
 
             return result, residual
 
-        def replacement(input: torch.Tensor, residual: torch.Tensor,
-                        weight: torch.Tensor, scale: torch.Tensor):
+        def replacement(input: torch.Tensor, weight: torch.Tensor,
+                        residual: torch.Tensor, scale: torch.Tensor):
             # In case we're matching native rms-norm, conversions might be
             # optimized out. We convert here just to be safe.
             input = input.to(dtype=torch.float16)  # TODO model dtype
@@ -185,11 +186,8 @@ def replacement(input: torch.Tensor, residual: torch.Tensor,
             return at[1], at[2]
 
         inputs = [
-            # TODO: maybe 32bit for torch impl? yes to resolve bug
-            # TODO dtype doesn't seem to matter? it does matter for what cvts get traced
-            empty_bf16(5, 4),  # input
-            empty_bf16(5, 4),  # residual
-            empty_bf16(4, ),  # weight
+            # input, weight, residual
+            *self.rmsnorm_matcher.inputs(),
             empty_fp32(1, 1)  # scale
         ]
 
@@ -242,15 +240,10 @@ def replacement(input: torch.Tensor, weight: torch.Tensor):
             # result, scale
             return at[1], at[2]
 
-        inputs = [
-            empty_bf16(5, 4),  # input
-            empty_bf16(4),  # weight
-        ]
-
         pm.register_replacement(
             pattern,
             replacement,
-            inputs,
+            self.rmsnorm_matcher.inputs(),
             pm.fwd_only,
             pm_pass,
         )
@@ -272,16 +265,16 @@ def __init__(self,
 
     def register(self, pm_pass: PatternMatcherPass):
 
-        def pattern(input: torch.Tensor, residual: torch.Tensor,
-                    weight: torch.Tensor):
+        def pattern(input: torch.Tensor, weight: torch.Tensor,
+                    residual: torch.Tensor):
             result_rms, residual = self.rmsnorm_matcher(
                 input, weight, residual)
             result, scale = self.quant_matcher(result_rms)
 
             return result, residual, scale
 
-        def replacement(input: torch.Tensor, residual: torch.Tensor,
-                        weight: torch.Tensor):
+        def replacement(input: torch.Tensor, weight: torch.Tensor,
+                        residual: torch.Tensor):
             # In case we're matching native rms-norm, conversions might be
             # optimized out. We convert here just to be safe.
             input = input.to(dtype=torch.float16)  # TODO model dtype
@@ -301,16 +294,10 @@ def replacement(input: torch.Tensor, residual: torch.Tensor,
             # result, residual, scale
             return at[1], at[3], at[2]
 
-        inputs = [
-            empty_bf16(5, 4),  # input
-            empty_bf16(5, 4),  # residual
-            empty_bf16(4),  # weight
-        ]
-
         pm.register_replacement(
             pattern,
             replacement,
-            inputs,
+            self.rmsnorm_matcher.inputs(),
             pm.fwd_only,
             pm_pass,
         )
```
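
Across all four patterns, the hand-built `empty_bf16`/`empty_fp32` example tensors are replaced by `self.rmsnorm_matcher.inputs()`, so the tracing inputs always carry the dtypes the matcher itself will trace with. A standalone sketch of that contract (the class and shapes are illustrative, not the vLLM API):

```python
# Illustrative sketch: a matcher owns its example-input construction so the
# pattern matcher traces with dtypes consistent with the matched ops.
import torch


class SketchMatcher:
    def __init__(self, model_dtype: torch.dtype, enabled: bool):
        self.model_dtype = model_dtype
        self.enabled = enabled  # custom op traced vs. native fp32 path

    def inputs(self) -> list[torch.Tensor]:
        in_dtype = self.model_dtype if self.enabled else torch.float32
        return [
            torch.empty(5, 16, dtype=in_dtype),       # input
            torch.empty(16, dtype=self.model_dtype),  # weight
        ]


matcher = SketchMatcher(torch.bfloat16, enabled=False)
inputs = [*matcher.inputs(), torch.empty(1, 1, dtype=torch.float32)]  # + scale
print([(tuple(t.shape), t.dtype) for t in inputs])
```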

vllm/compilation/matcher_utils.py

Lines changed: 93 additions & 32 deletions
```diff
@@ -1,6 +1,7 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
-from typing import Optional, Union
+from abc import ABC, abstractmethod
+from typing import Optional
 
 import torch
 from torch._higher_order_ops import auto_functionalized
@@ -31,55 +32,71 @@
 # kNvfp4Quant] = torch.ops._C.scaled_fp4_quant.default  # noqa: E501
 
 
-class MatcherRMSNorm:  # TODO separate residual and not residual
+class MatcherCustomOp(ABC):
 
-    def __init__(self, epsilon: float, enabled: Optional[bool] = None):
-        self.epsilon = epsilon
+    def __init__(self, enabled: bool):
+        self.model_dtype = get_current_vllm_config().model_config.dtype
+
+        self.enabled = enabled
+        self.forward = self.forward_custom if enabled else self.forward_native
+
+    @abstractmethod
+    def forward_custom(self, *args, **kws):
+        pass
+
+    @abstractmethod
+    def forward_native(self, *args, **kws):
+        pass
 
+    def __call__(self, *args, **kws):
+        return self.forward(*args, **kws)
+
+    def empty(self, *args, **kws):
+        return torch.empty(*args, dtype=self.model_dtype, device="cuda", **kws)
+
+    def empty_f32(self, *args, **kws):
+        return torch.empty(*args, dtype=torch.float32, device="cuda", **kws)
+
+
+class MatcherRMSNorm(MatcherCustomOp):
+
+    def __init__(self, epsilon: float, enabled: Optional[bool] = None):
         if enabled is None:
             # TODO either pass config to enabled or set it globally
             # (global during pass init seems reasonable)
             enabled = RMSNorm.enabled()
 
-        self.forward = self.forward_custom if enabled else self.forward_native
-        self.model_dtype = get_current_vllm_config().model_config.dtype
-        print(self.model_dtype)
+        super().__init__(enabled)
+        self.epsilon = epsilon
 
     def inputs(self):
-        return
+        input = self.empty(5, 16) if self.enabled else self.empty_f32(5, 16)
+        weight = self.empty(16, )
+        return [input, weight]
 
     def forward_custom(
         self,
         input: torch.Tensor,
         weight: torch.Tensor,
         residual: Optional[torch.Tensor] = None,
-    ) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
-        if residual is None:
-            result = torch.empty_like(input)
-            _, result = auto_functionalized(
-                RMS_OP,
-                result=result,
-                input=input,
-                weight=weight,
-                epsilon=self.epsilon,
-            )
-
-            return result
-        else:
-            _, result, residual = auto_functionalized(RMS_ADD_OP,
-                                                      input=input,
-                                                      residual=residual,
-                                                      weight=weight,
-                                                      epsilon=self.epsilon)
+    ) -> torch.Tensor:
+        result = torch.empty_like(input)
+        _, result = auto_functionalized(
+            RMS_OP,
+            result=result,
+            input=input,
+            weight=weight,
+            epsilon=self.epsilon,
+        )
 
-            return result, residual
+        return result
 
     def forward_native(
         self,
         input: torch.Tensor,
         weight: torch.Tensor,
         residual: Optional[torch.Tensor] = None,
-    ) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
+    ) -> torch.Tensor:
         x = input.to(torch.float32)
         if residual is not None:
             x = x + residual
```
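
The new `MatcherCustomOp` base binds `forward` once at construction time: the custom-op graph when the op is enabled, the native decomposition otherwise. A minimal standalone sketch of that dispatch pattern (stub classes, not the vLLM base):

```python
# Minimal sketch of the enabled-based dispatch in MatcherCustomOp: forward is
# chosen once in __init__ rather than branching on every call.
from abc import ABC, abstractmethod


class DispatchSketch(ABC):
    def __init__(self, enabled: bool):
        self.enabled = enabled
        self.forward = self.forward_custom if enabled else self.forward_native

    @abstractmethod
    def forward_custom(self, x): ...

    @abstractmethod
    def forward_native(self, x): ...

    def __call__(self, x):
        return self.forward(x)


class Doubler(DispatchSketch):
    def forward_custom(self, x):
        return 2 * x  # stands in for the fused custom op

    def forward_native(self, x):
        return x + x  # stands in for the traced native decomposition


assert Doubler(enabled=True)(3) == Doubler(enabled=False)(3) == 6
```
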
```diff
@@ -94,13 +111,57 @@ def forward_native(
 
         return x if residual is None else (x, residual)
 
-    def __call__(
+
+class MatcherFusedAddRMSNorm(MatcherCustomOp):
+
+    def __init__(self, epsilon: float, enabled: Optional[bool] = None):
+        if enabled is None:
+            # TODO either pass config to enabled or set it globally
+            # (global during pass init seems reasonable)
+            enabled = RMSNorm.enabled()
+
+        super().__init__(enabled)
+        self.epsilon = epsilon
+
+    def inputs(self):
+        input = self.empty(5, 16) if self.enabled else self.empty_f32(5, 16)
+        weight = self.empty(16, )
+        residual = self.empty(5, 16)
+        return [input, weight, residual]
+
+    def forward_custom(
         self,
         input: torch.Tensor,
         weight: torch.Tensor,
-        residual: Optional[torch.Tensor] = None,
-    ) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
-        return self.forward(input, weight, residual)
+        residual: torch.Tensor,
+    ) -> tuple[torch.Tensor, torch.Tensor]:
+        _, result, residual = auto_functionalized(RMS_ADD_OP,
+                                                  input=input,
+                                                  residual=residual,
+                                                  weight=weight,
+                                                  epsilon=self.epsilon)
+
+        return result, residual
+
+    def forward_native(
+        self,
+        input: torch.Tensor,
+        weight: torch.Tensor,
+        residual: torch.Tensor,
+    ) -> tuple[torch.Tensor, torch.Tensor]:
+        x = input.to(torch.float32)
+        if residual is not None:
+            x = x + residual
+            residual = x.to(self.model_dtype)
+
+        variance = x.pow(2).mean(dim=-1, keepdim=True)
+
+        x = x * torch.rsqrt(variance + self.epsilon)
+        x = x.to(self.model_dtype)
+        if weight is not None:
+            x = x * weight
+
+        return x if residual is None else (x, residual)
 
 
 class MatcherQuant:
```
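
For reference, the math the new `forward_native` traces is standard fused-add RMSNorm: the residual add happens in fp32, the updated residual is captured in the model dtype, and the output is scaled by `rsqrt(mean(x^2) + eps)`. A self-contained sketch under an assumed bfloat16 model dtype (the helper name is illustrative, not from the diff):

```python
# Standalone sketch of the fused-add RMSNorm reference traced by
# MatcherFusedAddRMSNorm.forward_native, without the vLLM config plumbing.
import torch


def fused_add_rms_norm_ref(
    x: torch.Tensor,
    weight: torch.Tensor,
    residual: torch.Tensor,
    epsilon: float = 1e-6,
    model_dtype: torch.dtype = torch.bfloat16,  # assumed model dtype
) -> tuple[torch.Tensor, torch.Tensor]:
    x = x.to(torch.float32) + residual       # residual add in fp32
    residual = x.to(model_dtype)             # new residual, model dtype
    variance = x.pow(2).mean(dim=-1, keepdim=True)
    x = x * torch.rsqrt(variance + epsilon)  # normalize by RMS
    return x.to(model_dtype) * weight, residual


out, res = fused_add_rms_norm_ref(
    torch.randn(5, 16, dtype=torch.bfloat16),
    torch.ones(16, dtype=torch.bfloat16),
    torch.randn(5, 16, dtype=torch.bfloat16),
)
print(out.shape, res.shape)  # torch.Size([5, 16]) torch.Size([5, 16])
```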
