-
-
Notifications
You must be signed in to change notification settings - Fork 18.9k
[XPU] Enable torch.compile for XPU GDN attention #39466
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 9 commits
b7bfd5d
ab86c66
b4f4aa4
6901b85
a144955
cca9139
1b76790
b202893
bd734a7
f3a85d8
369d238
ac14bd2
e7f647c
f55db61
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -620,53 +620,20 @@ def forward_xpu( | |
| # ============================================================ | ||
| # Part 2: Core Attention | ||
| # ============================================================ | ||
| forward_context = get_forward_context() | ||
| attn_metadata: AttentionMetadata = forward_context.attn_metadata | ||
| core_attn_out = torch.zeros( | ||
| (num_tokens, self.num_v_heads // self.tp_size, self.head_v_dim), | ||
| dtype=hidden_states.dtype, | ||
| device=hidden_states.device, | ||
| ) | ||
| z = torch.empty_like(core_attn_out) | ||
| if attn_metadata is not None: | ||
| attn_metadata = attn_metadata[self.prefix] | ||
|
|
||
| # TODO: xpu does not support this param yet | ||
| spec_sequence_masks = attn_metadata.spec_sequence_masks | ||
| assert spec_sequence_masks is None | ||
|
|
||
| conv_weights = self.conv1d.weight.view( | ||
| self.conv1d.weight.size(0), self.conv1d.weight.size(2) | ||
| ) | ||
|
|
||
| conv_state = self.kv_cache[0] | ||
| ssm_state = self.kv_cache[1] | ||
|
|
||
| torch.ops._xpu_C.gdn_attention( | ||
| core_attn_out, | ||
| z, | ||
| projected_states_qkvz, | ||
| projected_states_ba, | ||
| self.num_k_heads, | ||
| self.num_v_heads, | ||
| self.head_k_dim, | ||
| self.head_v_dim, | ||
| conv_state=conv_state, | ||
| ssm_state=ssm_state, | ||
| conv_weights=conv_weights, | ||
| conv_bias=self.conv1d.bias, | ||
| activation=self.activation, | ||
| A_log=self.A_log, | ||
| dt_bias=self.dt_bias, | ||
| num_prefills=attn_metadata.num_prefills, | ||
| num_decodes=attn_metadata.num_decodes, | ||
| has_initial_state=attn_metadata.has_initial_state, | ||
| non_spec_query_start_loc=attn_metadata.non_spec_query_start_loc, | ||
| non_spec_state_indices_tensor=attn_metadata.non_spec_state_indices_tensor, | ||
| num_actual_tokens=attn_metadata.num_actual_tokens, | ||
| tp_size=self.tp_size, | ||
| reorder_input=not self.gqa_interleaved_layout, | ||
| ) | ||
| torch.ops.vllm.gdn_attention_core_xpu( | ||
| core_attn_out, | ||
| z, | ||
| projected_states_qkvz, | ||
| projected_states_ba, | ||
| self.prefix, | ||
| ) | ||
|
Comment on lines +628 to +634
|
||
|
|
||
| # ============================================================ | ||
| # Part 3: Output Projection | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
When `attn_metadata` is `None`, this implementation returns without writing to `z` (and `z` is created with `torch.empty_like` in `forward_xpu`). That leaves `z` uninitialized, but it is subsequently consumed by the output projection (`self.norm(core_attn_out, z)`), producing nondeterministic outputs/NaNs during compile/profile passes that run with `attn_metadata=None`. To keep behavior well-defined, initialize `z` (e.g., zero-fill or another safe default) before returning in the `attn_metadata is None` path, or allocate `z` as zeros in `forward_xpu` for the no-metadata case.