Skip to content
Original file line number Diff line number Diff line change
Expand Up @@ -403,3 +403,57 @@ def test_tts_config(self, loaded_target_classes) -> None:
assert config.use_parallel_embedding is False
assert config.return_proj_buf is False
assert config.sampling_mode == "per_call"

def test_prefix_graph_config_helpers(self, loaded_target_classes) -> None:
"""Prefix graph helpers parse deploy config values and keep valid seq lens only."""
_ = loaded_target_classes
common_mod = sys.modules["vllm_omni.model_executor.models.common.qwen3_code_predictor"]
wrapper_cls = common_mod.CodePredictorWrapper

assert wrapper_cls._parse_positive_int_set("64; 128,0,-1") == {
64,
128,
}
assert wrapper_cls._parse_positive_int_set([2, "4", 0]) == {2, 4}
with pytest.raises(ValueError, match="Invalid positive int config value 'bad'"):
wrapper_cls._parse_positive_int_set("2,bad")

wrapper = object.__new__(wrapper_cls)
wrapper._prefix_graph_seq_lens = {1, 2, 4, 8, 99}
assert wrapper._prefix_seq_lens(6) == [2, 4]

def test_prefix_graph_env_requires_cuda_graphs(
self,
mocker: MockerFixture,
loaded_target_classes,
) -> None:
"""Avoid prefix warmup on shared code-predictor users that disable CUDA graphs."""
_ = loaded_target_classes
common_mod = sys.modules["vllm_omni.model_executor.models.common.qwen3_code_predictor"]
mocker.patch.object(common_mod.current_omni_platform, "is_npu", return_value=False)

cp_config, _ = _make_tiny_config(loaded_target_classes)
vllm_config = _make_vllm_config(mocker, max_num_seqs=2)
vllm_config.model_config.stage_connector_config = {
"extra": {
"code_predictor_prefix_graphs": True,
"code_predictor_prefix_graph_buckets": [2],
"code_predictor_prefix_graph_seq_lens": "2,3",
}
}

no_graph_wrapper = common_mod.CodePredictorWrapper(
vllm_config=vllm_config,
cp_config=cp_config,
wrapper_config=common_mod.CodePredictorWrapperConfig(use_cuda_graphs=False),
)
assert no_graph_wrapper._prefix_graphs_enabled is False
assert no_graph_wrapper._prefix_graph_buckets == {2}
assert no_graph_wrapper._prefix_graph_seq_lens == {2, 3}

graph_wrapper = common_mod.CodePredictorWrapper(
vllm_config=vllm_config,
cp_config=cp_config,
wrapper_config=common_mod.CodePredictorWrapperConfig(use_cuda_graphs=True),
)
assert graph_wrapper._prefix_graphs_enabled is True
129 changes: 124 additions & 5 deletions tests/model_executor/models/qwen3_tts/test_cuda_graph_decoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,15 +87,16 @@ def wrapper(decoder):
w = CUDAGraphDecoderWrapper(
decoder=decoder,
capture_sizes=[25, 50, 100],
capture_batch_sizes=[1, 2],
num_quantizers=NUM_QUANTIZERS,
enabled=True,
)
w.warmup(DEVICE)
return w


def _random_codes(seq_len, device=DEVICE):
return torch.randint(0, 100, (1, NUM_QUANTIZERS, seq_len), dtype=torch.long, device=device)
def _random_codes(seq_len, batch_size=1, device=DEVICE):
return torch.randint(0, 100, (batch_size, NUM_QUANTIZERS, seq_len), dtype=torch.long, device=device)


# ──────────────────────────────────────────────────────────────────
Expand Down Expand Up @@ -213,6 +214,44 @@ def test_chunked_decode_exact_size_equivalence(decoder, wrapper, total_len):
torch.testing.assert_close(graph_out, eager_out, atol=0, rtol=0)


def test_chunked_decode_output_survives_later_replay(wrapper):
"""Chunked output must not alias graph static buffers overwritten by later replays."""
codes = _random_codes(100)
overwrite_codes = _random_codes(100)

with torch.no_grad():
graph_out = wrapper.chunked_decode_with_cudagraph(codes, chunk_size=50, left_context_size=0)
expected = graph_out.clone()
_ = wrapper.decode(overwrite_codes[..., :50])
_ = wrapper.decode(overwrite_codes)

torch.testing.assert_close(graph_out, expected, atol=0, rtol=0)


def test_batched_chunked_decode_variable_lengths_matches_per_request_eager(decoder, wrapper):
"""Variable-length chunk batching should match independent chunked decodes."""
long_codes = _random_codes(100)
short_codes = _random_codes(50)
padded_codes = torch.zeros(2, NUM_QUANTIZERS, 100, dtype=torch.long, device=DEVICE)
padded_codes[0, :, :] = long_codes[0]
padded_codes[1, :, :50] = short_codes[0]

