[Feature] Add DFlash Speculative Decoding Support for Qwen3-VL Model#18387

Open
EanWang211123 wants to merge 54 commits into sgl-project:main from EanWang211123:vlm-dflash-test

Conversation


@EanWang211123 EanWang211123 commented Feb 7, 2026

Motivation

This PR adds DFlash speculative decoding support for the Qwen3-VL model. It depends on #16818.

DFlash Speculative Decoding:

  1. The DFlash draft model only requires the target model's hidden states as input. This allows a DFlash draft model adapted for a corresponding target base model to be used with a multimodal version of that target model (e.g., Qwen3-VL-8B-Instruct + Qwen3-8B-Dflash-b16).
  2. Thanks to DFlash's generalization capability, no additional training on multimodal data was needed during testing to reach an average acceptance length of over 2 tokens.
  3. SpecForge issue for VL model DFlash adaptation: [Feature] [RFC] DFlash Training Adaptation for Qwen VL Models SpecForge#461
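As a rough illustration of the acceptance-length metric reported below, the sketch computes the mean number of tokens emitted per target-model verification step, assuming each step emits the accepted draft tokens plus one token from the target model itself (a common convention; SGLang's exact accounting may differ):

```python
def mean_acceptance_length(accepted_per_step):
    """Toy metric: average tokens emitted per target-model forward pass.

    Assumes each verification step emits the accepted draft tokens
    plus one token from the target model (the exact bookkeeping in
    SGLang may differ slightly).
    """
    steps = len(accepted_per_step)
    total_tokens = sum(accepted_per_step) + steps  # +1 token per step
    return total_tokens / steps

# e.g. 3 verification steps accepting 2, 1, and 3 draft tokens
print(mean_acceptance_length([2, 1, 3]))  # 3.0
```

An acceptance length of 2.8 (as measured below) thus means roughly 1.8 draft tokens are accepted per verification step on average.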

Modifications

New Files

  • benchmark/dflash/bench_dflash_mmstar.py: MMStar benchmark script that reports throughput and acceptance length.

Changed Files

  • python/sglang/srt/models/qwen3_vl.py: Added the set_dflash_layers_to_capture interface.

Key Features

Multimodal Adaptation:
Follows the standalone-style multimodal speculative decoding adaptation approach (e.g., Qwen3-8B-VL + Qwen3-0.6B), using the same MRoPE adaptation logic.
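For context, MRoPE assigns each token a 3-D (temporal, height, width) position id. A toy sketch of the scheme (offsets simplified; not the exact Qwen3-VL logic): text tokens advance all three axes together, while image patch tokens keep a fixed temporal id and index the patch grid on the height/width axes.

```python
# Simplified illustration of MRoPE-style 3-D position ids for a single
# h x w image patch grid appended after n_text text tokens.
# The real Qwen3-VL offsets and multi-image handling differ.
def mrope_positions(n_text, h, w):
    pos = [(i, i, i) for i in range(n_text)]  # text: all axes move together
    for r in range(h):                        # image patches: fixed temporal id,
        for c in range(w):                    # grid coordinates on h/w axes
            pos.append((n_text, n_text + r, n_text + c))
    return pos

print(mrope_positions(2, 2, 3)[:4])
```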

Restore global server_args after DFlash worker initialization to prevent SHM feature decoding failure:
When launching DFlash speculative decoding with Qwen3-VL (tp_size=2), the first image request triggers TypeError: object supporting the buffer API required. Running the VLM alone, or with other speculative decoding methods, works fine.

Root Cause

In a single-node SGLang deployment, the tokenizer process transfers image feature tensors to the scheduler via shared memory (SHM). The sender wraps the data in ShmPointerMMData, and the receiver unwraps it with unwrap_shm_features, which consults the global server_args.skip_tokenizer_init to decide whether unwrapping is needed:

def unwrap_shm_features(obj):
    if ... or get_global_server_args().skip_tokenizer_init:
        return obj  # Skip unwrapping

When initializing DFlash's draft worker, the code deep-copies server_args and sets skip_tokenizer_init=True (text-only draft models need no tokenizer). During ModelRunner.__init__, the draft worker then calls set_global_server_args_for_scheduler(draft_server_args), overwriting the global server_args with the draft version.
As a result, the tokenizer wraps the features correctly, but the scheduler, reading the polluted global, skips unwrapping and passes the raw ShmPointerMMData object straight to hashlib.sha256(), raising the TypeError.
Other speculative decoding methods such as EAGLE are unaffected: they pass the original server_args through directly (no deepcopy, no modification), so the global stays intact.
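The failure pattern is easy to reproduce outside SGLang. In this toy sketch (GLOBAL_ARGS, set_global_args, and unwrap are illustrative stand-ins, not SGLang's API), a deep-copied config with a flipped flag is installed globally and silently changes the behavior of an unrelated code path:

```python
import copy

GLOBAL_ARGS = None  # stand-in for the global server_args

def set_global_args(args):
    global GLOBAL_ARGS
    GLOBAL_ARGS = args

def unwrap(obj):
    # Mirrors unwrap_shm_features: skip unwrapping when the global
    # config claims the tokenizer was never initialized.
    if GLOBAL_ARGS["skip_tokenizer_init"]:
        return obj          # wrapper leaks through untouched
    return obj["payload"]   # proper unwrap

# Target worker installs its args.
set_global_args({"skip_tokenizer_init": False})

# Draft worker deep-copies, flips the flag, then overwrites the global.
draft_args = copy.deepcopy(GLOBAL_ARGS)
draft_args["skip_tokenizer_init"] = True
set_global_args(draft_args)

wrapped = {"payload": b"image-features"}
print(unwrap(wrapped))  # the dict wrapper, not b"image-features"
```

The downstream consumer then receives the wrapper object where it expected raw bytes, which is exactly what hashlib.sha256() rejects in the real failure.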

Fix

In dflash_worker.py, save the global server_args before creating the draft worker and restore it afterward:

# Save the target worker's global server_args before draft construction
saved_server_args = get_global_server_args()
self.draft_worker = TpModelWorker(server_args=draft_server_args, ...)
# Restore the original args so the scheduler sees the target configuration
set_global_server_args_for_scheduler(saved_server_args)
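The same save/restore pattern can also be written as a context manager so the restore happens even if draft-worker construction raises. This is a generic sketch (preserve_global and the get_global/set_global pair are hypothetical stand-ins, not SGLang functions):

```python
from contextlib import contextmanager

_GLOBAL = {"skip_tokenizer_init": False}  # stand-in for global server_args

def get_global():
    return _GLOBAL

def set_global(args):
    global _GLOBAL
    _GLOBAL = args

@contextmanager
def preserve_global():
    saved = get_global()
    try:
        yield
    finally:
        set_global(saved)  # restore even if the body raises

with preserve_global():
    set_global({"skip_tokenizer_init": True})  # draft-worker init happens here
print(get_global()["skip_tokenizer_init"])  # False: original args restored
```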

Risk Assessment

  • Changes are isolated within dflash_worker.py and do not affect other speculative decoding methods
  • Global variables remain properly set during draft worker initialization, so the draft's own initialization logic is unaffected
  • The restored object is the same one previously set by the target worker, preserving all modified fields (e.g., use_mla_backend)
  • The entire operation completes before the scheduler event loop starts, eliminating concurrency risks

Tests

Environment: 4090D
Models: Qwen3-VL-8B-Instruct, Qwen3-8B-DFlash-b16
Test Dataset: MMStar

Test Commands

# Baseline (no speculative decoding)
SGLANG_DISABLE_CUDNN_CHECK=1 \
CUDA_VISIBLE_DEVICES=0,1 \
python -m sglang.launch_server \
--model-path /models/Qwen3-VL-8B-Instruct/ \
--tp-size 2 \
--dtype bfloat16 \
--mem-fraction-static 0.65 \
--cuda-graph-max-bs 32 --context-length 40960 --port 30000

# With DFlash speculative decoding
SGLANG_DISABLE_CUDNN_CHECK=1 \
CUDA_VISIBLE_DEVICES=0,1 \
python -m sglang.launch_server \
--model-path /models/Qwen3-VL-8B-Instruct \
--speculative-algorithm DFLASH \
--speculative-draft-model-path /models/Qwen3-8B-DFlash-b16 \
--tp-size 2 \
--dtype bfloat16 \
--mem-fraction-static 0.65 \
--cuda-graph-max-bs 32 --context-length 40960

# Run benchmark
python benchmark/dflash/bench_dflash_mmstar.py --port 30000 \
--dataset-path /datasets/mmstar \
--num-samples 10 --concurrency 1  \
--max-completion-tokens 2048 --temperature 0.0

Test Results

Concurrency = 1:

| Metric             | DFlash | Baseline |
| ------------------ | ------ | -------- |
| Throughput (tok/s) | 45.48  | 23.60    |
| Acceptance Length  | 2.8    | N/A      |

DFlash throughput is roughly 1.93x the baseline at concurrency 1.

Checklist

Review Process

  1. Ping Merge Oncalls to start the PR flow. See the PR Merge Process.
  2. Get approvals from CODEOWNERS and other reviewers.
  3. Trigger CI tests with comments or contact authorized users to do so.
    • /tag-run-ci-label, /rerun-failed-ci, /tag-and-rerun-ci
  4. After green CI and required approvals, ask Merge Oncalls to merge.

Contributor

@gemini-code-assist gemini-code-assist bot left a comment

Code Review

This pull request introduces support for DFlash speculative decoding, which is a significant feature. The changes are extensive, touching many parts of the system from server configuration and model execution to attention backends and model implementations. The implementation appears robust and well-integrated with the existing speculative decoding framework.

Key changes include:

  • A new DFlashWorker and associated data structures (DFlashDraftInput, DFlashVerifyInput) to manage the DFlash-specific workflow.
  • A new dflash.py model implementation for the draft model, which correctly omits embedding and LM head layers.
  • Modifications to attention backends (flashinfer, trtllm_mha) to support DFlash's requirements, including a critical correctness fix in the trtllm_mha backend.
  • Integration with CUDA graph capture for performance.
  • New benchmark scripts for validation.

The code is well-structured, and the changes are generally clear and well-commented. I have one suggestion for improving the exception handling in the server argument parsing logic to make it more robust. Overall, this is a high-quality contribution.

@EanWang211123 EanWang211123 changed the title [Feature] Add DFlash Speculative Decoding Support for Qwen-VL Model [Feature] Add DFlash Speculative Decoding Support for Qwen3-VL Model Mar 4, 2026
@EanWang211123 EanWang211123 deleted the vlm-dflash-test branch March 13, 2026 08:45
@EanWang211123 EanWang211123 restored the vlm-dflash-test branch March 13, 2026 08:56
@EanWang211123 EanWang211123 reopened this Mar 24, 2026
@EanWang211123 EanWang211123 requested a review from hzh0425 as a code owner March 24, 2026 09:14