Skip to content
Open
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 21 additions & 0 deletions tests/test_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
from vllm.config.vllm import (
OPTIMIZATION_LEVEL_TO_CONFIG,
OptimizationLevel,
enable_allreduce_rms_fusion,
)
from vllm.platforms import current_platform

Expand Down Expand Up @@ -58,6 +59,26 @@ def test_async_scheduling_with_pipeline_parallelism_is_allowed():
assert cfg.scheduler_config.async_scheduling is True


def test_enable_allreduce_rms_fusion_disabled_for_pp():
    """Fusion is enabled for TP=2/PP=1/DP=1 and disabled once PP becomes 2.

    The flashinfer and platform checks are mocked to return True so that the
    parallel configuration is the only input that varies between assertions.
    """
    tp_only = ParallelConfig(
        tensor_parallel_size=2,
        pipeline_parallel_size=1,
        data_parallel_size=1,
    )
    cfg = VllmConfig(parallel_config=tp_only)

    flashinfer_present = patch(
        "vllm.utils.flashinfer.has_flashinfer", return_value=True
    )
    on_cuda = patch.object(current_platform, "is_cuda", return_value=True)
    capability_ok = patch.object(
        current_platform, "is_device_capability", return_value=True
    )

    with flashinfer_present, on_cuda, capability_ok:
        # Baseline: tensor parallelism alone keeps the fusion on.
        assert enable_allreduce_rms_fusion(cfg)

        # Same mocked environment, but with pipeline parallelism switched on.
        cfg.parallel_config.pipeline_parallel_size = 2
        assert not enable_allreduce_rms_fusion(cfg)

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

The current test correctly covers the pipeline parallelism case as intended. However, to make it more robust and prevent future regressions, it would be beneficial to cover all gating conditions within enable_allreduce_rms_fusion, including tensor_parallel_size and data_parallel_size.

A parameterized test would be a clean way to test all ParallelConfig combinations and improve the test's clarity and maintainability.

@pytest.mark.parametrize(
    ("parallel_config", "should_be_enabled"),
    [
        # Should be enabled with only TP > 1
        pytest.param(
            ParallelConfig(
                tensor_parallel_size=2,
                pipeline_parallel_size=1,
                data_parallel_size=1,
            ),
            True,
            id="TP-only",
        ),
        # Should be disabled with TP <= 1
        pytest.param(
            ParallelConfig(
                tensor_parallel_size=1,
                pipeline_parallel_size=1,
                data_parallel_size=1,
            ),
            False,
            id="No-TP",
        ),
        # Should be disabled with PP > 1
        pytest.param(
            ParallelConfig(
                tensor_parallel_size=2,
                pipeline_parallel_size=2,
                data_parallel_size=1,
            ),
            False,
            id="With-PP",
        ),
        # Should be disabled with DP > 1
        pytest.param(
            ParallelConfig(
                tensor_parallel_size=2,
                pipeline_parallel_size=1,
                data_parallel_size=2,
            ),
            False,
            id="With-DP",
        ),
    ],
)
def test_enable_allreduce_rms_fusion_gating(parallel_config, should_be_enabled):
    """Check the fusion gate against each parametrized ParallelConfig.

    The flashinfer and platform checks are mocked to return True, so the
    expected result depends only on the TP/PP/DP sizes of `parallel_config`.
    """
    cfg = VllmConfig(parallel_config=parallel_config)

    flashinfer_present = patch(
        "vllm.utils.flashinfer.has_flashinfer", return_value=True
    )
    on_cuda = patch.object(current_platform, "is_cuda", return_value=True)
    capability_ok = patch.object(
        current_platform, "is_device_capability", return_value=True
    )

    with flashinfer_present, on_cuda, capability_ok:
        fusion_enabled = enable_allreduce_rms_fusion(cfg)

    assert fusion_enabled is should_be_enabled



@dataclass
class _TestConfigFields:
a: int
Expand Down