[Features] Qwen3.6 PartialRoPE supports #568

Open
PatchouliTIS wants to merge 6 commits into vllm-project:main from PatchouliTIS:patchy/qwen36_support_pr

Conversation

@PatchouliTIS PatchouliTIS commented Jun 3, 2026

PR Description: Qwen3.6 and Multimodal Verifiers Support

Purpose

Add end-to-end training support for Qwen3.6 (Hybrid Attention, MoE) and other multimodal verifiers (Qwen3-Omni, Qwen3.5-MoE) to the speculators training pipeline. This includes proper handling of MRoPE position encoding, partial rotary factors, multimodal data preprocessing with sidecar tensors, and a fix for the token-id mismatch that caused 100% sample rejection during online hidden-states generation.

Description

Qwen3.6 Hybrid Attention & RoPE Training (eagle3/core.py, scripts/train.py)

  • Add _select_rotary_emb_class and PartialMRoPE to support partial_rotary_factor < 1 in MRoPE-aware draft models.
  • Add --draft-mrope-full-head-hack (default on): rescales mrope_section by 1/partial_rotary_factor and pins partial_rotary_factor=1.0 to ensure bit-equivalent rotation between HF's rotate_half (training) and vLLM's neox-partial rotation (inference).
  • Add --draft-rope-scaling, --draft-rope-theta, --draft-max-position-embeddings for explicit RoPE override.
  • Add --draft-intermediate-size, --draft-num-attention-heads, --draft-num-key-value-heads, --draft-head-dim for MoE drafter geometry override.
  • Support unwrap_verifier_text_config for nested multimodal configs (thinker_config.text_config).
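The rescaling behind --draft-mrope-full-head-hack can be illustrated with a small sketch. It assumes mrope_section holds per-axis frequency counts that sum to half the rotary dimension. The name rescale_mrope_section and the example values are hypothetical, and the PR's actual implementation may differ.

```python
from fractions import Fraction


def rescale_mrope_section(mrope_section, partial_rotary_factor):
    """Hypothetical helper: scale mrope_section by 1 / partial_rotary_factor.

    Returns the rescaled sections together with the pinned factor 1.0.
    """
    if not 0 < partial_rotary_factor <= 1:
        raise ValueError("partial_rotary_factor must be in (0, 1]")
    # Exact rational arithmetic, so a factor of 0.25 gives a scale of exactly 4.
    scale = 1 / Fraction(partial_rotary_factor).limit_denominator(1024)
    scaled = [section * scale for section in mrope_section]
    if any(value.denominator != 1 for value in scaled):
        raise ValueError("rescaled mrope_section is not integral")
    return [int(value) for value in scaled], 1.0


# Illustrative values only: head_dim=256 with factor 0.25 gives rotary_dim=64,
# so the sections sum to 32; after rescaling they sum to head_dim // 2 = 128.
sections, factor = rescale_mrope_section([11, 11, 10], 0.25)
```

After rescaling, the sections cover the whole head, which is why the factor can be pinned to 1.0 in this sketch.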

Multimodal Data Pipeline (preprocessing.py, train/data.py)

  • Add multimodal sidecar persistence: _save_multimodal_sidecar writes image_grid_thw, video_grid_thw, etc. to .safetensors files.
  • Add _build_multimodal_loss_mask: zeroes loss at placeholder token positions (image/video/audio pad tokens).
  • Add _loss_mask_from_assistant_token_spans and placeholder-aware _loss_mask_from_ids_fallback for accurate multimodal assistant mask alignment.
  • ArrowDataset now calls .with_format(None) after load_from_disk to expose multimodal metadata columns hidden by prepare_data.py's persisted set_format.
  • Add _make_rope_index_fn supporting both Qwen3-Omni (nested thinker_config) and Qwen3.5/3.6 MoE (flat vision_config + text_config) config layouts for 3D MRoPE position_ids generation.
  • collate_fn supports mixed 1D/3D position_ids batching.
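As a rough illustration of the placeholder masking idea, the sketch below zeroes the loss mask wherever the token is a multimodal placeholder. The function name and the token ids are hypothetical stand-ins, not the PR's _build_multimodal_loss_mask.

```python
def mask_placeholder_positions(input_ids, loss_mask, placeholder_token_ids):
    """Hypothetical sketch: zero the loss mask at placeholder token positions."""
    if len(input_ids) != len(loss_mask):
        raise ValueError("input_ids and loss_mask must have the same length")
    placeholders = set(placeholder_token_ids)
    return [
        0 if token in placeholders else mask
        for token, mask in zip(input_ids, loss_mask)
    ]


# Made-up ids: 9001 stands in for an image pad token, 9002 for a video pad token.
IMAGE_PAD, VIDEO_PAD = 9001, 9002
input_ids = [1, IMAGE_PAD, IMAGE_PAD, 42, 43, VIDEO_PAD, 44]
loss_mask = [0, 1, 1, 1, 1, 1, 1]
masked = mask_placeholder_positions(input_ids, loss_mask, [IMAGE_PAD, VIDEO_PAD])
```

Positions that were already masked out stay at zero; only placeholder positions are changed.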

Online Hidden-States Generation Fix (vllm_client.py, train/data.py)

  • Root cause: chat.completions.create(messages=...) re-renders the chat template server-side, introducing an off-by-one newline token at vision-placeholder boundaries vs. prepare_data.py's HF processor tokenization. This caused extract_output's strict token-id check to reject 100% of samples → loss=0.
  • Fix: Default to token-id Completions path — send pre-tokenized input_ids + multi_modal_data (extracted media URLs) via extra_body, bypassing server-side template rendering entirely. messages retained only as opt-in fallback.
  • Add _collect_mm_payload_from_messages to flatten media URLs from persisted messages into vLLM's multi_modal_data format.
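To make the root cause concrete, here is a minimal sketch of a strict token-id comparison. The helper name and all token ids are hypothetical; it only shows how a single extra newline token after a vision placeholder makes the two sequences diverge.

```python
def first_token_mismatch(expected_ids, returned_ids):
    """Return the first index where the sequences differ, or None if identical."""
    for index, (expected, returned) in enumerate(zip(expected_ids, returned_ids)):
        if expected != returned:
            return index
    if len(expected_ids) != len(returned_ids):
        # One sequence is a strict prefix of the other.
        return min(len(expected_ids), len(returned_ids))
    return None


# Made-up ids: 9001/9002 mark a vision span, 198 stands in for a newline token.
client_ids = [1, 9001, 9002, 42, 43]        # tokenized once, client-side
server_ids = [1, 9001, 9002, 198, 42, 43]   # re-rendered template adds a newline
mismatch_at = first_token_mismatch(client_ids, server_ids)
```

Sending the client-side ids directly avoids the second rendering step, so there is nothing left to diverge.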

vLLM Launch Script (scripts/launch_vllm.py)

  • Add --layer-ids alias for --target-layer-ids.
  • Support thinker_config unwrap for Qwen3-Omni/Qwen3.6 multimodal verifiers.
  • Add resolve_layer_ids separating training target layer ids from vLLM extraction layer ids.
  • Flatten text-backbone config into hf_config_overrides to fix get_hf_text_config() failures with nested multimodal configs.
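The config unwrapping mentioned above can be sketched as follows. This is an assumption-laden illustration: unwrap_text_config is a hypothetical name, and SimpleNamespace objects stand in for real config classes.

```python
from types import SimpleNamespace


def unwrap_text_config(config):
    """Hypothetical sketch: descend into thinker_config and text_config if present."""
    for attr in ("thinker_config", "text_config"):
        nested = getattr(config, attr, None)
        if nested is not None:
            config = nested
    return config


# Nested layout (thinker_config.text_config), flat layout (text_config only),
# and a plain text-only config, each with an illustrative layer count.
nested = SimpleNamespace(
    thinker_config=SimpleNamespace(text_config=SimpleNamespace(num_hidden_layers=48))
)
flat = SimpleNamespace(text_config=SimpleNamespace(num_hidden_layers=40))
plain = SimpleNamespace(num_hidden_layers=32)
```

All three layouts resolve to an object that exposes the text-backbone fields directly.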

Misc

  • EAGLE3_LOSS_CE_WEIGHT env var for optional KL+CE mixed loss (eagle3/metrics.py).
  • vocab_mapping.py: support thinker_config.text_config.vocab_size unwrap.
  • eagle3/data.py: shift_batch handles 2D [3, seq_len] MRoPE position_ids.

Related Issue

N/A — Internal support for Qwen3.6-35B-A3B Eagle3/DFlash drafter training.

Tests

  • python3 -m py_compile passes on all 10 modified files.
  • prepare_data.py successfully preprocesses multimodal (image+text) dataset with sidecar output.
  • train.py with --on-missing generate successfully generates hidden states via token-id Completions path (no more token-id mismatch rejection).
  • End-to-end Eagle3 training on Qwen3.6-35B-A3B with --draft-mrope-full-head-hack produces non-zero loss and expected acceptance metrics.

Checklist

  • The purpose of the PR, such as "Fix some issue (link existing issues this PR will resolve)".
  • The test plan/results, such as providing test command and pasting the results.
  • (Optional) The necessary documentation update.
  • I (a human) have written or reviewed the code in this PR to the best of my ability.

Summary by CodeRabbit

  • New Features

    • Enhanced multimodal support for image, video, and audio data preprocessing and training
    • Configurable draft model parameters (intermediate size, attention heads, RoPE settings) via CLI
    • Improved token-id validation and multimodal data handling in speculative decoding requests
    • MRoPE-aware rotary embedding support for improved model architecture compatibility
  • Bug Fixes

    • Fixed position_ids alignment for multimodal models
    • Enhanced loss mask computation for robust training
  • Chores

    • Added --layer-ids as a CLI alias for --target-layer-ids

PatchouliTaisa added 3 commits June 3, 2026 10:44
Signed-off-by: PatchouliTaisa <patchychen@tencent.com>
Signed-off-by: PatchouliTaisa <patchychen@tencent.com>
Signed-off-by: PatchouliTaisa <patchychen@tencent.com>

coderabbitai Bot commented Jun 3, 2026

Walkthrough

This PR adds end-to-end multimodal support: vLLM layer-ID resolution and validation, multimodal preprocessing with safetensor sidecars and loss-mask alignment, vLLM API routing for token-id parity, multimodal-aware training data loading with RoPE indexing and collation, drafter configuration derivation with MRoPE support, and model-level rotary and loss adjustments.

Changes

Multimodal Speculative Extraction Pipeline

Layer / File(s): Summary

  • vLLM layer-ID resolution for speculative extraction (scripts/launch_vllm.py): Helper functions unwrap verifier configs, compute DeepStack-aware default target ids, deduplicate/validate ids, and resolve training vs extraction ids with --include-last-layer semantics; main() uses AutoConfig + resolve_layer_ids() and builds speculative vLLM config with extraction ids.
  • Multimodal data preprocessing and sidecars (scripts/prepare_data.py, src/speculators/data_generation/preprocessing.py): Adds multimodal constants/helpers, serializes multimodal messages, builds safetensor sidecars, and computes multimodal-aware loss masks (subsequence + regex fallback). _preprocess_batch writes sidecars and validates samples; dataset builders accept placeholder_token_ids and multimodal_output_dir.
  • vLLM client routing and token-id parity (src/speculators/data_generation/vllm_client.py): extract_output gains trust_server_token_ids. ClientItem adds multi_modal_data. Hidden-state generation can use Chat Completions (trust server token ids) or Completions with prompt=token_ids and forwarded multimodal data to preserve token-id alignment.
  • Training dataset loading, RoPE indexing, collation (src/speculators/train/data.py, src/speculators/train/vocab_mapping.py, scripts/data_generation_offline.py): Adds multimodal sidecar keys, builds ClientItem multimodal payloads, creates RoPE index adapters, implements _build_position_ids for multimodal MRoPE, preserves metadata via with_format(None), validates/loads hidden-state token ids and sidecars, and updates collation to support 2D/mrope position_ids. get_target_vocab_size() handles nested thinker_config.
  • Drafter parameter derivation and CLI wiring (scripts/train.py): Adds unwrap_verifier_text_config() and get_default_draft_intermediate_size(). Extends create_transformer_layer_config() with drafter overrides (intermediate size, heads/GQA, head_dim, rope params, mrope_full_head_hack), reconstructs/validates geometry and RoPE parameters, and threads new CLI flags and verifier_name_or_path through training/validation dataset construction.
  • Model rotary selection and eagle3 loss (src/speculators/models/eagle3/core.py, src/speculators/models/eagle3/rotary_partial.py, src/speculators/models/eagle3/data.py, src/speculators/models/eagle3/metrics.py): Adds MRoPE-aware rotary selection (including Qwen-Omni wrapping), dynamic partial-MRoPE class creation, partial-neox rotary implementation with install/uninstall helpers, conditional position_ids slicing in shift_batch, and eagle3_loss() blending KL and CE via EAGLE3_LOSS_CE_WEIGHT (default loss in compute_metrics).
  • Unit tests for partial-neox rotary (tests/unit/models/test_eagle3_rotary_partial.py): New tests for full/partial rotation parity, tail-channel pass-through, idempotent patch install/uninstall, and defensive validation.
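The "tail-channel pass-through" property named in the test summary can be shown with a pure-Python sketch of neox-style partial rotation on a single head vector. This is not the code in rotary_partial.py: the function names are hypothetical and plain lists stand in for tensors.

```python
import math


def rotate_half(x):
    """Neox-style helper: map [x1, x2] (two halves) to [-x2, x1]."""
    half = len(x) // 2
    return [-value for value in x[half:]] + x[:half]


def apply_partial_neox_rotary(x, angles, rotary_dim):
    """Rotate the first rotary_dim channels of x; pass the tail through unchanged.

    x is one head vector; angles holds rotary_dim // 2 rotation angles.
    """
    if rotary_dim % 2 or rotary_dim > len(x):
        raise ValueError("rotary_dim must be even and at most len(x)")
    if len(angles) != rotary_dim // 2:
        raise ValueError("expected rotary_dim // 2 angles")
    head, tail = x[:rotary_dim], x[rotary_dim:]
    # Each angle is shared by channel i and channel i + rotary_dim // 2.
    cos = [math.cos(angle) for angle in angles] * 2
    sin = [math.sin(angle) for angle in angles] * 2
    rotated = rotate_half(head)
    out = [h * c + r * s for h, c, r, s in zip(head, cos, rotated, sin)]
    return out + tail


vector = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0]
result = apply_partial_neox_rotary(vector, [math.pi / 2, 0.0], rotary_dim=4)
```

With rotary_dim equal to the full vector length the tail is empty, which corresponds to the full-rotation case.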

Estimated code review effort

🎯 5 (Critical) | ⏱️ ~120 minutes

Suggested reviewers

  • shanjiaz
  • dsikka

"A rabbit in code, with whiskers bright,
I stitched the layers through the night.
Sidecars hum and rotaries spin,
Hidden states hop out grins within.
🐇🌙"

🚥 Pre-merge checks | ✅ 4 | ❌ 1

❌ Failed checks (1 warning)

  • Docstring Coverage: ⚠️ Warning. Docstring coverage is 43.27% which is insufficient. The required threshold is 80.00%. Resolution: Write docstrings for the functions missing them to satisfy the coverage threshold.

✅ Passed checks (4 passed)

  • Description Check: ✅ Passed. Check skipped - CodeRabbit’s high-level summary is enabled.
  • Title check: ✅ Passed. The title '[Features] Qwen3.6 PartialRoPE supports' clearly identifies the main feature addition (Qwen3.6 PartialRoPE support), which aligns with the PR's primary objective of adding end-to-end training support for Qwen3.6 with PartialRoPE (partial_rotary_factor < 1) and related multimodal verifier enhancements.
  • Linked Issues check: ✅ Passed. Check skipped because no linked issues were found for this pull request.
  • Out of Scope Changes check: ✅ Passed. Check skipped because no linked issues were found for this pull request.

mergify Bot commented Jun 3, 2026

Merge Protections

Your pull request matches the following merge protections and will not be merged until they are valid.

🔴 Require two reviews

Waiting for

  • #approved-reviews-by >= 2
This rule is failing.

PRs labelled "two-reviews" must have at least two approving reviews before merging.



@coderabbitai coderabbitai Bot left a comment

Actionable comments posted: 3

🧹 Nitpick comments (4)
scripts/launch_vllm.py (2)

176-195: 💤 Low value

Duplicated resolve_layer_ids call can be simplified.

Both branches of the if args.target_layer_ids conditional call resolve_layer_ids with identical arguments. The only difference is the warning for custom IDs. Consider extracting the call before the conditional.

Suggested simplification
-    if args.target_layer_ids:
-        training_target_layer_ids, extraction_layer_ids, layer_id_source = resolve_layer_ids(
-            args, multimodal_config, num_hidden_layers
-        )
+    training_target_layer_ids, extraction_layer_ids, layer_id_source = resolve_layer_ids(
+        args, multimodal_config, num_hidden_layers
+    )
+
+    if args.target_layer_ids:
         warnings.warn(
             "Using custom target layer ids. Pass "
             f"{training_target_layer_ids} to the training script; vLLM will "
             f"extract {extraction_layer_ids}.",
             stacklevel=2,
         )
-    else:
-        training_target_layer_ids, extraction_layer_ids, layer_id_source = resolve_layer_ids(
-            args, multimodal_config, num_hidden_layers
-        )
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@scripts/launch_vllm.py` around lines 176 - 195, Call resolve_layer_ids(args,
multimodal_config, num_hidden_layers) once before the if and assign its return
to training_target_layer_ids, extraction_layer_ids, layer_id_source; then keep
the if args.target_layer_ids only to emit the warnings.warn(...) using those
variables. This removes the duplicated resolve_layer_ids call while preserving
the custom-ID warning behavior and the subsequent print of layer ids.

56-63: 💤 Low value

Default layer IDs may become negative for shallow models.

candidate_layer_ids = [2, num_hidden_layers // 2, num_hidden_layers - 3] can produce negative values when num_hidden_layers < 3. Although validate_layer_ids would catch < 0, the error message would be confusing since users didn't explicitly pass these IDs.

Suggested guard for shallow models
 def get_default_target_layer_ids(multimodal_config, num_hidden_layers: int) -> list[int]:
     """Return default auxiliary layer ids used by training."""
+    if num_hidden_layers < 4:
+        # Shallow model: just use layer 0 (embedding output) or mid-layer
+        return [max(0, num_hidden_layers // 2)]
     deepstack_layers = set(get_deepstack_visual_indexes(multimodal_config))
     candidate_layer_ids = [2, num_hidden_layers // 2, num_hidden_layers - 3]
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@scripts/launch_vllm.py` around lines 56 - 63, The
get_default_target_layer_ids function can produce negative layer ids for shallow
models; update the candidate generation and post-processing so candidates are
clamped into the valid range [0, num_hidden_layers-1] (e.g., replace raw
expressions in candidate_layer_ids with guarded values or apply a clamp step),
then apply the deepstack adjustment (the existing layer_id - 1 rule) while
ensuring the adjusted id doesn't become negative, and finally deduplicate/filter
out any out-of-range ids before returning; reference
get_default_target_layer_ids, candidate_layer_ids, and the deepstack adjustment
to locate where to add the guards so validate_layer_ids no longer has to surface
confusing errors for defaults.
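A minimal sketch of the clamping approach described in this prompt, under simplifying assumptions: default_target_layer_ids is a hypothetical name, and the DeepStack adjustment is deliberately left out because its exact rule is not shown here.

```python
def default_target_layer_ids(num_hidden_layers):
    """Hypothetical sketch: clamp the default candidates into the valid range."""
    if num_hidden_layers < 1:
        raise ValueError("num_hidden_layers must be positive")
    candidates = [2, num_hidden_layers // 2, num_hidden_layers - 3]
    last = num_hidden_layers - 1
    clamped = [min(max(layer_id, 0), last) for layer_id in candidates]
    # dict.fromkeys removes duplicates while keeping first-seen order.
    return list(dict.fromkeys(clamped))


deep = default_target_layer_ids(32)     # candidates are already in range
shallow = default_target_layer_ids(2)   # raw candidates would be [2, 1, -1]
```

For deep models the defaults are unchanged; for shallow ones the result contains only valid, distinct ids.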
src/speculators/models/eagle3/metrics.py (1)

18-36: 💤 Low value

Environment variable parsed at import time may cause inconsistent behavior across workers.

EAGLE3_LOSS_CE_WEIGHT is parsed once at module import. In distributed training, if env vars differ across workers or change after import, the loss function will use stale values. This is likely acceptable since env vars are typically set before process launch, but worth documenting.

🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@src/speculators/models/eagle3/metrics.py` around lines 18 - 36, The module
currently parses EAGLE3_LOSS_CE_WEIGHT at import time using
_EAGLE3_LOSS_CE_WEIGHT_RAW and _EAGLE3_LOSS_CE_WEIGHT which can become stale
across worker processes; change to a lazy getter (e.g., add a function
get_eagle3_loss_ce_weight()) that reads os.getenv each time,
validates/parses/clamps the value exactly as the current logic does, and replace
direct references to _EAGLE3_LOSS_CE_WEIGHT in loss code with calls to
get_eagle3_loss_ce_weight(); alternatively allow injection of the weight via a
parameter to the relevant loss constructor/function to avoid relying on
import-time globals.
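The lazy getter suggested here could look roughly like the sketch below. The default of 0.0, the [0, 1] clamp range, and the silent fallback on invalid input are assumptions; the existing parsing logic in metrics.py may differ.

```python
import os


def get_eagle3_loss_ce_weight(default=0.0):
    """Read EAGLE3_LOSS_CE_WEIGHT at call time instead of at import time.

    Assumed behavior: fall back to `default` when unset or invalid, clamp to [0, 1].
    """
    raw = os.getenv("EAGLE3_LOSS_CE_WEIGHT")
    if raw is None or not raw.strip():
        return default
    try:
        weight = float(raw)
    except ValueError:
        return default
    if weight != weight:  # NaN never equals itself
        return default
    return min(max(weight, 0.0), 1.0)
```

Because the environment is read on every call, a value changed after import is picked up on the next call.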
src/speculators/train/data.py (1)

277-309: ⚖️ Poor tradeoff

Fragile coupling to internal transformers implementation details.

This code binds internal model methods (get_llm_pos_ids_for_vision, get_rope_index) to a SimpleNamespace dummy object. These internal APIs are not part of the public transformers contract and may break with minor version upgrades.

Consider adding a version check or wrapping in a try/except with a more informative fallback warning.

What is the current version of transformers and does Qwen3OmniMoeThinkerForConditionalGeneration have get_rope_index method?
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@src/speculators/train/data.py` around lines 277 - 309, The
_make_rope_index_fn_qwen3_omni function tightly couples to internal transformer
internals by binding Qwen3OmniMoeThinkerForConditionalGeneration methods to a
SimpleNamespace; update it to guard for missing/internal API changes by checking
transformers.__version__ and the presence of the methods before binding (use
hasattr for get_llm_pos_ids_for_vision and get_rope_index), and if the methods
or an acceptable version are not present return None while logging a clear
warning that includes transformers.__version__ and the missing symbol names
(refer to _make_rope_index_fn_qwen3_omni,
Qwen3OmniMoeThinkerForConditionalGeneration, get_llm_pos_ids_for_vision,
get_rope_index, SimpleNamespace, MethodType).
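The hasattr guard described in this prompt might be sketched as below. bind_methods_or_none and FakeModel are hypothetical stand-ins, and the version string is passed in by the caller so that the sketch does not need transformers installed.

```python
import warnings
from types import MethodType, SimpleNamespace


def bind_methods_or_none(source_cls, method_names, version="unknown"):
    """Bind plain methods of source_cls onto a dummy object, or return None.

    Returns None (with a warning) when any requested method is missing.
    """
    missing = [name for name in method_names if not hasattr(source_cls, name)]
    if missing:
        warnings.warn(
            f"{source_cls.__name__} (library version {version}) lacks "
            f"{missing}; falling back to no RoPE index function.",
            stacklevel=2,
        )
        return None
    dummy = SimpleNamespace()
    for name in method_names:
        # Accessing a method on the class yields a plain function in Python 3.
        setattr(dummy, name, MethodType(getattr(source_cls, name), dummy))
    return dummy


class FakeModel:
    """Stand-in for a model class that exposes only one of the two methods."""

    def get_rope_index(self):
        return "rope-index"
```

A caller would then treat a None result as "feature unavailable" rather than failing with an AttributeError deep inside data loading.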
🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

Inline comments:
In `@scripts/train.py`:
- Around line 752-757: The parser argument for "--draft-rope-scaling" uses
json.loads but the module isn't imported, causing a NameError at runtime; add an
import for the json module near the top imports (so json is available to the
parser.add_argument lambda and any other use), ensuring the json.loads call in
the "--draft-rope-scaling" handler works correctly.

In `@src/speculators/train/data.py`:
- Around line 597-610: The call in _get_raw_data uses
self._maybe_load_hs_file(index) but _maybe_load_hs_file is a module-level
function that expects a file_path: Path, not an index; replace the incorrect
call with a call that passes the multimodal file path from the fetched row (e.g.
_maybe_load_hs_file(row["mm_file"]) converted to Path) or alternatively
implement a small instance wrapper method named _maybe_load_hs_file(self,
file_path) that forwards to the module function; ensure the returned loaded_hs
handling remains the same and still handles None results.
- Around line 803-818: The MRoPE branch in create_collate_fn is producing a
rank-3 position_ids tensor (due to expanding pos with expand(3, -1) and
unsqueezing) while the non-MRoPE branch produces rank-2, causing downstream
mis-slicing in shift_batch; fix create_collate_fn so both branches produce the
same shape (match the non-MRoPE output). Specifically, in the has_mrope branch
(the loop over batch building position_ids) stop expanding 1-D pos to (3,
seq_len) and instead ensure you concatenate along dim=0 to produce a rank-2
tensor, then call pad_last_dim_to_length (not pad to create an extra leading
dim) and finally unsqueeze(0) just like the else branch; keep references to
position_ids, pad_last_dim_to_length, slice_and_pad_to_length, and
create_collate_fn so shift_batch in eagle3/data.py sees a consistent shape.

ℹ️ Review info
⚙️ Run configuration

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Pro

Run ID: 268e87a4-974a-495a-a822-01c74c77e46e

📥 Commits

Reviewing files that changed from the base of the PR and between 927c9c8 and 55e7f98.

📒 Files selected for processing (10)
  • scripts/launch_vllm.py
  • scripts/prepare_data.py
  • scripts/train.py
  • src/speculators/data_generation/preprocessing.py
  • src/speculators/data_generation/vllm_client.py
  • src/speculators/models/eagle3/core.py
  • src/speculators/models/eagle3/data.py
  • src/speculators/models/eagle3/metrics.py
  • src/speculators/train/data.py
  • src/speculators/train/vocab_mapping.py

Comment thread scripts/train.py
Comment thread src/speculators/train/data.py Outdated
Comment on lines +803 to +818
if has_mrope:
    position_ids = []
    for sample in batch:  # type: ignore[assignment]
        pos = sample["position_ids"]
        if pos.ndim == 1:
            pos = pos.unsqueeze(0).expand(3, -1)
        position_ids.append(pos)
    collated_positions = torch.cat(position_ids, dim=-1)
    collated_data["position_ids"] = pad_last_dim_to_length(
        collated_positions, max_len
    ).unsqueeze(1)
else:
    collated_positions = torch.cat([b["position_ids"] for b in batch], dim=0)  # type: ignore[index]
    collated_data["position_ids"] = slice_and_pad_to_length(
        collated_positions, max_len
    ).unsqueeze(0)

⚠️ Potential issue | 🔴 Critical | ⚡ Quick win

🧩 Analysis chain

🏁 Script executed:

#!/bin/bash
# Search for how position_ids is consumed in the model forward pass
rg -n --type=py 'position_ids' src/speculators/models/eagle3/ -A3 -B3

Repository: vllm-project/speculators

Length of output: 6804


🏁 Script executed:

#!/bin/bash
set -euo pipefail

# Show collate logic area in src/speculators/train/data.py around the reported lines
sed -n '740,860p' src/speculators/train/data.py

# Inspect Eagle3 data preprocessing that slices position_ids
sed -n '1,120p' src/speculators/models/eagle3/data.py

# Inspect Eagle3 core around rotary_emb usage and any position_ids shape handling
# (rg output indicated rotary_emb call around lines ~321 and position_ids comments around ~242-264)
sed -n '220,420p' src/speculators/models/eagle3/core.py

Repository: vllm-project/speculators

Length of output: 13575


🏁 Script executed:

#!/bin/bash
set -euo pipefail

# Find rotary_emb implementation in eagle3/core.py (or imported modules)
rg -n --type=py "def .*rotary_emb|rotary_emb" src/speculators/models/eagle3/core.py
# Also search for MRoPE/has_mrope/3-row position handling terms across eagle3 model code
rg -n --type=py "MRoPE|mrope|3[, ]|position_ids\.ndim" src/speculators/models/eagle3 -S

Repository: vllm-project/speculators

Length of output: 4697


🏁 Script executed:

#!/bin/bash
set -euo pipefail

# Locate where create_collate_fn (or collate function) is defined to confirm exact lines and has_mrope condition
rg -n --type=py "def create_collate_fn|has_mrope|pad_last_dim_to_length|slice_and_pad_to_length" src/speculators/train/data.py
sed -n '760,835p' src/speculators/train/data.py

Repository: vllm-project/speculators

Length of output: 3559


🏁 Script executed:

#!/bin/bash
set -euo pipefail

echo "==== src/speculators/train/data.py (collate section) ===="
sed -n '740,860p' src/speculators/train/data.py

echo
echo "==== src/speculators/models/eagle3/data.py (position_ids slicing) ===="
sed -n '1,120p' src/speculators/models/eagle3/data.py

echo
echo "==== src/speculators/models/eagle3/core.py (rotary_emb + position_ids handling) ===="
sed -n '200,420p' src/speculators/models/eagle3/core.py

echo
echo "==== Search rotary_emb and MRoPE/position_ids.ndim usage in eagle3 ===="
rg -n --type=py "rotary_emb|MRoPE|mrope|position_ids\.ndim" src/speculators/models/eagle3 -S

Repository: vllm-project/speculators

Length of output: 16712


Fix/justify the MRoPE position_ids collation shape difference ([3, 1, max_len] vs [1, max_len])

src/speculators/train/data.py:create_collate_fn produces rank-3 position_ids for MRoPE (pos.ndim == 1 → expand to [3, seq_len], then pad_last_dim_to_length(...).unsqueeze(1) → [3, 1, max_len]) but rank-2 position_ids for non-MRoPE ([1, max_len]).

Eagle3 training code expects position_ids shapes it can slice in src/speculators/models/eagle3/data.py:shift_batch (it only special-cases position_ids.ndim == 2; otherwise it slices position_ids[1:]). If shift_batch ever receives the collated MRoPE tensor, it will take the wrong slice path and corrupt position_ids.

Make position_ids shape consistent across both branches (or add explicit handling for position_ids.ndim == 3 where it’s consumed).
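The two shapes in question can be reproduced without torch. In the sketch below nested Python lists stand in for tensors, zero is an assumed pad value, and the function names are hypothetical; it mirrors only the shape bookkeeping of the two collate branches.

```python
def shape(nested):
    """Return the dimensions of a regular nested list, e.g. [3, 1, 8]."""
    dims = []
    while isinstance(nested, list):
        dims.append(len(nested))
        nested = nested[0] if nested else None
    return dims


def collate_mrope_positions(samples, max_len):
    """MRoPE branch: broadcast 1-D samples to 3 rows, concatenate, pad, add dim 1."""
    rows = [[], [], []]
    for pos in samples:  # each sample is assumed to be non-empty
        if not isinstance(pos[0], list):
            pos = [pos, pos, pos]  # 1-D sample: same positions on all three axes
        for row, values in zip(rows, pos):
            row.extend(values)
    padded = [(row + [0] * max_len)[:max_len] for row in rows]
    return [[row] for row in padded]  # analogous to unsqueeze(1)


def collate_plain_positions(samples, max_len):
    """Non-MRoPE branch: concatenate 1-D samples, pad, add a leading dim."""
    flat = [value for pos in samples for value in pos]
    return [(flat + [0] * max_len)[:max_len]]  # analogous to unsqueeze(0)


mrope = collate_mrope_positions([[0, 1, 2], [[0, 1], [0, 1], [0, 1]]], max_len=8)
plain = collate_plain_positions([[0, 1, 2], [0, 1]], max_len=8)
```

Any consumer that slices along the sequence axis has to pick the last dimension in both cases, which is why a rank check limited to 2-D inputs is fragile.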

PatchouliTaisa added 3 commits June 3, 2026 15:42
Signed-off-by: PatchouliTaisa <patchychen@tencent.com>
Signed-off-by: PatchouliTaisa <patchychen@tencent.com>
Signed-off-by: PatchouliTaisa <patchychen@tencent.com>

@coderabbitai coderabbitai Bot left a comment

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (1)
scripts/data_generation_offline.py (1)

243-265: ⚠️ Potential issue | 🟠 Major | 🏗️ Heavy lift

Fix multimodal offline generation to use vLLM Chat Completions (vision encoder)

scripts/data_generation_offline.py calls generate_hidden_states_async(...) without use_chat_completions, so it always takes the Completions path (even when build_client_item provides multimodal payloads). generate_hidden_states_async explicitly notes that the vision encoder path is only used when use_chat_completions=True and client_item["messages"] is present, and the online implementation in src/speculators/train/data.py mirrors this for multimodal rows.

Update the worker to set use_chat_completions like the online code (multimodal && item.get("messages") is not None), pass it to generate_hidden_states_async, and handle the tuple return (path, server_token_ids)—using server_token_ids for check_safetensors_file instead of item["input_ids"].

🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@scripts/data_generation_offline.py` around lines 243 - 265, The generation
call currently always uses the Completions path; set use_chat_completions =
multimodal and item.get("messages") is not None and pass it into
generate_hidden_states_async (i.e., await generate_hidden_states_async(...,
use_chat_completions=use_chat_completions)); unpack the tuple return as
hidden_states_path, server_token_ids = await generate_hidden_states_async(...)
and then when validate_outputs is true call check_safetensors_file with
server_token_ids (not item["input_ids"]) while keeping the existing lock/writes
(wait_for_lock_async, shutil.move to target_hidden_states_path) logic intact.
🧹 Nitpick comments (2)
src/speculators/models/eagle3/rotary_partial.py (1)

72-92: 💤 Low value

Consider documenting concurrent-call behavior.

The function modifies global module state (modeling_llama.apply_rotary_pos_emb and modeling_qwen3.apply_rotary_pos_emb) without synchronization. If called concurrently from multiple threads during model initialization, race conditions could occur between the _INSTALLED check (line 78) and the assignments (line 90).

If concurrent installation is possible in your deployment environment, consider adding a lock or documenting that callers must serialize access.

🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@src/speculators/models/eagle3/rotary_partial.py` around lines 72 - 92,
install_partial_neox_rotary is racy: the check/use of the module-global
_INSTALLED and the subsequent reassignment of
modeling_llama.apply_rotary_pos_emb and modeling_qwen3.apply_rotary_pos_emb can
race if called concurrently; protect the install/uninstall critical section with
a module-level lock (e.g., threading.Lock) so the check of _INSTALLED, caching
of module._speculators_original_apply_rotary_pos_emb, and assignment of
module.apply_rotary_pos_emb are atomic, or alternatively document that callers
must serialize calls — modify install_partial_neox_rotary to acquire the lock at
the start and release after setting _INSTALLED and all module assignments (and
do similarly in any uninstall path if present).
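
One way the locking suggestion could look, as a sketch rather than the PR's actual code. The `SimpleNamespace` objects stand in for the `modeling_llama` and `modeling_qwen3` modules so the example runs without transformers; `_INSTALL_LOCK` and `_partial_neox_apply_rotary_pos_emb` are hypothetical names.

```python
import threading
import types

# Stand-ins for the transformers modules that get patched.
modeling_llama = types.SimpleNamespace(apply_rotary_pos_emb=lambda q, k: ("orig", q, k))
modeling_qwen3 = types.SimpleNamespace(apply_rotary_pos_emb=lambda q, k: ("orig", q, k))

_INSTALLED = False
_INSTALL_LOCK = threading.Lock()  # hypothetical module-level lock


def _partial_neox_apply_rotary_pos_emb(q, k):
    """Placeholder for the patched rotation; not the real implementation."""
    return ("patched", q, k)


def install_partial_neox_rotary():
    """Idempotent install; returns True only for the call that patched."""
    global _INSTALLED
    with _INSTALL_LOCK:
        # The check, the caching of originals, and the reassignment all
        # happen under one lock, so concurrent callers cannot interleave.
        if _INSTALLED:
            return False
        for module in (modeling_llama, modeling_qwen3):
            module._speculators_original_apply_rotary_pos_emb = (
                module.apply_rotary_pos_emb
            )
            module.apply_rotary_pos_emb = _partial_neox_apply_rotary_pos_emb
        _INSTALLED = True
        return True
```

Holding the lock across the whole critical section keeps a second caller from caching the already-patched function as the "original", which is the failure an unsynchronized check-then-assign would allow.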
src/speculators/data_generation/preprocessing.py (1)

879-892: 💤 Low value

assistant_pattern mutation affects subsequent samples in batch.

When assistant_pattern is None, the first multimodal sample triggers _detect_assistant_pattern and the result is assigned back to assistant_pattern. The detected value then applies to every subsequent sample in the same batch, which is likely intentional. However, the reassignment happens inside the per-sample loop, so the behavior depends on which sample comes first, and detection could fail for an edge-case sample. Consider detecting the pattern once before the loop, or documenting the reassignment as intentional.

🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@src/speculators/data_generation/preprocessing.py` around lines 879 - 892, The
code mutates assistant_pattern inside the per-sample loop by calling
_detect_assistant_pattern(processor) on the first multimodal sample, causing the
detected pattern to persist across subsequent samples; to fix, move the
detection out of the per-sample loop so assistant_pattern is computed once
(e.g., call _detect_assistant_pattern(processor) before iterating samples when
assistant_pattern is None) and then use that value inside the loop for
_loss_mask_from_assistant_token_spans and _loss_mask_from_ids_fallback, or if
you prefer to keep the in-loop behavior, add a clear comment on
assistant_pattern, _detect_assistant_pattern, and its intentional mutation to
document the side effect.
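
A simplified sketch of hoisting the detection out of the loop. Everything here is a stand-in: the `processor` dict, the `multimodal` key on samples, and the tuple "masks" are hypothetical placeholders for the real processor object and for `_loss_mask_from_assistant_token_spans` and `_loss_mask_from_ids_fallback`, whose signatures are not shown in this review.

```python
def _detect_assistant_pattern(processor):
    """Stand-in detector: records each call and returns the configured pattern."""
    processor["detect_calls"] += 1
    return processor.get("pattern")


def build_loss_masks(samples, processor, assistant_pattern=None):
    # Resolve the pattern once, before the loop, and only if a multimodal
    # sample is present and no pattern was supplied by the caller.
    if assistant_pattern is None and any(s.get("multimodal") for s in samples):
        assistant_pattern = _detect_assistant_pattern(processor)

    masks = []
    for sample in samples:
        if assistant_pattern is not None:
            # Placeholder for the span-based mask path.
            masks.append(("spans", assistant_pattern))
        else:
            # Placeholder for the ids-based fallback path.
            masks.append(("fallback", None))
    return masks
```

With detection done up front, the loop body no longer reassigns `assistant_pattern`, so every sample in the batch sees the same value regardless of ordering.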

ℹ️ Review info
⚙️ Run configuration

Configuration used: defaults

Review profile: CHILL

Plan: Pro

Run ID: c307ba3e-c7d0-4ce9-9eff-b2d6071d5c6f

📥 Commits

Reviewing files that changed from the base of the PR and between 55e7f98 and 49df015.

📒 Files selected for processing (9)
  • scripts/data_generation_offline.py
  • scripts/launch_vllm.py
  • scripts/train.py
  • src/speculators/data_generation/preprocessing.py
  • src/speculators/data_generation/vllm_client.py
  • src/speculators/models/eagle3/core.py
  • src/speculators/models/eagle3/rotary_partial.py
  • src/speculators/train/data.py
  • tests/unit/models/test_eagle3_rotary_partial.py
🚧 Files skipped from review as they are similar to previous changes (4)
  • src/speculators/models/eagle3/core.py
  • scripts/train.py
  • scripts/launch_vllm.py
  • src/speculators/train/data.py
