Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 6 additions & 4 deletions docs/serving/speech_api.md
Original file line number Diff line number Diff line change
Expand Up @@ -531,16 +531,18 @@ for result in response.json()["results"]:

All items are fanned out to `generate()` concurrently. The engine's stage worker automatically batches them up to the configured `max_batch_size` and queues the rest — no client-side throttling needed.

For best throughput, set both stages' `max_num_seqs` to ≥4 via `--stage-overrides`:
For best throughput, set both stages' `max_num_seqs` above 1 via `--stage-overrides`. On the current Qwen3-TTS CustomVoice benchmark, stage 1 performed best at `max_num_seqs: 10`:

```bash
vllm serve Qwen/Qwen3-TTS-12Hz-1.7B-CustomVoice \
--omni --port 8091 --trust-remote-code --enforce-eager \
--stage-overrides '{"0":{"max_num_seqs":4,"gpu_memory_utilization":0.2},
"1":{"max_num_seqs":4,"gpu_memory_utilization":0.2}}'
--stage-overrides '{"0":{"max_num_seqs":10,"gpu_memory_utilization":0.2},
"1":{"max_num_seqs":10,"gpu_memory_utilization":0.2}}'
```

The bundled `qwen3_tts.yaml` uses `max_num_seqs: 1` (single request) on both stages. Bumping to 4 yields roughly 4× throughput on the talker and lets stage 1 batch chunks across in-flight requests.
The bundled `qwen3_tts.yaml` uses a multi-request default and lets stage 1 batch chunks across in-flight requests. For latency-sensitive deployments, avoid forcing stage 1 back to `max_num_seqs: 1`; benchmark before reducing it below `10`.

The bundled config also sets `initial_codec_chunk_frames: 1`. This emits only the first audio chunk early to lower time-to-first-audio (TTFA), then returns to the normal `codec_chunk_frames` window so Code2Wav does not repeatedly decode tiny overlapping chunks.

## Supported Models

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -291,8 +291,18 @@ def test_deterministic_across_calls(decoder, wrapper):
[2, 4, 8, 16, 25, 32, 50, 64, 128, 256, 325],
[512],
),
(
{
"codec_chunk_frames": 25,
"codec_left_context_frames": 72,
"decode_chunk_size": 400,
"decode_left_context": 17,
},
[2, 4, 8, 16, 25, 32, 64, 97, 128, 256, 417],
[325, 512],
),
],
ids=["default", "streaming_c33", "streaming_c25"],
ids=["default", "streaming_c33", "streaming_c25", "custom_decode_chunk"],
)
def test_compute_capture_sizes(kwargs, expected_in, not_expected):
"""compute_capture_sizes produces expected sizes capped by max useful size."""
Expand Down
159 changes: 154 additions & 5 deletions tests/model_executor/models/qwen3_tts/test_qwen3_tts_code2wav.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,13 +23,33 @@ class _FakeDecoder(nn.Module):
def __init__(self, total_upsample: int = _TOTAL_UPSAMPLE):
super().__init__()
self.total_upsample = total_upsample

def chunked_decode(self, codes: torch.Tensor) -> torch.Tensor:
self.decode_calls: list[dict[str, int]] = []
self.cudagraph_calls: list[dict[str, int | torch.device]] = []

def to(self, *args, **kwargs):
return self

def chunked_decode(
self,
codes: torch.Tensor,
*,
chunk_size: int = 300,
left_context_size: int = 25,
) -> torch.Tensor:
self.decode_calls.append(
{
"chunk_size": chunk_size,
"left_context_size": left_context_size,
}
)
frames = codes.shape[-1]
wav_len = frames * self.total_upsample + 6
wav = torch.arange(wav_len, dtype=torch.float32)
return wav.view(1, 1, -1)

def enable_cudagraph(self, **kwargs):
self.cudagraph_calls.append(kwargs)


def _fake_dec_config():
return SimpleNamespace(
Expand All @@ -38,7 +58,12 @@ def _fake_dec_config():
)


