[Spyre-Next] Pytorch Native Attention on Spyre #853
jvlunteren merged 51 commits into torch-spyre:main
Conversation
Signed-off-by: Thomas Ortner <boh@zurich.ibm.com>
Signed-off-by: Joe Runde <joe@joerun.de>
…Spyre Signed-off-by: Jan van Lunteren <jvl@zurich.ibm.com>
Signed-off-by: Jan van Lunteren <jvl@zurich.ibm.com>
bohnstingl left a comment
I took a first look at the "infrastructure" surrounding the attention. I will look at the attention implementation itself next.
@@ -0,0 +1,64 @@
### TEST 1 - Disable prefix caching
+1: we do have the script vllm_spyre_next/examples/torch_spyre_inference.py, which does precisely that: it creates an LLM() instance and calls LLM.generate. It is more configurable; I copied it over from vllm_spyre. Could you try to include your changes there?
I think yes, it is good to have examples. But I would just not use the capital O here...
+1, we should probably include this in the existing example, with a flag to enable Spyre attention.
We integrated this example file into the already existing vllm_spyre_next/examples/torch_spyre_inference.py
I think we should really try to reuse the tests from upstream here. This is a stripped version of the upstream https://github.com/vllm-project/vllm/blob/main/tests/v1/attention/test_attention_backends.py. @joerunde is there a way to enable this upstream test and integrate our backend there? This would avoid us having to maintain this file and the attention_test_utils.py
Right, we should aim to reuse the upstream tests. Is the goal here to run the test_causal_backend_correctness test but with our AttentionBackendEnum.CUSTOM backend instead of with the full list of other supported backends?
We'll probably need to get the upstream test updated to correctly populate the supported attention backends based on the current platform, but in the meantime we could whip up something in our pytest plugin like this:
```yaml
- rel_path: tests/v1/attention/test_attention_backends.py
  allow_list:
    - test: "test_causal_backend_correctness"
      mode: mandatory_pass
      tags: [attention]
      fixtures:
        - patch_backend_list
```

where we'd inject the provided fixture(s) into the test, which we could implement like:
```python
@pytest.fixture
def patch_backend_list(monkeypatch):
    # import test_attention_backends from the cached vllm upstream tests
    our_backend_list = [
        AttentionBackendEnum.CUSTOM,
    ]
    monkeypatch.setattr(test_attention_backends, "BACKENDS_TO_TEST", our_backend_list)
    yield
```
This PR is already huge and we are trying to get it in ASAP. Can we address this in a follow-up?
I have those changes ready here: jvlunteren#1
I can plop that into a separate PR into the main repo here so it's ready to merge (and delete these two files) once this is in
# error further. The reference uses float32 softmax internally, widening
# the gap. rtol=5.0 is loose but expected; the results are numerically
# equivalent and the gap does not grow with model scale.
atol, rtol = 0.3, 5.0
Are these tolerances okay? They seem a bit large to me. Torch-spyre has 0.1 absolute and relative tolerances for its attention tests, see https://github.com/torch-spyre/torch-spyre/blob/7b3b4dcfe83838c65e9ef1d3a6aedce2605ae7df/tests/inductor/test_building_blocks.py#L57
If I recall correctly, I observed similar deviations when using SDPA during debugging. I initially attributed them to float16 precision, but I need to verify this.
# errors grow with query_len * kv_len * head_size.
    atol, rtol = 0.3, 5.0
else:
    atol, rtol = 0.1, 0.1
If I recall correctly, I observed similar deviations when using SDPA during debugging. I initially attributed them to float16 precision, but I need to verify this.
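For context, a minimal sketch of how such a comparison could be done. This is my own standalone helper for illustration, not code from this PR; it checks a result against a reference at the tolerances under discussion and reports the worst deviation, which helps judge whether a loose tolerance like atol=0.3 is actually being exercised:

```python
import torch

def check_attention_outputs(out: torch.Tensor, ref: torch.Tensor,
                            atol: float, rtol: float) -> float:
    """Compare an attention output against a reference; returns max abs diff."""
    # Compare in float32 so the comparison itself does not add rounding error.
    out32, ref32 = out.float(), ref.float()
    max_abs = (out32 - ref32).abs().max().item()
    assert torch.allclose(out32, ref32, atol=atol, rtol=rtol), \
        f"max abs diff {max_abs:.4f} exceeds atol={atol}, rtol={rtol}"
    return max_abs

# float16 accumulation error grows with the reduction length (kv_len),
# which is why long prompts may need looser tolerances than short ones.
ref = torch.randn(8, 128, 64)   # [num_heads, tokens, head_size]
out = ref.half().float()        # round-trip through float16
check_attention_outputs(out, ref, atol=0.1, rtol=0.1)
```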
Signed-off-by: Jan van Lunteren <jvl@zurich.ibm.com>
bohnstingl left a comment
I made a first pass through the attention mechanism. In general it looks good to me, but I think we could do some cleanup and consolidation. Please let me know if something is unclear.
cache_dtype_str: str = "auto",
) -> tuple[int, ...]:
    """KV cache shape: [2, num_blocks, block_size, num_kv_heads, head_size]"""
    return (2, num_blocks, block_size, num_kv_heads, head_size)
@tdoublep was opting for an implementation following (num_blocks, 2, ...), is this still the case, or do we want to use this format? See #774 (comment)
In the future, when we support kv cache transfer and offloading, will a specific layout be required?
I would also argue for (num_blocks, 2, ...), are there any issues/constraints with it @jvlunteren ?
I have changed the KV cache format to be (num_blocks, 2, ...).
@jvlunteren please have a look at my changes and confirm.
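For illustration, a sketch of the agreed layout. This is a hypothetical standalone helper, not the PR's exact signature:

```python
import torch

def get_kv_cache_shape(num_blocks: int, block_size: int,
                       num_kv_heads: int, head_size: int) -> tuple:
    # (num_blocks, 2, ...) keeps each block's K and V planes adjacent,
    # so a whole block can be moved as one contiguous slice, which is
    # convenient for future kv-cache transfer/offloading.
    return (num_blocks, 2, block_size, num_kv_heads, head_size)

cache = torch.empty(get_kv_cache_shape(16, 64, 8, 128), dtype=torch.float16)
k0, v0 = cache[0, 0], cache[0, 1]   # K and V of block 0, as views (no copy)
```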
Signed-off-by: Thomas Ortner <boh@zurich.ibm.com>
bot:next-test
bohnstingl left a comment
In principle the PR looks good to me.
There are some questions related to the test files that maybe @joerunde or @tjohnson31415 could comment on and apart from that, I think we should just get some more reviews in.
@bohnstingl @jvlunteren I heard you like PRs so I made a PR to your PR: jvlunteren#1 Let me know what you think; if we want to go this route to remove the duplicated test files here and use the upstream tests directly, then I can help clean that up / get any other test cases working that we need as well.
bringlein left a comment
Thanks for all your efforts @jvlunteren! (and @bohnstingl)
Just have a few comments.
mask_values: Mask values tensor [num_heads * kv_len, num_heads * query_len_padded]
    Pre-computed on CPU: 0.0 for valid, -65504.0 for masked/padded
"""
kq = k @ qt  # [num_heads * kv_len, num_heads * query_len_padded]
The current implementation processes the heads sequentially (AFAIK) due to cache size limitations. So I think we can merge this as is, but we should keep in mind that we will need to change this in the future.
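To make the concatenated-heads layout concrete, here is a self-contained sketch of the idea (simplified: one sequence, equal query/kv lengths, no paging or padding; names are mine, not this file's). With the block-diagonal mask it matches per-head attention exactly:

```python
import torch

def fused_heads_attention(q, k, v, scale):
    """All heads in one 2D matmul; q, k, v: [num_heads, seq_len, head_size]."""
    nh, L, d = q.shape
    q2, k2, v2 = (t.reshape(nh * L, d) for t in (q, k, v))
    scores = (q2 @ k2.T) * scale            # [nh*L, nh*L], has cross-head terms
    # Block-diagonal mask: keep only entries where query and key heads match.
    head_id = torch.arange(nh).repeat_interleave(L)
    scores = scores.masked_fill(head_id[:, None] != head_id[None, :],
                                float("-inf"))
    return (torch.softmax(scores, dim=-1) @ v2).reshape(nh, L, d)

q, k, v = (torch.randn(2, 4, 8) for _ in range(3))
out = fused_heads_attention(q, k, v, scale=8 ** -0.5)
ref = torch.nn.functional.scaled_dot_product_attention(q, k, v)
assert torch.allclose(out, ref, atol=1e-5)
```

The trade-off discussed above is visible here: one (nh·L)² score matrix instead of nh separate L² matrices, with the mask discarding the spurious cross-head blocks.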
# Positions along query and KV dimensions
q_pos = torch.arange(max_query_len, device=device)  # [max_query_len]
kv_pos = torch.arange(aligned_max_seq_len, device=device)  # [aligned_max_seq_len]
Do we know if these operations are supported on spyre? @bohnstingl just to evaluate how to move more parts of the attention computation to the device in the future...
Yes, torch.arange is supported on torch-spyre, see https://github.com/torch-spyre/torch-spyre/blob/a6caaf92e27925e892ba98629a32de694a8b8d9a/tests/inductor/test_inductor_ops.py#L1697-L1701. It should be supported in eager and in torch.compile-mode
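Since arange is supported, the two position vectors above can be turned into the additive mask with plain tensor ops. A simplified standalone sketch (single sequence, no padding-length handling; names are mine):

```python
import torch

def build_causal_mask(max_query_len: int, aligned_max_seq_len: int,
                      context_len: int) -> torch.Tensor:
    """Additive causal mask [max_query_len, aligned_max_seq_len].
    Query token i may attend to kv positions <= context_len + i."""
    q_pos = torch.arange(max_query_len)           # [max_query_len]
    kv_pos = torch.arange(aligned_max_seq_len)    # [aligned_max_seq_len]
    allowed = kv_pos[None, :] <= (context_len + q_pos)[:, None]
    # -65504.0 is the most negative finite float16 value; it stands in for
    # -inf so the mask stays representable after a cast to float16.
    mask = torch.full(allowed.shape, -65504.0)
    mask[allowed] = 0.0
    return mask

mask = build_causal_mask(2, 4, context_len=1)
# row 0 attends to kv positions 0..1, row 1 to 0..2; the rest is masked
```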
| """Transposed attention for Spyre: handles all heads at once. | ||
|
|
||
| Args: | ||
| qt: Query transposed [head_size, num_heads * query_len_padded] |
Do we know what the maximum and minimum supported tensor sizes are here? Maybe we should just annotate this with an assert or a comment for now? (Also regarding shapes needing to be divisible by 128/64/etc. for stickification?)
# Define all backend configurations of full cudagraph to be tested
full_cg_backend_configs = {
    # FA3 on Hopper
    "FA3": BackendConfig(
do we need all these tests for spyre here?
I agree I think we'd only need to test our new attention backend.
I have changes ready to land once this PR is in to address this here: https://github.com/vllm-project/vllm-spyre/pull/884/changes#diff-035ed2d7ff0ff23d9dfc96c0f483503dbb4c16529e7ea88192e717621f4da830R632-R636
Yes, absolutely. This file is an artifact from upstream needed to run the correctness tests and was basically taken over as-is. The PR from Joe makes much more sense for this.
elif max_query_len >= 32:
    atol, rtol = 0.3, 5.0  # float16 accumulation errors for large prompts
else:
    atol, rtol = 0.2, 0.2
just out of interest: Which values do we use currently?
Basically, we hit every tolerance value in one way or another in the current tests.
yannicks1 left a comment
Disclaimer: only some initial comments. I am reviewing the actual backend in the afternoon. (Just making sure no comments get lost.)
@@ -0,0 +1,353 @@
# SPDX-License-Identifier: Apache-2.0
This seems to be mostly copied from tests/v1/attention/utils.py. Are there any fundamental changes (I did not do the diff), or could we just import the file from upstream vllm?
Yes, the file was a direct copy from upstream. Initially I didn't want to link the upstream file because we have to do it in a bit of a weird way, but I have changed that now.
return super().get_attn_backend_cls(selected_backend, *args, **kwargs)

@classmethod
def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
I think we're missing some logic in here to set --max-num-seqs to 1 since we don't support batch sizes > 1 when AttentionBackendEnum.CUSTOM is enabled
This logic has been added. Now when our CUSTOM attention backend is selected, the max_num_seqs is restricted to 1
I think this is actually not necessary anymore and I removed it again.
Signed-off-by: Thomas Ortner <boh@zurich.ibm.com>
I made some new commits and changed:
@jvlunteren could you please confirm both points above?
Signed-off-by: Thomas Ortner <boh@zurich.ibm.com>
| parser.add_argument("--model", type=str, default="ibm-ai-platform/micro-g3.3-8b-instruct-1b") | ||
| parser.add_argument("--max_model_len", "--max-model-len", type=int, default=2048) | ||
| parser.add_argument("--max_num_seqs", "--max-num-seqs", type=int, default=2) | ||
| parser.add_argument("--max_num_batched_tokens", "--max-num-batched-tokens", type=int, default=2) |
Why would we want this to default to 2? This defines the number of tokens in the batch; 2 is super low, does it even work? We still use the base scheduler in vllm_spyre_next, so the value is not overridden for our granite3.3-8b model as was the case in vllm_spyre...
+1
This will create a lot of chunked prefills for the example below I guess? Is this something that even works right now?
output_all_seqs = torch.zeros_like(query)

# Process each sequence separately
for seq_idx in range(num_seqs):
One of the next steps should be thinking about how we handle varlen batches in a proper way on Spyre
@jvlunteren @bohnstingl Can we resolve conflicts and merge?
Signed-off-by: Jan van Lunteren <161835099+jvlunteren@users.noreply.github.com>
Description
This PR builds on @bohnstingl’s branch (https://github.com/bohnstingl/vllm/tree/naive_attn_backend) and includes his commits. @bohnstingl is a co‑author of this work.
It introduces a vLLM v1 attention backend (SpyreAttentionPagedBackend) that runs on Spyre hardware using only PyTorch native operations (matmul, softmax). The backend implements the standard vLLM v1 triple (AttentionBackend / AttentionImpl / AttentionMetadataBuilder) and processes each forward pass in six steps, among them building the slot_mapping and assembling [num_seqs, seq_len, num_kv_heads, head_size] tensors.

Due to current limitations in Spyre's matmul support, attention is computed one sequence at a time rather than across the full batch. In addition, all attention heads are fused into a single 2D matmul by reshaping Q, K, and V from [num_heads, seq_len, head_size] into [num_heads × seq_len, head_size]. A block-diagonal attention mask is applied to zero out spurious cross-head interactions introduced by the concatenated layout. This makes the result equivalent to num_heads independent attention computations, at the cost of a single larger matrix multiply. To match the hardware's expected memory layout, the kernel "stickifies" inputs to the matmul operation using a double transpose followed by .contiguous().

Spyre requires static tensor shapes at compile time (torch._dynamo.config.assume_static_by_default = True). To avoid a full recompile on every decode step, KV tensors are padded to the next multiple of KV_LENGTH_ALIGNMENT = 256. This buckets sequence lengths into tiers (256, 512, 768, …) so that only the first request at each tier pays the compilation cost. Query tokens are chunked to a fixed QUERY_CHUNK_SIZE = 32 for the same reason.

Grouped-query attention is handled by repeat_interleave on the KV heads before the kernel, expanding num_kv_heads to num_heads in one vectorized call.
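The repeat_interleave GQA expansion can be sketched as follows (a minimal standalone example; the helper name is mine):

```python
import torch

def expand_kv_for_gqa(k: torch.Tensor, v: torch.Tensor, num_heads: int):
    """Expand KV heads so each query head gets its own copy.
    k, v: [num_kv_heads, seq_len, head_size]."""
    group = num_heads // k.shape[0]   # query heads per kv head
    # Single vectorized call; output head i comes from kv head i // group.
    return (k.repeat_interleave(group, dim=0),
            v.repeat_interleave(group, dim=0))

k = torch.randn(2, 16, 64)            # 2 kv heads
k8, v8 = expand_kv_for_gqa(k, k, num_heads=8)
assert k8.shape == (8, 16, 64) and torch.equal(k8[3], k[0])
```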
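The KV-length bucketing described above can be sketched like this (KV_LENGTH_ALIGNMENT is the value quoted in this description; the helper names are mine):

```python
import torch
import torch.nn.functional as F

KV_LENGTH_ALIGNMENT = 256  # alignment tier size from the description

def pad_kv_len(seq_len: int) -> int:
    """Round up to the next tier (256, 512, 768, ...), so every length in
    (0, 256] reuses one compiled graph, (256, 512] the next, and so on."""
    return -(-seq_len // KV_LENGTH_ALIGNMENT) * KV_LENGTH_ALIGNMENT

def pad_kv(kv: torch.Tensor) -> torch.Tensor:
    """Zero-pad K or V along dim 0 ([seq_len, ...]) up to its tier."""
    pad = pad_kv_len(kv.shape[0]) - kv.shape[0]
    # F.pad takes pairs from the last dim backwards; only dim 0 is padded.
    return F.pad(kv, (0, 0) * (kv.dim() - 1) + (0, pad))

assert (pad_kv_len(1), pad_kv_len(256), pad_kv_len(257)) == (256, 256, 512)
assert pad_kv(torch.zeros(300, 8, 64)).shape == (512, 8, 64)
```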
A use_sdpa flag routes computation through torch.nn.functional.scaled_dot_product_attention instead of the Spyre kernel. This uses the semantic mask directly (no block-diagonal mask needed, since SDPA handles heads independently) and is primarily used for CPU testing and debugging.

Related Issues
Relates to #647
Test Plan
I created and executed a debug build that validates SDPA (CPU) outputs against the Spyre attention implementation for a range of models and prompts. The discrepancies were within acceptable bounds, accounting for Spyre's float16 precision versus the precisions used by SDPA on the CPU. I also added a unit test (test_spyre_paged_attn.py) for this validation, which is included in this PR.

Checklist
- Code formatted (bash format.sh)
- Commits include a Signed-off-by: line (DCO compliance)

@tdoublep @bohnstingl