Add patch_qwen3_5 for triton ops fused_recurrent_gated_delta_rule#7109
Add patch_qwen3_5 for triton ops fused_recurrent_gated_delta_rule#7109wangxiyuan merged 7 commits intovllm-project:mainfrom
Conversation
Summary of ChangesHello, I'm Gemini Code Assist1! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed! This pull request enhances the Highlights
🧠 New Feature in Public Preview: You can now enable Memory to help Gemini Code Assist learn from your team's feedback. This makes future code reviews more consistent and personalized to your project's style. Click here to enable Memory in your admin console. Changelog
Activity
Using Gemini Code AssistThe full guide for Gemini Code Assist can be found on our documentation page, here are some quick tips. Invoking Gemini You can request assistance from Gemini at any point by creating a comment using either
Customization To customize Gemini Code Assist for GitHub experience, repository maintainers can create a configuration file and/or provide a custom code review style guide (such as PEP-8 for Python) by creating and adding files to a Limitations & Feedback Gemini Code Assist may make mistakes. Please leave feedback on any instances where its feedback is incorrect or counter productive. You can react with 👍 and 👎 on @gemini-code-assist comments. If you're interested in giving your feedback about your experience with Gemini Code Assist for GitHub and other Google products, sign up here. You can also get AI-powered code generation, chat, as well as code reviews directly in the IDE at no cost with the Gemini Code Assist IDE Extension. Footnotes
|
There was a problem hiding this comment.
Code Review
This pull request adds a patch for Qwen3.5's GatedDeltaNet to enable it on Ascend hardware. The patch introduces a workaround for an operator limitation by using a Triton kernel for the decode path. My review identifies a critical bug where a variable could be used before initialization, potentially causing a crash. I've provided a code suggestion to fix this issue.
| if attn_metadata.num_prefills > 0 or spec_sequence_masks is not None: | ||
| g, beta = fused_gdn_gating_patch(self.A_log, a, b, self.dt_bias) | ||
| if spec_sequence_masks is not None: | ||
| if attn_metadata.num_prefills == 0 and attn_metadata.num_decodes == 0: | ||
| g_spec = g | ||
| beta_spec = beta | ||
| g_non_spec = None | ||
| beta_non_spec = None | ||
| else: | ||
| g_spec = g.index_select(1, spec_token_indx) | ||
| beta_spec = beta.index_select(1, spec_token_indx) | ||
| g_non_spec = g.index_select(1, non_spec_token_indx) | ||
| beta_non_spec = beta.index_select(1, non_spec_token_indx) | ||
| else: | ||
| g_spec = None | ||
| beta_spec = None | ||
| g_non_spec = g | ||
| beta_non_spec = beta | ||
|
|
||
| # 2. Recurrent attention | ||
|
|
||
| # 2.1: Process the multi-query part | ||
| if spec_sequence_masks is not None: | ||
| core_attn_out_spec, last_recurrent_state = fused_recurrent_gated_delta_rule( | ||
| q=query_spec, | ||
| k=key_spec, | ||
| v=value_spec, | ||
| g=g_spec, | ||
| beta=beta_spec, | ||
| initial_state=ssm_state, | ||
| inplace_final_state=True, | ||
| cu_seqlens=spec_query_start_loc[: attn_metadata.num_spec_decodes + 1], | ||
| ssm_state_indices=spec_state_indices_tensor, | ||
| num_accepted_tokens=num_accepted_tokens, | ||
| use_qk_l2norm_in_kernel=True, | ||
| ) | ||
| else: | ||
| core_attn_out_spec, last_recurrent_state = None, None | ||
|
|
||
| # 2.2: Process the remaining part | ||
| if attn_metadata.num_prefills > 0: | ||
| initial_state = ssm_state[non_spec_state_indices_tensor].contiguous() | ||
| initial_state[~has_initial_state, ...] = 0 | ||
| ( | ||
| core_attn_out_non_spec, | ||
| last_recurrent_state, | ||
| ) = chunk_gated_delta_rule( | ||
| q=query_non_spec, | ||
| k=key_non_spec, | ||
| v=value_non_spec, | ||
| g=g_non_spec, | ||
| beta=beta_non_spec, | ||
| initial_state=initial_state, | ||
| output_final_state=True, | ||
| cu_seqlens=non_spec_query_start_loc, | ||
| head_first=False, | ||
| use_qk_l2norm_in_kernel=True, | ||
| ) | ||
| # Init cache | ||
| ssm_state[non_spec_state_indices_tensor] = last_recurrent_state.to(ssm_state.dtype) | ||
| elif attn_metadata.num_decodes > 0: | ||
| core_attn_out_non_spec, last_recurrent_state = fused_recurrent_gated_delta_rule( | ||
| q=query_non_spec, | ||
| k=key_non_spec, | ||
| v=value_non_spec, | ||
| g=g_non_spec, | ||
| beta=beta_non_spec, | ||
| initial_state=ssm_state, | ||
| inplace_final_state=True, | ||
| cu_seqlens=non_spec_query_start_loc[: attn_metadata.num_decodes + 1], | ||
| ssm_state_indices=non_spec_state_indices_tensor, | ||
| use_qk_l2norm_in_kernel=True, | ||
| ) | ||
| else: | ||
| core_attn_out_non_spec, last_recurrent_state = None, None | ||
|
|
||
| elif attn_metadata.num_decodes > 0: | ||
| core_attn_out_non_spec = fused_sigmoid_gating_delta_rule_update( | ||
| A_log=self.A_log.contiguous(), | ||
| dt_bias=self.dt_bias.contiguous(), | ||
| q=query_non_spec.contiguous(), | ||
| k=key_non_spec.contiguous(), | ||
| v=value_non_spec.contiguous(), | ||
| a=a.contiguous(), | ||
| b=b.contiguous(), | ||
| initial_state_source=ssm_state, | ||
| initial_state_indices=non_spec_state_indices_tensor, | ||
| cu_seqlens=non_spec_query_start_loc, | ||
| use_qk_l2norm_in_kernel=True, | ||
| softplus_beta=1.0, | ||
| softplus_threshold=20.0, | ||
| ) | ||
|
|
||
| # 3. Merge core attention output | ||
| if spec_sequence_masks is not None and core_attn_out_non_spec is not None: | ||
| merged_out = torch.empty( | ||
| (1, num_actual_tokens, *core_attn_out_spec.shape[2:]), | ||
| dtype=core_attn_out_non_spec.dtype, | ||
| device=core_attn_out_non_spec.device, | ||
| ) | ||
| merged_out.index_copy_(1, spec_token_indx, core_attn_out_spec) | ||
| merged_out.index_copy_(1, non_spec_token_indx, core_attn_out_non_spec) | ||
| if not enable_sp(): | ||
| core_attn_out[:num_actual_tokens] = merged_out.squeeze(0) | ||
| else: | ||
| core_attn_out[:num_actual_tokens] = merged_out.squeeze(0)[:num_actual_tokens] | ||
| elif spec_sequence_masks is not None: | ||
| if not enable_sp(): | ||
| core_attn_out[:num_actual_tokens] = core_attn_out_spec.squeeze(0) | ||
| else: | ||
| core_attn_out[:num_actual_tokens] = core_attn_out_spec.squeeze(0)[:num_actual_tokens] | ||
| else: | ||
| if not enable_sp(): | ||
| core_attn_out[:num_actual_tokens] = core_attn_out_non_spec.squeeze(0) | ||
| else: | ||
| core_attn_out[:num_actual_tokens] = core_attn_out_non_spec.squeeze(0)[:num_actual_tokens] |
There was a problem hiding this comment.
The variable core_attn_out_non_spec is not initialized on all code paths, which can lead to an UnboundLocalError. Specifically, if num_prefills == 0, spec_sequence_masks is None, and num_decodes == 0, the variable is never assigned a value, but it is accessed in the final else block of the merge logic. This is a critical issue that could cause a runtime crash.
core_attn_out_non_spec = None
core_attn_out_spec = None
if attn_metadata.num_prefills > 0 or spec_sequence_masks is not None:
g, beta = fused_gdn_gating_patch(self.A_log, a, b, self.dt_bias)
if spec_sequence_masks is not None:
if attn_metadata.num_prefills == 0 and attn_metadata.num_decodes == 0:
g_spec = g
beta_spec = beta
g_non_spec = None
beta_non_spec = None
else:
g_spec = g.index_select(1, spec_token_indx)
beta_spec = beta.index_select(1, spec_token_indx)
g_non_spec = g.index_select(1, non_spec_token_indx)
beta_non_spec = beta.index_select(1, non_spec_token_indx)
else:
g_spec = None
beta_spec = None
g_non_spec = g
beta_non_spec = beta
# 2. Recurrent attention
# 2.1: Process the multi-query part
if spec_sequence_masks is not None:
core_attn_out_spec, last_recurrent_state = fused_recurrent_gated_delta_rule(
q=query_spec,
k=key_spec,
v=value_spec,
g=g_spec,
beta=beta_spec,
initial_state=ssm_state,
inplace_final_state=True,
cu_seqlens=spec_query_start_loc[: attn_metadata.num_spec_decodes + 1],
ssm_state_indices=spec_state_indices_tensor,
num_accepted_tokens=num_accepted_tokens,
use_qk_l2norm_in_kernel=True,
)
else:
core_attn_out_spec, last_recurrent_state = None, None
# 2.2: Process the remaining part
if attn_metadata.num_prefills > 0:
initial_state = ssm_state[non_spec_state_indices_tensor].contiguous()
initial_state[~has_initial_state, ...] = 0
(
core_attn_out_non_spec,
last_recurrent_state,
) = chunk_gated_delta_rule(
q=query_non_spec,
k=key_non_spec,
v=value_non_spec,
g=g_non_spec,
beta=beta_non_spec,
initial_state=initial_state,
output_final_state=True,
cu_seqlens=non_spec_query_start_loc,
head_first=False,
use_qk_l2norm_in_kernel=True,
)
# Init cache
ssm_state[non_spec_state_indices_tensor] = last_recurrent_state.to(ssm_state.dtype)
elif attn_metadata.num_decodes > 0:
core_attn_out_non_spec, last_recurrent_state = fused_recurrent_gated_delta_rule(
q=query_non_spec,
k=key_non_spec,
v=value_non_spec,
g=g_non_spec,
beta=beta_non_spec,
initial_state=ssm_state,
inplace_final_state=True,
cu_seqlens=non_spec_query_start_loc[: attn_metadata.num_decodes + 1],
ssm_state_indices=non_spec_state_indices_tensor,
use_qk_l2norm_in_kernel=True,
)
else:
core_attn_out_non_spec, last_recurrent_state = None, None
elif attn_metadata.num_decodes > 0:
core_attn_out_non_spec = fused_sigmoid_gating_delta_rule_update(
A_log=self.A_log.contiguous(),
dt_bias=self.dt_bias.contiguous(),
q=query_non_spec.contiguous(),
k=key_non_spec.contiguous(),
v=value_non_spec.contiguous(),
a=a.contiguous(),
b=b.contiguous(),
initial_state_source=ssm_state,
initial_state_indices=non_spec_state_indices_tensor,
cu_seqlens=non_spec_query_start_loc,
use_qk_l2norm_in_kernel=True,
softplus_beta=1.0,
softplus_threshold=20.0,
)
# 3. Merge core attention output
if spec_sequence_masks is not None and core_attn_out_non_spec is not None:
merged_out = torch.empty(
(1, num_actual_tokens, *core_attn_out_spec.shape[2:]),
dtype=core_attn_out_non_spec.dtype,
device=core_attn_out_non_spec.device,
)
merged_out.index_copy_(1, spec_token_indx, core_attn_out_spec)
merged_out.index_copy_(1, non_spec_token_indx, core_attn_out_non_spec)
if not enable_sp():
core_attn_out[:num_actual_tokens] = merged_out.squeeze(0)
else:
core_attn_out[:num_actual_tokens] = merged_out.squeeze(0)[:num_actual_tokens]
elif spec_sequence_masks is not None:
if not enable_sp():
core_attn_out[:num_actual_tokens] = core_attn_out_spec.squeeze(0)
else:
core_attn_out[:num_actual_tokens] = core_attn_out_spec.squeeze(0)[:num_actual_tokens]
elif core_attn_out_non_spec is not None:
if not enable_sp():
core_attn_out[:num_actual_tokens] = core_attn_out_non_spec.squeeze(0)
else:
core_attn_out[:num_actual_tokens] = core_attn_out_non_spec.squeeze(0)[:num_actual_tokens]Signed-off-by: pppeng <zepengliu912@qq.com>
0e7c5c2 to
6688359
Compare
|
👋 Hi! Thank you for contributing to the vLLM Ascend project. The following points will speed up your PR merge:
If CI fails, you can run linting and testing checks locally according Contributing and Testing. |
Signed-off-by: pppeng <zepengliu912@qq.com>
Signed-off-by: pppeng <zepengliu912@qq.com>
Signed-off-by: pppeng <zepengliu912@qq.com>
Signed-off-by: pppeng <60355449+ppppeng@users.noreply.github.com>
Signed-off-by: pppeng <zepengliu912@qq.com>
| from tests.e2e.conftest import VllmRunner | ||
|
|
||
|
|
||
| def test_qwen3_5_27b_distributed_mp_tp4(): |
There was a problem hiding this comment.
plz add the test in
477c589 to
3999080
Compare
Signed-off-by: pppeng <zepengliu912@qq.com>
002dfa5 to
91c51f3
Compare
…to qwen3next_graph * 'main' of https://github.com/vllm-project/vllm-ascend: (88 commits) [main][bugfix] Fixed the problem of speculative decoding in FULL mode (vllm-project#7148) fixed fia pad logic in graph mode. (vllm-project#7144) [Doc] fix DSV3.1 PD configs (vllm-project#7187) refactor: add a check before layer_sharding logging (vllm-project#7186) [Build] Add support for Ascend950 chip (vllm-project#7151) Revert "[CI] fix skiped e2e test when upgrade vllm version (vllm-project#6654)" (vllm-project#7166) [MODELRUNNERV2]fix penality ops (vllm-project#7013) [Bugfix][LoRA] Fix the issue when enable LoRA + tp + fully_sharded_loras (vllm-project#6650) [KV Pool]get_num_new_matched_tokens return 0 if token length < block_size (vllm-project#7146) [CI] Build Image for v0.16.0rc1 (vllm-project#7155) [CI] Skip `test_mooncake_layerwise_connector.py` in `ut` (vllm-project#7147) [BugFix]Fix recomputed scheduler bug (vllm-project#7137) [Model] Support Minimax-m2.5 on NPU (vllm-project#7105) [P/D]Mooncake Layerwise Connector supports hybrid attention manager with multiple kvcache groups (vllm-project#7022) Add patch_qwen3_5 for triton ops fused_recurrent_gated_delta_rule (vllm-project#7109) [Doc][ReleaseNote] Add release notes for v0.16.0rc1 (vllm-project#7067) [Misc] Download on both hk and guiyang region (vllm-project#7129) [bugdix] The problem that the w4a8 weight fails to be loaded when the EP is not enabled is resolved. (vllm-project#7090) [eagle][cp] fix eagle_cp enable bug2 (vllm-project#7079) [CI]Upgrade niglty multi-node-tests max-parallel to 2 (vllm-project#7035) ...
…lm-project#7109) ### What this PR does / why we need it? The ops `torch_npu.npu_recurrent_gated_delta_rule` currently does not support `ssm_state` inputs in float32 format, we temporarily retain the _forward_core implementation with triton for Qwen3_5 --------- Signed-off-by: pppeng <zepengliu912@qq.com> Signed-off-by: pppeng <60355449+ppppeng@users.noreply.github.com>
What this PR does / why we need it?
The ops
torch_npu.npu_recurrent_gated_delta_rulecurrently does not supportssm_stateinputs in float32 format,we temporarily retain the _forward_core implementation with triton for Qwen3_5.
Does this PR introduce any user-facing change?
How was this patch tested?