77
88import vllm .envs as envs
99from vllm .compilation .collective_fusion import AsyncTPPass
10- from vllm .compilation .fx_utils import (find_specified_fn ,
11- find_specified_fn_maybe )
1210from vllm .config import (CompilationConfig , DeviceConfig , ModelConfig ,
1311 PassConfig , VllmConfig )
1412from vllm .distributed import (tensor_model_parallel_all_gather ,
@@ -93,7 +91,7 @@ def ops_in_model_after(self):
9391
9492
9593@multi_gpu_test (num_gpus = 2 )
96- @pytest .mark .parametrize ("test_model" , [" TestMMRSModel" , " TestAGMMModel" ])
94+ @pytest .mark .parametrize ("test_model" , [TestMMRSModel , TestAGMMModel ])
9795@pytest .mark .parametrize ("batch_size" , [8 ])
9896@pytest .mark .parametrize ("seq_len" , [16 ])
9997@pytest .mark .parametrize ("hidden_size" , [16 ])
@@ -117,7 +115,8 @@ def run_torch_spawn(fn, nprocs):
117115
118116
119117def async_tp_pass_on_test_model (local_rank : int , world_size : int ,
120- test_model : str , batch_size : int , seq_len : int ,
118+ test_model_cls : torch .nn .Module ,
119+ batch_size : int , seq_len : int ,
121120 hidden_size : int , dtype : torch .dtype ):
122121 current_platform .seed_everything (0 )
123122
@@ -158,12 +157,7 @@ def async_tp_pass_on_test_model(local_rank: int, world_size: int,
158157 async_tp_pass = AsyncTPPass (vllm_config )
159158 backend = TestBackend (async_tp_pass )
160159
161- if test_model == "TestMMRSModel" :
162- model = TestMMRSModel (hidden_size )
163- elif test_model == "TestAGMMModel" :
164- model = TestAGMMModel (hidden_size )
165- else :
166- raise ValueError (f"Unknown model: { test_model } " )
160+ model = test_model_cls (hidden_size )
167161
168162 hidden_states = torch .randn ((batch_size * seq_len , hidden_size ),
169163 dtype = dtype ,
@@ -172,21 +166,14 @@ def async_tp_pass_on_test_model(local_rank: int, world_size: int,
172166 compiled_model = torch .compile (model , backend = backend )
173167 compiled_model (hidden_states )
174168
175- # Check substitution worked
176- pre_nodes = backend .graph_pre_pass .nodes
177- post_nodes = backend .graph_post_pass .nodes
178-
179- # In pre-nodes, all reduce should exist,
169+ # In pre-nodes, all gather or reduce scatter should exist,
180170 # fused_matmul_reduce_scatter or fused_all_gather_matmul should not
181- for op in model .ops_in_model_before ():
182- find_specified_fn (pre_nodes , op )
183- for op in model .ops_in_model_after ():
184- assert find_specified_fn_maybe (pre_nodes , op ) is None
171+ backend .check_before_ops (model .ops_in_model_before (),
172+ ops_fully_replaced = False )
185173
186174 # In post-nodes, fused_matmul_reduce_scatter or \
187175 # fused_all_gather_matmul should exist
188- for op in model .ops_in_model_after ():
189- find_specified_fn (post_nodes , op )
176+ backend .check_after_ops (model .ops_in_model_after ())
190177
191178
192179@create_new_process_for_each_test ()
@@ -258,12 +245,9 @@ def test_async_tp_pass_correctness(
258245 "mp" ,
259246 ]
260247
261- try :
262- compare_two_settings (model_id ,
263- aysnc_tp_args ,
264- tp_args ,
265- async_tp_env ,
266- tp_env ,
267- method = "generate" )
268- except Exception :
269- raise
248+ compare_two_settings (model_id ,
249+ aysnc_tp_args ,
250+ tp_args ,
251+ async_tp_env ,
252+ tp_env ,
253+ method = "generate" )
0 commit comments