Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
28 commits
Select commit Hold shift + click to select a range
6873ecc
[Feat] Optimize Qwen3-Omni code predictor with torch.compile + bucket…
JuanPZuluaga Mar 31, 2026
f305a02
Merge branch 'main' of https://github.com/vllm-project/vllm-omni into…
JuanPZuluaga Apr 3, 2026
3d08fb3
update and adressed comments
JuanPZuluaga Apr 3, 2026
8b66ed3
Merge branch 'main' of https://github.com/vllm-project/vllm-omni into…
JuanPZuluaga Apr 7, 2026
f88761a
Merge branch 'main' of https://github.com/vllm-project/vllm-omni into…
JuanPZuluaga Apr 7, 2026
5a516d1
Merge branch 'main' of https://github.com/vllm-project/vllm-omni into…
JuanPZuluaga Apr 9, 2026
af185b1
Merge branch 'main' of https://github.com/vllm-project/vllm-omni into…
JuanPZuluaga Apr 9, 2026
4809861
Merge branch 'main' of https://github.com/vllm-project/vllm-omni into…
JuanPZuluaga Apr 10, 2026
b62f9f3
abstract Qwen3TTS and Qwen3Omni CodePredictor into same class
JuanPZuluaga Apr 10, 2026
2d92c5a
Merge branch 'main' of https://github.com/vllm-project/vllm-omni into…
JuanPZuluaga Apr 10, 2026
5eeaa8b
update test
JuanPZuluaga Apr 10, 2026
2ed7dde
add wrapper test
JuanPZuluaga Apr 10, 2026
cb6dc82
Merge branch 'main' of https://github.com/vllm-project/vllm-omni into…
JuanPZuluaga Apr 10, 2026
545723d
Merge branch 'main' of https://github.com/vllm-project/vllm-omni into…
JuanPZuluaga Apr 11, 2026
837d72d
Merge branch 'main' of https://github.com/vllm-project/vllm-omni into…
JuanPZuluaga Apr 12, 2026
39d04e6
Merge branch 'main' into feat/cuda-graph-code-predictor
hsliuustc0106 Apr 12, 2026
fe5e00b
fix conflicts and merge main
JuanPZuluaga Apr 14, 2026
b51b012
Merge branch 'feat/cuda-graph-code-predictor' of github.com:JuanPZulu…
JuanPZuluaga Apr 14, 2026
538eca2
update naming
JuanPZuluaga Apr 14, 2026
28318aa
Merge branch 'main' of https://github.com/vllm-project/vllm-omni into…
JuanPZuluaga Apr 14, 2026
0b12340
_ensure_buffers now takes the actual batch size istead of max_seqs
JuanPZuluaga Apr 15, 2026
0ebb4f6
Merge branch 'main' of https://github.com/vllm-project/vllm-omni into…
JuanPZuluaga Apr 15, 2026
84b0ef9
Merge branch 'main' of https://github.com/vllm-project/vllm-omni into…
JuanPZuluaga Apr 15, 2026
10f174d
Merge branch 'feat/cuda-graph-code-predictor' of https://github.com/J…
JuanPZuluaga Apr 15, 2026
95c389d
Merge branch 'main' of https://github.com/vllm-project/vllm-omni into…
JuanPZuluaga Apr 15, 2026
70064af
merge
JuanPZuluaga Apr 15, 2026
c40d0ae
Merge branch 'main' of https://github.com/vllm-project/vllm-omni into…
JuanPZuluaga Apr 15, 2026
4c4c581
revert example
JuanPZuluaga Apr 15, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
from pytest_mock import MockerFixture

