
[Feature] add DFlash Support#8118

Merged
wangxiyuan merged 1 commit into vllm-project:main from chenaoxuan:dflash-latest
Apr 16, 2026

Conversation

@chenaoxuan
Contributor

@chenaoxuan chenaoxuan commented Apr 10, 2026

This PR is inherited from PR-7162 and supports the latest vllm-ascend main. The old version is closed.

Purpose

This is the first DFlash support on Ascend NPU, and we will continue to maintain it.

DFlash ("DFlash: Block Diffusion for Flash Speculative Decoding") is a parallel speculative decoding algorithm that generates multiple candidate tokens at once through a diffusion process.

Main changes:

  • Aligns with the official vLLM support merged in PR-36847.
  • Add a DFlash proposer implementation based on SpecDecodeBaseProposer.
  • Modify the attention backend and add a bidirectional attention branch.
  • Modify model_runner_v1 to support calling the DFlash module.

Quick Start

[!Attention!]
As of April 10, vllm-ascend is not compatible with the vllm commit that supports DFlash, so a cherry-pick is required:

```shell
cd vllm
git checkout -b new-branch v0.19.0
git cherry-pick 494636b29d3b3a7b35020e4becb6c6995e200f9d
```

[Weights]
Use official DFlash weights.

[Config]
--speculative-config '{"num_speculative_tokens": 8, "method":"dflash","model":"weight_path","enforce_eager": true}'
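For context, the flag above plugs into an ordinary serve command. A minimal sketch follows; the target model name is illustrative and `weight_path` is a placeholder for the downloaded DFlash draft weights, neither is taken from this PR:

```shell
# Hypothetical invocation: serve a target model with the DFlash drafter.
# "weight_path" must be replaced with the path to the official DFlash weights.
vllm serve Qwen/Qwen3-8B \
  --speculative-config '{"num_speculative_tokens": 8, "method": "dflash", "model": "weight_path", "enforce_eager": true}'
```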

Test Results

Acceptance rate

Verified against the SGLang (GPU) and vLLM (GPU) versions of Qwen3-8B-DFlash-b16 on the GSM8K dataset.

T = 0, Draft Tokens = 16, Max Tokens = 2048

| Batch Size | Framework | Mean Acceptance Length |
|-----|-----|-----|
| 4 | SGLang | 6.07 |
| 4 | vLLM | 6.08 |
| 4 | vLLM-Ascend | 6.05 |
| 8 | SGLang | 6.07 |
| 8 | vLLM | 6.08 |
| 8 | vLLM-Ascend | 6.06 |
| 16 | SGLang | 6.08 |
| 16 | vLLM | 6.08 |
| 16 | vLLM-Ascend | 6.08 |
| 32 | SGLang | 6.08 |
| 32 | vLLM | 6.09 |
| 32 | vLLM-Ascend | 6.08 |

Performance

Qwen3-8B, DP1/TP1, GSM8K prompts repeated to an input length of 3.5K, output length 1.5K, 400 samples, batch_size 16, temperature 0

| Method | Graph Mode | Spec Num | Mean Acceptance Length | TPOT (ms) | Output Token Throughput (token/s) |
|-----|-----|-----|-----|-----|-----|
| Eagle3 | FULL_DECODE_ONLY | 3 | 2.81 | 16.4 | 943.60 (baseline) |
| Eagle3 | FULL_DECODE_ONLY | 8 | 3.60 | 19.5 | 795.34 (↓15.7%) |
| DFlash | PIECEWISE | 8 | 5.25 | 12.4 | 1248.93 (↑32.4%) |
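The relative throughput figures in the table can be sanity-checked from the absolute numbers; a small standalone check (the values are copied from the table above, the helper name is ours):

```python
# Sanity-check the relative throughput deltas reported above.
baseline = 943.60   # Eagle3, spec num 3 (token/s)
eagle3_8 = 795.34   # Eagle3, spec num 8
dflash_8 = 1248.93  # DFlash, spec num 8

def rel_change(value, base):
    """Percentage change relative to the baseline, one decimal place."""
    return round((value - base) / base * 100, 1)

print(rel_change(eagle3_8, baseline))  # -15.7
print(rel_change(dflash_8, baseline))  # 32.4
```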

Accuracy

Qwen3-8B, DP1/TP1, output length 3.5K, 300 samples, batch_size 16, temperature 0

| Method | Graph Mode | Spec Num | Dataset | Accuracy (%) |
|-----|-----|-----|-----|-----|
| Eagle3 | FULL_DECODE_ONLY | 3 | gsm8k | 84.67 |
| DFlash | PIECEWISE | 8 | gsm8k | 85.00 |

Next Plan

  • Support FULL_DECODE_ONLY
  • Support Qwen3.5
  • The NPU Triton multi-core path is faulty; currently a single core processes all requests, which needs to be improved.
  • Operator optimization: the FIA operator currently supports at most 16 TND-layout entries, so the maximum spec_num is 15. This limit can be bypassed, but at a performance cost.
  • vLLM main: https://github.com/vllm-project/vllm/commit/v0.19.0
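The FIA limit in the last bullet implies a simple configuration constraint: the draft tokens plus the bonus token must fit in 16 slots. An illustrative guard, not code from this PR (the constant and function name are ours, and the "+1 bonus token" accounting is our reading of "maximum is 16, therefore spec_num is at most 15"):

```python
# Assumed operator limit described in the PR's Next Plan section.
FIA_MAX_TND = 16

def check_spec_num(num_speculative_tokens: int) -> int:
    """Reject configs whose draft + bonus token count exceeds the FIA limit."""
    if num_speculative_tokens + 1 > FIA_MAX_TND:
        raise ValueError(
            f"num_speculative_tokens={num_speculative_tokens} exceeds the "
            f"FIA TND limit ({FIA_MAX_TND - 1} max)")
    return num_speculative_tokens

check_spec_num(8)   # the value used in this PR's config
check_spec_num(15)  # the stated maximum
```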

@github-actions
Contributor

👋 Hi! Thank you for contributing to the vLLM Ascend project. The following points will speed up your PR merge:

  • A PR should do only one thing, smaller PRs enable faster reviews.
  • Every PR should include unit tests and end-to-end tests to ensure it works and is not broken by future PRs.
  • Write the commit message by filling out the PR description to help reviewers and future developers understand.

If CI fails, you can run linting and testing checks locally according to Contributing and Testing.

@gemini-code-assist
Contributor

Summary of Changes

Hello, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request introduces comprehensive support for DFlash, a parallel speculative decoding algorithm, within the vllm-ascend framework. It involves significant modifications to the attention mechanism to support bidirectional attention, the addition of a dedicated DFlash proposer, and the integration of a new Triton kernel for efficient input processing. These changes enable vllm-ascend to leverage DFlash for improved inference performance on Ascend NPUs, aligning with recent upstream vLLM developments.

Highlights

  • DFlash Speculative Decoding Support: Added comprehensive support for DFlash, a parallel speculative decoding algorithm, to the vllm-ascend framework.
  • Attention Backend Modifications: Modified the attention mechanism to support bidirectional attention, which is crucial for DFlash's operation.
  • New Triton Kernel for Input Preparation: Introduced a new Triton kernel to efficiently prepare and expand inputs for DFlash speculative decoding.
  • DFlash Proposer Integration: Implemented a dedicated DFlash proposer and integrated it into the speculative decoding pipeline.


Contributor

@gemini-code-assist gemini-code-assist Bot left a comment


Code Review

This pull request implements DFlash speculative decoding for Ascend, introducing the AscendDflashProposer, a Triton kernel for input expansion, and patches for DFlashQwen3Model to enable fused KV precomputation. It also updates the attention mechanism to support non-causal sequences. Feedback was provided to refactor the attention forward pass to eliminate code duplication and improve maintainability.

Suggested PR Title:

[Attention][Feature] Implement DFlash speculative decoding support

Suggested PR Summary:

### What this PR does / why we need it?
This PR implements DFlash speculative decoding for Ascend, introducing the `AscendDflashProposer`, a Triton kernel for input expansion, and patches for `DFlashQwen3Model` to enable fused KV precomputation. It also updates the attention mechanism to support non-causal sequences. Feedback was provided to refactor the attention forward pass to eliminate code duplication and improve maintainability.

### Does this PR introduce _any_ user-facing change?
Yes, it adds 'dflash' as a speculative decoding method.

### How was this patch tested?
The changes were integrated into the speculative decoding framework.

Comment on lines +852 to +882
```python
if not attn_metadata.causal:
    attn_output, _ = torch_npu.npu_fused_infer_attention_score(
        query=query,
        key=key,
        value=value,
        block_table=block_table,
        input_layout="TND",
        block_size=block_size,
        actual_seq_lengths=attn_metadata.actual_seq_lengths_q,
        actual_seq_lengths_kv=actual_seq_lengths_kv,
        num_key_value_heads=self.num_kv_heads,
        num_heads=self.num_heads,
        scale=self.scale,
        sparse_mode=0,
    )
else:
    attn_output, _ = torch_npu.npu_fused_infer_attention_score(
        query=query,
        key=key,
        value=value,
        atten_mask=attn_metadata.attn_mask,
        block_table=block_table,
        input_layout="TND",
        block_size=block_size,
        actual_seq_lengths=attn_metadata.actual_seq_lengths_q,
        actual_seq_lengths_kv=actual_seq_lengths_kv,
        num_key_value_heads=self.num_kv_heads,
        num_heads=self.num_heads,
        scale=self.scale,
        sparse_mode=3,
    )
```
Contributor


Severity: high

There is significant code duplication between the if and else blocks. This makes the code harder to maintain, as changes to the arguments of torch_npu.npu_fused_infer_attention_score must be applied in two places, increasing the risk of introducing bugs.

To improve maintainability, you can refactor the common arguments into a dictionary.

```python
common_args = {
    "query": query,
    "key": key,
    "value": value,
    "block_table": block_table,
    "input_layout": "TND",
    "block_size": block_size,
    "actual_seq_lengths": attn_metadata.actual_seq_lengths_q,
    "actual_seq_lengths_kv": actual_seq_lengths_kv,
    "num_key_value_heads": self.num_kv_heads,
    "num_heads": self.num_heads,
    "scale": self.scale,
}
if not attn_metadata.causal:
    attn_output, _ = torch_npu.npu_fused_infer_attention_score(
        **common_args,
        sparse_mode=0,
    )
else:
    attn_output, _ = torch_npu.npu_fused_infer_attention_score(
        **common_args,
        atten_mask=attn_metadata.attn_mask,
        sparse_mode=3,
    )
```

@chenaoxuan force-pushed the dflash-latest branch 5 times, most recently from b8ba62c to d59c789 on April 10, 2026 07:27
@github-actions
Copy link
Copy Markdown
Contributor

This pull request has conflicts, please resolve those before we can evaluate the pull request.

@chenaoxuan force-pushed the dflash-latest branch 10 times, most recently from 1a5e879 to 2b2eb49 on April 14, 2026 09:38
@chenaoxuan force-pushed the dflash-latest branch 3 times, most recently from 32974e4 to bfa2d6f on April 14, 2026 10:00
@lilinsiman added the ready and ready-for-test labels on Apr 14, 2026
Signed-off-by: chenaoxuan <cax1165@163.com>
@wangxiyuan wangxiyuan merged commit 36b1e04 into vllm-project:main Apr 16, 2026
52 checks passed
1kzk pushed a commit to 1kzk/vllm-ascend that referenced this pull request Apr 20, 2026
Pz1116 pushed a commit to Pz1116/vllm-ascend that referenced this pull request Apr 20, 2026
anning-2026 pushed a commit to anning-2026/vllm-ascend that referenced this pull request Apr 21, 2026

Labels

module:ops, ready, ready-for-test


3 participants