with torch.no_grad():
eager_long = _eager_chunked(decoder, long_codes, chunk_size=50, left_context_size=0)
eager_short = _eager_chunked(decoder, short_codes, chunk_size=50, left_context_size=0)
graph_out = wrapper.batched_chunked_decode_with_cudagraph(
padded_codes,
[100, 50],
chunk_size=50,
left_context_size=0,
max_batch_size=2,
)

assert graph_out.shape == (2, 1, 100 * TOTAL_UPSAMPLE)
torch.testing.assert_close(graph_out[0:1], eager_long, atol=1e-6, rtol=1e-6)
torch.testing.assert_close(graph_out[1:2, :, : 50 * TOTAL_UPSAMPLE], eager_short, atol=1e-6, rtol=1e-6)


def _eager_chunked(decoder, codes, chunk_size, left_context_size):
"""Eager chunked decode matching the real decoder's chunked_decode logic."""
wavs = []
Expand Down Expand Up @@ -254,15 +293,95 @@ def test_disabled_wrapper_matches_eager(decoder, wrapper):
torch.testing.assert_close(graph_out, eager_out, atol=0, rtol=0)


def test_batch_size_gt1_falls_back(decoder, wrapper):
"""Batch size > 1 should fall back to eager (bit-identical)."""
codes = torch.randint(0, 100, (2, NUM_QUANTIZERS, 25), dtype=torch.long, device=DEVICE)
def test_batch_size_gt1_uses_matching_graph(decoder, wrapper):
"""Captured batch size > 1 should replay a matching graph."""
assert (2, 25) in wrapper.graphs
codes = _random_codes(25, batch_size=2)
with torch.no_grad():
eager_out = decoder(codes)
graph_out = wrapper.decode(codes)
torch.testing.assert_close(graph_out, eager_out, atol=0, rtol=0)


def test_uncaptured_batch_size_falls_back(decoder, wrapper):
"""Uncaptured batch sizes should fall back to eager."""
assert (3, 25) not in wrapper.graphs
codes = _random_codes(25, batch_size=3)
with torch.no_grad():
eager_out = decoder(codes)
graph_out = wrapper.decode(codes)
torch.testing.assert_close(graph_out, eager_out, atol=0, rtol=0)


def test_extra_capture_shape_uses_sparse_graph(decoder):
"""Extra capture shapes should not expand to a full batch x size product."""
sparse_wrapper = CUDAGraphDecoderWrapper(
decoder=decoder,
capture_sizes=[25],
capture_batch_sizes=[1],
extra_capture_shapes=[(2, 50)],
num_quantizers=NUM_QUANTIZERS,
enabled=True,
)
sparse_wrapper.warmup(DEVICE)

assert (1, 25) in sparse_wrapper.graphs
assert (2, 50) in sparse_wrapper.graphs
assert (2, 25) not in sparse_wrapper.graphs

codes = _random_codes(50, batch_size=2)
with torch.no_grad():
eager_out = decoder(codes)
graph_out = sparse_wrapper.decode(codes)
torch.testing.assert_close(graph_out, eager_out, atol=0, rtol=0)


def test_compile_shape_supports_exact_and_padded_buckets(decoder, monkeypatch):
"""Configured torch.compile shapes should replay exact and padded CUDA Graph buckets."""

compile_kwargs = {}

def _fake_compile(model, **_kwargs):
compile_kwargs.update(_kwargs)

def _compiled(codes):
return model(codes) + 0.125

return _compiled

monkeypatch.setattr(torch, "compile", _fake_compile)

compiled_wrapper = CUDAGraphDecoderWrapper(
decoder=decoder,
capture_sizes=[25, 50],
capture_batch_sizes=[1],
compile_shapes=[(1, 25), (1, 50)],
num_quantizers=NUM_QUANTIZERS,
enabled=True,
)
compiled_wrapper.warmup(DEVICE)

exact_codes = _random_codes(25)
padded_codes = _random_codes(30)
uncaptured_codes = _random_codes(60)
padded_static = torch.zeros(1, NUM_QUANTIZERS, 50, dtype=torch.long, device=DEVICE)
padded_static[:, :, :30] = padded_codes
with torch.no_grad():
exact_eager = decoder(exact_codes)
exact_out = compiled_wrapper.decode(exact_codes)
padded_graph_expected = decoder(padded_static)[..., : 30 * TOTAL_UPSAMPLE]
padded_out = compiled_wrapper.decode(padded_codes)
uncaptured_eager = decoder(uncaptured_codes)
uncaptured_out = compiled_wrapper.decode(uncaptured_codes)

torch.testing.assert_close(exact_out, exact_eager + 0.125, atol=0, rtol=0)
torch.testing.assert_close(padded_out, padded_graph_expected + 0.125, atol=0, rtol=0)
torch.testing.assert_close(uncaptured_out, uncaptured_eager, atol=0, rtol=0)
assert compile_kwargs["mode"] == "default"
assert compile_kwargs["fullgraph"] is False
assert compile_kwargs["dynamic"] is False


def test_deterministic_across_calls(decoder, wrapper):
"""Same input should produce identical CUDA graph output across calls."""
codes = _random_codes(30)
Expand Down
Loading
Loading