Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
350 changes: 350 additions & 0 deletions docs/design/dcp_communication_patterns.md

Large diffs are not rendered by default.

26 changes: 17 additions & 9 deletions tests/distributed/test_context_parallel.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,8 @@ class ParallelSetup(NamedTuple):
tp_size: int
pp_size: int
dcp_size: int
cp_kv_cache_interleave_size: int
pcp_size: int
dcp_kv_cache_interleave_size: int
eager_mode: bool
chunked_prefill: bool

Expand All @@ -73,7 +74,8 @@ def detailed(
tp_base: int = 4,
pp_base: int = 1,
dcp_multipliers: list[float] | None = None,
cp_kv_cache_interleave_size: int = 1,
pcp_base: int = 1,
dcp_kv_cache_interleave_size: int = 1,
multi_node_only: bool = False,
runner: RunnerOption = "auto",
attn_backend: str | None = None,
Expand All @@ -91,8 +93,9 @@ def detailed(
ParallelSetup(
tp_size=tp_base,
pp_size=pp_multiplier * pp_base,
dcp_size=int(dcp_multiplier * tp_base),
cp_kv_cache_interleave_size=cp_kv_cache_interleave_size,
dcp_size=max(1, int(dcp_multiplier * tp_base)),
pcp_size=pcp_base,
dcp_kv_cache_interleave_size=dcp_kv_cache_interleave_size,
eager_mode=eager_mode_val,
chunked_prefill=chunked_prefill_val,
)
Expand Down Expand Up @@ -126,16 +129,18 @@ def iter_params(self, model_id: str):
CPTestSettings.detailed(dcp_multipliers=[1]),
CPTestSettings.detailed(
dcp_multipliers=[0.5],
cp_kv_cache_interleave_size=64,
dcp_kv_cache_interleave_size=64,
attn_backend="FLASHMLA",
),
CPTestSettings.detailed(tp_base=1, pcp_base=4, dcp_kv_cache_interleave_size=64),
CPTestSettings.detailed(tp_base=2, pcp_base=2, dcp_kv_cache_interleave_size=64),
],
"Qwen/Qwen2.5-1.5B-Instruct": [
CPTestSettings.detailed(
cp_kv_cache_interleave_size=16, attn_backend="FLASH_ATTN"
dcp_kv_cache_interleave_size=16, attn_backend="FLASH_ATTN"
),
CPTestSettings.detailed(
cp_kv_cache_interleave_size=16, attn_backend="FLASHINFER"
dcp_kv_cache_interleave_size=16, attn_backend="FLASHINFER"
),
],
}
Expand All @@ -156,7 +161,8 @@ def _test_cp_gsm8k(
tp_size,
pp_size,
dcp_size,
cp_kv_cache_interleave_size,
pcp_size,
dcp_kv_cache_interleave_size,
eager_mode,
chunked_prefill,
) = parallel_setup
Expand Down Expand Up @@ -212,8 +218,10 @@ def _test_cp_gsm8k(
str(pp_size),
"--decode-context-parallel-size",
str(dcp_size),
"--prefill-context-parallel-size",
str(pcp_size),
"--dcp-kv-cache-interleave-size",
str(cp_kv_cache_interleave_size),
str(dcp_kv_cache_interleave_size),
"--distributed-executor-backend",
distributed_backend,
]
Expand Down
192 changes: 192 additions & 0 deletions tests/distributed/test_dcp_a2a.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,192 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Unit tests for DCP A2A communication backend (no GPU required).

