diff --git a/.buildkite/scripts/simple_test.sh b/.buildkite/scripts/simple_test.sh index 33248d99cde..55ac27cec9f 100755 --- a/.buildkite/scripts/simple_test.sh +++ b/.buildkite/scripts/simple_test.sh @@ -52,3 +52,4 @@ VENV_PYTHON="${VENV_DIR}/bin/python" "${VENV_PYTHON}" -m pytest -v -s tests/entrypoints/ "${VENV_PYTHON}" -m pytest -v -s tests/diffusion/cache/ "${VENV_PYTHON}" -m pytest -v -s tests/model_executor/models/qwen2_5_omni/test_audio_length.py +"${VENV_PYTHON}" -m pytest -v -s tests/worker/ diff --git a/tests/worker/test_omni_gpu_model_runner.py b/tests/worker/test_omni_gpu_model_runner.py new file mode 100644 index 00000000000..b0132306c81 --- /dev/null +++ b/tests/worker/test_omni_gpu_model_runner.py @@ -0,0 +1,123 @@ +from contextlib import contextmanager + +import torch + +from vllm_omni.worker.gpu_model_runner import OmniGPUModelRunner + + +class DummyBuffer: + """A minimal buffer wrapper that exposes the `.gpu` attribute.""" + + def __init__(self, t: torch.Tensor): + self.gpu = t + + +class DummyInputBatch: + """A minimal input batch that only provides `req_ids`.""" + + def __init__(self, req_ids): + self.req_ids = req_ids + + +class DummyReqState: + """A minimal request state container.""" + + pass + + +class DummyTalkerMTP(torch.nn.Module): + """A fake talker_mtp module for deterministic CPU testing.""" + + def forward(self, req_input_ids, req_embeds, last_talker_hidden, text_step): + # Deterministic behavior: + # - output embeds = input embeds + 1 + # - output codes = [[0], [1], ...] + bsz = req_embeds.shape[0] + new_embeds = req_embeds + 1.0 + codes = torch.arange(bsz, dtype=torch.int64).view(bsz, 1) + return new_embeds, codes + + +@contextmanager +def _noop_forward_context(*args, **kwargs): + """A no-op context manager to replace vLLM forward context in CPU tests.""" + yield + + +def _make_runner(req_ids=("r1", "r2"), hidden_size=4): + # Create an instance without calling OmniGPUModelRunner.__init__ + runner = object.__new__(OmniGPUModelRunner) + + # Minimal attributes used by OmniGPUModelRunner._talker_mtp_forward + runner.input_batch = DummyInputBatch(list(req_ids)) + runner.requests = {rid: DummyReqState() for rid in req_ids} + + # query_start_loc.cpu[req_index] is used to locate the token position + # in the flattened `inputs_embeds`. + runner.query_start_loc = type("QSL", (), {})() + # Map: r1 -> offset 0, r2 -> offset 3 + runner.query_start_loc.cpu = torch.tensor([0, 3], dtype=torch.int32) + + bsz = len(req_ids) + runner.talker_mtp_input_ids = DummyBuffer(torch.zeros((bsz,), dtype=torch.int64)) + runner.talker_mtp_inputs_embeds = DummyBuffer(torch.zeros((bsz, hidden_size), dtype=torch.float32)) + runner.last_talker_hidden = DummyBuffer(torch.zeros((bsz, hidden_size), dtype=torch.float32)) + runner.text_step = DummyBuffer(torch.zeros((bsz, hidden_size), dtype=torch.float32)) + + runner.talker_mtp = DummyTalkerMTP() + runner.vllm_config = object() + + # Provide a minimal implementation that returns the expected 4-tuple. + def _determine_batch_execution_and_padding(**kwargs): + return None, object(), None, None + + runner._determine_batch_execution_and_padding = _determine_batch_execution_and_padding + + # Use the real merge method from OmniGPUModelRunner. + return runner + + +def test_talker_mtp_forward_cpu_updates_inputs_and_info(monkeypatch): + # Patch the module-level `set_forward_context` symbol used inside + # OmniGPUModelRunner._talker_mtp_forward. + import vllm_omni.worker.gpu_model_runner as mod # Must be the same module that defines OmniGPUModelRunner + + monkeypatch.setattr(mod, "set_forward_context", _noop_forward_context) + + runner = _make_runner(req_ids=("r1", "r2"), hidden_size=4) + + # Initialize per-request embeds (batch-major inside talker_mtp_inputs_embeds) + runner.talker_mtp_inputs_embeds.gpu[0] = torch.tensor([1.0, 2.0, 3.0, 4.0]) + runner.talker_mtp_inputs_embeds.gpu[1] = torch.tensor([10.0, 20.0, 30.0, 40.0]) + + # Flattened `inputs_embeds`: offsets 0 and 3 will be overwritten + inputs_embeds = torch.zeros((6, 4), dtype=torch.float32) + + # Call the original implementation from OmniGPUModelRunner (no re-implementation) + OmniGPUModelRunner._talker_mtp_forward(runner, ["r1", "r2"], inputs_embeds) + + # Validate embeds were written back (+1) + assert torch.allclose(inputs_embeds[0], torch.tensor([2.0, 3.0, 4.0, 5.0])) + assert torch.allclose(inputs_embeds[3], torch.tensor([11.0, 21.0, 31.0, 41.0])) + + # Validate per-request additional_information_cpu was updated + info_r1 = runner.requests["r1"].additional_information_cpu + info_r2 = runner.requests["r2"].additional_information_cpu + assert int(info_r1["code_predictor_codes"][0, 0]) == 0 + assert int(info_r2["code_predictor_codes"][0, 0]) == 1 + + +def test_talker_mtp_forward_cpu_empty_batch_noop(monkeypatch): + import vllm_omni.worker.gpu_model_runner as mod + + monkeypatch.setattr(mod, "set_forward_context", _noop_forward_context) + + runner = _make_runner(req_ids=("r1",), hidden_size=4) + + inputs_embeds = torch.randn((2, 4)) + before = inputs_embeds.clone() + + OmniGPUModelRunner._talker_mtp_forward(runner, [], inputs_embeds) + + # Ensure no changes were made + assert torch.allclose(inputs_embeds, before) diff --git a/vllm_omni/entrypoints/omni_stage.py b/vllm_omni/entrypoints/omni_stage.py index af6f60f0420..804ab7b7fb8 100644 --- a/vllm_omni/entrypoints/omni_stage.py +++ b/vllm_omni/entrypoints/omni_stage.py @@ -951,8 +951,6 @@ async def _stage_worker_async( except Exception as e: logger.warning("Device setup failed: %s", e) - max_batch_size = int(runtime_cfg.get("max_batch_size", 1) or 1) - engine_args["max_num_seqs"] = max_batch_size # Initialize OmniConnectors if configured to match sync worker behavior connectors: dict[Any, Any] = {} if connectors_config: diff --git a/vllm_omni/entrypoints/utils.py b/vllm_omni/entrypoints/utils.py index 135b0e89ff2..eae3ea7afc4 100644 --- a/vllm_omni/entrypoints/utils.py +++ b/vllm_omni/entrypoints/utils.py @@ -195,6 +195,10 @@ def load_stage_configs_from_yaml(config_path: str, base_engine_args: dict | None # Update base_engine_args with stage-specific engine_args if they exist if hasattr(stage_arg, "engine_args") and stage_arg.engine_args is not None: base_engine_args_tmp = OmegaConf.merge(base_engine_args_tmp, stage_arg.engine_args) + if hasattr(stage_arg, "runtime") and stage_arg.runtime is not None: + runtime_cfg = stage_arg.runtime + max_batch_size = int(runtime_cfg.get("max_batch_size", 1) or 1) + base_engine_args_tmp["max_num_seqs"] = max_batch_size stage_arg.engine_args = base_engine_args_tmp return stage_args diff --git a/vllm_omni/model_executor/models/qwen3_omni/qwen3_omni.py b/vllm_omni/model_executor/models/qwen3_omni/qwen3_omni.py index 7675caa638e..ba46d4a1483 100644 --- a/vllm_omni/model_executor/models/qwen3_omni/qwen3_omni.py +++ b/vllm_omni/model_executor/models/qwen3_omni/qwen3_omni.py @@ -573,30 +573,22 @@ def talker_preprocess(self, input_ids: torch.Tensor, input_embeds: torch.Tensor, if input_embeds is None and input_ids is not None: input_embeds = self.talker.embed_input_ids(input_ids) - text_step = torch.zeros( - 1, - self.talker_config.text_config.hidden_size, - device=self._module_device(self.talker), - dtype=torch.bfloat16, - ) - last_talker_hidden = torch.zeros( - 1, - 1, - self.talker_config.text_config.hidden_size, - device=self._module_device(self.talker), - dtype=torch.bfloat16, - ) - span_len = input_ids.shape[0] if span_len > 1: # prefill input_ids, input_embeds, update_dict = self.talker_preprocess_prefill(input_ids, input_embeds, **info_dict) + code_predictor_codes = torch.zeros( + (input_embeds.shape[0], self.talker.num_code_groups), + device=self._module_device(self.talker), + dtype=torch.long, + ) + update_dict["code_predictor_codes"] = code_predictor_codes else: last_talker_hidden, text_step, update_dict = self.talker_preprocess_decode( input_ids, input_embeds, **info_dict ) - update_dict["mtp_inputs"] = last_talker_hidden, text_step + update_dict["mtp_inputs"] = last_talker_hidden, text_step return input_ids, input_embeds, update_dict @@ -608,24 +600,19 @@ def talker_mtp( text_step: torch.Tensor, ): # TODO(Peiqi): not support intermediate_tensors now - input_ids = safe_tensor_reshape(input_ids, (1, -1)) + input_ids = safe_tensor_reshape(input_ids, (input_ids.shape[0], -1)) inputs_embeds = safe_tensor_reshape(input_embeds, (-1, self.talker_config.text_config.hidden_size)) - text_step = safe_tensor_reshape(text_step, (1, -1)) - last_talker_hidden = safe_tensor_reshape(last_talker_hidden, (1, 1, self.talker_config.text_config.hidden_size)) + text_step = safe_tensor_reshape(text_step, (-1, self.talker_config.text_config.hidden_size)) + last_talker_hidden = safe_tensor_reshape( + last_talker_hidden, (-1, 1, self.talker_config.text_config.hidden_size) + ) # for profiling if inputs_embeds.shape[-1] == 2048: inputs_embeds = self.text_projection(inputs_embeds) - if inputs_embeds.shape[0] == 1: - code_predictor_codes, summed_embeddings = self.talker.code_predictor_forward( - input_ids, inputs_embeds.clone(), last_talker_hidden=last_talker_hidden - ) - inputs_embeds = summed_embeddings.clone() - else: - code_predictor_codes = torch.zeros( - (inputs_embeds.shape[0], self.talker.num_code_groups), - device=self._module_device(self.talker), - dtype=torch.long, - ) + code_predictor_codes, summed_embeddings = self.talker.code_predictor_forward( + input_ids, inputs_embeds.clone(), last_talker_hidden=last_talker_hidden + ) + inputs_embeds = summed_embeddings.clone() inputs_embeds = (inputs_embeds + text_step).reshape(-1, self.talker_config.text_config.hidden_size) return inputs_embeds, code_predictor_codes.squeeze(-1) @@ -848,7 +835,7 @@ def talker_preprocess_decode(self, input_ids: torch.Tensor, input_embeds: torch. use_vec = q_tail[0:1, :] new_q_tail = ( q_tail[1:, :].detach().to("cpu").contiguous() - if q_tail.shape[1] > 1 + if q_tail.shape[0] > 1 else self.tts_pad_embed.to(input_embeds.device, dtype=input_embeds.dtype) ) text_step = use_vec.to(input_embeds.device, dtype=input_embeds.dtype) diff --git a/vllm_omni/model_executor/models/qwen3_omni/qwen3_omni_moe_talker.py b/vllm_omni/model_executor/models/qwen3_omni/qwen3_omni_moe_talker.py index 4e8730eab52..2f1893e00ca 100644 --- a/vllm_omni/model_executor/models/qwen3_omni/qwen3_omni_moe_talker.py +++ b/vllm_omni/model_executor/models/qwen3_omni/qwen3_omni_moe_talker.py @@ -234,13 +234,7 @@ def code_predictor_forward( # Use the corresponding lm_head for this layer logits = self.code_predictor.lm_head[layer_idx](hidden_state[:, -1:, :]) # [batch, 1, vocab_size] - if len(pos_codes) > 1: - input_ids_for_logits_processors = torch.cat(pos_codes[1:], dim=1).to( - device=logits.device, dtype=torch.long - ) - else: - input_ids_for_logits_processors = self.empty_code - logits = logits_processors(input_ids_for_logits_processors, logits.squeeze(0)).unsqueeze(0) + logits = logits_processors(None, logits[:, -1]) # Sample from the filtered distribution probs = F.softmax(logits, dim=-1) @@ -288,7 +282,7 @@ def code_predictor_forward( all_summed_embeddings.append(pos_summed) # Concatenate across positions: [batch, seq_len, hidden_size] - summed_embeddings = torch.cat(all_summed_embeddings, dim=1) + summed_embeddings = torch.cat(all_summed_embeddings, dim=1).squeeze(1) return result_codes, summed_embeddings diff --git a/vllm_omni/model_executor/stage_input_processors/qwen3_omni.py b/vllm_omni/model_executor/stage_input_processors/qwen3_omni.py index 246ea2996e8..a1457a9750b 100644 --- a/vllm_omni/model_executor/stage_input_processors/qwen3_omni.py +++ b/vllm_omni/model_executor/stage_input_processors/qwen3_omni.py @@ -160,7 +160,7 @@ def talker2code2wav( # Process each talker output for i, talker_output in enumerate(talker_outputs): output = talker_output.outputs[0] - seq_len = len(output.token_ids) + seq_len = len(output.token_ids) - 1 # Extract codec codes from talker output # Expected shape: [8, seq_len] (8-layer RVQ codes) codec_codes = ( diff --git a/vllm_omni/worker/gpu_model_runner.py b/vllm_omni/worker/gpu_model_runner.py index 69729d95429..b0d0e165e08 100644 --- a/vllm_omni/worker/gpu_model_runner.py +++ b/vllm_omni/worker/gpu_model_runner.py @@ -563,11 +563,7 @@ def _dummy_run( ubatch_slices=ubatch_slices, ), ): - if ( - getattr(self.model, "talker", None) is not None - and hasattr(self.model, "talker_mtp") - and num_tokens_padded == 1 - ): + if getattr(self.model, "talker", None) is not None and hasattr(self.model, "talker_mtp"): outputs = self.talker_mtp( self.talker_mtp_input_ids.gpu[:num_tokens_padded], self.talker_mtp_inputs_embeds.gpu[:num_tokens_padded], @@ -884,6 +880,7 @@ def _preprocess( if hasattr(self.model, "has_preprocess") and self.model.has_preprocess: # Overlay custom prompt_embeds per request for the prompt portion; # collect additional_information (tensor/list) for prefill portion only + decode_req_ids = [] for req_index, req_id in enumerate(self.input_batch.req_ids): req_state = self.requests.get(req_id) req_infos = getattr(req_state, "additional_information_cpu", None) if req_state is not None else None @@ -897,33 +894,14 @@ def _preprocess( req_input_ids, req_embeds, update_dict = self.model.preprocess( input_ids=input_ids[s:e], input_embeds=inputs_embeds[s:e], **req_infos ) - # run talker mtp decode - if hasattr(self.model, "talker_mtp"): - _cudagraph_mode, batch_desc, _, _ = self._determine_batch_execution_and_padding( - num_tokens=span_len, - num_reqs=1, - num_scheduled_tokens_np=num_scheduled_tokens_np[req_index], - max_num_scheduled_tokens=1, - force_eager=span_len > 1, - use_cascade_attn=False, - ) + if hasattr(self.model, "talker_mtp") and span_len == 1: last_talker_hidden, text_step = update_dict.pop("mtp_inputs") - if _cudagraph_mode != CUDAGraphMode.NONE: - self.talker_mtp_input_ids.gpu[:span_len].copy_(req_input_ids) - self.talker_mtp_inputs_embeds.gpu[:span_len].copy_(req_embeds) - self.last_talker_hidden.gpu[:span_len].copy_(last_talker_hidden) - self.text_step.gpu[:span_len].copy_(text_step) - req_input_ids = self.talker_mtp_input_ids.gpu[:span_len] - req_embeds = self.talker_mtp_inputs_embeds.gpu[:span_len] - last_talker_hidden = self.last_talker_hidden.gpu[:span_len] - text_step = self.text_step.gpu[:span_len] - with set_forward_context( - None, self.vllm_config, cudagraph_runtime_mode=_cudagraph_mode, batch_descriptor=batch_desc - ): - req_embeds, code_predictor_codes = self.talker_mtp( - req_input_ids, req_embeds, last_talker_hidden, text_step - ) - update_dict["code_predictor_codes"] = code_predictor_codes + decode_slice = slice(len(decode_req_ids), len(decode_req_ids) + 1) + self.talker_mtp_input_ids.gpu[decode_slice].copy_(req_input_ids) + self.talker_mtp_inputs_embeds.gpu[decode_slice].copy_(req_embeds) + self.last_talker_hidden.gpu[decode_slice].copy_(last_talker_hidden) + self.text_step.gpu[decode_slice].copy_(text_step) + decode_req_ids.append(req_id) # TODO(Peiqi): the merge stage could move out from the critical path self._merge_additional_information_update(req_id, update_dict) @@ -934,6 +912,10 @@ def _preprocess( if isinstance(req_input_ids, torch.Tensor) and req_input_ids.numel() == seg_len: input_ids[s : s + seg_len] = req_input_ids + # run talker mtp decode + if hasattr(self.model, "talker_mtp"): + self._talker_mtp_forward(decode_req_ids, inputs_embeds) + return ( input_ids, inputs_embeds, @@ -943,6 +925,34 @@ def _preprocess( ec_connector_output, ) + def _talker_mtp_forward(self, decode_req_ids: list[str], inputs_embeds: torch.Tensor) -> None: + decode_batch_size = len(decode_req_ids) + if decode_batch_size == 0: + return + _cudagraph_mode, batch_desc, _, _ = self._determine_batch_execution_and_padding( + num_tokens=decode_batch_size, + num_reqs=decode_batch_size, + num_scheduled_tokens_np=np.ones(decode_batch_size, dtype=np.int32), + max_num_scheduled_tokens=1, + use_cascade_attn=False, + ) + req_input_ids = self.talker_mtp_input_ids.gpu[:decode_batch_size] + req_embeds = self.talker_mtp_inputs_embeds.gpu[:decode_batch_size] + last_talker_hidden = self.last_talker_hidden.gpu[:decode_batch_size] + text_step = self.text_step.gpu[:decode_batch_size] + with set_forward_context( + None, self.vllm_config, cudagraph_runtime_mode=_cudagraph_mode, batch_descriptor=batch_desc + ): + req_embeds, code_predictor_codes = self.talker_mtp(req_input_ids, req_embeds, last_talker_hidden, text_step) + # update the inputs_embeds and code_predictor_codes + code_predictor_codes_cpu = code_predictor_codes.detach().to("cpu").contiguous() + for idx, req_id in enumerate(decode_req_ids): + req_index = self.input_batch.req_ids.index(req_id) + start_offset = int(self.query_start_loc.cpu[req_index]) + inputs_embeds[start_offset : start_offset + 1] = req_embeds[idx : idx + 1] + update_dict = {"code_predictor_codes": code_predictor_codes_cpu[idx : idx + 1]} + self._merge_additional_information_update(req_id, update_dict) + def _model_forward( self, input_ids: torch.Tensor | None = None,