def _make_model() -> Qwen3TTSCode2Wav:
def _make_model(
*,
stage_connector_config=None,
async_chunk: bool = False,
device: torch.device | None = None,
) -> Qwen3TTSCode2Wav:
dec_config = _fake_dec_config()
tok_config = SimpleNamespace(
decoder_config=dec_config,
Expand All @@ -56,13 +81,51 @@ def _make_model() -> Qwen3TTSCode2Wav:
):
model = Qwen3TTSCode2Wav(
vllm_config=SimpleNamespace(
model_config=SimpleNamespace(model="unused"),
device_config=SimpleNamespace(device=torch.device("cpu")),
load_config=SimpleNamespace(),
model_config=SimpleNamespace(
model="unused",
revision=None,
stage_connector_config=stage_connector_config,
async_chunk=async_chunk,
),
device_config=SimpleNamespace(device=device or torch.device("cpu")),
)
)
return model


def _load_weights_noop(model: Qwen3TTSCode2Wav) -> set[str]:
    """Run ``model.load_weights`` with the weight-loading machinery stubbed out.

    ``DefaultModelLoader`` and ``AutoWeightsLoader`` in the code2wav module are
    swapped for local fakes for the duration of the call: the fake model loader
    yields an empty weights iterator, and the fake auto-loader reports a single
    loaded name, ``"decoder.fake_weight"``.

    Returns:
        Whatever ``model.load_weights`` returns while the fakes are installed.
    """

    class _FakeAutoWeightsLoader:
        """Stand-in that accepts any constructor arguments."""

        def __init__(self, *_: object, **__: object):
            """Ignore all arguments."""

        def load_weights(self, _weights: object) -> set[str]:
            """Pretend exactly one decoder weight was loaded."""
            return {"decoder.fake_weight"}

    class _FakeModelLoader:
        """Stand-in model loader that never produces any weights."""

        class Source:
            """Placeholder for the loader's source descriptor."""

            def __init__(self, **_: object):
                """Ignore all keyword arguments."""

        def __init__(self, _load_config: object):
            """Ignore the load configuration."""

        def _get_weights_iterator(self, _source: object):
            """Return an already-exhausted iterator."""
            return iter(())

    with (
        patch(
            "vllm_omni.model_executor.models.qwen3_tts.qwen3_tts_code2wav.DefaultModelLoader",
            _FakeModelLoader,
        ),
        patch(
            "vllm_omni.model_executor.models.qwen3_tts.qwen3_tts_code2wav.AutoWeightsLoader",
            _FakeAutoWeightsLoader,
        ),
    ):
        loaded_names = model.load_weights(iter(()))
    return loaded_names


def test_forward_trims_context_on_exact_frame_boundaries():
model = _make_model()

Expand All @@ -87,3 +150,89 @@ def test_forward_trims_trailing_padding_without_context():
audio = out.multimodal_outputs["model_outputs"][0]
expected = torch.arange(24, dtype=torch.float32)
torch.testing.assert_close(audio, expected)


def test_connector_codec_chunking_does_not_override_decode_chunking():
    """Connector codec chunk settings must not leak into decode chunking.

    Only ``codec_chunk_frames`` / ``codec_left_context_frames`` are configured,
    so the model's decode chunk size and left context are expected to stay at
    300 and 25, both on the model attributes and in the arguments recorded by
    the decoder's most recent ``chunked_decode`` call.
    """
    connector_extra = {
        "codec_chunk_frames": 25,
        "codec_left_context_frames": 72,
    }
    model = _make_model(
        async_chunk=True,
        stage_connector_config={"extra": connector_extra},
    )

    loaded_names = _load_weights_noop(model)
    assert loaded_names == {"decoder.fake_weight"}

    assert model._decode_chunk_frames == 300
    assert model._decode_left_context_frames == 25

    codes = torch.arange(12, dtype=torch.long)
    runtime_info = [{"meta": {"left_context_size": 0}}]
    model.forward(input_ids=codes, runtime_additional_information=runtime_info)

    expected_call = {
        "chunk_size": 300,
        "left_context_size": 25,
    }
    last_call = model.decoder.decode_calls[-1]
    assert last_call == expected_call


