
[Model] Extract GatedDeltaNetAttention into shared layer for Qwen3Next and Qwen3.5#37975

Merged
jikunshang merged 12 commits into vllm-project:main from wxsIcey:refactor-gdn
Mar 27, 2026

Conversation

@wxsIcey
Contributor

@wxsIcey wxsIcey commented Mar 24, 2026

Purpose

Move the GDN (Gated Delta Net) layer implementation from qwen3_next.py into a dedicated gdn_linear_attn.py, and unify Qwen3Next and Qwen3.5 under a single GatedDeltaNetAttention class.
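As a rough illustration of the unification (the class and field names below are illustrative placeholders, not the actual vLLM code), the model-specific differences become constructor parameters of one shared class instead of two subclasses:

```python
# Illustrative sketch only: field names and values are invented to show the
# parameterisation idea, not taken from gdn_linear_attn.py.
from dataclasses import dataclass


@dataclass
class GDNConfig:
    num_k_heads: int
    num_v_heads: int
    head_k_dim: int
    head_v_dim: int
    use_gqa_interleaved_layout: bool = False  # Qwen3.5-style layout
    lora_compatible: bool = False


def make_gdn_config(model_family: str) -> GDNConfig:
    """Map a model family to the parameters the shared layer needs."""
    if model_family == "qwen3_next":
        return GDNConfig(16, 32, 128, 128)
    if model_family == "qwen3_5":
        return GDNConfig(16, 32, 128, 128,
                         use_gqa_interleaved_layout=True,
                         lora_compatible=True)
    raise ValueError(f"unknown model family: {model_family}")
```

With this shape, both models instantiate the same GatedDeltaNetAttention-style class and differ only in the config they pass in.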

Test Plan

Test Result


Essential Elements of an Effective PR Description Checklist
  • The purpose of the PR, such as "Fix some issue (link existing issues this PR will resolve)".
  • The test plan, such as providing test command.
  • The test results, such as pasting the results comparison before and after, or e2e results
  • (Optional) The necessary documentation update, such as updating supported_models.md and examples for a new model.
  • (Optional) Release notes update. If your change is user facing, please update the release notes draft in the Google Doc.

@mergify mergify bot added the qwen Related to Qwen models label Mar 24, 2026
@gemini-code-assist
Contributor

Warning

Gemini encountered an error creating the review. You can try again by commenting /gemini review.

@wxsIcey
Contributor Author

wxsIcey commented Mar 24, 2026

Since non-contiguous key-value layouts are not supported on XPU and NPU, the operators of the GatedDeltaNetAttention layer must be rewritten for those platforms:
#33657
vllm-project/vllm-ascend#6640

For in-tree platform dispatch, I currently do not have a good solution. For out-of-tree platform dispatch, the PluggableLayer approach can be used. Therefore, this refactoring is proposed.
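A purely hypothetical sketch of what out-of-tree layer substitution can look like (the registry and decorator here are invented for illustration; vLLM's actual PluggableLayer mechanism differs in detail):

```python
# Hypothetical registry pattern: a platform plugin replaces a named layer
# implementation at import time. All names here are invented for illustration.
_LAYER_REGISTRY: dict[str, type] = {}


def register_layer(name: str):
    def deco(cls):
        _LAYER_REGISTRY[name] = cls  # later registrations override earlier ones
        return cls
    return deco


@register_layer("GatedDeltaNetAttention")
class GatedDeltaNetAttention:
    def forward(self, x):
        return "cuda path"


# An out-of-tree plugin (e.g. for NPU) overrides the entry with its own
# implementation, without touching the in-tree model code:
@register_layer("GatedDeltaNetAttention")
class NPUGatedDeltaNetAttention(GatedDeltaNetAttention):
    def forward(self, x):
        return "npu path"


layer = _LAYER_REGISTRY["GatedDeltaNetAttention"]()
```

The model code only looks the layer up by name, so the plugin's override takes effect transparently.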

cc @ZJY0516

@ZJY0516
Member

ZJY0516 commented Mar 24, 2026

cc @vadiklyutiy @jikunshang @tdoublep

@vadiklyutiy
Collaborator

This is a very important and valuable change! The GDN implementation was quite messy—thank you very much for your contribution. Will take a look later today.

@vadiklyutiy
Collaborator

/gemini review

@gemini-code-assist
Contributor

Warning

Gemini encountered an error creating the review. You can try again by commenting /gemini review.

@wxsIcey
Contributor Author

wxsIcey commented Mar 24, 2026

/gemini review

@ZJY0516
Member

ZJY0516 commented Mar 24, 2026

@codex review

Contributor

@gemini-code-assist gemini-code-assist bot left a comment


Code Review

The Gated Delta Net (GDN) attention implementation, including its custom operations and Triton kernels, has been refactored into a new dedicated file, gdn_linear_attn.py. The new GatedDeltaNetAttention class serves as a unified implementation for both Qwen3-Next and Qwen3.5, replacing the previously separate GDN classes and handling model-specific configuration such as GQA interleaved layouts and LoRA compatibility through parameters.

A critical issue was identified in the fix_query_key_value_ordering method: new_tensor_shape_ba is derived from mixed_qkvz.size() instead of mixed_ba.size(), which could lead to a runtime error if the number of tokens differs between these tensors.
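A minimal sketch of the flagged issue, with tensor shapes modeled as plain tuples (the helper below is illustrative, not the actual vLLM source):

```python
# Shapes are modeled as (num_tokens, hidden) tuples; names follow the review
# comment, not the real code.
def ba_view_shape(size, num_heads):
    """Derive the per-head view shape for the b/a tensor from a .size() tuple."""
    num_tokens, hidden = size
    return (num_tokens, num_heads, hidden // num_heads)


mixed_qkvz_size = (10, 4096)  # 10 tokens
mixed_ba_size = (8, 64)       # hypothetical case where the token counts differ

# Buggy: the b/a view shape is derived from the qkvz tensor's size...
buggy = ba_view_shape(mixed_qkvz_size, 32)   # (10, 32, 128)
# ...fixed: derive it from mixed_ba itself.
fixed = ba_view_shape(mixed_ba_size, 32)     # (8, 32, 2)

# Reshaping mixed_ba to `buggy` would fail at runtime: it expects
# 10 * 32 * 128 elements, but mixed_ba holds only 8 * 64.
```

As long as both tensors always carry the same token count the bug is latent, which is why it can survive a simple smoke test.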

@wangxiyuan
Contributor

Yes, we have needed this for a long time.

@ZJY0516
Member

ZJY0516 commented Mar 24, 2026

Please test Qwen3.5, Qwen3-Next, and LoRA.

@wxsIcey
Contributor Author

wxsIcey commented Mar 24, 2026

please test qwen 3.5, qwen 3 next and lora

OK, I will add it.

@yma11
Contributor

yma11 commented Mar 24, 2026

please test qwen 3.5, qwen 3 next and lora

OK, I will add it.

I tested this refactor on XPU platform, Qwen3.5-9B shows accuracy issue. Can you check?

@wxsIcey
Contributor Author

wxsIcey commented Mar 24, 2026

please test qwen 3.5, qwen 3 next and lora

OK, I will add it.

I tested this refactor on XPU platform, Qwen3.5-9B shows accuracy issue. Can you check?

I just performed a simple test on the A100, and the output is normal. I don't have an XPU machine for testing. What are your test cases and outputs? Or did you perform an accuracy test?

import os
os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn"
from vllm import LLM, SamplingParams

def main():
    prompts = [
        "The future of AI is",
        "Who is the President of the United States?",
    ]
    sampling_params = SamplingParams(temperature=0.8)

    llm = LLM(
        model="/shared/models/modelscope/models/models/Qwen/Qwen3.5-9B",
        tensor_parallel_size=1,
        gpu_memory_utilization=0.9,
        max_model_len=8092,
    )

    outputs = llm.generate(prompts, sampling_params=sampling_params)
    for output in outputs:
        prompt = output.prompt
        generated_text = output.outputs[0].text
        print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")

if __name__ == "__main__":
    main()

Output:

Prompt: 'The future of AI is', Generated text: ' here\n\nAI is an immensely powerful tool, but it needs to be managed.'
Prompt: 'Who is the President of the United States?', Generated text: '\n\n**Joe Biden** is the President of the United States. He assumed office'

Collaborator

@jikunshang jikunshang left a comment


Thanks for refactoring. just some minor comments from my side.

use_qk_l2norm_in_kernel=use_qk_l2norm_in_kernel,
)

def forward_native(
Collaborator


Ideally, forward_native should be a torch-native impl so that every platform can leverage it; using Triton here means the CPU platform will throw an error. I am OK with keeping this, just a minor concern. cc @bigPYJ1151

Member


I don't think we need a torch-native impl, just as there is no torch-native flash attention in vLLM.

Collaborator


Agree we don't need it here. It's just about naming; maybe we should rename it to forward_triton to avoid confusion.
My understanding is forward_native in CustomOp should be a torch-native impl. https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/custom_op.py#L138-L144
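The convention under discussion can be sketched with a stripped-down dispatch model (simplified; the real class lives in vllm/model_executor/custom_op.py, and the platform detection here is a stand-in):

```python
# Simplified sketch of CustomOp-style dispatch. Platform detection and method
# names are illustrative; vLLM resolves the platform via current_platform.
class CustomOp:
    PLATFORM = "cpu"  # stand-in for runtime platform detection

    def forward(self, *args):
        # Dispatch to forward_<platform> when defined; otherwise fall back to
        # forward_native, which by convention should be pure torch so that
        # every platform can run it.
        impl = getattr(self, f"forward_{self.PLATFORM}", self.forward_native)
        return impl(*args)

    def forward_native(self, *args):
        raise NotImplementedError


class DoubleOp(CustomOp):
    def forward_native(self, x):
        return 2 * x  # portable fallback


class TritonDoubleOp(DoubleOp):
    def forward_cuda(self, x):
        return 2 * x  # would call a Triton kernel on a GPU platform
```

Under this convention, a forward_native that invokes Triton breaks the fallback contract: on a platform without Triton (e.g. CPU), dispatch falls through to it and crashes.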

Contributor Author

@wxsIcey wxsIcey Mar 25, 2026


I agree that forward_native should be a torch-native implementation, so using Triton there is not ideal. However, CustomOp provides platform-specific forward dispatch ( https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/custom_op.py#L196-L207 ), which a forward_triton method would not participate in. I think the best practice is the vLLM IR proposed by Luka ( #32358 ): we could define the Triton kernel as an IR kernel and specify its platform-wide usage.

So we can wait for the IR PR to be merged and then refactor the code. Is that acceptable?

cc @ProExpertProg

Member


I think we can do this later — it doesn't necessarily have to be done in this PR.

@yma11
Contributor

yma11 commented Mar 25, 2026


Thanks for checking. Then it is likely a platform-specific issue on our side. Let me take a look further.

@mergify

mergify bot commented Mar 26, 2026

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @wxsIcey.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify mergify bot added the needs-rebase label Mar 26, 2026
wxsIcey added 2 commits March 26, 2026 08:30
…t and Qwen3.5

Signed-off-by: wxsIcey <1790571317@qq.com>
Signed-off-by: wxsIcey <1790571317@qq.com>
wxsIcey added 6 commits March 26, 2026 08:30
Signed-off-by: wxsIcey <1790571317@qq.com>
Signed-off-by: Icey <1790571317@qq.com>
Signed-off-by: wxsIcey <1790571317@qq.com>
Signed-off-by: Icey <1790571317@qq.com>
Signed-off-by: wxsIcey <1790571317@qq.com>
Signed-off-by: Icey <1790571317@qq.com>
@jikunshang jikunshang added the ready ONLY add when PR is ready to merge/full CI is needed label Mar 26, 2026
@mergify mergify bot removed the needs-rebase label Mar 26, 2026
@mergify

mergify bot commented Mar 26, 2026

Hi @wxsIcey, the pre-commit checks have failed. Please run:

uv pip install 'pre-commit>=4.5.1'
pre-commit install
pre-commit run --all-files

Then, commit the changes and push to your branch.

For future commits, pre-commit will run automatically on changed files before each commit.

Tip

Is mypy failing?
mypy is run differently in CI. If the failure is related to this check, please use the following command to run it locally:
# For mypy (substitute "3.10" with the failing version if needed)
pre-commit run --hook-stage manual mypy-3.10

@jikunshang
Collaborator

I triggered LM Eval Qwen3.5 Models (B200) case in CI
https://buildkite.com/vllm/ci/builds/58211#019d2948-6733-4b4a-846f-92163a24ad03

wxsIcey added 2 commits March 26, 2026 13:02
Signed-off-by: Icey <1790571317@qq.com>
@wxsIcey
Contributor Author

wxsIcey commented Mar 27, 2026

I triggered LM Eval Qwen3.5 Models (B200) case in CI https://buildkite.com/vllm/ci/builds/58211#019d2948-6733-4b4a-846f-92163a24ad03

Thank you. All tests have passed. Is it possible to get approval?

@yma11
Contributor

yma11 commented Mar 27, 2026

@jikunshang PTAL again.

@jikunshang
Collaborator

@claude review


@claude claude bot left a comment


A few things still need attention before this can merge: there is an unresolved reviewer question on gdn_linear_attn.py, pre-commit was failing as of March 26 (unclear if fixed), and the gdn_in_proj stub calls an undefined _forward_in_proj method (see inline comment).

Extended reasoning...

Overview

This PR extracts GatedDeltaNetAttention from qwen3_next.py into a new shared gdn_linear_attn.py and removes the Qwen3NextGatedDeltaNet/Qwen3_5GatedDeltaNet subclass hierarchy in favour of a single parameterised class. It touches ~1000 lines across three files, primarily a refactor but with meaningful behavioural changes.

Security Risks

No security-sensitive code paths are touched. Risk is limited to model correctness and inference performance.

Level of Scrutiny

The change is architecturally meaningful as the foundation for out-of-tree platform dispatch unifying two previously divergent implementations. LM Eval on B200 passed, but several reviewer concerns remain open. This warrants human approval rather than bot shadow-approval.

Other Factors

  1. Unresolved reviewer comment - The inline question on gdn_linear_attn.py (comment id 2981910661, "is this necessary?") is not marked resolved in the timeline.
  2. Pre-commit failure - Mergify reported a pre-commit failure on March 26; it is not confirmed whether this was subsequently fixed.
  3. gdn_in_proj dead code - The stub function at line 950 calls self._forward_in_proj which does not exist on GatedDeltaNetAttention. While currently unreachable dead code, it is a footgun for future platform-plugin authors (see inline comment for details).
  4. forward_native naming - ChunkGatedDeltaRule.forward_native actually invokes a Triton/FLA kernel, violating the CustomOp convention. The team agreed to defer this to a follow-up.

Comment on lines +950 to +961
layer_name: str,
) -> tuple[torch.Tensor, torch.Tensor]:
"""
Custom op for the input projection.
"""
forward_context: ForwardContext = get_forward_context()
self = forward_context.no_compile_layers[layer_name]
return self._forward_in_proj(hidden_states)


def gdn_attention_core(
mixed_qkv: torch.Tensor,


🟡 The gdn_in_proj function (line 950) calls self._forward_in_proj(hidden_states), but GatedDeltaNetAttention defines no _forward_in_proj method — this would raise AttributeError if invoked. The function is also never registered via direct_register_custom_op (unlike gdn_attention_core), making it unreachable dead code. The _qkvz_output_size helper (with docstring "for gdn_in_proj fake impl") is similarly orphaned and should be removed or completed.

Extended reasoning...

The gdn_in_proj function defined at lines 946–957 of vllm/model_executor/layers/mamba/gdn_linear_attn.py is an incomplete custom-op stub left over from the refactoring. It retrieves the layer from the forward context and then calls self._forward_in_proj(hidden_states), but GatedDeltaNetAttention (and none of its base classes — PluggableLayer, MambaBase) defines any _forward_in_proj method. A grep of the entire codebase confirms only one occurrence of _forward_in_proj: the call site at line 957 itself.

The companion _qkvz_output_size method at line 668 has a docstring that explicitly reads "for gdn_in_proj fake impl", confirming that the author intended this to become a full custom op (analogous to the working gdn_attention_core op), with a real implementation and a fake/shape-only implementation for torch.compile. That work was never completed: direct_register_custom_op is called only for gdn_attention_core, never for gdn_in_proj.

As things stand, gdn_in_proj is unreachable dead code. No code path in the codebase calls it directly, and it is not registered as a torch.ops.vllm.* custom op that could be dispatched to. So there is no runtime failure today.

However, the purpose of PluggableLayer is precisely to allow out-of-tree platform plugins (e.g., XPU, NPU) to register alternative forward implementations. If a plugin author discovers gdn_in_proj and tries to wire it up as a custom op for the input-projection step, the stub will crash with AttributeError: GatedDeltaNetAttention object has no attribute _forward_in_proj the first time it is called. This makes the code a footgun for future contributors and platform integrators.

Step-by-step proof:

  1. A platform plugin registers gdn_in_proj via direct_register_custom_op pointing at gdn_in_proj.
  2. The plugin arranges for the forward pass to call torch.ops.vllm.gdn_in_proj(hidden_states, qkvz_size, ba_size, self.prefix).
  3. gdn_in_proj runs, resolves self from the forward context, then executes self._forward_in_proj(hidden_states).
  4. Python raises AttributeError: GatedDeltaNetAttention object has no attribute _forward_in_proj.

Fix: Either (a) implement _forward_in_proj on GatedDeltaNetAttention and register gdn_in_proj with direct_register_custom_op plus a gdn_in_proj_fake shape function, or (b) remove gdn_in_proj and _qkvz_output_size entirely if the custom-op abstraction for input projection is not needed in this PR.
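A hedged sketch of fix option (a): a registered custom op pairs a real implementation with a shape-only "fake" implementation so torch.compile can trace it. Registration through vLLM's direct_register_custom_op is elided here, and plain Python lists stand in for tensors; all names are illustrative.

```python
# Real/fake implementation pairing for a custom op, modeled without torch.
# The fake impl must return the same output shapes as the real one, since
# tracing uses only shapes, never values.
def gdn_in_proj_real(hidden_states, qkvz_width, ba_width):
    """Stand-in for the layer's input projection: one row per input token."""
    n = len(hidden_states)
    qkvz = [[0.0] * qkvz_width for _ in range(n)]
    ba = [[0.0] * ba_width for _ in range(n)]
    return qkvz, ba


def gdn_in_proj_fake(hidden_states, qkvz_width, ba_width):
    """Fake (meta) impl: only the output shapes matter during tracing."""
    n = len(hidden_states)
    return (n, qkvz_width), (n, ba_width)
```

Since the PR ultimately took option (b) and removed the stub, this only illustrates what completing the op would have required: keeping the real and fake shapes in lockstep.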

Contributor Author


Thanks for the review. I forgot to remove the unnecessary gdn_in_proj when resolving the conflict at #38152. It has been fixed.

wxsIcey added 2 commits March 27, 2026 02:23
Signed-off-by: Icey <1790571317@qq.com>
Signed-off-by: Icey <1790571317@qq.com>
@wxsIcey
Contributor Author

wxsIcey commented Mar 27, 2026

@claude review

@jikunshang jikunshang merged commit a8eab8f into vllm-project:main Mar 27, 2026
61 checks passed
@github-project-automation github-project-automation bot moved this from In review to Done in Qwen3.5 Mar 27, 2026
yma11 pushed a commit to yma11/vllm that referenced this pull request Mar 27, 2026
…t and Qwen3.5 (vllm-project#37975)

Signed-off-by: wxsIcey <1790571317@qq.com>
Signed-off-by: Icey <1790571317@qq.com>
Monishver11 pushed a commit to Monishver11/vllm that referenced this pull request Mar 27, 2026
…t and Qwen3.5 (vllm-project#37975)

Signed-off-by: wxsIcey <1790571317@qq.com>
Signed-off-by: Icey <1790571317@qq.com>
Signed-off-by: Monishver Chandrasekaran <monishverchandrasekaran@gmail.com>
nithinvc pushed a commit to nithinvc/vllm that referenced this pull request Mar 27, 2026
…t and Qwen3.5 (vllm-project#37975)

Signed-off-by: wxsIcey <1790571317@qq.com>
Signed-off-by: Icey <1790571317@qq.com>

Signed-off-by: Nithin Chalapathi <nithin.ch10@gmail.com>
JiantaoXu pushed a commit to JiantaoXu/vllm that referenced this pull request Mar 28, 2026
…t and Qwen3.5 (vllm-project#37975)

Signed-off-by: wxsIcey <1790571317@qq.com>
Signed-off-by: Icey <1790571317@qq.com>

Labels

qwen Related to Qwen models ready ONLY add when PR is ready to merge/full CI is needed

Projects

Status: Done

Development

Successfully merging this pull request may close these issues.

6 participants