diff --git a/.buildkite/test-pipeline.yaml b/.buildkite/test-pipeline.yaml index 9c200a577167..663980ddbb6c 100644 --- a/.buildkite/test-pipeline.yaml +++ b/.buildkite/test-pipeline.yaml @@ -399,8 +399,6 @@ steps: - pytest -v -s compile/test_fusion_attn.py - pytest -v -s compile/test_functionalization.py - pytest -v -s compile/test_silu_mul_quant_fusion.py - - pytest -v -s compile/test_sequence_parallelism.py - - pytest -v -s compile/test_async_tp.py - pytest -v -s compile/test_fusion_all_reduce.py - pytest -v -s compile/test_decorator.py - pytest -v -s compile/test_noop_elimination.py @@ -1081,6 +1079,8 @@ steps: working_dir: "/vllm-workspace/" num_gpus: 2 commands: + - pytest -v -s tests/compile/test_async_tp.py + - pytest -v -s tests/compile/test_sequence_parallelism.py - pytest -v -s tests/distributed/test_context_parallel.py - CUDA_VISIBLE_DEVICES=1,2 VLLM_ALL2ALL_BACKEND=deepep_high_throughput VLLM_USE_DEEP_GEMM=1 VLLM_LOGGING_LEVEL=DEBUG python3 examples/offline_inference/data_parallel.py --model Qwen/Qwen1.5-MoE-A2.7B --tp-size=1 --dp-size=2 --max-model-len 2048 diff --git a/vllm/compilation/collective_fusion.py b/vllm/compilation/collective_fusion.py index 5860833c14ce..d8a4c112ecb1 100644 --- a/vllm/compilation/collective_fusion.py +++ b/vllm/compilation/collective_fusion.py @@ -169,15 +169,23 @@ def replacement( scale_a: torch.Tensor, scale_b: torch.Tensor, ) -> torch.Tensor: + # Calculate output shape: input @ mat2 with scatter_dim reduced + output_shape = [*input.shape[:-1], mat2.shape[1]] + scatter_dim = 0 gemm_rs = torch.ops.symm_mem.fused_scaled_matmul_reduce_scatter( input, mat2, scale_a, scale_b, "avg", - scatter_dim=0, - out_dtype=self.dtype, - group_name=self.tp.device_group.group_name, + scatter_dim, # orig_scatter_dim + scatter_dim, # scatter_dim_after_maybe_reshape + self.tp.device_group.group_name, + output_shape, + None, # bias + None, # result_scale + self.dtype, # out_dtype + False, # use_fast_accum ) return gemm_rs @@ -296,15 +304,23 @@ def replacement( scale_b: torch.Tensor, cutlass_mm_output: torch.Tensor, ) -> torch.Tensor: + # Calculate output shape: input @ mat2 with scatter_dim reduced + output_shape = [*input.shape[:-1], mat2.shape[1]] + scatter_dim = 0 gemm_rs = torch.ops.symm_mem.fused_scaled_matmul_reduce_scatter( input, mat2, scale_a, scale_b, "avg", - scatter_dim=0, - out_dtype=self.dtype, - group_name=self.tp.device_group.group_name, + scatter_dim, # orig_scatter_dim + scatter_dim, # scatter_dim_after_maybe_reshape + self.tp.device_group.group_name, + output_shape, + None, # bias + None, # result_scale + self.dtype, # out_dtype + False, # use_fast_accum ) return gemm_rs