
[Spyre-Next] Integrated custom attention backend#774

Closed
bohnstingl wants to merge 10 commits intomainfrom
pytorch_paged_attn

Conversation

@bohnstingl
Collaborator

Description

This PR provides the skeleton for adding new attention backends to the vllm_spyre_next plugin.
It uses the first draft implementation of the pytorch-native paged attention as an example.

Related Issues

#648

Checklist

  • I have read the contributing guidelines
  • My code follows the project's code style (run bash format.sh)
  • I have added tests for my changes (if applicable)
  • I have updated the documentation (if applicable)
  • My commits include a Signed-off-by: line (DCO compliance)

cc @tdoublep @dilipgb

Signed-off-by: Thomas Ortner <boh@zurich.ibm.com>
Signed-off-by: Thomas Ortner <boh@zurich.ibm.com>
@github-actions github-actions bot changed the title Integrated custom attention backend [Spyre-Next] Integrated custom attention backend Feb 27, 2026
@github-actions

👋 Hi! Thank you for contributing to vLLM support on Spyre.
Just a reminder: make sure your code passes all the linting checks, otherwise your PR cannot be merged. To do so, run ./format.sh.
Now you are good to go 🚀.

We also recommend installing prek and configuring it to check your code before every local commit.

@bohnstingl bohnstingl requested a review from tdoublep February 27, 2026 11:44
@bohnstingl bohnstingl self-assigned this Feb 27, 2026
@bohnstingl
Collaborator Author

Below is a small script to test the attention backend:

from vllm import LLM, SamplingParams
from vllm.v1.attention.backends.registry import AttentionBackendEnum
from vllm.config import AttentionConfig

def print_outputs(outputs, engine):
    print("-" * 50)
    for output in outputs:
        prompt = output.prompt
        generated_text = output.outputs[0].text
        print(f"Generated text: {generated_text!r}")
        print("-" * 50)
    for m in engine.llm_engine.get_metrics():
        if "cache" in m.name:
            print(m.name, m.value)


def main():
    MODEL = "ibm-granite/granite-3.3-8b-instruct"

    # Sampling parameters for the inference process
    sampling_params = SamplingParams(
        max_tokens=5,  # Maximum number of tokens to produce
    )

    # Prompts to use for inference
    prompts = [
        "What are IBMs main businesses?",
    ]

    engine = LLM(
        model=MODEL,                     # Model to use for inference.
        gpu_memory_utilization=0.9,      # Increasing utilization provides more KV cache space.
        enable_prefix_caching=True,      # Whether prefix caching is enabled.
        # enforce_eager=True,            # Whether to use eager mode instead of torch.compile.
        disable_log_stats=False,         # Keep engine stats enabled.
        attention_config=AttentionConfig(backend=AttentionBackendEnum.CUSTOM),  # Select the new custom attention backend
    )

    # Generate response for prompt 0
    outputs = engine.generate(prompts[0], sampling_params)
    print_outputs(outputs, engine)

if __name__ == "__main__":
    main()

Comment thread vllm_spyre_next/vllm_spyre_next/platform.py Outdated
cache_dtype_str: str = "auto",
) -> tuple[int, ...]:
"""KV cache shape: [2, num_blocks, block_size, num_kv_heads, head_size]"""
return (2, num_blocks, block_size, num_kv_heads, head_size)
Collaborator

Do we really want to use this shape for the KV cache? In my experience this layout is actually pretty annoying. We could opt to follow the (num_blocks, 2, ...) layout instead.

Collaborator Author


I have not thought about the impact of the KV cache shape, tbh. If you prefer (num_blocks, 2, ...) instead and have good reasons, we can certainly change it
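For illustration, here is a toy sketch contrasting the two layouts discussed above. The dimension sizes and variable names are made up for the example, not taken from this PR:

```python
# Hypothetical sketch contrasting the two KV cache layouts discussed above;
# the dimension sizes are illustrative only.
import torch

num_blocks, block_size, num_kv_heads, head_size = 4, 16, 8, 64

# Layout in this PR: [2, num_blocks, block_size, num_kv_heads, head_size].
# Splitting K/V at the outermost dimension means a given block's K and V
# live far apart in memory.
kv_split_first = torch.zeros(2, num_blocks, block_size, num_kv_heads, head_size)
k_cache, v_cache = kv_split_first[0], kv_split_first[1]

# Alternative layout: [num_blocks, 2, block_size, num_kv_heads, head_size].
# Each block's K and V are adjacent, so per-block copies (e.g. block swaps
# for prefix caching) touch one contiguous slab per block.
kv_block_first = torch.zeros(num_blocks, 2, block_size, num_kv_heads, head_size)
k_block0, v_block0 = kv_block_first[0, 0], kv_block_first[0, 1]

print(k_cache.shape)   # torch.Size([4, 16, 8, 64])
print(k_block0.shape)  # torch.Size([16, 8, 64])
```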

Signed-off-by: Thomas Ortner <boh@zurich.ibm.com>
Signed-off-by: Thomas Ortner <boh@zurich.ibm.com>
@bohnstingl bohnstingl marked this pull request as ready for review March 3, 2026 00:00
Signed-off-by: Thomas Ortner <boh@zurich.ibm.com>
Signed-off-by: Thomas Ortner <boh@zurich.ibm.com>
Signed-off-by: Thomas Ortner <boh@zurich.ibm.com>
Signed-off-by: Thomas Ortner <boh@zurich.ibm.com>
@bohnstingl
Collaborator Author

I updated the PR by:

  • Implementing a method to select the relevant blocks from the KV cache instead of using the entire KV cache -> a step closer towards paged attention
  • Removing unnecessary .item() calls that would cause trouble with torch.compile
  • Introducing an example of how to use the attention mechanism with a model
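The block-selection step in the first bullet can be sketched roughly as below; the names, sizes, and block-table contents are hypothetical, not the PR's actual code:

```python
# Rough sketch of gathering only a sequence's blocks from the global KV
# cache via a block table, instead of attending over the whole cache.
# All names and sizes here are illustrative.
import torch

num_blocks, block_size, num_kv_heads, head_size = 8, 4, 2, 16
key_cache = torch.randn(num_blocks, block_size, num_kv_heads, head_size)

# Hypothetical block table: this sequence occupies physical blocks 5, 2, 7.
block_table = torch.tensor([5, 2, 7])

# Tensor indexing keeps everything on-device with no .item() calls, so the
# graph stays traceable under torch.compile.
seq_keys = key_cache[block_table]                         # [3, 4, 2, 16]
seq_keys = seq_keys.reshape(-1, num_kv_heads, head_size)  # [12, 2, 16]
print(seq_keys.shape)
```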

@joerunde
Collaborator

joerunde commented Mar 3, 2026

thanks @bohnstingl!

General questions:

  1. Should we update the plugin to make this attention backend the default? This would make it easier to run upstream tests on it without modification, for example
  2. Can we turn the example here into a small test case? That way we can be sure the attention backend continues to compile and run with each PR

@bohnstingl
Collaborator Author

For me, 1 could be an option, yes. However, I have not yet tried this backend with torch-spyre, and it will likely fail. In parallel, we are also working on a different attention backend that should be compatible with torch-spyre. We may want to have that one as the default, and make the one from this PR the default at a later stage, maybe.

Regarding 2, I wouldn't do that, tbh. The example I added was just meant to showcase usage with LLM(...). Upstream vLLM has correctness tests for the attention backends. I integrated an earlier version of this backend into those correctness tests (https://github.com/bohnstingl/vllm/blob/bb0782267790fc9c479286e4772b28b7925bf34d/tests/v1/attention/test_attention_backends.py#L45) and they passed. So I would rather run those upstream tests with the backends that we develop here integrated. WDYT?

@joerunde
Collaborator

joerunde commented Mar 3, 2026

However, I have not yet tried this backend with torch-spyre and it will likely fail

Ah, interesting. I'm a little confused, then, about why we'd want to merge this in. Don't we want all the code here to work and run with torch-spyre?

So I would rather want to run those upstream tests with the backends that we develop here integrated. WDYT?

Agreed, and we should probably start integrating upstream tests soon. Is your plan to put the pytorch native implementation here, or to put it upstream and consume it here?

If we do want to merge a working implementation here that runs with torch-spyre, then I think at least as a stop-gap we should have a small test here covering it while we work towards consuming upstream tests

@bohnstingl
Collaborator Author

bohnstingl commented Mar 3, 2026

Ah, interesting. I'm a little confused, then, about why we'd want to merge this in. Don't we want all the code here to work and run with torch-spyre?

The idea with this PR is to have the proper skeleton in place, which also allows us to swap the attention backend out for a different one. In a sense, this attention formulation sticks closely to vLLM and may contain some operations that are not yet supported in torch-spyre, whereas the one that @jvlunteren is working on is supported on torch-spyre but may not be properly supported by vLLM. That's the trade-off. In the short term, the attention backend from Jan, also tracked here, may be the way to go, but in the long run we will need a "paged" version of it for performance reasons.

Agreed, and we should probably start integrating upstream tests soon. Is your plan to put the pytorch native implementation here, or to put it upstream and consume it here?

If we do want to merge a working implementation here that runs with torch-spyre, then I think at least as a stop-gap we should have a small test here covering it while we work towards consuming upstream tests

I don't think the pytorch-native attention backend from here has any chance of landing in upstream vLLM, so it will just reside here. The tests, however, should be consumed from upstream, I think, and I fully agree with you: we should start working on that. If you want, I can strip out one correctness test from upstream and include it here to cover the gap until we have a proper way to consume the upstream tests?

EDIT: I added a stripped testcase from upstream to check for the attention correctness.
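For reference, the general shape of such a stripped correctness check might look like this toy version, written against plain torch SDPA rather than the actual upstream test. The block-table convention, names, and shapes are all illustrative assumptions:

```python
# Toy correctness check in the spirit of the upstream vLLM attention backend
# tests: scatter K/V into out-of-order physical blocks, gather them back via
# a block table, and compare attention output against a dense reference.
import torch
import torch.nn.functional as F

torch.manual_seed(0)
heads, dim, block_size, num_logical = 2, 8, 4, 3
seq_len = num_logical * block_size

q = torch.randn(1, heads, 1, dim)        # single decode-step query
k = torch.randn(1, heads, seq_len, dim)
v = torch.randn(1, heads, seq_len, dim)

ref = F.scaled_dot_product_attention(q, k, v)  # dense reference

# Block table: logical block i lives in physical slot table[i].
table = torch.tensor([2, 0, 1])
k_blocks = k.reshape(1, heads, num_logical, block_size, dim)
v_blocks = v.reshape(1, heads, num_logical, block_size, dim)
k_cache = torch.zeros_like(k_blocks)
v_cache = torch.zeros_like(v_blocks)
k_cache[:, :, table] = k_blocks          # scatter into physical slots
v_cache[:, :, table] = v_blocks

# Gathering via the same table recovers logical order; the attention output
# must then match the dense reference.
k_back = k_cache[:, :, table].reshape(1, heads, seq_len, dim)
v_back = v_cache[:, :, table].reshape(1, heads, seq_len, dim)
out = F.scaled_dot_product_attention(q, k_back, v_back)

assert torch.allclose(ref, out, atol=1e-6)
print("paged gather matches dense reference")
```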

Signed-off-by: Thomas Ortner <boh@zurich.ibm.com>
Signed-off-by: Thomas Ortner <boh@zurich.ibm.com>
@bohnstingl
Collaborator Author

Closing in favor of #798

@bohnstingl bohnstingl closed this Mar 6, 2026
@bohnstingl bohnstingl deleted the pytorch_paged_attn branch March 6, 2026 09:29