def test_decode_chunking_can_be_overridden_separately():
    """Explicit ``decode_*`` keys set decode chunking independently of codec keys.

    With both codec and decode settings present in the connector ``extra``
    mapping, the model attributes must reflect the decode values (400 and 17).
    """
    connector_extra = {
        "codec_chunk_frames": 25,
        "codec_left_context_frames": 72,
        "decode_chunk_frames": 400,
        "decode_left_context_frames": 17,
    }
    model = _make_model(
        async_chunk=True,
        stage_connector_config={"extra": connector_extra},
    )

    _load_weights_noop(model)

    assert model._decode_chunk_frames == 400
    assert model._decode_left_context_frames == 17


def test_decode_chunking_override_is_passed_to_cudagraph():
    """Decode chunk overrides are forwarded when CUDA graphs are enabled.

    The model is built for a ``cuda`` device; after loading weights, the last
    entry in ``model.decoder.cudagraph_calls`` must carry the device, both
    codec settings, and the overridden decode settings.
    """
    cuda_device = torch.device("cuda")
    connector_extra = {
        "codec_chunk_frames": 25,
        "codec_left_context_frames": 72,
        "decode_chunk_frames": 400,
        "decode_left_context_frames": 17,
    }
    model = _make_model(
        async_chunk=True,
        device=cuda_device,
        stage_connector_config={"extra": connector_extra},
    )

    _load_weights_noop(model)

    expected_kwargs = {
        "device": torch.device("cuda"),
        "codec_chunk_frames": 25,
        "codec_left_context_frames": 72,
        "decode_chunk_size": 400,
        "decode_left_context": 17,
    }
    last_call = model.decoder.cudagraph_calls[-1]
    assert last_call == expected_kwargs


def test_invalid_decode_chunking_is_rejected():
    """A ``decode_chunk_frames`` of zero must raise ``ValueError`` on weight load.

    The error message is expected to contain ``decode_chunk_frames=0``.
    """
    invalid_extra = {"decode_chunk_frames": 0}
    model = _make_model(
        async_chunk=True,
        stage_connector_config={"extra": invalid_extra},
    )

    with pytest.raises(ValueError, match="decode_chunk_frames=0"):
        _load_weights_noop(model)
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ def _req(rid, *, finished, initial_codec_chunk_frames=None):
)


def _tm(*, chunk_frames=25, left_context=25, max_num_seqs=1):
def _tm(*, chunk_frames=25, left_context=25, max_num_seqs=1, initial_chunk_frames=0):
return SimpleNamespace(
code_prompt_token_ids=defaultdict(list),
scheduler_max_num_seqs=max_num_seqs,
Expand All @@ -45,6 +45,7 @@ def _tm(*, chunk_frames=25, left_context=25, max_num_seqs=1):
"extra": {
"codec_chunk_frames": chunk_frames,
"codec_left_context_frames": left_context,
"initial_codec_chunk_frames": initial_chunk_frames,
}
}
),
Expand Down Expand Up @@ -99,48 +100,37 @@ def test_flush_on_finish():