Tests cover:
1. DCP A2A config validation (--dcp-comm-backend)
2. KVP group function exists
3. LSE-weighted combination correctness
"""

import math

import pytest
import torch

from vllm.config.parallel import ParallelConfig


class TestDCPCommBackendConfig:
    """Validation tests for the --dcp-comm-backend parallel-config option."""

    def test_default_is_ag_rs(self):
        """A freshly constructed config defaults to the ag_rs backend."""
        assert ParallelConfig().dcp_comm_backend == "ag_rs"

    def test_a2a_requires_dcp_greater_than_1(self):
        """Selecting a2a while DCP is disabled (size 1) must be rejected."""
        expected = "requires decode_context_parallel_size > 1"
        with pytest.raises(ValueError, match=expected):
            ParallelConfig(
                dcp_comm_backend="a2a",
                decode_context_parallel_size=1,
            )

    def test_a2a_with_dcp_valid(self):
        """a2a is accepted once decode_context_parallel_size exceeds 1."""
        cfg = ParallelConfig(
            dcp_comm_backend="a2a",
            tensor_parallel_size=8,
            decode_context_parallel_size=4,
        )
        assert cfg.dcp_comm_backend == "a2a"

    def test_invalid_backend_rejected(self):
        """An unknown backend string raises a validation error."""
        with pytest.raises(ValueError, match="must be one of"):
            ParallelConfig(
                dcp_comm_backend="invalid",
            )

    def test_ag_rs_with_dcp_1_valid(self):
        """ag_rs stays valid even when DCP is effectively disabled (size 1)."""
        cfg = ParallelConfig(
            dcp_comm_backend="ag_rs",
            decode_context_parallel_size=1,
        )
        assert cfg.dcp_comm_backend == "ag_rs"


class TestLSEWeightedCombine:
    """Exercise the CPU reference combine used to validate the Triton kernel.

    Given per-rank outputs and per-rank log-sum-exp (LSE) values, the
    reference `_lse_weighted_combine` computes

        result[b,h,d] = sum_n(w_n * output_n[b,h,d])

    where w_n = exp(lse_n) / sum_k(exp(lse_k)), i.e. a softmax over ranks.
    All tests run on CPU tensors only — no GPU required.
    """

    def test_importable(self):
        """The reference implementation can be imported and is callable."""
        from vllm.v1.attention.ops.dcp_alltoall import _lse_weighted_combine

        assert callable(_lse_weighted_combine)

    def test_single_rank(self):
        """With one rank the combine is a no-op on the output tensor."""
        from vllm.v1.attention.ops.dcp_alltoall import _lse_weighted_combine

        # Shapes: N=1 rank, B=2, H=4, D=8.
        rank_outputs = torch.randn(1, 2, 4, 8)
        rank_lses = torch.randn(1, 2, 4)

        combined = _lse_weighted_combine(rank_outputs, rank_lses)

        assert combined.shape == (2, 4, 8)
        torch.testing.assert_close(
            combined, rank_outputs.squeeze(0), rtol=1e-5, atol=1e-5
        )

    def test_equal_lse(self):
        """Identical LSEs weight every rank equally (a plain mean)."""
        from vllm.v1.attention.ops.dcp_alltoall import _lse_weighted_combine

        _N, B, H, D = 2, 1, 1, 4
        rank_outputs = torch.tensor(
            [
                [[[1.0, 2.0, 3.0, 4.0]]],  # rank 0
                [[[5.0, 6.0, 7.0, 8.0]]],  # rank 1
            ]
        )
        # Same LSE (0.0) on both ranks -> softmax weights are 0.5 each.
        rank_lses = torch.zeros(2, 1, 1)

        combined = _lse_weighted_combine(rank_outputs, rank_lses)

        mean = (rank_outputs[0] + rank_outputs[1]) / 2
        assert combined.shape == (B, H, D)
        torch.testing.assert_close(combined, mean, rtol=1e-5, atol=1e-5)

    def test_dominant_rank(self):
        """A much larger LSE makes that rank dominate the combination."""
        from vllm.v1.attention.ops.dcp_alltoall import _lse_weighted_combine

        B, H, D = 1, 1, 2
        rank_outputs = torch.tensor(
            [
                [[[0.0, 0.0]]],  # rank 0
                [[[1.0, 1.0]]],  # rank 1
            ]
        )
        # Rank 0's weight is exp(-100) / (exp(-100) + 1) ~= 0.
        rank_lses = torch.tensor(
            [
                [[-100.0]],  # rank 0: negligible contribution
                [[0.0]],  # rank 1: dominant
            ]
        )

        combined = _lse_weighted_combine(rank_outputs, rank_lses)

        assert combined.shape == (B, H, D)
        torch.testing.assert_close(
            combined, rank_outputs[1].squeeze(0), atol=1e-5, rtol=1e-5
        )

    def test_mathematically_correct(self):
        """Result matches a hand-computed softmax-weighted sum."""
        from vllm.v1.attention.ops.dcp_alltoall import _lse_weighted_combine

        rank_outputs = torch.tensor(
            [
                [[[2.0, 4.0]]],
                [[[6.0, 8.0]]],
            ]
        )
        rank_lses = torch.tensor(
            [
                [[1.0]],  # exp(1) ≈ 2.718
                [[2.0]],  # exp(2) ≈ 7.389
            ]
        )

        combined = _lse_weighted_combine(rank_outputs, rank_lses)

        denom = math.exp(1) + math.exp(2)
        w0 = math.exp(1) / denom
        w1 = math.exp(2) / denom
        expected = torch.tensor([[[w0 * 2.0 + w1 * 6.0, w0 * 4.0 + w1 * 8.0]]])

        torch.testing.assert_close(combined, expected, rtol=1e-4, atol=1e-4)

    def test_return_lse(self):
        """With return_lse=True the global logsumexp is also returned."""
        from vllm.v1.attention.ops.dcp_alltoall import _lse_weighted_combine

        B, H, D = 1, 1, 2
        rank_outputs = torch.tensor(
            [
                [[[1.0, 2.0]]],
                [[[3.0, 4.0]]],
            ]
        )
        rank_lses = torch.tensor(
            [
                [[1.0]],
                [[2.0]],
            ]
        )

        combined, global_lse = _lse_weighted_combine(
            rank_outputs, rank_lses, return_lse=True
        )

        want_lse = math.log(math.exp(1) + math.exp(2))

        assert combined.shape == (B, H, D)
        assert global_lse.shape == (B, H)
        assert abs(global_lse.item() - want_lse) < 1e-5


# Allow running this test file directly (outside a pytest invocation)
# by delegating to pytest's CLI with verbose output.
if __name__ == "__main__":
    pytest.main([__file__, "-v"])
18 changes: 17 additions & 1 deletion tests/quantization/test_torchao.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
@pytest.mark.skipif(not TORCHAO_AVAILABLE, reason="torchao is not available")
def test_pre_quantized_model(vllm_runner):
with vllm_runner(
"torchao-testing/opt-125m-Float8WeightOnlyConfig-v2-0.15.0",
"drisspg/fp8-opt-125m",
quantization="torchao",
dtype="bfloat16",
enforce_eager=True,
Expand Down Expand Up @@ -52,6 +52,22 @@ def test_opt_125m_int8wo_model_loading_with_params(vllm_runner, pt_load_map_loca
assert output


@pytest.mark.skipif(not TORCHAO_AVAILABLE, reason="torchao is not available")
def test_opt_125m_int4wo_model_per_module_quant(vllm_runner):
    """Load an int4 weight-only, per-module-quantized torchao checkpoint
    and verify greedy generation produces output."""
    # Clear any compiled graphs left over from earlier tests in the session.
    torch._dynamo.reset()
    with vllm_runner(
        model_name="jerryzh168/opt-125m-int4wo-per-module",
        quantization="torchao",
        dtype="bfloat16",
        pt_load_map_location="cuda:0",
        enforce_eager=True,
    ) as llm:
        output = llm.generate_greedy(["The capital of France is"], max_tokens=4)

    assert output


@pytest.mark.skipif(not TORCHAO_AVAILABLE, reason="torchao is not available")
def test_qwenvl_int8wo_model_loading_with_params(vllm_runner):
torch._dynamo.reset()
Expand Down
Loading