-
Notifications
You must be signed in to change notification settings - Fork 1k
[Feature] Support Qwen3 Omni talker mtp batch inference #722
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
6c5a143
90ec453
1aa251a
025f29f
9d3a984
2147982
f3afd1d
eb2562e
3df2f24
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,123 @@ | ||
| from contextlib import contextmanager | ||
|
|
||
| import torch | ||
|
|
||
| from vllm_omni.worker.gpu_model_runner import OmniGPUModelRunner | ||
|
|
||
|
|
||
| class DummyBuffer: | ||
| """A minimal buffer wrapper that exposes the `.gpu` attribute.""" | ||
|
|
||
| def __init__(self, t: torch.Tensor): | ||
| self.gpu = t | ||
|
|
||
|
|
||
| class DummyInputBatch: | ||
| """A minimal input batch that only provides `req_ids`.""" | ||
|
|
||
| def __init__(self, req_ids): | ||
| self.req_ids = req_ids | ||
|
|
||
|
|
||
| class DummyReqState: | ||
| """A minimal request state container.""" | ||
|
|
||
| pass | ||
|
|
||
|
|
||
| class DummyTalkerMTP(torch.nn.Module): | ||
| """A fake talker_mtp module for deterministic CPU testing.""" | ||
|
|
||
| def forward(self, req_input_ids, req_embeds, last_talker_hidden, text_step): | ||
| # Deterministic behavior: | ||
| # - output embeds = input embeds + 1 | ||
| # - output codes = [[0], [1], ...] | ||
| bsz = req_embeds.shape[0] | ||
| new_embeds = req_embeds + 1.0 | ||
| codes = torch.arange(bsz, dtype=torch.int64).view(bsz, 1) | ||
| return new_embeds, codes | ||
|
|
||
|
|
||
| @contextmanager | ||
| def _noop_forward_context(*args, **kwargs): | ||
| """A no-op context manager to replace vLLM forward context in CPU tests.""" | ||
| yield | ||
|
|
||
|
|
||
| def _make_runner(req_ids=("r1", "r2"), hidden_size=4): | ||
| # Create an instance without calling OmniGPUModelRunner.__init__ | ||
| runner = object.__new__(OmniGPUModelRunner) | ||
|
|
||
| # Minimal attributes used by OmniGPUModelRunner._talker_mtp_forward | ||
| runner.input_batch = DummyInputBatch(list(req_ids)) | ||
| runner.requests = {rid: DummyReqState() for rid in req_ids} | ||
|
|
||
| # query_start_loc.cpu[req_index] is used to locate the token position | ||
| # in the flattened `inputs_embeds`. | ||
| runner.query_start_loc = type("QSL", (), {})() | ||
| # Map: r1 -> offset 0, r2 -> offset 3 | ||
| runner.query_start_loc.cpu = torch.tensor([0, 3], dtype=torch.int32) | ||
|
|
||
| bsz = len(req_ids) | ||
| runner.talker_mtp_input_ids = DummyBuffer(torch.zeros((bsz,), dtype=torch.int64)) | ||
| runner.talker_mtp_inputs_embeds = DummyBuffer(torch.zeros((bsz, hidden_size), dtype=torch.float32)) | ||
| runner.last_talker_hidden = DummyBuffer(torch.zeros((bsz, hidden_size), dtype=torch.float32)) | ||
| runner.text_step = DummyBuffer(torch.zeros((bsz, hidden_size), dtype=torch.float32)) | ||
|
|
||
| runner.talker_mtp = DummyTalkerMTP() | ||
| runner.vllm_config = object() | ||
|
|
||
| # Provide a minimal implementation that returns the expected 4-tuple. | ||
| def _determine_batch_execution_and_padding(**kwargs): | ||
| return None, object(), None, None | ||
|
|
||
| runner._determine_batch_execution_and_padding = _determine_batch_execution_and_padding | ||
|
|
||
| # Use the real merge method from OmniGPUModelRunner. | ||
| return runner | ||
|
|
||
|
|
||
| def test_talker_mtp_forward_cpu_updates_inputs_and_info(monkeypatch): | ||
| # Patch the module-level `set_forward_context` symbol used inside | ||
| # OmniGPUModelRunner._talker_mtp_forward. | ||
| import vllm_omni.worker.gpu_model_runner as mod # Must be the same module that defines OmniGPUModelRunner | ||
|
|
||
| monkeypatch.setattr(mod, "set_forward_context", _noop_forward_context) | ||
|
|
||
| runner = _make_runner(req_ids=("r1", "r2"), hidden_size=4) | ||
|
|
||
| # Initialize per-request embeds (batch-major inside talker_mtp_inputs_embeds) | ||
| runner.talker_mtp_inputs_embeds.gpu[0] = torch.tensor([1.0, 2.0, 3.0, 4.0]) | ||
| runner.talker_mtp_inputs_embeds.gpu[1] = torch.tensor([10.0, 20.0, 30.0, 40.0]) | ||
|
|
||
| # Flattened `inputs_embeds`: offsets 0 and 3 will be overwritten | ||
| inputs_embeds = torch.zeros((6, 4), dtype=torch.float32) | ||
|
|
||
| # Call the original implementation from OmniGPUModelRunner (no re-implementation) | ||
| OmniGPUModelRunner._talker_mtp_forward(runner, ["r1", "r2"], inputs_embeds) | ||
|
|
||
| # Validate embeds were written back (+1) | ||
| assert torch.allclose(inputs_embeds[0], torch.tensor([2.0, 3.0, 4.0, 5.0])) | ||
| assert torch.allclose(inputs_embeds[3], torch.tensor([11.0, 21.0, 31.0, 41.0])) | ||
|
|
||
| # Validate per-request additional_information_cpu was updated | ||
| info_r1 = runner.requests["r1"].additional_information_cpu | ||
| info_r2 = runner.requests["r2"].additional_information_cpu | ||
| assert int(info_r1["code_predictor_codes"][0, 0]) == 0 | ||
| assert int(info_r2["code_predictor_codes"][0, 0]) == 1 | ||
|
|
||
|
|
||
| def test_talker_mtp_forward_cpu_empty_batch_noop(monkeypatch): | ||
| import vllm_omni.worker.gpu_model_runner as mod | ||
|
|
||
| monkeypatch.setattr(mod, "set_forward_context", _noop_forward_context) | ||
|
|
||
| runner = _make_runner(req_ids=("r1",), hidden_size=4) | ||
|
|
||
| inputs_embeds = torch.randn((2, 4)) | ||
| before = inputs_embeds.clone() | ||
|
|
||
| OmniGPUModelRunner._talker_mtp_forward(runner, [], inputs_embeds) | ||
|
|
||
| # Ensure no changes were made | ||
| assert torch.allclose(inputs_embeds, before) |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -563,11 +563,7 @@ def _dummy_run( | |
| ubatch_slices=ubatch_slices, | ||
| ), | ||
| ): | ||
| if ( | ||
| getattr(self.model, "talker", None) is not None | ||
| and hasattr(self.model, "talker_mtp") | ||
| and num_tokens_padded == 1 | ||
| ): | ||
| if getattr(self.model, "talker", None) is not None and hasattr(self.model, "talker_mtp"): | ||
| outputs = self.talker_mtp( | ||
| self.talker_mtp_input_ids.gpu[:num_tokens_padded], | ||
| self.talker_mtp_inputs_embeds.gpu[:num_tokens_padded], | ||
|
|
@@ -884,6 +880,7 @@ def _preprocess( | |
| if hasattr(self.model, "has_preprocess") and self.model.has_preprocess: | ||
| # Overlay custom prompt_embeds per request for the prompt portion; | ||
| # collect additional_information (tensor/list) for prefill portion only | ||
| decode_req_ids = [] | ||
| for req_index, req_id in enumerate(self.input_batch.req_ids): | ||
| req_state = self.requests.get(req_id) | ||
| req_infos = getattr(req_state, "additional_information_cpu", None) if req_state is not None else None | ||
|
|
@@ -897,33 +894,14 @@ def _preprocess( | |
| req_input_ids, req_embeds, update_dict = self.model.preprocess( | ||
| input_ids=input_ids[s:e], input_embeds=inputs_embeds[s:e], **req_infos | ||
| ) | ||
| # run talker mtp decode | ||
| if hasattr(self.model, "talker_mtp"): | ||
| _cudagraph_mode, batch_desc, _, _ = self._determine_batch_execution_and_padding( | ||
| num_tokens=span_len, | ||
| num_reqs=1, | ||
| num_scheduled_tokens_np=num_scheduled_tokens_np[req_index], | ||
| max_num_scheduled_tokens=1, | ||
| force_eager=span_len > 1, | ||
| use_cascade_attn=False, | ||
| ) | ||
| if hasattr(self.model, "talker_mtp") and span_len == 1: | ||
| last_talker_hidden, text_step = update_dict.pop("mtp_inputs") | ||
| if _cudagraph_mode != CUDAGraphMode.NONE: | ||
| self.talker_mtp_input_ids.gpu[:span_len].copy_(req_input_ids) | ||
| self.talker_mtp_inputs_embeds.gpu[:span_len].copy_(req_embeds) | ||
| self.last_talker_hidden.gpu[:span_len].copy_(last_talker_hidden) | ||
| self.text_step.gpu[:span_len].copy_(text_step) | ||
| req_input_ids = self.talker_mtp_input_ids.gpu[:span_len] | ||
| req_embeds = self.talker_mtp_inputs_embeds.gpu[:span_len] | ||
| last_talker_hidden = self.last_talker_hidden.gpu[:span_len] | ||
| text_step = self.text_step.gpu[:span_len] | ||
| with set_forward_context( | ||
| None, self.vllm_config, cudagraph_runtime_mode=_cudagraph_mode, batch_descriptor=batch_desc | ||
| ): | ||
| req_embeds, code_predictor_codes = self.talker_mtp( | ||
| req_input_ids, req_embeds, last_talker_hidden, text_step | ||
| ) | ||
| update_dict["code_predictor_codes"] = code_predictor_codes | ||
| decode_slice = slice(len(decode_req_ids), len(decode_req_ids) + 1) | ||
| self.talker_mtp_input_ids.gpu[decode_slice].copy_(req_input_ids) | ||
| self.talker_mtp_inputs_embeds.gpu[decode_slice].copy_(req_embeds) | ||
| self.last_talker_hidden.gpu[decode_slice].copy_(last_talker_hidden) | ||
| self.text_step.gpu[decode_slice].copy_(text_step) | ||
| decode_req_ids.append(req_id) | ||
|
|
||
| # TODO(Peiqi): the merge stage could move out from the critical path | ||
| self._merge_additional_information_update(req_id, update_dict) | ||
|
|
@@ -934,6 +912,10 @@ def _preprocess( | |
| if isinstance(req_input_ids, torch.Tensor) and req_input_ids.numel() == seg_len: | ||
| input_ids[s : s + seg_len] = req_input_ids | ||
|
|
||
| # run talker mtp decode | ||
| if hasattr(self.model, "talker_mtp"): | ||
| self._talker_mtp_forward(decode_req_ids, inputs_embeds) | ||
|
|
||
| return ( | ||
| input_ids, | ||
| inputs_embeds, | ||
|
|
@@ -943,6 +925,34 @@ def _preprocess( | |
| ec_connector_output, | ||
| ) | ||
|
|
||
| def _talker_mtp_forward(self, decode_req_ids: list[str], inputs_embeds: torch.Tensor) -> None: | ||
| decode_batch_size = len(decode_req_ids) | ||
| if decode_batch_size == 0: | ||
| return | ||
| _cudagraph_mode, batch_desc, _, _ = self._determine_batch_execution_and_padding( | ||
| num_tokens=decode_batch_size, | ||
| num_reqs=decode_batch_size, | ||
| num_scheduled_tokens_np=np.ones(decode_batch_size, dtype=np.int32), | ||
| max_num_scheduled_tokens=1, | ||
| use_cascade_attn=False, | ||
| ) | ||
| req_input_ids = self.talker_mtp_input_ids.gpu[:decode_batch_size] | ||
| req_embeds = self.talker_mtp_inputs_embeds.gpu[:decode_batch_size] | ||
| last_talker_hidden = self.last_talker_hidden.gpu[:decode_batch_size] | ||
| text_step = self.text_step.gpu[:decode_batch_size] | ||
| with set_forward_context( | ||
| None, self.vllm_config, cudagraph_runtime_mode=_cudagraph_mode, batch_descriptor=batch_desc | ||
| ): | ||
| req_embeds, code_predictor_codes = self.talker_mtp(req_input_ids, req_embeds, last_talker_hidden, text_step) | ||
| # update the inputs_embeds and code_predictor_codes | ||
| code_predictor_codes_cpu = code_predictor_codes.detach().to("cpu").contiguous() | ||
| for idx, req_id in enumerate(decode_req_ids): | ||
| req_index = self.input_batch.req_ids.index(req_id) | ||
| start_offset = int(self.query_start_loc.cpu[req_index]) | ||
| inputs_embeds[start_offset : start_offset + 1] = req_embeds[idx : idx + 1] | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. By this code we can directly use inputs_embeds = req_embeds? then avoid iterate each request. I can remove the update of "code_predictor_codes" and "merge update" here. If so, we maybe able to directly place batch_mtp_forward inside the model forward.
Collaborator
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. req_embeds does not include prefill requests. Assigning inputs_embeds = req_embeds directly maybe incorrect for mixed prefill and decode batch. |
||
| update_dict = {"code_predictor_codes": code_predictor_codes_cpu[idx : idx + 1]} | ||
| self._merge_additional_information_update(req_id, update_dict) | ||
|
|
||
| def _model_forward( | ||
| self, | ||
| input_ids: torch.Tensor | None = None, | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This call now runs for every batch, but the
talker_mtp_*buffers are only populated for decode requests in_prepare_inputswhenspan_len == 1. For prefill-only batches (or any batch wherenum_tokens_paddedexceeds the number of decode requests), the slice[:num_tokens_padded]pulls stale/uninitialized entries (and can exceed themax_num_reqs-sized buffers), which can drivetalker_mtpwith invalid token IDs/hidden states and trigger out-of-range embedding errors or random sampling on garbage data. This should be guarded by the actual decode batch size (or removed) and sliced by that size rather thannum_tokens_padded.Useful? React with 👍 / 👎.