
[Spec] Cleanup idle stub and shape-check patterns #24881

Merged
hnyls2002 merged 3 commits into main from lsyin/spec-misc
May 10, 2026

Conversation

@hnyls2002
Collaborator

@hnyls2002 hnyls2002 commented May 10, 2026

  • Use create_idle_input (zero-length tensors) for the all-finished idle stub in V1 EAGLE and multi-layer EAGLE workers, instead of bare EagleDraftInput() with None fields — robust to next-iter merge_batch / filter_batch
  • Unify input_ids.numel() -> input_ids.shape[0] across spec workers (4 sites)
  • Document EagleDraftInput batch-uniform field invariant on topk_p / topk_index / hidden_states / bonus_tokens
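The idle-stub change in the first bullet can be sketched as follows. This is a minimal pure-Python stand-in (lists play the role of tensors; `DraftInput`, `create_idle_input`, and `merge_batch` here are illustrative names, not the actual sglang implementation): zero-length containers let the next iteration's merge logic concatenate unconditionally, where `None` fields would force special-casing or crash.

```python
# Hypothetical sketch of the "idle stub" pattern described above.
# Real sglang code uses torch tensors and the EagleDraftInput class;
# this stand-in only shows why zero-length beats None.
class DraftInput:
    def __init__(self, topk_p=None, topk_index=None, hidden_states=None):
        self.topk_p = topk_p
        self.topk_index = topk_index
        self.hidden_states = hidden_states

    @classmethod
    def create_idle_input(cls):
        # Zero-length containers instead of None fields: a later
        # merge_batch / filter_batch can operate on them uniformly.
        return cls(topk_p=[], topk_index=[], hidden_states=[])

    def merge_batch(self, other):
        # Concatenation along the batch dimension; with an idle stub
        # this is simply a no-op contribution, no None checks needed.
        self.topk_p += other.topk_p
        self.topk_index += other.topk_index
        self.hidden_states += other.hidden_states


idle = DraftInput.create_idle_input()
real = DraftInput(topk_p=[0.9], topk_index=[42], hidden_states=[[0.0]])
idle.merge_batch(real)
print(idle.topk_index)  # → [42]
```

With a bare `DraftInput()` (all fields `None`), the same `merge_batch` call would raise a `TypeError`, which is the failure mode the PR guards against.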

Follows up on #24859.
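The `numel()` → `shape[0]` unification in the second bullet matters because the two only coincide for 1-D tensors. A numpy stand-in (where `arr.size` plays the role of torch's `numel()`) makes the distinction concrete:

```python
import numpy as np

# numpy stand-in for the torch semantics: arr.size ~ tensor.numel(),
# arr.shape[0] is the length along the first (batch/token) dimension.
flat = np.zeros(5)
assert flat.size == flat.shape[0] == 5   # 1-D: the two agree

two_d = np.zeros((5, 3))
print(two_d.size, two_d.shape[0])        # → 15 5: they diverge in 2-D
```

Using `shape[0]` states the intent (count along dim 0) and stays correct if `input_ids` ever gains a second dimension, which is presumably why the PR standardizes on it.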

Contributor

@gemini-code-assist gemini-code-assist Bot left a comment


Code Review

This pull request refactors the handling of idle draft inputs in EagleWorker and MultiLayerEagleWorker by replacing manual EagleDraftInput instantiation with a call to _draft_preprocess_idle. A potential issue was identified in MultiLayerEagleWorker where the inherited _draft_preprocess_idle method might use an incorrect hidden size attribute, which could lead to an AttributeError or shape mismatch during subsequent batch operations.

# Install an idle EagleDraftInput so next iter's scheduler
# ops (merge_batch / filter_batch) see well-typed empty
# tensors instead of None.
self._draft_preprocess_idle(batch)
Contributor


medium

The _draft_preprocess_idle method in MultiLayerEagleWorker (defined at line 363) delegates to EAGLEWorker._draft_preprocess_idle, which uses self.model_config.spec_hidden_size. However, MultiLayerEagleWorker appears to use self.model_config.hidden_size for its hidden states (as seen in line 683). This could lead to an AttributeError if spec_hidden_size is missing from the MTP model config, or a shape mismatch crash during merge_batch in the next iteration if the sizes differ. It would be safer to implement a local version of _draft_preprocess_idle that uses the correct hidden size logic for MTP models.
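The reviewer's suggested fix can be sketched as a defensive hidden-size lookup. Everything below is hypothetical (the stand-in class, `_idle_hidden_size`, and the `Cfg` config are illustrative, not sglang's actual API); the point is that the MTP path should not assume `spec_hidden_size` exists on its model config:

```python
# Hypothetical sketch of the review suggestion: resolve the hidden size
# defensively instead of assuming spec_hidden_size is present.
class MultiLayerEagleWorkerSketch:  # stand-in; real class lives in sglang
    def __init__(self, model_config):
        self.model_config = model_config

    def _idle_hidden_size(self):
        # Fall back to the base hidden_size when the MTP model config
        # carries no dedicated spec_hidden_size attribute.
        return getattr(
            self.model_config, "spec_hidden_size",
            self.model_config.hidden_size,
        )


class Cfg:  # minimal config without spec_hidden_size
    hidden_size = 4096


worker = MultiLayerEagleWorkerSketch(Cfg())
print(worker._idle_hidden_size())  # → 4096 (fallback path)
```

A local `_draft_preprocess_idle` override would then build its zero-length `hidden_states` with this resolved size, avoiding both the `AttributeError` and the shape mismatch the reviewer describes.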

@hnyls2002
Collaborator Author

/rerun-test test_eagle_infer_a.py test_eagle_infer_b.py test_eagle_constrained_decoding.py test_standalone_speculative_decoding.py

@github-actions
Contributor

github-actions Bot commented May 10, 2026

🚀 1-gpu-h100 (4 tests): ✅ View workflow run

cd test/ && python3 registered/spec/eagle/test_eagle_infer_a.py
cd test/ && python3 registered/spec/eagle/test_eagle_infer_b.py
cd test/ && python3 registered/spec/eagle/test_eagle_constrained_decoding.py
cd test/ && python3 registered/spec/test_standalone_speculative_decoding.py

@hnyls2002
Collaborator Author

/tag-and-rerun-ci

@hnyls2002 hnyls2002 changed the title misc cleanup [Spec] Cleanup idle stub and shape-check patterns May 10, 2026
@hnyls2002 hnyls2002 merged commit 8cc16c9 into main May 10, 2026
124 of 184 checks passed
@hnyls2002 hnyls2002 deleted the lsyin/spec-misc branch May 10, 2026 09:39
ltcs11 added a commit to ltcs11/sglang that referenced this pull request May 11, 2026
* main: (87 commits)
  [Fix] Disable FlashInfer allreduce fusion under deterministic inference (sgl-project#24629)
  fix: STANDALONE spec-decode hidden-size mismatch crash (sgl-project#24217)
  Followup fix for Custom AR V2 in non NVL scenarios (sgl-project#24742)
  Fix reduce_scatterv producer contract for SUM_LEN (sgl-project#24785)
  [NPU]Documentation update for communications quantization feature (sgl-project#24668)
  [Session R3] Add routed_experts_start_len for absolute routing slice control (sgl-project#24851)
  [Model] Add MiniCPM-V 4.6 support (sgl-project#24855)
  Support Intern-S2-Preview (sgl-project#24875)
  [PD] Unify dsv4 dispatch with swa (sgl-project#24888)
  Optimize MHC pipeline: DeepGemm, fused norm, fused hc_head (sgl-project#24775)
  Fix PD bootstrap failure handling (sgl-project#24772)
  [Spec] Cleanup idle stub and shape-check patterns (sgl-project#24881)
  [Bug] Add dsv4 state_type branch to mooncake disaggregation (sgl-project#24878)
  [Spec V1] Split draft-extend phase from `EagleDraftInput` into new `EagleDraftExtendInput` (sgl-project#24859)
  [Gemma4] Optimize Gemm4 with fused Q/K/V RMSNorm + per-expert FP8 ckpt loader (sgl-project#24696)
  [spec decoding] support kimi-k2.5-eagle3-mla (sgl-project#24826)
  [SPEC V2] fix: skip stale state updates in spec-v2 overlap (sgl-project#23456)
  [RL] Call torch.cuda.empty_cache() for `in-place` pause mode to avoid OOM (sgl-project#24854)
  [diffusion] CI: add cache-dit CI tests (sgl-project#19213)
  [Utils] Make request dump robust to unpicklable server_args and large meta_info (sgl-project#24767)
  ...

# Conflicts:
#	python/sglang/srt/utils/common.py