_CASES = [
# ── IC boundary rule ──────────────────────────────────────────────
# IC phase: length <= chunk_size (uses <=, consistent with fish_speech)
# IC emits fill the entire first chunk_size worth of frames, so the
# normal phase always starts at a clean chunk boundary.
# initial_coverage = (chunk_size // initial_chunk_size) * initial_chunk_size
# initial_codec_chunk_frames only controls the first emitted chunk.
# After that, the processor returns to codec_chunk_frames-sized windows
# to avoid flooding Code2Wav with repeated tiny overlapping decodes.
#
# Dynamic IC=16, cs=25, initial_coverage=16
# IC does NOT evenly divide cs, so initial_coverage < cs.
# IC emits at 16; frames 17-25 remain in IC phase but 25%16!=0 -> hold.
# Normal phase: adjusted = length - 16, emit when adjusted % 25 == 0.
((25, 25, 0), 24, False, None), # IC: 24<=25, 24%16!=0 -> hold
((25, 25, 0), 25, False, None), # IC: 25<=25, 25%16!=0 -> hold
((25, 25, 0), 24, False, None),
((25, 25, 0), 25, False, None),
((25, 25, 0), 41, False, (16, 41)), # normal: adjusted=25, 25%25==0 -> emit, lc=16
#
# Per-request IC=10, cs=25, initial_coverage=20
# IC does NOT evenly divide cs; IC emits at 10, 20.
# Frames 21-25 are still IC phase but 21..25 % 10 != 0 -> hold.
((25, 25, 10), 9, False, None), # IC: 9%10!=0 -> hold
((25, 25, 10), 10, False, (0, 10)), # IC: 10%10==0 -> emit, lc=0
((25, 25, 10), 25, False, None), # IC: 25<=25, 25%10!=0 -> hold
((25, 25, 10), 45, False, (20, 45)), # normal: adjusted=25, 25%25==0 -> emit, lc=20
# Per-request IC=10, cs=25: first emit at 10, then 35, 60...
((25, 25, 10), 9, False, None),
((25, 25, 10), 10, False, (0, 10)),
((25, 25, 10), 25, False, None),
((25, 25, 10), 35, False, (10, 35)),
((25, 25, 10), 45, False, None),
((25, 25, 10), 5, True, (0, 5)), # finished flushes IC tail
((25, 25, 10), 33, True, (20, 33)), # finished flushes normal tail
((25, 25, 10), 33, True, (10, 33)), # finished flushes normal tail
#
# IC=8, cs=16: IC evenly divides chunk_size (edge case)
# initial_coverage = (16//8)*8 = 16 == chunk_size.
# IC fills the entire first chunk: emits at 8 and 16.
# Normal phase starts at frame 17; first normal emit at 16+16=32.
((16, 25, 8), 8, False, (0, 8)), # IC: 8%8==0 -> emit, lc=0
((16, 25, 8), 16, False, (8, 16)), # IC: 16<=16, 16%8==0 -> emit, lc=8
((16, 25, 8), 24, False, None), # normal: adjusted=8, 8%16!=0 -> hold
((16, 25, 8), 32, False, (16, 32)), # normal: adjusted=16, 16%16==0 -> first emit, lc=16
# IC=8, cs=16: first emit at 8, then 24, 40...
((16, 25, 8), 8, False, (0, 8)),
((16, 25, 8), 16, False, None),
((16, 25, 8), 24, False, (8, 24)),
((16, 25, 8), 32, False, None),
#
# IC=5, cs=25: IC evenly divides chunk_size
# initial_coverage = (25//5)*5 = 25 == chunk_size.
# IC fills the entire first chunk: emits at 5, 10, 15, 20, 25.
# Normal phase starts at frame 26; first normal emit at 25+25=50.
# Emit intervals: 5,5,5,5,5,25,25,... — smooth transition, no gap.
((25, 25, 5), 5, False, (0, 5)), # IC: 5%5==0 -> emit, lc=0
((25, 25, 5), 12, False, None), # IC: 12%5!=0 -> hold
((25, 25, 5), 25, False, (20, 25)), # IC: 25<=25, 25%5==0 -> emit, lc=20
((25, 25, 5), 30, False, None), # normal: adjusted=5, 5%25!=0 -> hold
((25, 25, 5), 50, False, (25, 50)), # normal: adjusted=25, 25%25==0 -> first emit, lc=25
# IC=5, cs=25: first emit at 5, then 30, 55...
((25, 25, 5), 5, False, (0, 5)),
((25, 25, 5), 12, False, None),
((25, 25, 5), 25, False, None),
((25, 25, 5), 30, False, (5, 30)),
((25, 25, 5), 50, False, None),
#
# Per-request override: IC=15 at n_frames=10 -> 10%15!=0 -> hold
((25, 25, 15), 10, False, None),
Expand Down Expand Up @@ -172,10 +162,10 @@ def test_dynamic_ic_adapts_to_load():
assert p1 is not None
assert len(p1["codes"]["audio"]) == _Q * 2

# High load: add 4 others -> active=5/8 -> IC=8 -> emit at 8
# High load on a new request: active=6/8 -> IC=8 -> emit at 8
for i in range(4):
tm.code_prompt_token_ids[f"other-{i}"] = [[0]]
p2 = _call(tm, "r", n_frames=8)
p2 = _call(tm, "new-high-load", n_frames=8)
assert p2 is not None
assert len(p2["codes"]["audio"]) == _Q * 8

Expand All @@ -201,13 +191,11 @@ def test_ic_load_change_mid_request():
for i in range(6):
tm.code_prompt_token_ids[f"other-{i}"] = [[0]] * 10

# IC for "r" is still cached as 2.
# initial_coverage = ((25-1)//2)*2 = 24, first normal emit at 24+25=49
# IC for "r" is still cached as 2. The first normal emit is at 2+25=27.
assert _call(tm, "r", n_frames=25) is None
assert _call(tm, "r", n_frames=27) is None
p3 = _call(tm, "r", n_frames=49)
p3 = _call(tm, "r", n_frames=27)
assert p3 is not None
assert p3["meta"]["left_context_size"] == 24
assert p3["meta"]["left_context_size"] == 2

# A *new* request under high load gets IC=16 (not IC=2).
# Frame 2 would emit under IC=2 but must hold under IC=16.
Expand All @@ -216,6 +204,25 @@ def test_ic_load_change_mid_request():
assert p4 is not None


def test_connector_initial_chunk_config_overrides_dynamic_ic():
    """A connector-configured initial chunk size wins over the dynamic one.

    With ``initial_chunk_frames=4`` the first payload is emitted at 4 frames
    even though seven other requests are registered. Nothing is emitted at 25
    frames, and the next payload arrives at 29 frames with a left context of 4.
    """
    tm = _tm(initial_chunk_frames=4, max_num_seqs=8)

    # Under high load dynamic IC would be 16, but connector config pins the
    # first chunk to 4 frames.
    tm.code_prompt_token_ids.update({f"other-{idx}": [[0]] for idx in range(7)})

    first_payload = _call(tm, "r", n_frames=4)
    assert first_payload is not None
    assert len(first_payload["codes"]["audio"]) == _Q * 4

    # Only the first chunk uses the small size; the next emit is 4+25.
    held = _call(tm, "r", n_frames=25)
    assert held is None

    second_payload = _call(tm, "r", n_frames=29)
    assert second_payload is not None
    assert second_payload["meta"]["left_context_size"] == 4


@pytest.mark.parametrize(
"active,max_bs,max_ic,expected",
[
Expand Down Expand Up @@ -269,7 +276,7 @@ def test_ref_code_context_applies_to_all_streaming_chunks():
"""ref_code is prepended as decoder context on every chunk, not just the first."""
tm = _tm()
rid = "r-ref2"
tm.code_prompt_token_ids[rid] = [_FRAME[:] for _ in range(20)]
tm.code_prompt_token_ids[rid] = [_FRAME[:] for _ in range(35)]
tm.put_req_chunk[rid] = 1
ref_code = torch.tensor([[9, 9, 9, 9], [8, 8, 8, 8]], dtype=torch.long)
tm.request_payload[rid] = ref_code
Expand All @@ -284,7 +291,7 @@ def test_ref_code_context_applies_to_all_streaming_chunks():
assert payload is not None
# ref_code (2 frames) prepended as left context on second chunk too
assert payload["meta"]["left_context_size"] == 10 + 2
assert len(payload["codes"]["audio"]) == _Q * (20 + 2)
assert len(payload["codes"]["audio"]) == _Q * (35 + 2)


def test_ref_code_context_can_be_buffered_before_first_emit():
Expand Down
Loading
Loading