
Commit 4734bfe

address comments

Signed-off-by: cascade812 <[email protected]>
1 parent 623862d commit 4734bfe


7 files changed: +75 -84 lines changed


tests/compile/backend.py

Lines changed: 18 additions & 0 deletions
@@ -5,6 +5,8 @@

 from torch import fx

+from vllm.compilation.fx_utils import (find_specified_fn,
+                                       find_specified_fn_maybe)
 from vllm.compilation.inductor_pass import InductorPass
 from vllm.config import get_current_vllm_config

@@ -44,3 +46,19 @@ def post_pass(self, graph: fx.Graph):
         self.graph_post_pass = deepcopy(graph)
         # assign by reference, will reflect the final state of the graph
         self.final_graph = graph
+
+    def check_before_ops(self, ops,
+                         find_fn=find_specified_fn,
+                         find_fn_maybe=find_specified_fn_maybe,
+                         ops_fully_replaced=True):
+        for op in ops:
+            find_fn(self.graph_pre_pass.nodes, op)
+            if ops_fully_replaced:
+                assert find_fn_maybe(self.graph_post_pass.nodes, op) is None
+
+    def check_after_ops(self, ops,
+                        find_fn=find_specified_fn,
+                        find_fn_maybe=find_specified_fn_maybe):
+        for op in ops:
+            find_fn(self.graph_post_pass.nodes, op)
+            assert find_fn_maybe(self.graph_pre_pass.nodes, op) is None
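
Note: a minimal usage sketch of the new helpers (my_pass, model, and
example_input are placeholders, not from this commit). A pass test wraps an
InductorPass in TestBackend, compiles through it, and then asserts the
rewrite in both directions:

    backend = TestBackend(my_pass)  # my_pass: any InductorPass under test
    compiled = torch.compile(model, backend=backend)
    compiled(example_input)

    # Ops the pass consumes must appear pre-pass (and, when fully
    # replaced, disappear post-pass); fused ops must appear only post-pass.
    backend.check_before_ops(model.ops_in_model_before())
    backend.check_after_ops(model.ops_in_model_after())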

tests/compile/test_async_tp.py

Lines changed: 14 additions & 30 deletions
@@ -7,8 +7,6 @@

 import vllm.envs as envs
 from vllm.compilation.collective_fusion import AsyncTPPass
-from vllm.compilation.fx_utils import (find_specified_fn,
-                                       find_specified_fn_maybe)
 from vllm.config import (CompilationConfig, DeviceConfig, ModelConfig,
                          PassConfig, VllmConfig)
 from vllm.distributed import (tensor_model_parallel_all_gather,

@@ -93,7 +91,7 @@ def ops_in_model_after(self):


 @multi_gpu_test(num_gpus=2)
-@pytest.mark.parametrize("test_model", ["TestMMRSModel", "TestAGMMModel"])
+@pytest.mark.parametrize("test_model", [TestMMRSModel, TestAGMMModel])
 @pytest.mark.parametrize("batch_size", [8])
 @pytest.mark.parametrize("seq_len", [16])
 @pytest.mark.parametrize("hidden_size", [16])

@@ -117,7 +115,8 @@ def run_torch_spawn(fn, nprocs):


 def async_tp_pass_on_test_model(local_rank: int, world_size: int,
-                                test_model: str, batch_size: int, seq_len: int,
+                                test_model_cls: type[torch.nn.Module],
+                                batch_size: int, seq_len: int,
                                 hidden_size: int, dtype: torch.dtype):
     current_platform.seed_everything(0)

@@ -158,12 +157,7 @@ def async_tp_pass_on_test_model(local_rank: int, world_size: int,
     async_tp_pass = AsyncTPPass(vllm_config)
     backend = TestBackend(async_tp_pass)

-    if test_model == "TestMMRSModel":
-        model = TestMMRSModel(hidden_size)
-    elif test_model == "TestAGMMModel":
-        model = TestAGMMModel(hidden_size)
-    else:
-        raise ValueError(f"Unknown model: {test_model}")
+    model = test_model_cls(hidden_size)

     hidden_states = torch.randn((batch_size * seq_len, hidden_size),
                                 dtype=dtype,

@@ -172,21 +166,14 @@ def async_tp_pass_on_test_model(local_rank: int, world_size: int,
     compiled_model = torch.compile(model, backend=backend)
     compiled_model(hidden_states)

-    # Check substitution worked
-    pre_nodes = backend.graph_pre_pass.nodes
-    post_nodes = backend.graph_post_pass.nodes
-
-    # In pre-nodes, all reduce should exist,
+    # In pre-nodes, all gather or reduce scatter should exist,
     # fused_matmul_reduce_scatter or fused_all_gather_matmul should not
-    for op in model.ops_in_model_before():
-        find_specified_fn(pre_nodes, op)
-    for op in model.ops_in_model_after():
-        assert find_specified_fn_maybe(pre_nodes, op) is None
+    backend.check_before_ops(model.ops_in_model_before(),
+                             ops_fully_replaced=False)

     # In post-nodes, fused_matmul_reduce_scatter or
     # fused_all_gather_matmul should exist
-    for op in model.ops_in_model_after():
-        find_specified_fn(post_nodes, op)
+    backend.check_after_ops(model.ops_in_model_after())


 @create_new_process_for_each_test()

@@ -258,12 +245,9 @@ def test_async_tp_pass_correctness(
         "mp",
     ]

-    try:
-        compare_two_settings(model_id,
-                             aysnc_tp_args,
-                             tp_args,
-                             async_tp_env,
-                             tp_env,
-                             method="generate")
-    except Exception:
-        raise
+    compare_two_settings(model_id,
+                         aysnc_tp_args,
+                         tp_args,
+                         async_tp_env,
+                         tp_env,
+                         method="generate")
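
Note: parametrizing with the model classes themselves (rather than string
names) removes the name-to-class dispatch entirely. A minimal,
self-contained sketch of the pattern (the classes here are illustrative):

    import pytest

    class ModelA:
        def __init__(self, hidden_size):
            self.hidden_size = hidden_size

    class ModelB:
        def __init__(self, hidden_size):
            self.hidden_size = hidden_size

    # The test body instantiates the class directly; no if/elif on names.
    @pytest.mark.parametrize("model_cls", [ModelA, ModelB])
    def test_instantiation(model_cls):
        model = model_cls(hidden_size=16)
        assert model.hidden_size == 16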

tests/compile/test_fusion.py

Lines changed: 17 additions & 19 deletions
@@ -29,6 +29,10 @@ def __init__(self, hidden_size: int, eps: float, static: bool,
         self.cutlass_fp8_enabled = cutlass_fp8_enabled
         self.norm = [RMSNorm(hidden_size, eps) for _ in range(3)]
         self.wscale = [torch.rand(1, dtype=torch.float32) for _ in range(2)]
+        self.key = QuantKey(dtype=FP8_DTYPE,
+                            static=static,
+                            per_tensor=static,
+                            symmetric=True)
         if static:
             self.scale = [torch.rand(1, dtype=torch.float32) for _ in range(2)]
         else:

@@ -59,6 +63,15 @@ def forward(self, x):
         y3, resid = self.norm[2](x3, resid)  # use resid here
         return y3

+    def ops_in_model_before(self):
+        return [QUANT_OPS[self.key]]
+
+    def ops_in_model_after(self):
+        return [
+            FUSED_OPS[FusedRMSQuantKey(self.key, False)],
+            FUSED_OPS[FusedRMSQuantKey(self.key, True)]
+        ]
+

 @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
 @pytest.mark.parametrize("hidden_size", [64, 3392, 4096])

@@ -107,25 +120,10 @@ def test_fusion_rmsnorm_quant(dtype, hidden_size, num_tokens, eps, static,

     torch.testing.assert_close(result, result2, atol=ATOL, rtol=RTOL)

-    # Check substitution worked
-    pre_nodes = backend.graph_pre_pass.nodes
-    post_nodes = backend.graph_post_pass.nodes
-
-    # static is per-tensor, dynamic is per-token
-    key = QuantKey(dtype=FP8_DTYPE,
-                   static=static,
-                   per_tensor=static,
-                   symmetric=True)
-    rms_quant = FUSED_OPS[FusedRMSQuantKey(key, False)]
-    add_rms_quant = FUSED_OPS[FusedRMSQuantKey(key, True)]
-    fp8_quant = QUANT_OPS[key]
-
     # In pre-nodes, fp8 quant should be there and fused kernels should not
-    assert find_auto_fn_maybe(pre_nodes, rms_quant) is None
-    assert find_auto_fn_maybe(pre_nodes, add_rms_quant) is None
-    find_auto_fn(pre_nodes, fp8_quant)
+    backend.check_before_ops(model.ops_in_model_before(), find_auto_fn,
+                             find_auto_fn_maybe)

     # In post-nodes, fused kernels should be there and fp8 quant should not
-    find_auto_fn(post_nodes, rms_quant)
-    find_auto_fn(post_nodes, add_rms_quant)
-    assert find_auto_fn_maybe(post_nodes, fp8_quant) is None
+    backend.check_after_ops(model.ops_in_model_after(), find_auto_fn,
+                            find_auto_fn_maybe)
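
Note: the hoisted self.key restates what the deleted local computed; as the
removed comment said, static quantization is per-tensor and dynamic is
per-token, so the same flag fills both fields. A sketch using the test's own
names:

    key = QuantKey(dtype=FP8_DTYPE, static=static, per_tensor=static,
                   symmetric=True)
    fp8_quant = QUANT_OPS[key]                              # expected before fusion
    rms_quant = FUSED_OPS[FusedRMSQuantKey(key, False)]     # rms_norm + quant
    add_rms_quant = FUSED_OPS[FusedRMSQuantKey(key, True)]  # fused_add_rms_norm + quant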

tests/compile/test_sequence_parallelism.py

Lines changed: 18 additions & 29 deletions
@@ -5,9 +5,7 @@

 import vllm.envs as envs
 from vllm.compilation.fix_functionalization import FixFunctionalizationPass
-from vllm.compilation.fx_utils import (find_auto_fn, find_auto_fn_maybe,
-                                       find_specified_fn,
-                                       find_specified_fn_maybe, is_func)
+from vllm.compilation.fx_utils import find_auto_fn, find_auto_fn_maybe, is_func
 from vllm.compilation.sequence_parallelism import SequenceParallelismPass
 from vllm.config import (CompilationConfig, DeviceConfig, ModelConfig,
                          PassConfig, VllmConfig)

@@ -21,17 +19,6 @@
 from ..utils import multi_gpu_test
 from .backend import TestBackend

-OPS_IN_MODEL_BEFORE = [
-    torch.ops.vllm.all_reduce.default,
-]
-
-OPS_IN_MODEL_AFTER = [
-    torch.ops.vllm.reduce_scatter.default,
-    torch.ops.vllm.all_gather.default,
-]
-
-OPS_IN_MODEL = [torch.ops._C.fused_add_rms_norm.default]
-
 prompts = [
     "Hello, my name is",
     "The president of the United States is",

@@ -78,6 +65,18 @@ def forward(self, hidden_states, residual):

         return norm_output, residual_output

+    def ops_in_model_before(self):
+        return [torch.ops.vllm.all_reduce.default]
+
+    def ops_in_model_after(self):
+        return [
+            torch.ops.vllm.reduce_scatter.default,
+            torch.ops.vllm.all_gather.default
+        ]
+
+    def ops_in_model(self):
+        return [torch.ops._C.fused_add_rms_norm.default]
+

 @multi_gpu_test(num_gpus=2)
 @pytest.mark.parametrize("batch_size", [8])

@@ -156,34 +155,24 @@ def sequence_parallelism_pass_on_test_model(local_rank: int, world_size: int,
     compiled_model_func = torch.compile(model, backend=backend_func)
     compiled_model_func(hidden_states, residual)

-    # Check substitution worked
-    pre_nodes = backend_no_func.graph_pre_pass.nodes
-    post_nodes = backend_no_func.graph_post_pass.nodes
-
     # In pre-nodes, all reduce should be there,
     # reduce scatter and all gather should not
-    for op in OPS_IN_MODEL_BEFORE:
-        find_specified_fn(pre_nodes, op)
-    for op in OPS_IN_MODEL_AFTER:
-        assert find_specified_fn_maybe(pre_nodes, op) is None
+    backend_no_func.check_before_ops(model.ops_in_model_before())

     # In post-nodes, reduce scatter and all gather should be there,
     # all reduce should not
-    for op in OPS_IN_MODEL_AFTER:
-        find_specified_fn(post_nodes, op)
-    for op in OPS_IN_MODEL_BEFORE:
-        assert find_specified_fn_maybe(post_nodes, op) is None
+    backend_no_func.check_after_ops(model.ops_in_model_after())

     # check if the functionalization pass is applied
-    for op in OPS_IN_MODEL:
+    for op in model.ops_in_model():
         find_auto_fn(backend_no_func.graph_post_pass.nodes, op)
         assert find_auto_fn_maybe(backend_func.graph_post_pass.nodes,
                                   op) is None  # noqa: E501

     # make sure the ops were all de-functionalized
     found = dict()
     for node in backend_func.graph_post_pass.nodes:
-        for op in OPS_IN_MODEL:
+        for op in model.ops_in_model():
             if is_func(node, op):
                 found[op] = True
-    assert all(found[op] for op in OPS_IN_MODEL)
+    assert all(found[op] for op in model.ops_in_model())
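
Note: found[op] raises KeyError when an op never appears in the graph, so a
missing op surfaces as an error rather than a clean assertion failure. A
defensive variant (a sketch, not part of this commit) would use dict.get:

    found = {}
    for node in backend_func.graph_post_pass.nodes:
        for op in model.ops_in_model():
            if is_func(node, op):  # node calls op directly, i.e. de-functionalized
                found[op] = True
    assert all(found.get(op, False) for op in model.ops_in_model())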

vllm/compilation/collective_fusion.py

Lines changed: 3 additions & 2 deletions
@@ -106,10 +106,11 @@ def __init__(self, config: VllmConfig):
         enable_symm_mem_for_group(get_tp_group().device_group.group_name)
         self.patterns: PatternMatcherPass = PatternMatcherPass(
             pass_name="async_tp_pass")
-        GEMMReduceScatterPattern(self.dtype,
+        GEMMReduceScatterPattern(self.model_dtype,
                                  self.device).register(self.patterns)

-        AllGatherGEMMPattern(self.dtype, self.device).register(self.patterns)
+        AllGatherGEMMPattern(self.model_dtype,
+                             self.device).register(self.patterns)

     def is_applicable_for_shape(self, shape: Optional[int]) -> bool:
         # only do replace for specific shapes

vllm/compilation/sequence_parallelism.py

Lines changed: 3 additions & 3 deletions
@@ -243,12 +243,12 @@ def __init__(self, config: VllmConfig):
             pass_name="sequence_parallelism_pass")
         for epsilon in [1e-5, 1e-6]:
             EmbeddingAllReduceRMSNormPattern(
-                epsilon, self.dtype, self.device).register(self.patterns)
+                epsilon, self.model_dtype, self.device).register(self.patterns)

-            MiddleAllReduceRMSNormPattern(epsilon, self.dtype,
+            MiddleAllReduceRMSNormPattern(epsilon, self.model_dtype,
                                           self.device).register(self.patterns)

-            LastAllReduceRMSNormPattern(epsilon, self.dtype,
+            LastAllReduceRMSNormPattern(epsilon, self.model_dtype,
                                         self.device).register(self.patterns)
         # WARNING: This is a hack to clear the pattern matcher cache
         # and allow multiple values of epsilon.

vllm/compilation/vllm_inductor_pass.py

Lines changed: 2 additions & 1 deletion
@@ -26,7 +26,8 @@ class VllmInductorPass(InductorPass):

     def __init__(self, config: VllmConfig):
         self.pass_config = config.compilation_config.pass_config
-        self.dtype = config.model_config.dtype if config.model_config else None
+        self.model_dtype = config.model_config.dtype if config.model_config \
+            else None
         self.device = config.device_config.device if config.device_config \
             else None
         self.pass_name = self.__class__.__name__
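
Note: the rename from self.dtype to self.model_dtype propagates to every
pass built on VllmInductorPass (see the two pattern files above); the more
specific name distinguishes the model's parameter dtype from other dtypes a
pass may handle. A hypothetical subclass sketch (MyPattern and the
PatternMatcherPass wiring are placeholders, not from this commit):

    class MyPatternPass(VllmInductorPass):
        def __init__(self, config: VllmConfig):
            super().__init__(config)
            self.patterns = PatternMatcherPass(pass_name="my_pass")
            # model_dtype is None when no model config is present, so
            # dtype-specific pattern registration guards on it.
            if self.model_dtype is not None:
                MyPattern(self.model_dtype,
                          self.device).register(self.patterns)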
