Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
126 changes: 126 additions & 0 deletions tests/diffusion/attention/test_piecewise_attn.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,126 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""End-to-end test for ``piecewise_attn`` (CPU).

Verify that running attention in segments (causal outside full-attn spans,
bidirectional inside full-attn spans) matches running a single full SDPA call
with the equivalent 2D attention mask.

Covers:
* batch size = 1 and batch size > 1 (homogeneous CFG-like batch)
* query length == key length (full prefill)
* query length < key length (decode-like tail slice)
* various full-attn-span layouts (none / start / middle / end / multi)
"""

from __future__ import annotations

import pytest
import torch
import torch.nn.functional as F

from vllm_omni.diffusion.attention.backends.utils.piecewise_attn import (
piecewise_attn,
)

DEVICE = torch.device("cpu")


def _sdpa_attn_func(q, k, v, causal, softmax_scale):
q_ = q.transpose(1, 2)
k_ = k.transpose(1, 2)
v_ = v.transpose(1, 2)
attn_mask = None
if causal:
Sq, Sk = q_.shape[-2], k_.shape[-2]
i = torch.arange(Sq, device=q.device).unsqueeze(1)
j = torch.arange(Sk, device=q.device).unsqueeze(0)
attn_mask = j <= (i + (Sk - Sq))
out = F.scaled_dot_product_attention(q_, k_, v_, attn_mask=attn_mask, scale=softmax_scale)
return out.transpose(1, 2).contiguous()


def _full_reference(query, key, value, global_spans, q_start, q_end, softmax_scale):
"""Build a full 2D mask with global spans and compute reference output."""
Sk = key.shape[1]
mask = torch.tril(torch.ones(Sk, Sk, dtype=torch.bool, device=key.device))
for a, e in global_spans:
mask[a:e, :e] = True
mask_q = mask[q_start:q_end, :]
q_ = query.transpose(1, 2)
k_ = key.transpose(1, 2)
v_ = value.transpose(1, 2)
out = F.scaled_dot_product_attention(q_, k_, v_, attn_mask=mask_q, scale=softmax_scale)
return out.transpose(1, 2).contiguous()


SPAN_CASES = [
pytest.param([], id="no-spans"),
pytest.param([(0, 10)], id="span-at-start"),
pytest.param([(10, 30), (54, 64)], id="multi-spans"),
]

Q_RANGE_CASES = [
pytest.param((0, 64), id="q_eq_k"), # Sq == Sk (prefill)
pytest.param((53, 64), id="q_lt_k"), # Sq < Sk (decode-like)
]

BATCH_CASES = [
pytest.param(1, id="B1"),
pytest.param(2, id="B2"),
]


@pytest.mark.parametrize("global_spans", SPAN_CASES)
@pytest.mark.parametrize("q_range", Q_RANGE_CASES)
@pytest.mark.parametrize("batch_size", BATCH_CASES)
def test_piecewise_matches_full(global_spans, q_range, batch_size):
torch.manual_seed(0)
H, D, Sk = 2, 16, 64
q_start, q_end = q_range
Sq = q_end - q_start

key = torch.randn(batch_size, Sk, H, D, device=DEVICE)
value = torch.randn(batch_size, Sk, H, D, device=DEVICE)
query = torch.randn(batch_size, Sq, H, D, device=DEVICE)

full_attn_spans = [list(global_spans) for _ in range(batch_size)]
softmax_scale = 1.0 / (D**0.5)

got = piecewise_attn(
query,
key,
value,
full_attn_spans=full_attn_spans,
softmax_scale=softmax_scale,
attn_func=_sdpa_attn_func,
)
expected = _full_reference(query, key, value, global_spans, q_start, q_end, softmax_scale)
torch.testing.assert_close(got, expected, atol=1e-5, rtol=1e-5)


def test_piecewise_span_fully_before_qstart():
"""Spans entirely before query region produce pure causal attention."""
torch.manual_seed(0)
B, H, D, Sk = 1, 2, 16, 30
q_start, q_end = 15, 30
Sq = q_end - q_start

key = torch.randn(B, Sk, H, D, device=DEVICE)
value = torch.randn(B, Sk, H, D, device=DEVICE)
query = torch.randn(B, Sq, H, D, device=DEVICE)

global_spans = [(5, 10)]
full_attn_spans = [list(global_spans) for _ in range(B)]
softmax_scale = 1.0 / (D**0.5)

got = piecewise_attn(
query,
key,
value,
full_attn_spans=full_attn_spans,
softmax_scale=softmax_scale,
attn_func=_sdpa_attn_func,
)
expected = _full_reference(query, key, value, global_spans, q_start, q_end, softmax_scale)
torch.testing.assert_close(got, expected, atol=1e-5, rtol=1e-5)
9 changes: 6 additions & 3 deletions tests/test_arg_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -393,16 +393,19 @@ def test_nullify_stage_engine_defaults_resets_inherited_defaults():
def test_non_override_flags_keep_real_defaults_after_nullify():
import argparse

from vllm_omni.config.stage_config import deploy_override_field_names
from vllm_omni.engine.arg_utils import nullify_stage_engine_defaults

parser = argparse.ArgumentParser()
parser.add_argument("--hsdp-shard-size", type=int, default=-1, help="HSDP shard size.")
parser.add_argument("--batch-timeout", type=int, default=10, help="Batch timeout.")
parser.add_argument("--max-num-seqs", type=int, default=64, help="Max num seqs.")
nullify_stage_engine_defaults(parser)

hsdp = next(a for a in parser._actions if a.dest == "hsdp_shard_size")
assert "batch_timeout" not in deploy_override_field_names()

batch_timeout = next(a for a in parser._actions if a.dest == "batch_timeout")
max_num_seqs = next(a for a in parser._actions if a.dest == "max_num_seqs")
assert hsdp.default == -1
assert batch_timeout.default == 10
assert max_num_seqs.default is None


Expand Down
171 changes: 163 additions & 8 deletions tests/test_config_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -165,6 +165,147 @@ def test_to_omegaconf_omits_none_deploy_overrides_for_engine_args(self):
for name in deploy_override_field_names() - {"devices"}:
assert name not in engine_args

def test_to_omegaconf_diffusion_parallel_overrides_replace_nested_values(self):
config = StageConfig(
stage_id=1,
model_stage="diffusion",
stage_type=StageType.DIFFUSION,
yaml_engine_args={
"parallel_config": {
"pipeline_parallel_size": 1,
"data_parallel_size": 1,
"tensor_parallel_size": 4,
"enable_expert_parallel": False,
"ulysses_degree": 1,
"ring_degree": 1,
"ulysses_mode": "strict",
"sequence_parallel_size": 1,
"cfg_parallel_size": 1,
"vae_patch_parallel_size": 1,
"use_hsdp": False,
"hsdp_shard_size": -1,
"hsdp_replicate_size": 1,
}
},
runtime_overrides={
"pipeline_parallel_size": 2,
"data_parallel_size": 3,
"tensor_parallel_size": 8,
"enable_expert_parallel": True,
"ulysses_degree": 2,
"ring_degree": 4,
"ulysses_mode": "advanced_uaa",
"sequence_parallel_size": 8,
"cfg_parallel_size": 2,
"vae_patch_parallel_size": 2,
"use_hsdp": True,
"hsdp_shard_size": 8,
"hsdp_replicate_size": 2,
},
)

omega_config = config.to_omegaconf()

assert omega_config.engine_args.parallel_config.pipeline_parallel_size == 2
assert omega_config.engine_args.parallel_config.data_parallel_size == 3
assert omega_config.engine_args.parallel_config.tensor_parallel_size == 8
assert omega_config.engine_args.parallel_config.enable_expert_parallel is True
assert omega_config.engine_args.parallel_config.ulysses_degree == 2
assert omega_config.engine_args.parallel_config.ring_degree == 4
assert omega_config.engine_args.parallel_config.ulysses_mode == "advanced_uaa"
assert omega_config.engine_args.parallel_config.sequence_parallel_size == 8
assert omega_config.engine_args.parallel_config.cfg_parallel_size == 2
assert omega_config.engine_args.parallel_config.vae_patch_parallel_size == 2
assert omega_config.engine_args.parallel_config.use_hsdp is True
assert omega_config.engine_args.parallel_config.hsdp_shard_size == 8
assert omega_config.engine_args.parallel_config.hsdp_replicate_size == 2
assert "pipeline_parallel_size" not in omega_config.engine_args
assert "data_parallel_size" not in omega_config.engine_args
assert "tensor_parallel_size" not in omega_config.engine_args
assert "enable_expert_parallel" not in omega_config.engine_args
assert "ulysses_degree" not in omega_config.engine_args
assert "ring_degree" not in omega_config.engine_args
assert "ulysses_mode" not in omega_config.engine_args
assert "sequence_parallel_size" not in omega_config.engine_args
assert "cfg_parallel_size" not in omega_config.engine_args
assert "vae_patch_parallel_size" not in omega_config.engine_args
assert "use_hsdp" not in omega_config.engine_args
assert "hsdp_shard_size" not in omega_config.engine_args
assert "hsdp_replicate_size" not in omega_config.engine_args

def test_to_omegaconf_diffusion_parallel_overrides_create_parallel_config(self):
config = StageConfig(
stage_id=1,
model_stage="diffusion",
stage_type=StageType.DIFFUSION,
runtime_overrides={
"pipeline_parallel_size": 2,
"data_parallel_size": 3,
"tensor_parallel_size": 8,
"enable_expert_parallel": True,
"ulysses_degree": 2,
"ring_degree": 4,
"ulysses_mode": "advanced_uaa",
"sequence_parallel_size": 8,
"cfg_parallel_size": 2,
"vae_patch_parallel_size": 2,
"use_hsdp": True,
"hsdp_shard_size": 8,
"hsdp_replicate_size": 2,
},
)

omega_config = config.to_omegaconf()

assert omega_config.engine_args.parallel_config.pipeline_parallel_size == 2
assert omega_config.engine_args.parallel_config.data_parallel_size == 3
assert omega_config.engine_args.parallel_config.tensor_parallel_size == 8
assert omega_config.engine_args.parallel_config.enable_expert_parallel is True
assert omega_config.engine_args.parallel_config.ulysses_degree == 2
assert omega_config.engine_args.parallel_config.ring_degree == 4
assert omega_config.engine_args.parallel_config.ulysses_mode == "advanced_uaa"
assert omega_config.engine_args.parallel_config.sequence_parallel_size == 8
assert omega_config.engine_args.parallel_config.cfg_parallel_size == 2
assert omega_config.engine_args.parallel_config.vae_patch_parallel_size == 2
assert omega_config.engine_args.parallel_config.use_hsdp is True
assert omega_config.engine_args.parallel_config.hsdp_shard_size == 8
assert omega_config.engine_args.parallel_config.hsdp_replicate_size == 2
assert "pipeline_parallel_size" not in omega_config.engine_args
assert "data_parallel_size" not in omega_config.engine_args
assert "tensor_parallel_size" not in omega_config.engine_args
assert "enable_expert_parallel" not in omega_config.engine_args
assert "ulysses_degree" not in omega_config.engine_args
assert "ring_degree" not in omega_config.engine_args
assert "ulysses_mode" not in omega_config.engine_args
assert "sequence_parallel_size" not in omega_config.engine_args
assert "cfg_parallel_size" not in omega_config.engine_args
assert "vae_patch_parallel_size" not in omega_config.engine_args
assert "use_hsdp" not in omega_config.engine_args
assert "hsdp_shard_size" not in omega_config.engine_args
assert "hsdp_replicate_size" not in omega_config.engine_args

def test_to_omegaconf_llm_parallel_overrides_remain_top_level(self):
config = StageConfig(
stage_id=0,
model_stage="thinker",
stage_type=StageType.LLM,
runtime_overrides={
"pipeline_parallel_size": 2,
"data_parallel_size": 3,
"tensor_parallel_size": 8,
},
)

omega_config = config.to_omegaconf()

assert omega_config.engine_args.pipeline_parallel_size == 2
assert omega_config.engine_args.data_parallel_size == 3
assert omega_config.engine_args.tensor_parallel_size == 8
assert "pipeline_parallel_size" in omega_config.engine_args
assert "data_parallel_size" in omega_config.engine_args
assert "tensor_parallel_size" in omega_config.engine_args
assert "parallel_config" not in omega_config.engine_args


class TestModelPipeline:
"""Tests for ModelPipeline class."""
Expand Down Expand Up @@ -830,31 +971,45 @@ def test_deploy_override_fields_include_deploy_schema_fields(self):

expected_fields = {
"async_chunk",
# StageDeployConfig: stage placement and runtime fields.
"devices",
# StageDeployConfig: vLLM EngineArgs fields.
"async_scheduling",
"compilation_config",
"config_format",
"data_parallel_size",
"devices",
"disable_hybrid_kv_cache_manager",
"distributed_executor_backend",
"dtype",
"enable_chunked_prefill",
"enable_flashinfer_autotune",
"enable_prefix_caching",
"enforce_eager",
"gpu_memory_utilization",
"load_format",
"max_model_len",
"max_num_batched_tokens",
"max_num_seqs",
"mm_processor_cache_gb",
"pipeline_parallel_size",
"profiler_config",
"quantization",
"skip_mm_profiling",
"subtalker_sampling_params",
"tensor_parallel_size",
"tokenizer_mode",
# StageDeployConfig: diffusion parallel_config deploy override fields.
"cfg_parallel_size",
"enable_expert_parallel",
"hsdp_replicate_size",
"hsdp_shard_size",
"ring_degree",
"sequence_parallel_size",
"ulysses_degree",
"ulysses_mode",
"use_hsdp",
"vae_patch_parallel_size",
# DeployConfig: pipeline-wide engine settings.
"data_parallel_size",
"distributed_executor_backend",
"dtype",
"enable_chunked_prefill",
"enable_prefix_caching",
"pipeline_parallel_size",
"quantization",
"trust_remote_code",
}

Expand Down
Loading
Loading