[Spyre-Next] PyTorch Native Attention on Spyre: 4D Attention Kernel #914
jvlunteren wants to merge 54 commits into torch-spyre:main
Conversation
Signed-off-by: Thomas Ortner <boh@zurich.ibm.com>
Signed-off-by: Joe Runde <joe@joerun.de>
Signed-off-by: Jan van Lunteren <jvl@zurich.ibm.com>
👋 Hi! Thank you for contributing to vLLM support on Spyre. We also recommend installing prek and configuring it to check your code before every local commit.
Prepares tensors on CPU (reshape, stickify, build mask), transfers to
Spyre for the compiled matmul kernel, then transfers the result back.
# Q: [B, padQ, num_heads, D] -> [B, num_heads, padQ, D]
Batch size (number of sequences)
The shapes listed in the comments were originally based on shortened variable names to keep the comments brief and within the line-width limit. For clarity, I have now replaced these abbreviated names with the full variable names used in the code.
OK, but in vLLM we should never have a batch size dimension in that way? Everything should be "flat"?
The query argument in the forward method in line 240 has a "flat" vLLM v1 shape [num_tokens, num_heads, head_size].
This gets converted in line 283 to [num_seqs, max_query_len, num_heads, head_size] in order to be able to use torch.matmul.
In line 310 the output is converted back into the "flat" vLLM v1 shape [num_actual_tokens, num_heads, head_size].
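The flat-to-padded conversion described above can be sketched as follows. This is a hypothetical illustration, not the PR's code: `query_start_locs` is an assumed name for the per-sequence cumulative token offsets (length `num_seqs + 1`), and the scatter loop stands in for whatever vectorized copy the kernel actually uses.

```python
import torch

def pad_queries(query, query_start_locs, max_query_len):
    # query: flat vLLM v1 shape [num_tokens, num_heads, head_size]
    num_tokens, num_heads, head_size = query.shape
    num_seqs = len(query_start_locs) - 1
    # Scatter each sequence's tokens into a zero-padded 4D tensor
    # [num_seqs, max_query_len, num_heads, head_size] usable by torch.matmul.
    padded = query.new_zeros(num_seqs, max_query_len, num_heads, head_size)
    for i in range(num_seqs):
        start, end = query_start_locs[i], query_start_locs[i + 1]
        padded[i, : end - start] = query[start:end]
    return padded

def unpad_output(padded, query_start_locs):
    # Inverse conversion: gather the valid rows of each sequence and
    # concatenate them back into the flat [num_actual_tokens, ...] shape.
    rows = [padded[i, : query_start_locs[i + 1] - query_start_locs[i]]
            for i in range(padded.shape[0])]
    return torch.cat(rows, dim=0)
```

A round trip (`unpad_output(pad_queries(q, locs, m), locs)`) recovers the original flat tensor, which is the property the forward pass relies on.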
Signed-off-by: Jan van Lunteren <jvl@zurich.ibm.com>
bringlein left a comment:
looks great. I had just two questions for my understanding.
query: torch.Tensor,  # [num_seqs, max_query_len, num_heads, head_size]
key: torch.Tensor,    # [num_seqs, aligned_max_seq_len, num_kv_heads, head_size]
value: torch.Tensor,  # [num_seqs, aligned_max_seq_len, num_kv_heads, head_size]
So we expect key and value to be padded, but not the query? What is the rationale behind this interface? (If there is one; I'm fully aware this could also just be temporary.)
and as @tdoublep pointed out, is there a way to support the flattened varlen format?
The query at the input is "flat" [num_tokens, num_heads, head_size]. The query gets padded inside the code (lines 573-577).
# Compiled attention on Spyre
output_spyre_t = self.attn_op(qt_spyre, k_spyre, vt_spyre, sm_scale_spyre, mask_spyre)
output_spyre = self.attn_op(q_spyre, k_spyre, v_spyre, self.scale, mask_spyre)
can we actually start profiling the performance of the different versions?
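As a starting point for that profiling, the kernel variants could be compared with `torch.utils.benchmark`. This is only a hypothetical sketch: `attn_v2` below is a stand-in for the 4D matmul kernel, and on real hardware the compiled Spyre ops (the `self.attn_op` calls above) would be benchmarked instead.

```python
import torch
import torch.utils.benchmark as benchmark

def bench(fn, *args):
    # Time repeated calls to fn(*args); returns mean seconds per call.
    timer = benchmark.Timer(stmt="fn(*args)",
                            globals={"fn": fn, "args": args})
    return timer.timeit(50).mean

def attn_v2(q, k, v):
    # Plain 4D batched-matmul attention as a reference workload.
    scale = q.shape[-1] ** -0.5
    probs = torch.softmax(q @ k.transpose(-2, -1) * scale, dim=-1)
    return probs @ v

# Illustrative shapes: [num_seqs, num_heads, seq_len, head_size].
q = torch.randn(2, 4, 16, 32)
k = torch.randn(2, 4, 64, 32)
v = torch.randn(2, 4, 64, 32)
print(f"4D matmul attention: {bench(attn_v2, q, k, v) * 1e3:.3f} ms/iter")
```

Running the same harness over both `attn_op` variants would give a like-for-like comparison on identical inputs.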
Could we re-open this PR against the new spyre-inference repo? Then we can merge it.
Description
This PR extends PR #853 by replacing the 2D transposed attention kernel with a 4D broadcast matmul kernel, eliminating per‑sequence and per‑chunk loops, GQA head duplication, and block‑diagonal masking.
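The idea of handling GQA via broadcasting rather than head duplication can be sketched as below. This is an illustrative reconstruction, not the PR's kernel: query heads are viewed as `[num_seqs, num_kv_heads, group, ...]` so that a single `torch.matmul` broadcasts over the group dimension instead of materializing repeated KV heads, and the mask shape is an assumption.

```python
import torch

def attention_4d(q, q_heads_k, v, scale, mask):
    # q: [B, H, Tq, D]; q_heads_k (keys): [B, Hkv, Tk, D]; v: [B, Hkv, Tk, D]
    # mask: additive attention mask [B, 1, Tq, Tk] (0 = keep, -inf = drop)
    B, H, Tq, D = q.shape
    Hkv = q_heads_k.shape[1]
    group = H // Hkv  # query heads per KV head (GQA group size)
    # View queries per KV-head group; give K/V a broadcast dim of size 1 so
    # matmul broadcasts over `group` without duplicating KV data.
    q = q.view(B, Hkv, group, Tq, D)
    k = q_heads_k.unsqueeze(2)                              # [B, Hkv, 1, Tk, D]
    v = v.unsqueeze(2)                                      # [B, Hkv, 1, Tk, D]
    scores = torch.matmul(q, k.transpose(-2, -1)) * scale   # [B, Hkv, group, Tq, Tk]
    scores = scores + mask.unsqueeze(1)                     # broadcast over heads
    probs = torch.softmax(scores, dim=-1)
    out = torch.matmul(probs, v)                            # [B, Hkv, group, Tq, D]
    return out.reshape(B, H, Tq, D)
```

Because the entire batch and all heads go through one broadcast matmul, the per-sequence and per-chunk loops of the 2D kernel disappear, and padding plus the additive mask replace block-diagonal masking.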
Related Issues
Relates to #647
Test Plan
Same approach as in PR #853.
Checklist
- Code formatted (`bash format.sh`)
- Commits include a `Signed-off-by:` line (DCO compliance)