
[Spyre-Next] Pytorch Native Attention on Spyre #853

Merged
jvlunteren merged 51 commits into torch-spyre:main from jvlunteren:pytorch_native_attention
Apr 13, 2026

Conversation

@jvlunteren
Collaborator

@jvlunteren jvlunteren commented Mar 20, 2026

Description

This PR builds on @bohnstingl’s branch (https://github.com/bohnstingl/vllm/tree/naive_attn_backend) and includes his commits. @bohnstingl is a co‑author of this work.

It introduces a vLLM v1 attention backend (SpyreAttentionPagedBackend) that runs on Spyre hardware using only PyTorch native operations (matmul, softmax). The backend implements the standard vLLM v1 triple (AttentionBackend / AttentionImpl / AttentionMetadataBuilder) and processes each forward pass in six steps:

  1. Write new key/value tokens into the paged KV cache via vectorized scatter (slot_mapping)
  2. Gather the relevant KV entries from scattered blocks into compact [num_seqs, seq_len, num_kv_heads, head_size] tensors
  3. Reshape the flat query token batch into per-sequence padded tensors
  4. Build a per-sequence boolean attention mask encoding padding validity and (for prefill) causality
  5. Compute attention per sequence with query chunking
  6. Extract only the real (non-padded) output tokens back to the flat batch layout
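Steps 1 and 2 above can be sketched with plain PyTorch indexing. This is a minimal illustration with made-up shapes and names, not the actual backend code:

```python
import torch

# Toy paged KV cache: [2 (K/V), num_blocks, block_size, num_kv_heads, head_size]
num_blocks, block_size, num_kv_heads, head_size = 8, 4, 2, 16
kv_cache = torch.zeros(2, num_blocks, block_size, num_kv_heads, head_size)

# Step 1: scatter 3 new tokens into their slots (flat slot = block * block_size + offset).
key = torch.randn(3, num_kv_heads, head_size)
value = torch.randn(3, num_kv_heads, head_size)
slot_mapping = torch.tensor([4, 5, 6])  # block 1, offsets 0..2
kv_cache[0].view(-1, num_kv_heads, head_size)[slot_mapping] = key
kv_cache[1].view(-1, num_kv_heads, head_size)[slot_mapping] = value

# Step 2: gather one sequence's KV entries back into a compact tensor
# via its block table, then trim the block-aligned result to seq_len.
block_table = torch.tensor([1])
seq_len = 3
k_seq = kv_cache[0, block_table].reshape(-1, num_kv_heads, head_size)[:seq_len]
```

The scatter writes through a flat view of the cache, so the subsequent gather recovers exactly the tokens that were written.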

Due to current limitations in Spyre’s matmul support, attention is computed one sequence at a time rather than across the full batch. In addition, all attention heads are fused into a single 2D matmul by reshaping Q, K, and V from [num_heads, seq_len, head_size] into [num_heads × seq_len, head_size]. A block‑diagonal attention mask is applied to zero out spurious cross‑head interactions introduced by the concatenated layout. This makes the result equivalent to num_heads independent attention computations, at the cost of a single larger matrix multiply. To match the hardware’s expected memory layout, the kernel “stickifies” inputs to the matmul operation using a double transpose followed by .contiguous().
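The fused-heads trick can be illustrated in a few lines (toy shapes; the real kernel additionally stickifies inputs and chunks queries):

```python
import torch

H, L, d = 2, 3, 4  # num_heads, seq_len, head_size (toy values)
q, k, v = (torch.randn(H, L, d) for _ in range(3))
scale = d ** -0.5

# Reference: H independent per-head attention computations.
ref = torch.softmax(q @ k.transpose(-1, -2) * scale, dim=-1) @ v  # [H, L, d]

# Fused: concatenate heads along the sequence axis into one 2D matmul.
qf, kf, vf = q.reshape(H * L, d), k.reshape(H * L, d), v.reshape(H * L, d)
scores = qf @ kf.T * scale  # [H*L, H*L], contains spurious cross-head terms

# Block-diagonal mask: a query row of head h may only attend to KV rows of head h.
head_id = torch.arange(H).repeat_interleave(L)
scores = scores.masked_fill(head_id[:, None] != head_id[None, :], float("-inf"))
fused = torch.softmax(scores, dim=-1) @ vf  # [H*L, d]

assert torch.allclose(fused.view(H, L, d), ref, atol=1e-5)
```

The masked cross-head entries get zero weight after the softmax, so the single large matmul reproduces the per-head results exactly.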

Spyre requires static tensor shapes at compile time (torch._dynamo.config.assume_static_by_default = True). To avoid a full recompile on every decode step, KV tensors are padded to the next multiple of KV_LENGTH_ALIGNMENT = 256. This buckets sequence lengths into tiers (256, 512, 768, …) so that only the first request at each tier pays the compilation cost. Query tokens are chunked to a fixed QUERY_CHUNK_SIZE = 32 for the same reason.
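The bucketing itself is simple round-up arithmetic; a sketch using the constant named above (the function name is hypothetical):

```python
KV_LENGTH_ALIGNMENT = 256  # value from the PR description

def padded_kv_len(seq_len: int) -> int:
    """Round seq_len up to the next tier (256, 512, 768, ...) so that all
    sequences in the same tier share one compiled graph."""
    return -(-seq_len // KV_LENGTH_ALIGNMENT) * KV_LENGTH_ALIGNMENT
```

Only the first request that crosses into a new tier triggers a recompile; every later request in that tier reuses the cached graph.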

Grouped-query attention is handled by repeat_interleave on the KV heads before the kernel, expanding num_kv_heads to num_heads in one vectorized call.
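For illustration, with hypothetical head counts (8 query heads sharing 2 KV heads):

```python
import torch

num_heads, num_kv_heads, seq_len, head_size = 8, 2, 5, 16
k = torch.randn(num_kv_heads, seq_len, head_size)

# One vectorized call expands each KV head to serve its group of query heads.
k_expanded = k.repeat_interleave(num_heads // num_kv_heads, dim=0)  # [8, 5, 16]
```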

A use_sdpa flag routes computation through torch.nn.functional.scaled_dot_product_attention instead of the Spyre kernel. This uses the semantic mask directly (no block-diagonal needed, since SDPA handles heads independently) and is primarily used for CPU testing and debugging.
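A minimal sketch of this debug path, with assumed shapes and mask layout (the exact mask construction in the backend may differ):

```python
import torch
import torch.nn.functional as F

# SDPA handles heads independently, so a plain semantic mask suffices;
# no block-diagonal trick is needed here.
H, Lq, Lkv, d = 2, 3, 5, 8
q = torch.randn(1, H, Lq, d)
k = torch.randn(1, H, Lkv, d)
v = torch.randn(1, H, Lkv, d)

# Boolean mask, True = may attend: causal for queries placed at the end
# of the KV sequence (decode-style layout, assumed for this sketch).
offset = Lkv - Lq
mask = torch.arange(Lkv)[None, :] <= torch.arange(Lq)[:, None] + offset
out = F.scaled_dot_product_attention(q, k, v, attn_mask=mask)
```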

Related Issues

Relates to #647

Test Plan

I created and executed a debug build that validates SDPA (CPU) outputs against the Spyre attention implementation for a range of models and prompts. The discrepancies were within acceptable bounds, accounting for Spyre’s float16 precision versus the precisions used by SDPA on the CPU. I also added a unit test (test_spyre_paged_attn.py) for this validation, which is included in this PR.

Checklist

  • I have read the contributing guidelines
  • My code follows the project's code style (run bash format.sh)
  • I have added tests for my changes (if applicable)
  • I have updated the documentation (if applicable)
  • My commits include a Signed-off-by: line (DCO compliance)

@tdoublep @bohnstingl

@github-actions

👋 Hi! Thank you for contributing to vLLM support on Spyre.
Just a reminder: Make sure that your code passes all the linting checks, otherwise your PR won't be able to be merged. To do so, run ./format.sh.
Now you are good to go 🚀.

We also recommend installing prek and configuring it to check your code before every local commit.

bohnstingl and others added 17 commits March 23, 2026 09:30
Signed-off-by: Thomas Ortner <boh@zurich.ibm.com>
Signed-off-by: Thomas Ortner <boh@zurich.ibm.com>
Signed-off-by: Thomas Ortner <boh@zurich.ibm.com>
Signed-off-by: Thomas Ortner <boh@zurich.ibm.com>
Signed-off-by: Thomas Ortner <boh@zurich.ibm.com>
Signed-off-by: Thomas Ortner <boh@zurich.ibm.com>
Signed-off-by: Thomas Ortner <boh@zurich.ibm.com>
Signed-off-by: Thomas Ortner <boh@zurich.ibm.com>
Signed-off-by: Thomas Ortner <boh@zurich.ibm.com>
Signed-off-by: Thomas Ortner <boh@zurich.ibm.com>
Signed-off-by: Thomas Ortner <boh@zurich.ibm.com>
Signed-off-by: Thomas Ortner <boh@zurich.ibm.com>
Signed-off-by: Thomas Ortner <boh@zurich.ibm.com>
Signed-off-by: Joe Runde <joe@joerun.de>
…Spyre

Signed-off-by: Jan van Lunteren <jvl@zurich.ibm.com>
Signed-off-by: Jan van Lunteren <jvl@zurich.ibm.com>
Signed-off-by: Jan van Lunteren <jvl@zurich.ibm.com>
@jvlunteren force-pushed the pytorch_native_attention branch from 3c15910 to a5c719f on March 23, 2026 08:43
Signed-off-by: Jan van Lunteren <jvl@zurich.ibm.com>
@bohnstingl bohnstingl self-requested a review March 23, 2026 20:38
Collaborator

@bohnstingl bohnstingl left a comment


I took a first look at the "infrastructure" surrounding the attention. I will next look at the attention implementation itself.

@@ -0,0 +1,64 @@
### TEST 1 - Disable prefix caching
Collaborator


Are we okay with adding this file? I would be, but I just want to confirm with others @joerunde @tdoublep

Collaborator


+1: we do have the script vllm_spyre_next/examples/torch_spyre_inference.py, which does precisely this: it creates an LLM() instance and calls LLM.generate. It is more configurable; I copied it over from vllm_spyre. Could you try to include your changes there?

Collaborator


I think yes, it is good to have examples. But I would just not use the capital O here...

Collaborator


+1 on we should probably include this in the existing example, with a flag to enable spyre attention

Collaborator


We integrated this example file into the already existing vllm_spyre_next/examples/torch_spyre_inference.py

Comment thread vllm_spyre_next/tests/test_vllm_spyre_next.py
Collaborator


I think we should really try to reuse the tests from upstream here. This is a stripped version of the upstream https://github.com/vllm-project/vllm/blob/main/tests/v1/attention/test_attention_backends.py. @joerunde is there a way to enable this upstream test and integrate our backend there? This would avoid us having to maintain this file and the attention_test_utils.py

Collaborator

@joerunde joerunde Mar 31, 2026


Right, we should aim to reuse the upstream tests. Is the goal here to run the test_causal_backend_correctness test but with our AttentionBackendEnum.CUSTOM backend instead of with the full list of other supported backends?

We'll probably need to get the upstream test updated to correctly populate the supported attention backends based on the current platform, but in the meantime we could whip up something in our pytest plugin like this:

    - rel_path: tests/v1/attention/test_attention_backends.py
      allow_list:
        - test: "test_causal_backend_correctness"
          mode: mandatory_pass
          tags: [attention]
          fixtures:
            - patch_backend_list

where we'd inject the provided fixture(s) into the test, which we could implement like:

@pytest.fixture
def patch_backend_list(monkeypatch):
    # import test_attention_backends from the cached vllm upstream tests

    our_backend_list = [
        AttentionBackendEnum.CUSTOM,
    ]

    # monkeypatch.setattr is not a context manager; it undoes itself at teardown
    monkeypatch.setattr(test_attention_backends, "BACKENDS_TO_TEST", our_backend_list)
    yield

Collaborator


this PR is already huge and we are trying to get it in asap. Can we address this in a follow up?

Collaborator


agree

Collaborator


I have those changes ready here: jvlunteren#1

I can plop that into a separate PR into the main repo here so it's ready to merge (and delete these two files) once this is in

# error further. The reference uses float32 softmax internally, widening
# the gap. rtol=5.0 is loose but expected; the results are numerically
# equivalent and the gap does not grow with model scale.
atol, rtol = 0.3, 5.0
Collaborator


Are these tolerances okay? They seem a bit large to me. Torch-spyre has 0.1 absolute and relative tolerances for its attention tests, see https://github.com/torch-spyre/torch-spyre/blob/7b3b4dcfe83838c65e9ef1d3a6aedce2605ae7df/tests/inductor/test_building_blocks.py#L57

Collaborator Author


If I recall correctly, I observed similar deviations when using SDPA during debugging. I initially attributed them to float16 precision, but I need to verify this.

# errors grow with query_len * kv_len * head_size.
atol, rtol = 0.3, 5.0
else:
atol, rtol = 0.1, 0.1
Collaborator


Same as above.

Collaborator Author


If I recall correctly, I observed similar deviations when using SDPA during debugging. I initially attributed them to float16 precision, but I need to verify this.

Comment thread vllm_spyre_next/vllm_spyre_next/compat_utils.py Outdated
Comment thread vllm_spyre_next/vllm_spyre_next/platform.py
Signed-off-by: Jan van Lunteren <jvl@zurich.ibm.com>
Signed-off-by: Jan van Lunteren <jvl@zurich.ibm.com>
Signed-off-by: Jan van Lunteren <jvl@zurich.ibm.com>
Signed-off-by: Jan van Lunteren <jvl@zurich.ibm.com>
@bohnstingl bohnstingl self-requested a review March 24, 2026 16:50
Collaborator

@bohnstingl bohnstingl left a comment


I made a first pass through the attention mechanism. In general it looks good to me, but I think we could do some cleanup and consolidation. Please let me know if something is unclear.

Comment thread vllm_spyre_next/vllm_spyre_next/compat_utils.py Outdated
Comment thread vllm_spyre_next/vllm_spyre_next/v1/attention/backends/spyre_attn.py Outdated
Comment thread vllm_spyre_next/vllm_spyre_next/v1/attention/backends/spyre_attn.py
Comment thread vllm_spyre_next/vllm_spyre_next/v1/attention/backends/spyre_attn.py
cache_dtype_str: str = "auto",
) -> tuple[int, ...]:
"""KV cache shape: [2, num_blocks, block_size, num_kv_heads, head_size]"""
return (2, num_blocks, block_size, num_kv_heads, head_size)
Collaborator


@tdoublep was opting for an implementation following (num_blocks, 2, ...), is this still the case, or do we want to use this format? See #774 (comment)

Collaborator


In the future, when we support kv cache transfer and offloading, will a specific layout be required?

Collaborator


I would also argue for (num_blocks, 2, ...), are there any issues/constraints with it @jvlunteren ?

Collaborator


I have changed the KV cache format to be (num_blocks, 2, ...).
@jvlunteren please have a look at my changes and confirm.


Comment thread vllm_spyre_next/vllm_spyre_next/v1/attention/backends/spyre_attn.py Outdated
Comment thread vllm_spyre_next/vllm_spyre_next/v1/attention/backends/spyre_attn.py
Comment thread vllm_spyre_next/vllm_spyre_next/v1/attention/backends/spyre_attn.py Outdated
Comment thread vllm_spyre_next/vllm_spyre_next/v1/attention/backends/spyre_attn.py Outdated
Comment thread vllm_spyre_next/vllm_spyre_next/v1/attention/backends/spyre_attn.py Outdated
Signed-off-by: Thomas Ortner <boh@zurich.ibm.com>
Signed-off-by: Thomas Ortner <boh@zurich.ibm.com>
Signed-off-by: Thomas Ortner <boh@zurich.ibm.com>
@bohnstingl
Collaborator

bot:next-test

Collaborator

@bohnstingl bohnstingl left a comment


In principle the PR looks good to me.
There are some questions related to the test files that maybe @joerunde or @tjohnson31415 could comment on and apart from that, I think we should just get some more reviews in.

cc @tdoublep @yannicks1 @bringlein @maxdebayser

@joerunde
Collaborator

joerunde commented Apr 1, 2026

@bohnstingl @jvlunteren I heard you like PRs so I made a PR to your PR: jvlunteren#1

Let me know what you think, if we want to go this route to remove the duplicated test files here and use the upstream tests directly then I can help clean that up / get any other test cases working that we need as well.

Collaborator

@bringlein bringlein left a comment


thanks for all your efforts @jvlunteren ! (and @bohnstingl )
Just have a few comments.


Comment thread vllm_spyre_next/vllm_spyre_next/v1/attention/backends/spyre_attn.py
Comment thread vllm_spyre_next/vllm_spyre_next/v1/attention/backends/spyre_attn.py Outdated

mask_values: Mask values tensor [num_heads * kv_len, num_heads * query_len_padded]
Pre-computed on CPU: 0.0 for valid, -65504.0 for masked/padded
"""
kq = k @ qt # [num_heads * kv_len, num_heads * query_len_padded]
Collaborator


The current implementation does heads sequentially (AFAIK), due to cache size limitations. So I think we can merge it as is, but we should keep in mind that we need to change this in the future?


# Positions along query and KV dimensions
q_pos = torch.arange(max_query_len, device=device) # [max_query_len]
kv_pos = torch.arange(aligned_max_seq_len, device=device) # [aligned_max_seq_len]
Collaborator


Do we know if these operations are supported on spyre? @bohnstingl just to evaluate how to move more parts of the attention computation to the device in the future...

Collaborator


Yes, torch.arange is supported on torch-spyre, see https://github.com/torch-spyre/torch-spyre/blob/a6caaf92e27925e892ba98629a32de694a8b8d9a/tests/inductor/test_inductor_ops.py#L1697-L1701. It should be supported in eager and in torch.compile-mode
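As a side note, the quoted q_pos/kv_pos vectors broadcast directly into the boolean mask; a minimal sketch following the snippet's variable names (seq_len, offset, and the query placement at the end of the KV sequence are assumptions for illustration):

```python
import torch

max_query_len, aligned_max_seq_len = 4, 8
q_pos = torch.arange(max_query_len)          # [max_query_len]
kv_pos = torch.arange(aligned_max_seq_len)   # [aligned_max_seq_len]

seq_len = 6                                  # real tokens; the rest is padding
offset = seq_len - max_query_len             # queries are the last real tokens

valid = kv_pos[None, :] < seq_len            # hide padded KV slots
causal = kv_pos[None, :] <= q_pos[:, None] + offset  # prefill causality
mask = valid & causal                        # [max_query_len, aligned_max_seq_len]
```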

"""Transposed attention for Spyre: handles all heads at once.

Args:
qt: Query transposed [head_size, num_heads * query_len_padded]
Collaborator


Do we know what the maximum and minimum supported tensor sizes here are? Maybe we should just annotate this with an assert or a comment for now? (Also regarding shapes needing to be divisible by 128/64/etc. for stickification?)

# Define all backend configurations of full cudagraph to be tested
full_cg_backend_configs = {
# FA3 on Hopper
"FA3": BackendConfig(
Collaborator


do we need all these tests for spyre here?

Collaborator


I agree I think we'd only need to test our new attention backend.

I have changes ready to land once this PR is in to address this here: https://github.com/vllm-project/vllm-spyre/pull/884/changes#diff-035ed2d7ff0ff23d9dfc96c0f483503dbb4c16529e7ea88192e717621f4da830R632-R636

Collaborator


Yes, absolutely. This file is an artifact from upstream, used to run the correctness tests, and was basically just taken over as-is. The PR from Joe makes much more sense for this.

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

agree

elif max_query_len >= 32:
atol, rtol = 0.3, 5.0 # float16 accumulation errors for large prompts
else:
atol, rtol = 0.2, 0.2
Collaborator


just out of interest: Which values do we use currently?

Collaborator


Basically, we hit each of these tolerance values in one way or another in the current tests.

Collaborator

@yannicks1 yannicks1 left a comment


disclaimer: only some initial comments. I am reviewing the actual backend in the afternoon. (just making sure no comments get lost)


Comment thread vllm_spyre_next/vllm_spyre_next/platform.py Outdated
Comment thread vllm_spyre_next/vllm_spyre_next/platform.py

@@ -0,0 +1,353 @@
# SPDX-License-Identifier: Apache-2.0
Collaborator


This seems mostly copied from tests/v1/attention/utils.py. Are there any fundamental changes (I did not do the diff), or could we just import the file from upstream vllm?

Collaborator


Yes, the file was a direct copy from upstream. Initially I didn't want to link the upstream file, because we have to do it in a bit of a weird way, but I changed that now.

return super().get_attn_backend_cls(selected_backend, *args, **kwargs)

@classmethod
def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
Collaborator


I think we're missing some logic in here to set --max-num-seqs to 1 since we don't support batch sizes > 1 when AttentionBackendEnum.CUSTOM is enabled

Collaborator


This logic has been added. Now when our CUSTOM attention backend is selected, the max_num_seqs is restricted to 1

Collaborator


I think this is actually not necessary anymore and I removed it again.

Signed-off-by: Thomas Ortner <boh@zurich.ibm.com>
Signed-off-by: Thomas Ortner <boh@zurich.ibm.com>
Signed-off-by: Thomas Ortner <boh@zurich.ibm.com>
Signed-off-by: Thomas Ortner <boh@zurich.ibm.com>
@bohnstingl
Collaborator

bohnstingl commented Apr 3, 2026

I made some new commits and changed:

  1. The layout of the KV cache, it is now (num_blocks, 2, ...)
  2. I revisited the logic about max_num_seqs==1 and I think this was an old limitation that has meanwhile been resolved. The attention backend should support max_num_seqs > 1
  3. I tried to port the sdpa path to spyre, but torch-spyre does not yet support all necessary cases. In particular, there are limitations for GQA and for non-square attention computation, see https://github.com/torch-spyre/torch-spyre/blob/cb711cbd6a33be088e5cbd3396178e23128c5252/tests/inductor/test_inductor_ops.py#L1472-L1513

@jvlunteren could you please confirm both points above?

Signed-off-by: Thomas Ortner <boh@zurich.ibm.com>
parser.add_argument("--model", type=str, default="ibm-ai-platform/micro-g3.3-8b-instruct-1b")
parser.add_argument("--max_model_len", "--max-model-len", type=int, default=2048)
parser.add_argument("--max_num_seqs", "--max-num-seqs", type=int, default=2)
parser.add_argument("--max_num_batched_tokens", "--max-num-batched-tokens", type=int, default=2)
Collaborator


Why would we want this to default to 2? This defines the number of tokens in the batch; 2 is super low. Does it even work? We still use the base scheduler in vllm_spyre_next, so the value is not overridden for our granite3.3-8b model as was the case in vllm_spyre...

Collaborator


+1

This will create a lot of chunked prefills for the example below I guess? Is this something that even works right now?

Collaborator

@tdoublep tdoublep left a comment


LGTM


output_all_seqs = torch.zeros_like(query)

# Process each sequence separately
for seq_idx in range(num_seqs):
Collaborator


One of the next steps should be thinking about how we handle varlen batches in a proper way on Spyre

@tdoublep
Collaborator

@jvlunteren @bohnstingl Can we resolve conflicts and merge?

Signed-off-by: Jan van Lunteren <161835099+jvlunteren@users.noreply.github.com>
@jvlunteren jvlunteren merged commit d1697e8 into torch-spyre:main Apr 13, 2026
13 checks passed