# Direct file import to avoid vllm_omni.__init__ patch dependencies.
_BASE = os.path.join(
_MODELS = os.path.join(
os.path.dirname(__file__),
os.pardir,
os.pardir,
Expand All @@ -30,14 +30,16 @@
"vllm_omni",
"model_executor",
"models",
"qwen3_tts",
)
_BASE = os.path.join(_MODELS, "qwen3_tts")
_COMMON = os.path.join(_MODELS, "common")


def _load_module(name: str, filename: str):
path = os.path.abspath(os.path.join(_BASE, filename))
spec = importlib.util.spec_from_file_location(name, path)
mod = importlib.util.module_from_spec(spec)
sys.modules[name] = mod # register before exec (needed for dataclasses etc.)
spec.loader.exec_module(mod)
return mod

Expand All @@ -59,8 +61,17 @@ def _build_mock_modules(mocker: MockerFixture) -> dict[str, object]:
weight_utils_mock = mocker.MagicMock()
weight_utils_mock.default_weight_loader = lambda p, w: None

pkg = types.ModuleType("vllm_omni.model_executor.models.qwen3_tts")
pkg.__path__ = [os.path.abspath(_BASE)]
tts_pkg = types.ModuleType("vllm_omni.model_executor.models.qwen3_tts")
tts_pkg.__path__ = [os.path.abspath(_BASE)]

common_pkg = types.ModuleType("vllm_omni.model_executor.models.common")
common_pkg.__path__ = [os.path.abspath(_COMMON)]

models_pkg = types.ModuleType("vllm_omni.model_executor.models")
models_pkg.__path__ = [os.path.abspath(_MODELS)]

vllm_parallel_mock = mocker.MagicMock()
vllm_parallel_mock.VocabParallelEmbedding = torch.nn.Embedding

return {
"vllm_omni": mocker.MagicMock(),
Expand All @@ -69,9 +80,11 @@ def _build_mock_modules(mocker: MockerFixture) -> dict[str, object]:
"vllm.config": mocker.MagicMock(),
"vllm.config.vllm": vllm_config_mod,
"vllm.model_executor.model_loader.weight_utils": weight_utils_mock,
"vllm.model_executor.layers.vocab_parallel_embedding": vllm_parallel_mock,
"vllm_omni.model_executor": types.ModuleType("vllm_omni.model_executor"),
"vllm_omni.model_executor.models": types.ModuleType("vllm_omni.model_executor.models"),
"vllm_omni.model_executor.models.qwen3_tts": pkg,
"vllm_omni.model_executor.models": models_pkg,
"vllm_omni.model_executor.models.common": common_pkg,
"vllm_omni.model_executor.models.qwen3_tts": tts_pkg,
}


Expand All @@ -88,6 +101,15 @@ def _load_target_classes(mocker: MockerFixture):
)
sys.modules["vllm_omni.model_executor.models.qwen3_tts.configuration_qwen3_tts"] = config_mod

# Load the shared common module (thin wrappers import from it)
common_cp_path = os.path.abspath(os.path.join(_COMMON, "qwen3_code_predictor.py"))
common_spec = importlib.util.spec_from_file_location(
"vllm_omni.model_executor.models.common.qwen3_code_predictor", common_cp_path
)
common_cp_mod = importlib.util.module_from_spec(common_spec)
sys.modules["vllm_omni.model_executor.models.common.qwen3_code_predictor"] = common_cp_mod
common_spec.loader.exec_module(common_cp_mod)

cp_mod = _load_module(
"vllm_omni.model_executor.models.qwen3_tts.qwen3_tts_code_predictor_vllm",
"qwen3_tts_code_predictor_vllm.py",
Expand All @@ -104,6 +126,7 @@ def loaded_target_classes(mocker: MockerFixture):
config_mod.Qwen3TTSTalkerConfig,
cp_mod.Qwen3TTSTalkerCodePredictorForConditionalGenerationVLLM,
cp_mod.Qwen3TTSTalkerCodePredictorModelVLLM,
cp_mod.CodePredictorWrapperConfig,
)


