
[Spyre-Next] Pytorch Native Attention on Spyre: 4D Attention Kernel#914

Open
jvlunteren wants to merge 54 commits into torch-spyre:main from jvlunteren:pytorch_native_attention_v2

Conversation

@jvlunteren
Collaborator

Description

This PR extends PR #853 by replacing the 2D transposed attention kernel with a 4D broadcast matmul kernel, eliminating per‑sequence and per‑chunk loops, GQA head duplication, and block‑diagonal masking.
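For readers following the description, the 4D broadcast-matmul idea can be sketched roughly as below. This is an illustrative reimplementation, not the kernel in this PR: the helper name `broadcast_attention`, the head-major layouts, and the additive-mask convention are assumptions chosen to show how broadcasting removes GQA head duplication.

```python
import torch

def broadcast_attention(query, key, value, scale, mask):
    """Illustrative 4D attention with GQA handled by broadcasting.

    query:      [num_seqs, num_heads, max_query_len, head_size]
    key/value:  [num_seqs, num_kv_heads, aligned_max_seq_len, head_size]
    mask:       additive, broadcastable to the score tensor.
    """
    num_seqs, num_heads, q_len, head_size = query.shape
    num_kv_heads = key.shape[1]
    group = num_heads // num_kv_heads
    # Split the head dim into [num_kv_heads, group] so each K/V head
    # broadcasts across its query-head group instead of being physically
    # duplicated (no repeat_interleave of K/V).
    q = query.reshape(num_seqs, num_kv_heads, group, q_len, head_size)
    k = key.unsqueeze(2)    # [num_seqs, num_kv_heads, 1, kv_len, head_size]
    v = value.unsqueeze(2)
    scores = torch.matmul(q, k.transpose(-2, -1)) * scale + mask
    out = torch.matmul(torch.softmax(scores, dim=-1), v)
    return out.reshape(num_seqs, num_heads, q_len, head_size)
```

Because `torch.matmul` broadcasts over leading batch dimensions, the size-1 KV-group axis of `k` and `v` is expanded virtually during the matmul, which is what lets a single 4D/5D kernel replace the per-sequence and per-chunk loops.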

Related Issues

Relates to #647

Test Plan

Same approach as in PR #853.

Checklist

  • I have read the contributing guidelines
  • My code follows the project's code style (run bash format.sh)
  • I have added tests for my changes (if applicable)
  • I have updated the documentation (if applicable)
  • My commits include a Signed-off-by: line (DCO compliance)

bohnstingl and others added 30 commits March 23, 2026 09:30
Signed-off-by: Thomas Ortner <boh@zurich.ibm.com>
Signed-off-by: Joe Runde <joe@joerun.de>
…Spyre

Signed-off-by: Jan van Lunteren <jvl@zurich.ibm.com>
jvlunteren and others added 19 commits March 25, 2026 10:56
Signed-off-by: Jan van Lunteren <161835099+jvlunteren@users.noreply.github.com>
@github-actions

👋 Hi! Thank you for contributing to vLLM support on Spyre.
Just a reminder: make sure your code passes all the linting checks, otherwise your PR cannot be merged. To do so, run ./format.sh.
Now you are good to go 🚀.

We also recommend installing prek and configuring it to check your code before every local commit.

Signed-off-by: Jan van Lunteren <161835099+jvlunteren@users.noreply.github.com>

Prepares tensors on CPU (reshape, stickify, build mask), transfers to
Spyre for the compiled matmul kernel, then transfers the result back.
# Q: [B, padQ, num_heads, D] -> [B, num_heads, padQ, D]
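The permutation in that comment maps directly onto `torch.Tensor.permute`; the concrete sizes below are placeholders for illustration (B = num_seqs, padQ = padded query length, D = head_size), not values from the PR:

```python
import torch

# Placeholder sizes for illustration only.
B, padQ, num_heads, D = 2, 16, 8, 64
q = torch.randn(B, padQ, num_heads, D)
# [B, padQ, num_heads, D] -> [B, num_heads, padQ, D]
q_t = q.permute(0, 2, 1, 3).contiguous()
assert q_t.shape == (B, num_heads, padQ, D)
```

`.contiguous()` materializes the transposed layout, which a stickify/transfer step would typically require.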
Collaborator

what is B here?

Collaborator Author

Batch size (number of sequences)

Collaborator Author

The shapes listed in the comments were originally based on shortened variable names to keep the comments brief and within the line-width limit. For clarity, I have now replaced these abbreviated names with the full variable names used in the code.

Collaborator

OK, but in vLLM we should never have a batch size dimension like that, right? Everything should be "flat"?

Collaborator Author

The query argument in the forward method in line 240 has a "flat" vLLM v1 shape [num_tokens, num_heads, head_size].

This gets converted in line 283 to [num_seqs, max_query_len, num_heads, head_size] in order to be able to use torch.matmul.

In line 310 the output is converted back into the "flat" vLLM v1 shape [num_actual_tokens, num_heads, head_size].
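The flat-to-batched round trip described above can be sketched as follows; the sequence lengths and tensor sizes are made-up example values, and the padding loop is a simple stand-in for whatever the actual conversion in the code does:

```python
import torch

num_heads, head_size = 4, 8
seq_lens = [3, 5, 2]                    # per-sequence query lengths (example)
num_tokens = sum(seq_lens)
max_query_len = max(seq_lens)
flat = torch.randn(num_tokens, num_heads, head_size)  # vLLM v1 "flat" layout

# Pad to [num_seqs, max_query_len, num_heads, head_size] for torch.matmul.
padded = torch.zeros(len(seq_lens), max_query_len, num_heads, head_size)
start = 0
for i, n in enumerate(seq_lens):
    padded[i, :n] = flat[start:start + n]
    start += n

# ... batched attention would run on `padded` here ...

# Un-pad back to the flat [num_actual_tokens, num_heads, head_size] layout.
flat_again = torch.cat([padded[i, :n] for i, n in enumerate(seq_lens)])
assert torch.equal(flat_again, flat)
```

The padded positions only waste compute; a mask keeps them from affecting the attention output before the un-pad step drops them.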

Signed-off-by: Jan van Lunteren <jvl@zurich.ibm.com>
Collaborator

@bringlein bringlein left a comment

looks great. I had just two questions for my understanding.

Comment on lines +545 to +547
query: torch.Tensor, # [num_seqs, max_query_len, num_heads, head_size]
key: torch.Tensor, # [num_seqs, aligned_max_seq_len, num_kv_heads, head_size]
value: torch.Tensor, # [num_seqs, aligned_max_seq_len, num_kv_heads, head_size]
Collaborator

so we expect key and value to be padded, but not the query? What is the rationale behind this interface? (if there is one; I'm fully aware this could also just be temporary)

Collaborator

and as @tdoublep pointed out, is there a way to support the flattened varlen format?

Collaborator Author

The query at the input is "flat" [num_tokens, num_heads, head_size]. The query gets padded inside the code (lines 573-577).


# Compiled attention on Spyre
output_spyre_t = self.attn_op(qt_spyre, k_spyre, vt_spyre, sm_scale_spyre, mask_spyre)
output_spyre = self.attn_op(q_spyre, k_spyre, v_spyre, self.scale, mask_spyre)
Collaborator

can we actually start profiling the performance of the different versions?

Collaborator Author

Yes
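A minimal wall-clock harness along these lines could compare the two `attn_op` variants quoted above. This is a generic sketch, not Spyre-specific API; in particular, the synchronization caveat in the comment is an assumption about accelerator timing in general:

```python
import time

def profile_op(fn, warmup=3, iters=10):
    # Rough average wall-clock seconds per call. Real accelerator profiling
    # would also need a device synchronize before each timestamp so queued
    # work is actually finished when the clock is read.
    for _ in range(warmup):
        fn()
    start = time.perf_counter()
    for _ in range(iters):
        fn()
    return (time.perf_counter() - start) / iters

# Placeholder workload; the real comparison would wrap each attn_op call.
avg_s = profile_op(lambda: sum(range(10_000)))
assert avg_s > 0.0
```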

@tdoublep
Collaborator

Could we re-open this PR against the new spyre-inference repo? Then we can merge it.
