[Features] Qwen3.6 PartialRoPE supports#568
Conversation
Signed-off-by: PatchouliTaisa <patchychen@tencent.com>
📝 Walkthrough
This PR adds end-to-end multimodal support: vLLM layer-ID resolution and validation, multimodal preprocessing with safetensor sidecars and loss-mask alignment, vLLM API routing for token-id parity, multimodal-aware training data loading with RoPE indexing and collation, drafter configuration derivation with MRoPE support, and model-level rotary and loss adjustments.
Changes
Multimodal Speculative Extraction Pipeline
Estimated code review effort: 🎯 5 (Critical) | ⏱️ ~120 minutes
🚥 Pre-merge checks: ✅ 4 passed | ❌ 1 failed (1 warning)
Merge Protections
Your pull request matches the following merge protections and will not be merged until they are valid.
🔴 Require two reviews
This rule is failing. PRs labelled "two-reviews" must have at least two approving reviews before merging.
Actionable comments posted: 3
🧹 Nitpick comments (4)
scripts/launch_vllm.py (2)
176-195: 💤 Low value
Duplicated `resolve_layer_ids` call can be simplified.
Both branches of the `if args.target_layer_ids` conditional call `resolve_layer_ids` with identical arguments. The only difference is the warning for custom IDs. Consider extracting the call before the conditional.
Suggested simplification
-    if args.target_layer_ids:
-        training_target_layer_ids, extraction_layer_ids, layer_id_source = resolve_layer_ids(
-            args, multimodal_config, num_hidden_layers
-        )
+    training_target_layer_ids, extraction_layer_ids, layer_id_source = resolve_layer_ids(
+        args, multimodal_config, num_hidden_layers
+    )
+
+    if args.target_layer_ids:
         warnings.warn(
             "Using custom target layer ids. Pass "
             f"{training_target_layer_ids} to the training script; vLLM will "
             f"extract {extraction_layer_ids}.",
             stacklevel=2,
         )
-    else:
-        training_target_layer_ids, extraction_layer_ids, layer_id_source = resolve_layer_ids(
-            args, multimodal_config, num_hidden_layers
-        )
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `@scripts/launch_vllm.py` around lines 176 - 195, Call resolve_layer_ids(args, multimodal_config, num_hidden_layers) once before the if and assign its return to training_target_layer_ids, extraction_layer_ids, layer_id_source; then keep the if args.target_layer_ids only to emit the warnings.warn(...) using those variables. This removes the duplicated resolve_layer_ids call while preserving the custom-ID warning behavior and the subsequent print of layer ids.
56-63: 💤 Low value
Default layer IDs may become negative for shallow models.
`candidate_layer_ids = [2, num_hidden_layers // 2, num_hidden_layers - 3]` can produce negative values when `num_hidden_layers < 3`. Although `validate_layer_ids` would catch `< 0`, the error message would be confusing since users didn't explicitly pass these IDs.
Suggested guard for shallow models
 def get_default_target_layer_ids(multimodal_config, num_hidden_layers: int) -> list[int]:
     """Return default auxiliary layer ids used by training."""
+    if num_hidden_layers < 4:
+        # Shallow model: just use layer 0 (embedding output) or mid-layer
+        return [max(0, num_hidden_layers // 2)]
     deepstack_layers = set(get_deepstack_visual_indexes(multimodal_config))
     candidate_layer_ids = [2, num_hidden_layers // 2, num_hidden_layers - 3]
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `@scripts/launch_vllm.py` around lines 56 - 63, The get_default_target_layer_ids function can produce negative layer ids for shallow models; update the candidate generation and post-processing so candidates are clamped into the valid range [0, num_hidden_layers-1] (e.g., replace raw expressions in candidate_layer_ids with guarded values or apply a clamp step), then apply the deepstack adjustment (the existing layer_id - 1 rule) while ensuring the adjusted id doesn't become negative, and finally deduplicate/filter out any out-of-range ids before returning; reference get_default_target_layer_ids, candidate_layer_ids, and the deepstack adjustment to locate where to add the guards so validate_layer_ids no longer has to surface confusing errors for defaults.
src/speculators/models/eagle3/metrics.py (1)
18-36: 💤 Low value
Environment variable parsed at import time may cause inconsistent behavior across workers.
`EAGLE3_LOSS_CE_WEIGHT` is parsed once at module import. In distributed training, if env vars differ across workers or change after import, the loss function will use stale values. This is likely acceptable since env vars are typically set before process launch, but worth documenting.
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `@src/speculators/models/eagle3/metrics.py` around lines 18 - 36, The module currently parses EAGLE3_LOSS_CE_WEIGHT at import time using _EAGLE3_LOSS_CE_WEIGHT_RAW and _EAGLE3_LOSS_CE_WEIGHT which can become stale across worker processes; change to a lazy getter (e.g., add a function get_eagle3_loss_ce_weight()) that reads os.getenv each time, validates/parses/clamps the value exactly as the current logic does, and replace direct references to _EAGLE3_LOSS_CE_WEIGHT in loss code with calls to get_eagle3_loss_ce_weight(); alternatively allow injection of the weight via a parameter to the relevant loss constructor/function to avoid relying on import-time globals.
src/speculators/train/data.py (1)
277-309: ⚖️ Poor tradeoff
Fragile coupling to internal `transformers` implementation details.
This code binds internal model methods (`get_llm_pos_ids_for_vision`, `get_rope_index`) to a `SimpleNamespace` dummy object. These internal APIs are not part of the public transformers contract and may break with minor version upgrades.
Consider adding a version check or wrapping in a try/except with a more informative fallback warning.
What is the current version of transformers and does Qwen3OmniMoeThinkerForConditionalGeneration have get_rope_index method?
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `@src/speculators/train/data.py` around lines 277 - 309, The _make_rope_index_fn_qwen3_omni function tightly couples to internal transformer internals by binding Qwen3OmniMoeThinkerForConditionalGeneration methods to a SimpleNamespace; update it to guard for missing/internal API changes by checking transformers.__version__ and the presence of the methods before binding (use hasattr for get_llm_pos_ids_for_vision and get_rope_index), and if the methods or an acceptable version are not present return None while logging a clear warning that includes transformers.__version__ and the missing symbol names (refer to _make_rope_index_fn_qwen3_omni, Qwen3OmniMoeThinkerForConditionalGeneration, get_llm_pos_ids_for_vision, get_rope_index, SimpleNamespace, MethodType).
🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
Inline comments:
In `@scripts/train.py`:
- Around line 752-757: The parser argument for "--draft-rope-scaling" uses
json.loads but the module isn't imported, causing a NameError at runtime; add an
import for the json module near the top imports (so json is available to the
parser.add_argument lambda and any other use), ensuring the json.loads call in
the "--draft-rope-scaling" handler works correctly.
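A minimal sketch of the fix for the `scripts/train.py` item above, assuming the flag takes a JSON object on the command line. `build_parser` is a hypothetical helper name, and passing `json.loads` directly as `type=` is a simplification of the lambda mentioned in the comment; the point is only that `json` must be imported before the parser is built.

```python
import argparse
import json  # required: the type= callable below runs json.loads at parse time


def build_parser() -> argparse.ArgumentParser:
    """Hypothetical helper that registers only the flag discussed above."""
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--draft-rope-scaling",
        type=json.loads,  # parse the flag value as JSON
        default=None,
        help="JSON object overriding the drafter rope_scaling config",
    )
    return parser


# Example invocation with an illustrative (made-up) override value.
args = build_parser().parse_args(
    ["--draft-rope-scaling", '{"rope_type": "yarn", "factor": 2.0}']
)
```

Because `json.JSONDecodeError` subclasses `ValueError`, argparse reports malformed JSON as a normal usage error rather than a traceback.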
In `@src/speculators/train/data.py`:
- Around line 597-610: The call in _get_raw_data uses
self._maybe_load_hs_file(index) but _maybe_load_hs_file is a module-level
function that expects a file_path: Path, not an index; replace the incorrect
call with a call that passes the multimodal file path from the fetched row (e.g.
_maybe_load_hs_file(row["mm_file"]) converted to Path) or alternatively
implement a small instance wrapper method named _maybe_load_hs_file(self,
file_path) that forwards to the module function; ensure the returned loaded_hs
handling remains the same and still handles None results.
- Around line 803-818: The MRoPE branch in create_collate_fn is producing a
rank-3 position_ids tensor (due to expanding pos with expand(3, -1) and
unsqueezing) while the non-MRoPE branch produces rank-2, causing downstream
mis-slicing in shift_batch; fix create_collate_fn so both branches produce the
same shape (match the non-MRoPE output). Specifically, in the has_mrope branch
(the loop over batch building position_ids) stop expanding 1-D pos to (3,
seq_len) and instead ensure you concatenate along dim=0 to produce a rank-2
tensor, then call pad_last_dim_to_length (not pad to create an extra leading
dim) and finally unsqueeze(0) just like the else branch; keep references to
position_ids, pad_last_dim_to_length, slice_and_pad_to_length, and
create_collate_fn so shift_batch in eagle3/data.py sees a consistent shape.
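The `_maybe_load_hs_file` item above (lines 597-610) comes down to passing a path rather than an index. The sketch below illustrates that under assumptions: the loader body and the `DatasetSketch` class are stand-ins, not the project's code, and only the `row["mm_file"]` key is taken from the comment.

```python
from pathlib import Path
from typing import Any, Optional


def _maybe_load_hs_file(file_path: Path) -> Optional[dict[str, Any]]:
    """Stand-in for the module-level loader: None when the file is absent."""
    if not file_path.exists():
        return None
    return {"path": str(file_path)}  # placeholder payload, not real tensors


class DatasetSketch:
    """Hypothetical dataset holding rows that carry an 'mm_file' path."""

    def __init__(self, rows: list[dict[str, str]]) -> None:
        self.rows = rows

    def _get_raw_data(self, index: int) -> Optional[dict[str, Any]]:
        row = self.rows[index]
        # Call the module-level function with the row's path, not the index,
        # and not via self (it is not an instance method).
        loaded_hs = _maybe_load_hs_file(Path(row["mm_file"]))
        return loaded_hs  # callers must still handle None
```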
---
Nitpick comments:
In `@scripts/launch_vllm.py`:
- Around line 176-195: Call resolve_layer_ids(args, multimodal_config,
num_hidden_layers) once before the if and assign its return to
training_target_layer_ids, extraction_layer_ids, layer_id_source; then keep the
if args.target_layer_ids only to emit the warnings.warn(...) using those
variables. This removes the duplicated resolve_layer_ids call while preserving
the custom-ID warning behavior and the subsequent print of layer ids.
- Around line 56-63: The get_default_target_layer_ids function can produce
negative layer ids for shallow models; update the candidate generation and
post-processing so candidates are clamped into the valid range [0,
num_hidden_layers-1] (e.g., replace raw expressions in candidate_layer_ids with
guarded values or apply a clamp step), then apply the deepstack adjustment (the
existing layer_id - 1 rule) while ensuring the adjusted id doesn't become
negative, and finally deduplicate/filter out any out-of-range ids before
returning; reference get_default_target_layer_ids, candidate_layer_ids, and the
deepstack adjustment to locate where to add the guards so validate_layer_ids no
longer has to surface confusing errors for defaults.
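One way to read the clamping suggestion for lines 56-63 is sketched below. This is not the PR's code: `default_target_layer_ids` is a hypothetical name, `deepstack_layers` is passed in directly instead of being derived from the multimodal config, and the deepstack rule is assumed to mean "step back one layer when the candidate is a deepstack layer".

```python
def default_target_layer_ids(
    num_hidden_layers: int, deepstack_layers: set[int]
) -> list[int]:
    """Return default layer ids clamped into [0, num_hidden_layers - 1]."""
    if num_hidden_layers < 1:
        raise ValueError("num_hidden_layers must be positive")
    last = num_hidden_layers - 1
    candidates = [2, num_hidden_layers // 2, num_hidden_layers - 3]
    result: list[int] = []
    for layer_id in candidates:
        layer_id = min(max(layer_id, 0), last)  # clamp into the valid range
        if layer_id in deepstack_layers:
            # Assumed deepstack adjustment: use the previous layer, never < 0.
            layer_id = max(layer_id - 1, 0)
        if layer_id not in result:  # deduplicate while preserving order
            result.append(layer_id)
    return result
```

With this shape, a two-layer model yields `[1, 0]` instead of a list containing `-1`, so the validation step never has to report ids the user did not pass.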
In `@src/speculators/models/eagle3/metrics.py`:
- Around line 18-36: The module currently parses EAGLE3_LOSS_CE_WEIGHT at import
time using _EAGLE3_LOSS_CE_WEIGHT_RAW and _EAGLE3_LOSS_CE_WEIGHT which can
become stale across worker processes; change to a lazy getter (e.g., add a
function get_eagle3_loss_ce_weight()) that reads os.getenv each time,
validates/parses/clamps the value exactly as the current logic does, and replace
direct references to _EAGLE3_LOSS_CE_WEIGHT in loss code with calls to
get_eagle3_loss_ce_weight(); alternatively allow injection of the weight via a
parameter to the relevant loss constructor/function to avoid relying on
import-time globals.
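The lazy getter proposed for `metrics.py` could look roughly like this. The parsing rules are assumptions, since the current validation logic is not shown in this thread: here an unset, empty, or unparsable value falls back to a default of 0.0, and valid values are clamped to [0, 1].

```python
import math
import os


def get_eagle3_loss_ce_weight(default: float = 0.0) -> float:
    """Read EAGLE3_LOSS_CE_WEIGHT on every call instead of at import time."""
    raw = os.getenv("EAGLE3_LOSS_CE_WEIGHT")
    if raw is None or raw.strip() == "":
        return default
    try:
        weight = float(raw)
    except ValueError:
        return default  # assumed fallback for unparsable input
    if math.isnan(weight):
        return default
    return min(max(weight, 0.0), 1.0)  # assumed clamp range
```

Reading the variable per call costs a dictionary lookup per loss evaluation; passing the weight as an explicit parameter, as the comment also suggests, avoids even that.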
In `@src/speculators/train/data.py`:
- Around line 277-309: The _make_rope_index_fn_qwen3_omni function tightly
couples to internal transformer internals by binding
Qwen3OmniMoeThinkerForConditionalGeneration methods to a SimpleNamespace; update
it to guard for missing/internal API changes by checking
transformers.__version__ and the presence of the methods before binding (use
hasattr for get_llm_pos_ids_for_vision and get_rope_index), and if the methods
or an acceptable version are not present return None while logging a clear
warning that includes transformers.__version__ and the missing symbol names
(refer to _make_rope_index_fn_qwen3_omni,
Qwen3OmniMoeThinkerForConditionalGeneration, get_llm_pos_ids_for_vision,
get_rope_index, SimpleNamespace, MethodType).
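The `hasattr` guard described for `_make_rope_index_fn_qwen3_omni` can be sketched without importing transformers. `bind_methods_or_none` is a hypothetical helper; in real use `source_cls` would be the model class and `version` would be `transformers.__version__`.

```python
import warnings
from types import MethodType, SimpleNamespace
from typing import Optional


def bind_methods_or_none(
    source_cls: type, method_names: list[str], version: str = "unknown"
) -> Optional[SimpleNamespace]:
    """Bind the named methods of source_cls to a dummy object, or return None."""
    missing = [name for name in method_names if not hasattr(source_cls, name)]
    if missing:
        warnings.warn(
            f"{source_cls.__name__} lacks {missing} "
            f"(library version {version}); skipping rope-index binding.",
            stacklevel=2,
        )
        return None
    dummy = SimpleNamespace()
    for name in method_names:
        # getattr on the class yields the plain function; bind it to the dummy.
        setattr(dummy, name, MethodType(getattr(source_cls, name), dummy))
    return dummy
```

Returning `None` lets the caller fall back to plain 1-D position ids rather than failing with an `AttributeError` deep inside data loading.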
ℹ️ Review info
⚙️ Run configuration
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Pro
Run ID: 268e87a4-974a-495a-a822-01c74c77e46e
📒 Files selected for processing (10)
scripts/launch_vllm.py
scripts/prepare_data.py
scripts/train.py
src/speculators/data_generation/preprocessing.py
src/speculators/data_generation/vllm_client.py
src/speculators/models/eagle3/core.py
src/speculators/models/eagle3/data.py
src/speculators/models/eagle3/metrics.py
src/speculators/train/data.py
src/speculators/train/vocab_mapping.py
if has_mrope:
    position_ids = []
    for sample in batch:  # type: ignore[assignment]
        pos = sample["position_ids"]
        if pos.ndim == 1:
            pos = pos.unsqueeze(0).expand(3, -1)
        position_ids.append(pos)
    collated_positions = torch.cat(position_ids, dim=-1)
    collated_data["position_ids"] = pad_last_dim_to_length(
        collated_positions, max_len
    ).unsqueeze(1)
else:
    collated_positions = torch.cat([b["position_ids"] for b in batch], dim=0)  # type: ignore[index]
    collated_data["position_ids"] = slice_and_pad_to_length(
        collated_positions, max_len
    ).unsqueeze(0)
🧩 Analysis chain
🏁 Script executed:
#!/bin/bash
# Search for how position_ids is consumed in the model forward pass
rg -n --type=py 'position_ids' src/speculators/models/eagle3/ -A3 -B3
Repository: vllm-project/speculators
Length of output: 6804
🏁 Script executed:
#!/bin/bash
set -euo pipefail
# Show collate logic area in src/speculators/train/data.py around the reported lines
sed -n '740,860p' src/speculators/train/data.py
# Inspect Eagle3 data preprocessing that slices position_ids
sed -n '1,120p' src/speculators/models/eagle3/data.py
# Inspect Eagle3 core around rotary_emb usage and any position_ids shape handling
# (rg output indicated rotary_emb call around lines ~321 and position_ids comments around ~242-264)
sed -n '220,420p' src/speculators/models/eagle3/core.py
Repository: vllm-project/speculators
Length of output: 13575
🏁 Script executed:
#!/bin/bash
set -euo pipefail
# Find rotary_emb implementation in eagle3/core.py (or imported modules)
rg -n --type=py "def .*rotary_emb|rotary_emb" src/speculators/models/eagle3/core.py
# Also search for MRoPE/has_mrope/3-row position handling terms across eagle3 model code
rg -n --type=py "MRoPE|mrope|3[, ]|position_ids\.ndim" src/speculators/models/eagle3 -S
Repository: vllm-project/speculators
Length of output: 4697
🏁 Script executed:
#!/bin/bash
set -euo pipefail
# Locate where create_collate_fn (or collate function) is defined to confirm exact lines and has_mrope condition
rg -n --type=py "def create_collate_fn|has_mrope|pad_last_dim_to_length|slice_and_pad_to_length" src/speculators/train/data.py
sed -n '760,835p' src/speculators/train/data.py
Repository: vllm-project/speculators
Length of output: 3559
🏁 Script executed:
#!/bin/bash
set -euo pipefail
echo "==== src/speculators/train/data.py (collate section) ===="
sed -n '740,860p' src/speculators/train/data.py
echo
echo "==== src/speculators/models/eagle3/data.py (position_ids slicing) ===="
sed -n '1,120p' src/speculators/models/eagle3/data.py
echo
echo "==== src/speculators/models/eagle3/core.py (rotary_emb + position_ids handling) ===="
sed -n '200,420p' src/speculators/models/eagle3/core.py
echo
echo "==== Search rotary_emb and MRoPE/position_ids.ndim usage in eagle3 ===="
rg -n --type=py "rotary_emb|MRoPE|mrope|position_ids\.ndim" src/speculators/models/eagle3 -S
Repository: vllm-project/speculators
Length of output: 16712
Fix/justify the MRoPE position_ids collation shape difference ([3, 1, max_len] vs [1, max_len])
src/speculators/train/data.py:create_collate_fn produces rank-3 position_ids for MRoPE (pos.ndim==1 → expand to [3, seq_len], then pad_last_dim_to_length(...).unsqueeze(1) → [3, 1, max_len]) but rank-2 for non-MRoPE ([1, max_len]).
Eagle3 training code expects position_ids shapes it can slice in src/speculators/models/eagle3/data.py:shift_batch (it only special-cases position_ids.ndim == 2; otherwise it slices position_ids[1:]). If shift_batch ever receives the collated MRoPE tensor, it will take the wrong slice path and corrupt position_ids.
Make position_ids shape consistent across both branches (or add explicit handling for position_ids.ndim == 3 where it’s consumed).
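The shape argument above can be traced without torch by tracking shapes as tuples. This is only bookkeeping under assumptions: `pad_last_dim_to_length` and `slice_and_pad_to_length` are taken to leave the last dimension equal to `max_len`, and `fallback_slice_shape` models just the `position_ids[1:]` path that `shift_batch` is said to take when `ndim != 2`. All three function names are hypothetical.

```python
def mrope_collated_shape(seq_lens: list[int], max_len: int) -> tuple[int, ...]:
    """Shape produced by the has_mrope branch for 1-D per-sample inputs."""
    per_sample = [(3, n) for n in seq_lens]        # unsqueeze(0).expand(3, -1)
    concatenated = (3, sum(n for _, n in per_sample))  # torch.cat(dim=-1)
    padded = (concatenated[0], max_len)            # last dim forced to max_len
    return (padded[0], 1, padded[1])               # .unsqueeze(1)


def plain_collated_shape(seq_lens: list[int], max_len: int) -> tuple[int, ...]:
    """Shape produced by the non-MRoPE branch."""
    _concatenated = (sum(seq_lens),)               # torch.cat(dim=0)
    return (1, max_len)                            # pad/slice, then unsqueeze(0)


def fallback_slice_shape(shape: tuple[int, ...]) -> tuple[int, ...]:
    """Shape after position_ids[1:], i.e. slicing the leading axis."""
    return (max(shape[0] - 1, 0),) + shape[1:]
```

For two samples of lengths 5 and 7 padded to 16, the MRoPE branch yields `(3, 1, 16)`, and the leading-axis slice turns that into `(2, 1, 16)`: one of the three MRoPE rows is dropped while the sequence axis is left unshifted, which is the corruption the comment warns about.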
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (1)
scripts/data_generation_offline.py (1)
243-265: ⚠️ Potential issue | 🟠 Major | 🏗️ Heavy lift
Fix multimodal offline generation to use vLLM Chat Completions (vision encoder)
`scripts/data_generation_offline.py` calls `generate_hidden_states_async(...)` without `use_chat_completions`, so it always takes the Completions path (even when `build_client_item` provides multimodal payloads). `generate_hidden_states_async` explicitly notes that the vision encoder path is only used when `use_chat_completions=True` and `client_item["messages"]` is present, and the online implementation in `src/speculators/train/data.py` mirrors this for multimodal rows.
Update the worker to set `use_chat_completions` like the online code (`multimodal && item.get("messages") is not None`), pass it to `generate_hidden_states_async`, and handle the tuple return `(path, server_token_ids)`, using `server_token_ids` for `check_safetensors_file` instead of `item["input_ids"]`.
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `@scripts/data_generation_offline.py` around lines 243 - 265, The generation call currently always uses the Completions path; set use_chat_completions = multimodal and item.get("messages") is not None and pass it into generate_hidden_states_async (i.e., await generate_hidden_states_async(..., use_chat_completions=use_chat_completions)); unpack the tuple return as hidden_states_path, server_token_ids = await generate_hidden_states_async(...) and then when validate_outputs is true call check_safetensors_file with server_token_ids (not item["input_ids"]) while keeping the existing lock/writes (wait_for_lock_async, shutil.move to target_hidden_states_path) logic intact.
🧹 Nitpick comments (2)
src/speculators/models/eagle3/rotary_partial.py (1)
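72-92: 💤 Low value
Consider documenting concurrent-call behavior.
The function modifies global module state (`modeling_llama.apply_rotary_pos_emb` and `modeling_qwen3.apply_rotary_pos_emb`) without synchronization. If called concurrently from multiple threads during model initialization, race conditions could occur between the `_INSTALLED` check (line 78) and the assignments (line 90).
If concurrent installation is possible in your deployment environment, consider adding a lock or documenting that callers must serialize access.
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `@src/speculators/models/eagle3/rotary_partial.py` around lines 72 - 92, install_partial_neox_rotary is racy: the check/use of the module-global _INSTALLED and the subsequent reassignment of modeling_llama.apply_rotary_pos_emb and modeling_qwen3.apply_rotary_pos_emb can race if called concurrently; protect the install/uninstall critical section with a module-level lock (e.g., threading.Lock) so the check of _INSTALLED, caching of module._speculators_original_apply_rotary_pos_emb, and assignment of module.apply_rotary_pos_emb are atomic, or alternatively document that callers must serialize calls; modify install_partial_neox_rotary to acquire the lock at the start and release after setting _INSTALLED and all module assignments (and do similarly in any uninstall path if present).
src/speculators/data_generation/preprocessing.py (1)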
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `@src/speculators/models/eagle3/rotary_partial.py` around lines 72 - 92, install_partial_neox_rotary is racy: the check/use of the module-global _INSTALLED and the subsequent reassignment of modeling_llama.apply_rotary_pos_emb and modeling_qwen3.apply_rotary_pos_emb can race if called concurrently; protect the install/uninstall critical section with a module-level lock (e.g., threading.Lock) so the check of _INSTALLED, caching of module._speculators_original_apply_rotary_pos_emb, and assignment of module.apply_rotary_pos_emb are atomic, or alternatively document that callers must serialize calls — modify install_partial_neox_rotary to acquire the lock at the start and release after setting _INSTALLED and all module assignments (and do similarly in any uninstall path if present).src/speculators/data_generation/preprocessing.py (1)
879-892: 💤 Low value
`assistant_pattern` mutation affects subsequent samples in batch.
When `assistant_pattern` is `None` and the first multimodal sample triggers `_detect_assistant_pattern`, the pattern is reassigned in place. This side effect persists for all subsequent samples in the same batch, which is likely intentional. However, this mutation happens inside the per-sample loop and the pattern detection could fail for some edge-case samples. Consider detecting the pattern once before the loop or documenting this intentional mutation.
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `@src/speculators/data_generation/preprocessing.py` around lines 879 - 892, The code mutates assistant_pattern inside the per-sample loop by calling _detect_assistant_pattern(processor) on the first multimodal sample, causing the detected pattern to persist across subsequent samples; to fix, move the detection out of the per-sample loop so assistant_pattern is computed once (e.g., call _detect_assistant_pattern(processor) before iterating samples when assistant_pattern is None) and then use that value inside the loop for _loss_mask_from_assistant_token_spans and _loss_mask_from_ids_fallback, or if you prefer to keep the in-loop behavior, add a clear comment on assistant_pattern, _detect_assistant_pattern, and its intentional mutation to document the side effect.
🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
Outside diff comments:
In `@scripts/data_generation_offline.py`:
- Around line 243-265: The generation call currently always uses the Completions
path; set use_chat_completions = multimodal and item.get("messages") is not None
and pass it into generate_hidden_states_async (i.e., await
generate_hidden_states_async(..., use_chat_completions=use_chat_completions));
unpack the tuple return as hidden_states_path, server_token_ids = await
generate_hidden_states_async(...) and then when validate_outputs is true call
check_safetensors_file with server_token_ids (not item["input_ids"]) while
keeping the existing lock/writes (wait_for_lock_async, shutil.move to
target_hidden_states_path) logic intact.
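The control flow requested for the offline worker can be sketched as follows. The real `generate_hidden_states_async` signature is not shown in this thread, so `fake_generate_hidden_states_async` is a stub with an invented body; only the `use_chat_completions` keyword and the `(path, server_token_ids)` tuple return are taken from the comment.

```python
import asyncio
from typing import Any


async def fake_generate_hidden_states_async(
    item: dict[str, Any], *, use_chat_completions: bool = False
) -> tuple[str, list[int]]:
    """Stub: pretends server-side templating appends one extra token."""
    await asyncio.sleep(0)
    token_ids = list(item["input_ids"])
    if use_chat_completions:
        token_ids.append(0)  # simulated server/client tokenization drift
    return f"/tmp/hs_{len(token_ids)}.safetensors", token_ids


async def process_item(item: dict[str, Any], multimodal: bool) -> tuple[str, list[int]]:
    # Mirror the online rule: chat path only for multimodal rows with messages.
    use_chat_completions = multimodal and item.get("messages") is not None
    hidden_states_path, server_token_ids = await fake_generate_hidden_states_async(
        item, use_chat_completions=use_chat_completions
    )
    # Validation should compare against server_token_ids, not item["input_ids"].
    return hidden_states_path, server_token_ids
```

Validating against the ids the server actually used is what keeps a tokenization difference from rejecting every sample.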
---
Nitpick comments:
In `@src/speculators/data_generation/preprocessing.py`:
- Around line 879-892: The code mutates assistant_pattern inside the per-sample
loop by calling _detect_assistant_pattern(processor) on the first multimodal
sample, causing the detected pattern to persist across subsequent samples; to
fix, move the detection out of the per-sample loop so assistant_pattern is
computed once (e.g., call _detect_assistant_pattern(processor) before iterating
samples when assistant_pattern is None) and then use that value inside the loop
for _loss_mask_from_assistant_token_spans and _loss_mask_from_ids_fallback, or
if you prefer to keep the in-loop behavior, add a clear comment on
assistant_pattern, _detect_assistant_pattern, and its intentional mutation to
document the side effect.
In `@src/speculators/models/eagle3/rotary_partial.py`:
- Around line 72-92: install_partial_neox_rotary is racy: the check/use of the
module-global _INSTALLED and the subsequent reassignment of
modeling_llama.apply_rotary_pos_emb and modeling_qwen3.apply_rotary_pos_emb can
race if called concurrently; protect the install/uninstall critical section with
a module-level lock (e.g., threading.Lock) so the check of _INSTALLED, caching
of module._speculators_original_apply_rotary_pos_emb, and assignment of
module.apply_rotary_pos_emb are atomic, or alternatively document that callers
must serialize calls — modify install_partial_neox_rotary to acquire the lock at
the start and release after setting _INSTALLED and all module assignments (and
do similarly in any uninstall path if present).
ℹ️ Review info
⚙️ Run configuration
Configuration used: defaults
Review profile: CHILL
Plan: Pro
Run ID: c307ba3e-c7d0-4ce9-9eff-b2d6071d5c6f
📒 Files selected for processing (9)
scripts/data_generation_offline.py
scripts/launch_vllm.py
scripts/train.py
src/speculators/data_generation/preprocessing.py
src/speculators/data_generation/vllm_client.py
src/speculators/models/eagle3/core.py
src/speculators/models/eagle3/rotary_partial.py
src/speculators/train/data.py
tests/unit/models/test_eagle3_rotary_partial.py
🚧 Files skipped from review as they are similar to previous changes (4)
- src/speculators/models/eagle3/core.py
- scripts/train.py
- scripts/launch_vllm.py
- src/speculators/train/data.py
PR Description: Qwen3.6 and Multimodal Verifiers Support
Purpose
Add end-to-end training support for Qwen3.6 (Hybrid Attention, MoE) and other multimodal verifiers (Qwen3-Omni, Qwen3.5-MoE) to the speculators training pipeline. This includes proper handling of MRoPE position encoding, partial rotary factors, multimodal data preprocessing with sidecar tensors, and a fix for the token-id mismatch that caused 100% sample rejection during online hidden-states generation.
Description
Qwen3.6 Hybrid Attention & RoPE Training (`eagle3/core.py`, `scripts/train.py`)
- `_select_rotary_emb_class` and `PartialMRoPE` to support `partial_rotary_factor < 1` in MRoPE-aware draft models.
- `--draft-mrope-full-head-hack` (default on): rescales `mrope_section` by `1/partial_rotary_factor` and pins `partial_rotary_factor=1.0` to ensure bit-equivalent rotation between HF's `rotate_half` (training) and vLLM's neox-partial rotation (inference).
- `--draft-rope-scaling`, `--draft-rope-theta`, `--draft-max-position-embeddings` for explicit RoPE override.
- `--draft-intermediate-size`, `--draft-num-attention-heads`, `--draft-num-key-value-heads`, `--draft-head-dim` for MoE drafter geometry override.
- `unwrap_verifier_text_config` for nested multimodal configs (`thinker_config.text_config`).
Multimodal Data Pipeline (`preprocessing.py`, `train/data.py`)
- `_save_multimodal_sidecar` writes `image_grid_thw`, `video_grid_thw`, etc. to `.safetensors` files.
- `_build_multimodal_loss_mask`: zeroes loss at placeholder token positions (image/video/audio pad tokens).
- `_loss_mask_from_assistant_token_spans` and placeholder-aware `_loss_mask_from_ids_fallback` for accurate multimodal assistant mask alignment.
- `ArrowDataset` now calls `.with_format(None)` after `load_from_disk` to expose multimodal metadata columns hidden by `prepare_data.py`'s persisted `set_format`.
- `_make_rope_index_fn` supporting both Qwen3-Omni (nested `thinker_config`) and Qwen3.5/3.6 MoE (flat `vision_config` + `text_config`) config layouts for 3D MRoPE `position_ids` generation.
- `collate_fn` supports mixed 1D/3D `position_ids` batching.
Online Hidden-States Generation Fix (`vllm_client.py`, `train/data.py`)
- `chat.completions.create(messages=...)` re-renders the chat template server-side, introducing an off-by-one newline token at vision-placeholder boundaries vs. `prepare_data.py`'s HF processor tokenization. This caused `extract_output`'s strict token-id check to reject 100% of samples → `loss=0`.
- `Completions` path: send pre-tokenized `input_ids` + `multi_modal_data` (extracted media URLs) via `extra_body`, bypassing server-side template rendering entirely.
- `messages` retained only as opt-in fallback.
- `_collect_mm_payload_from_messages` to flatten media URLs from persisted messages into vLLM's `multi_modal_data` format.
vLLM Launch Script (`scripts/launch_vllm.py`)
- `--layer-ids` alias for `--target-layer-ids`.
- `thinker_config` unwrap for Qwen3-Omni/Qwen3.6 multimodal verifiers.
- `resolve_layer_ids` separating training target layer ids from vLLM extraction layer ids.
- `hf_config_overrides` to fix `get_hf_text_config()` failures with nested multimodal configs.
Misc
- `EAGLE3_LOSS_CE_WEIGHT` env var for optional KL+CE mixed loss (`eagle3/metrics.py`).
- `vocab_mapping.py`: support `thinker_config.text_config.vocab_size` unwrap.
- `eagle3/data.py`: `shift_batch` handles 2D `[3, seq_len]` MRoPE `position_ids`.
Related Issue
N/A: Internal support for Qwen3.6-35B-A3B Eagle3/DFlash drafter training.
Tests
- `python3 -m py_compile` passes on all 10 modified files.
- `prepare_data.py` successfully preprocesses multimodal (image+text) dataset with sidecar output.
- `train.py` with `--on-missing generate` successfully generates hidden states via token-id `Completions` path (no more token-id mismatch rejection).
- `--draft-mrope-full-head-hack` produces non-zero loss and expected acceptance metrics.
Checklist
Summary by CodeRabbit
New Features
Bug Fixes
Chores
--target-layer-idsto--layer-ids
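The `--draft-mrope-full-head-hack` rescaling described in the PR description (scale `mrope_section` by `1/partial_rotary_factor`, then pin the factor to 1.0) can be illustrated with plain arithmetic. This is a sketch, not the PR's implementation: `apply_full_head_hack` is a hypothetical name, the integrality check is an added assumption, and the section sizes in the usage note are made-up numbers.

```python
def apply_full_head_hack(
    mrope_section: list[int], partial_rotary_factor: float
) -> tuple[list[int], float]:
    """Return (rescaled mrope_section, pinned partial_rotary_factor)."""
    if not 0.0 < partial_rotary_factor <= 1.0:
        raise ValueError("partial_rotary_factor must be in (0, 1]")
    scaled = [size / partial_rotary_factor for size in mrope_section]
    rounded = [round(size) for size in scaled]
    # Assumption: sections must stay whole numbers of rotary dimensions.
    if any(abs(s - r) > 1e-6 for s, r in zip(scaled, rounded)):
        raise ValueError("rescaled mrope_section is not integral")
    return rounded, 1.0
```

For example, an illustrative `mrope_section` of `[11, 11, 10]` with a factor of 0.25 becomes `[44, 44, 40]` with the factor pinned to 1.0, so the sections now span the full head dimension.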