Expand All @@ -114,6 +137,7 @@ def _make_tiny_config(loaded_target_classes) -> tuple:
qwen3_tts_talker_config,
_,
_,
_,
) = loaded_target_classes
cp_config = qwen3_tts_talker_code_predictor_config(
vocab_size=64,
Expand Down Expand Up @@ -145,7 +169,7 @@ class TestCodePredictorDtypeAlignment:

def test_ensure_buffers_uses_given_dtype(self, mocker: MockerFixture, loaded_target_classes) -> None:
"""_ensure_buffers should create proj_buf with the given dtype."""
_, _, code_predictor_wrapper, _ = loaded_target_classes
_, _, code_predictor_wrapper, _, _ = loaded_target_classes
cp_config, talker_config = _make_tiny_config(loaded_target_classes)
vllm_config = _make_vllm_config(mocker)

Expand All @@ -156,17 +180,17 @@ def test_ensure_buffers_uses_given_dtype(self, mocker: MockerFixture, loaded_tar
)

# Create buffer in float16
predictor._ensure_buffers(torch.device("cpu"), torch.float16)
predictor._ensure_buffers(torch.device("cpu"), torch.float16, 4)
assert predictor._proj_buf is not None
assert predictor._proj_buf.dtype == torch.float16

# Re-create buffer in float32 (different dtype triggers re-allocation)
predictor._ensure_buffers(torch.device("cpu"), torch.float32)
predictor._ensure_buffers(torch.device("cpu"), torch.float32, 4)
assert predictor._proj_buf.dtype == torch.float32

def test_warmup_aligns_buffer_to_model_params(self, mocker: MockerFixture, loaded_target_classes) -> None:
"""_warmup_buckets should align proj_buf dtype to model parameters."""
_, _, code_predictor_wrapper, _ = loaded_target_classes
_, _, code_predictor_wrapper, _, _ = loaded_target_classes
cp_config, talker_config = _make_tiny_config(loaded_target_classes)
vllm_config = _make_vllm_config(mocker, max_num_seqs=2)

Expand All @@ -180,7 +204,7 @@ def test_warmup_aligns_buffer_to_model_params(self, mocker: MockerFixture, loade
predictor = predictor.to(torch.float16)

# Pre-create proj_buf with WRONG dtype (float32) — simulating the bug
predictor._ensure_buffers(torch.device("cpu"), torch.float32)
predictor._ensure_buffers(torch.device("cpu"), torch.float32, 2)
assert predictor._proj_buf.dtype == torch.float32

# Simulate _setup_compile having cached model dtype and compiled forward
Expand All @@ -194,7 +218,7 @@ def test_warmup_aligns_buffer_to_model_params(self, mocker: MockerFixture, loade

def test_setup_compile_caches_model_dtype(self, mocker: MockerFixture, loaded_target_classes) -> None:
"""_setup_compile should cache model parameter dtype."""
_, _, code_predictor_wrapper, _ = loaded_target_classes
_, _, code_predictor_wrapper, _, _ = loaded_target_classes
cp_config, talker_config = _make_tiny_config(loaded_target_classes)
vllm_config = _make_vllm_config(mocker, max_num_seqs=2)

Expand All @@ -211,7 +235,7 @@ def test_setup_compile_caches_model_dtype(self, mocker: MockerFixture, loaded_ta

def test_forward_with_mismatched_input_dtype(self, mocker: MockerFixture, loaded_target_classes) -> None:
"""forward() should not crash when inputs are float32 but model is float16."""
_, _, code_predictor_wrapper, _ = loaded_target_classes
_, _, code_predictor_wrapper, _, _ = loaded_target_classes
cp_config, talker_config = _make_tiny_config(loaded_target_classes)
vllm_config = _make_vllm_config(mocker, max_num_seqs=2)

Expand Down Expand Up @@ -250,9 +274,9 @@ class TestCodePredictorModelDtype:

def test_model_forward_float16(self, loaded_target_classes) -> None:
"""Inner model forward should work in float16."""
_, _, _, code_predictor_model = loaded_target_classes
_, _, _, code_predictor_model, _ = loaded_target_classes
cp_config, _ = _make_tiny_config(loaded_target_classes)
model = code_predictor_model(cp_config, talker_hidden_size=32).to(torch.float16)
model = code_predictor_model(cp_config, embedding_dim=32).to(torch.float16)

bsz, seq_len = 1, 4
inputs = torch.randn(bsz, seq_len, 32, dtype=torch.float16)
Expand All @@ -264,9 +288,9 @@ def test_model_forward_float16(self, loaded_target_classes) -> None:

def test_model_forward_float32(self, loaded_target_classes) -> None:
"""Inner model forward should work in float32."""
_, _, _, code_predictor_model = loaded_target_classes
_, _, _, code_predictor_model, _ = loaded_target_classes
cp_config, _ = _make_tiny_config(loaded_target_classes)
model = code_predictor_model(cp_config, talker_hidden_size=32).to(torch.float32)
model = code_predictor_model(cp_config, embedding_dim=32).to(torch.float32)

bsz, seq_len = 1, 4
inputs = torch.randn(bsz, seq_len, 32, dtype=torch.float32)
Expand All @@ -275,3 +299,37 @@ def test_model_forward_float32(self, loaded_target_classes) -> None:
output = model(inputs, pos_ids)
assert output.dtype == torch.float32
assert output.shape == (bsz, seq_len, 32)


class TestCodePredictorWrapperConfig:
"""Test wrapper configuration for different models."""

def test_omni_config(self, loaded_target_classes) -> None:
"""Qwen3-Omni uses correct wrapper config."""
_, _, _, _, code_predictor_wrapper_config = loaded_target_classes
config = code_predictor_wrapper_config(
use_cuda_graphs=False,
use_parallel_embedding=True,
use_projection=False,
return_proj_buf=True,
sampling_mode="stored",
)
assert config.use_cuda_graphs is False
assert config.use_parallel_embedding is True
assert config.return_proj_buf is True
assert config.sampling_mode == "stored"

def test_tts_config(self, loaded_target_classes) -> None:
"""Qwen3-TTS uses correct wrapper config."""
_, _, _, _, code_predictor_wrapper_config = loaded_target_classes
config = code_predictor_wrapper_config(
use_cuda_graphs=True,
use_parallel_embedding=False,
use_projection=True,
return_proj_buf=False,
sampling_mode="per_call",
)
assert config.use_cuda_graphs is True
assert config.use_parallel_embedding is False
assert config.return_proj_buf is False
assert config.sampling_mode == "per_call"
5 changes: 3 additions & 2 deletions vllm_omni/engine/stage_init_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -192,8 +192,9 @@ def extract_stage_metadata(stage_config: Any) -> StageMetadata:
default_sampling_params: OmniSamplingParams = SPClass(**default_sp)

custom_process_input_func: Callable | None = None
if hasattr(stage_config, "custom_process_input_func"):
mod_path, fn_name = stage_config.custom_process_input_func.rsplit(".", 1)
_cpif_path = getattr(stage_config, "custom_process_input_func", None)
if _cpif_path:
mod_path, fn_name = _cpif_path.rsplit(".", 1)
custom_process_input_func = getattr(importlib.import_module(mod_path), fn_name)

prompt_expand_func: Callable | None = None
Expand Down
Empty file.
Loading
Loading