[Feature] Support Qwen3 Omni talker mtp batch inference#722
Conversation
Signed-off-by: ZeldaHuang <hzm414167@alibaba-inc.com>
There was a problem hiding this comment.
💡 Codex Review
Here are some automated review suggestions for this pull request.
Reviewed commit: 6c5a143c41
ℹ️ About Codex in GitHub
Codex has been enabled to automatically review pull requests in this repo. Reviews are triggered when you
- Open a pull request for review
- Mark a draft as ready
- Comment "@codex review".
If Codex has suggestions, it will comment; otherwise it will react with 👍.
When you sign up for Codex through ChatGPT, Codex can also answer questions or update the PR, like "@codex address that feedback".
| decode_req_ids.append(req_id) | ||
| last_talker_hidden, text_step = update_dict.pop("mtp_inputs") | ||
| if _cudagraph_mode != CUDAGraphMode.NONE: | ||
| self.talker_mtp_input_ids.gpu[:span_len].copy_(req_input_ids) | ||
| self.talker_mtp_inputs_embeds.gpu[:span_len].copy_(req_embeds) | ||
| self.last_talker_hidden.gpu[:span_len].copy_(last_talker_hidden) | ||
| self.text_step.gpu[:span_len].copy_(text_step) | ||
| req_input_ids = self.talker_mtp_input_ids.gpu[:span_len] | ||
| req_embeds = self.talker_mtp_inputs_embeds.gpu[:span_len] | ||
| last_talker_hidden = self.last_talker_hidden.gpu[:span_len] | ||
| text_step = self.text_step.gpu[:span_len] | ||
| with set_forward_context( | ||
| None, self.vllm_config, cudagraph_runtime_mode=_cudagraph_mode, batch_descriptor=batch_desc | ||
| ): | ||
| req_embeds, code_predictor_codes = self.talker_mtp( | ||
| req_input_ids, req_embeds, last_talker_hidden, text_step | ||
| ) | ||
| update_dict["code_predictor_codes"] = code_predictor_codes | ||
| decode_slice = slice(len(decode_req_ids), len(decode_req_ids) + 1) |
There was a problem hiding this comment.
Fix decode buffer indexing off-by-one
The decode buffers are filled using decode_slice = slice(len(decode_req_ids), len(decode_req_ids) + 1) immediately after appending the current request, which makes the first decode request write into index 1 instead of 0. _talker_mtp_forward then reads [:decode_batch_size], so the first row is stale and the last request’s data is never read. In multi-request decode batches this silently corrupts embeddings/codes for at least one request.
Useful? React with 👍 / 👎.
| # run talker mtp decode | ||
| if hasattr(self.model, "talker_mtp"): | ||
| self._talker_mtp_forward(decode_req_ids, inputs_embeds) |
There was a problem hiding this comment.
Guard talker MTP when no decode requests
_talker_mtp_forward is called unconditionally whenever the model has talker_mtp, even if decode_req_ids is empty (e.g., a prefill-only batch where all span_len > 1). This makes decode_batch_size=0, and talker_mtp ultimately calls code_predictor_forward with seq_len=0, which hits torch.cat(all_codes_per_position, dim=2) on an empty list and raises a runtime error. This will crash prefill-only batches; add a guard to skip _talker_mtp_forward when there are no decode requests.
Useful? React with 👍 / 👎.
| for idx, req_id in enumerate(decode_req_ids): | ||
| req_index = self.input_batch.req_ids.index(req_id) | ||
| start_offset = int(self.query_start_loc.cpu[req_index]) | ||
| inputs_embeds[start_offset : start_offset + 1] = req_embeds[idx : idx + 1] |
There was a problem hiding this comment.
By this code we can directly use inputs_embeds = req_embeds? then avoid iterate each request. I can remove the update of "code_predictor_codes" and "merge update" here. If so, we maybe able to directly place batch_mtp_forward inside the model forward.
There was a problem hiding this comment.
req_embeds does not include prefill requests. Assigning inputs_embeds = req_embeds directly maybe incorrect for mixed prefill and decode batch.
Signed-off-by: ZeldaHuang <hzm414167@alibaba-inc.com>
There was a problem hiding this comment.
💡 Codex Review
Here are some automated review suggestions for this pull request.
Reviewed commit: 2147982303
ℹ️ About Codex in GitHub
Codex has been enabled to automatically review pull requests in this repo. Reviews are triggered when you
- Open a pull request for review
- Mark a draft as ready
- Comment "@codex review".
If Codex has suggestions, it will comment; otherwise it will react with 👍.
When you sign up for Codex through ChatGPT, Codex can also answer questions or update the PR, like "@codex address that feedback".
| if getattr(self.model, "talker", None) is not None and hasattr(self.model, "talker_mtp"): | ||
| outputs = self.talker_mtp( | ||
| self.talker_mtp_input_ids.gpu[:num_tokens_padded], | ||
| self.talker_mtp_inputs_embeds.gpu[:num_tokens_padded], |
There was a problem hiding this comment.
Guard talker_mtp call to decode-populated buffers
This call now runs for every batch, but the talker_mtp_* buffers are only populated for decode requests in _prepare_inputs when span_len == 1. For prefill-only batches (or any batch where num_tokens_padded exceeds the number of decode requests), the slice [:num_tokens_padded] pulls stale/uninitialized entries (and can exceed the max_num_reqs-sized buffers), which can drive talker_mtp with invalid token IDs/hidden states and trigger out-of-range embedding errors or random sampling on garbage data. This should be guarded by the actual decode batch size (or removed) and sliced by that size rather than num_tokens_padded.
Useful? React with 👍 / 👎.
Signed-off-by: ZeldaHuang <hzm414167@alibaba-inc.com>
Gaohan123
left a comment
There was a problem hiding this comment.
Could you please add a UT test to protect key methods? Thanks!
Signed-off-by: ZeldaHuang <hzm414167@alibaba-inc.com>
Signed-off-by: ZeldaHuang <hzm414167@alibaba-inc.com>
Done. |
…#722) Signed-off-by: ZeldaHuang <hzm414167@alibaba-inc.com> Signed-off-by: Chen Yang <2082464740@qq.com>
…#722) Signed-off-by: ZeldaHuang <hzm414167@alibaba-inc.com>
ref #420
Purpose
Test Plan
Test Result
Essential Elements of an Effective PR Description Checklist
supported_models.mdandexamplesfor a new model.BEFORE SUBMITTING, PLEASE READ https://github.com/vllm-project/vllm-omni/blob/main/CONTRIBUTING.md (anything written below this line will be removed by GitHub Actions)