Merged
51 commits
a0375d7
Integrated custom attention backend
bohnstingl Feb 27, 2026
11255ac
Formatting issues
bohnstingl Feb 27, 2026
89d5a75
Changed the name of the attention operation
bohnstingl Feb 27, 2026
bfbc64a
Changed filename
bohnstingl Feb 27, 2026
f8afb02
Implemented gather to avoid using full KV cache
bohnstingl Mar 3, 2026
df3ab2c
Removed .item() calls
bohnstingl Mar 3, 2026
8e0bd74
Cleanup and adding of example
bohnstingl Mar 3, 2026
90ce563
Lint
bohnstingl Mar 3, 2026
8d314b9
Added testcase for attention backend
bohnstingl Mar 3, 2026
0f34475
Added missing utils file
bohnstingl Mar 5, 2026
3bc3ee6
Reformat
bohnstingl Mar 6, 2026
c2d264b
Functional update
bohnstingl Mar 8, 2026
2e8e4aa
Lint issues
bohnstingl Mar 8, 2026
14b6ef7
:art: linting, vllm compatibility, test integration
joerunde Mar 9, 2026
c98a9a2
refactored attention backend to support compilation and execution on …
jvlunteren Mar 19, 2026
825a95c
formatting
jvlunteren Mar 20, 2026
a5c719f
add unit test
jvlunteren Mar 20, 2026
6da9be4
formatting
jvlunteren Mar 23, 2026
61e22f1
removed redundant code
jvlunteren Mar 24, 2026
2d8bb12
added empty line back
jvlunteren Mar 24, 2026
139ab4a
formatting
jvlunteren Mar 24, 2026
9bf1283
removed custom num_heads handling
jvlunteren Mar 24, 2026
0931224
removed compat_utils.py
jvlunteren Mar 25, 2026
621df53
renamed spyre_paged_attn.py to spyre_attn.py
jvlunteren Mar 25, 2026
2bf45c1
add dynamic=False argument to torch.compile
jvlunteren Mar 25, 2026
9919ba2
adapted test_spyre_attn.py to previous name change
jvlunteren Mar 25, 2026
a8c26f6
limit supported data types to float16
jvlunteren Mar 25, 2026
6338c32
limit supported kv cache data types to float16
jvlunteren Mar 25, 2026
e118284
removed redundant code
jvlunteren Mar 25, 2026
82d2daf
indicated if steps are executed on CPU and/or Spyre
jvlunteren Mar 25, 2026
781c095
renaming
jvlunteren Mar 25, 2026
c2556b3
further renaming
jvlunteren Mar 25, 2026
a2eadbd
use utils for transfers between cpu and spyre
jvlunteren Mar 25, 2026
0c6100f
various updates to test
jvlunteren Mar 25, 2026
d25ec3f
formatting
jvlunteren Mar 25, 2026
49b6109
WIP: reworked D2H movements
bohnstingl Mar 26, 2026
88e12ea
fixed supports_head_size()
jvlunteren Mar 26, 2026
17d2194
Merge branch 'pytorch_native_attention' of github.com:jvlunteren/vllm…
bohnstingl Mar 26, 2026
a13a657
Enforce dtype="float16"
bohnstingl Mar 26, 2026
91a24d6
Moved assert
bohnstingl Mar 26, 2026
dc5a07c
Corrected stripped attention test
bohnstingl Mar 26, 2026
80c7cc9
Updates to address review comments
bohnstingl Mar 26, 2026
8adae0a
Merge branch 'main' of github.com:vllm-project/vllm-spyre into pytorc…
bohnstingl Mar 30, 2026
af9d8f9
Integrated minor review findings
bohnstingl Mar 30, 2026
c6fd7f9
Merge branch 'main' of github.com:vllm-project/vllm-spyre into pytorc…
bohnstingl Apr 2, 2026
7882018
Integrated reviewer comments and suggestions
bohnstingl Apr 2, 2026
fef4e7f
Fixing formatting errors
bohnstingl Apr 3, 2026
3d3a169
Switched KV cache format to (num_blocks, 2, ...)
bohnstingl Apr 3, 2026
1749821
Removed outdated max_num_seqs==1 restriction
bohnstingl Apr 3, 2026
a3eecc5
Removed enforce_eager argument
bohnstingl Apr 9, 2026
3b02c2c
Merge branch 'main' into pytorch_native_attention
jvlunteren Apr 13, 2026
15 changes: 13 additions & 2 deletions vllm_spyre_next/examples/torch_spyre_inference.py
@@ -24,6 +24,7 @@ def parse_args():
parser.add_argument("--model", type=str, default="ibm-ai-platform/micro-g3.3-8b-instruct-1b")
parser.add_argument("--max_model_len", "--max-model-len", type=int, default=2048)
parser.add_argument("--max_num_seqs", "--max-num-seqs", type=int, default=2)
parser.add_argument("--max_num_batched_tokens", "--max-num-batched-tokens", type=int, default=2)
Collaborator
Why would we want this to default to 2? This defines the number of tokens in the batch, and 2 is extremely low — does it even work? We still use the base scheduler in vllm_spyre_next, so the value is not overridden for our granite3.3-8b model as it was in vllm_spyre...

Collaborator
+1

This will create a lot of chunked prefills for the example below, I guess? Is this something that even works right now?
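To illustrate the concern above, here is a small sketch (a hypothetical helper, not code from this PR) of the arithmetic: each engine step can schedule at most `max_num_batched_tokens` prompt tokens, so a tiny budget like the default of 2 splits every prompt into many chunked-prefill steps.

```python
# Hypothetical helper, not part of this PR: estimates how many chunked-prefill
# steps a prompt needs under a given per-step token budget.
def prefill_steps(prompt_len: int, max_num_batched_tokens: int) -> int:
    # Ceiling division: each step schedules at most the budgeted token count.
    return -(-prompt_len // max_num_batched_tokens)

print(prefill_steps(512, 2))     # 256 steps with the default budget of 2
print(prefill_steps(512, 1024))  # 1 step with a more typical budget
```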

parser.add_argument("--tp", type=int, default=1)
parser.add_argument("--num-prompts", "-n", type=int, default=3)
parser.add_argument(
@@ -34,6 +35,7 @@ def parse_args():
"This list is repeated until prompts are exhausted.",
)
parser.add_argument("--compare-with-cpu", action=argparse.BooleanOptionalAction)
parser.add_argument("--attention_backend", "--attention-backend", type=str, default=None)
parser.add_argument(
"--enforce_eager",
"--enforce-eager",
@@ -95,7 +97,11 @@ def main():
"Compose a LinkedIn post about your company's latest product release.",
]

prompts = [template.format(instr) for instr in instructions]
simple_prompt = [
"What are IBM's main businesses?",
]

prompts = simple_prompt + [template.format(instr) for instr in instructions]

prompts = prompts * (args.num_prompts // len(prompts) + 1)
prompts = prompts[0 : args.num_prompts]
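The two lines above use a tile-and-truncate pattern to pad the prompt list out to `--num-prompts`; a standalone sketch of the same logic:

```python
# Standalone sketch of the tile-and-truncate pattern used in the example script:
# repeat the base prompt list until it covers num_prompts, then cut to length.
base = ["p0", "p1", "p2"]
num_prompts = 7
prompts = base * (num_prompts // len(base) + 1)  # over-provision by one tile
prompts = prompts[0:num_prompts]                 # truncate to the exact count
print(prompts)  # ['p0', 'p1', 'p2', 'p0', 'p1', 'p2', 'p0']
```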
@@ -111,6 +117,8 @@ def main():
# lazy import to switch between old and new platform:
# platform registration happens at import time
from vllm import LLM, SamplingParams
from vllm.config import AttentionConfig
from vllm.v1.attention.backends.registry import AttentionBackendEnum
from vllm.config import CompilationConfig

sampling_params = [
@@ -124,10 +132,13 @@ def main():
max_model_len=args.max_model_len,
max_num_seqs=max_num_seqs,
tensor_parallel_size=args.tp,
max_num_batched_tokens=1024,
max_num_batched_tokens=args.max_num_batched_tokens,
dtype="float16",
enforce_eager=args.enforce_eager,
compilation_config=CompilationConfig(custom_ops=args.custom_ops),
attention_config=AttentionConfig(backend=AttentionBackendEnum[args.attention_backend])
if args.attention_backend is not None
else None,
)

# Generate texts from the prompts. The output is a list of RequestOutput objects
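The diff's conditional `attention_config=` argument follows a build-only-when-requested pattern; a minimal sketch with stand-in enum and config classes (the real `AttentionBackendEnum` members are not shown in this diff, so the names below are assumptions):

```python
from enum import Enum

# Stand-ins for vLLM's AttentionBackendEnum / AttentionConfig; the real classes
# live in vllm.v1.attention.backends.registry and vllm.config respectively.
class BackendEnum(Enum):
    SPYRE = "spyre"

class AttnConfig:
    def __init__(self, backend):
        self.backend = backend

def make_attention_config(name):
    # Only build a config when the user explicitly picked a backend; Enum[name]
    # raises KeyError on an unknown name, surfacing typos at startup. Passing
    # None keeps vLLM's default backend selection.
    return AttnConfig(backend=BackendEnum[name]) if name is not None else None

print(make_attention_config(None))             # None -> default selection
print(make_attention_config("SPYRE").backend)  # BackendEnum.SPYRE
```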