[Spyre-Next] Integrated custom attention backend #774
bohnstingl wants to merge 10 commits into main
Conversation
Signed-off-by: Thomas Ortner <boh@zurich.ibm.com>
👋 Hi! Thank you for contributing to vLLM support on Spyre. We also recommend installing prek and configuring it to check your code before every local commit.
Below is a small script to test the attention backend:

```python
from vllm import LLM, SamplingParams
from vllm.v1.attention.backends.registry import AttentionBackendEnum
from vllm.config import AttentionConfig


def print_outputs(outputs, engine):
    print("-" * 50)
    for output in outputs:
        prompt = output.prompt
        generated_text = output.outputs[0].text
        print(f"Generated text: {generated_text!r}")
        print("-" * 50)
    for m in engine.llm_engine.get_metrics():
        if "cache" in m.name:
            print(m.name, m.value)


def main():
    MODEL = "ibm-granite/granite-3.3-8b-instruct"

    # Sampling parameters for the inference process
    sampling_params = SamplingParams(
        max_tokens=5,  # Maximum number of tokens to produce
    )

    # Prompts to use for inference
    prompts = [
        "What are IBM's main businesses?",
    ]

    engine = LLM(
        model=MODEL,  # Model to use for inference.
        gpu_memory_utilization=0.9,  # Increasing utilization provides more KV cache space.
        enable_prefix_caching=True,  # Flag determining whether prefix caching is enabled or disabled.
        # enforce_eager=True,  # Flag determining whether eager mode or torch.compile should be used.
        disable_log_stats=False,  # Keep stats logging enabled.
        attention_config=AttentionConfig(backend=AttentionBackendEnum.CUSTOM),  # Select the new custom attention backend.
    )

    # Generate a response for prompt 0
    outputs = engine.generate(prompts[0], sampling_params)
    print_outputs(outputs, engine)


if __name__ == "__main__":
    main()
```
```python
    cache_dtype_str: str = "auto",
) -> tuple[int, ...]:
    """KV cache shape: [2, num_blocks, block_size, num_kv_heads, head_size]"""
    return (2, num_blocks, block_size, num_kv_heads, head_size)
```
Do we really want to use this shape for KV cache? In my experience this layout is pretty annoying actually. We could opt to follow the (num_blocks, 2, ...) layout instead
I have not thought about the impact of the KV cache shape, tbh. If you prefer (num_blocks, 2, ...) instead and have good reasons, we can certainly change it
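For illustration, here is a minimal sketch of the two layouts under discussion (toy sizes and variable names are hypothetical, not taken from the PR):

```python
import torch

# Toy sizes for illustration only
num_blocks, block_size, num_kv_heads, head_size = 4, 16, 2, 8

# Layout in this PR: [2, num_blocks, block_size, num_kv_heads, head_size].
# K and V each live in one large contiguous tensor.
kv_a = torch.zeros(2, num_blocks, block_size, num_kv_heads, head_size)
k_cache, v_cache = kv_a[0], kv_a[1]

# Proposed alternative: [num_blocks, 2, block_size, num_kv_heads, head_size].
# Each block carries its own K and V side by side, so per-block operations
# (e.g. copying or swapping one block) touch a single contiguous region.
kv_b = torch.zeros(num_blocks, 2, block_size, num_kv_heads, head_size)
k_blk, v_blk = kv_b[0, 0], kv_b[0, 1]  # K and V of physical block 0
```

Both hold the same data; the difference is purely which accesses end up contiguous in memory.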
Signed-off-by: Thomas Ortner <boh@zurich.ibm.com>
I updated by
thanks @bohnstingl! General questions:
For me 1 could be an option, yes. However, I have not yet tried this backend with torch-spyre, and it will likely fail. In parallel, we are also working on a different attention backend which should be compatible with torch-spyre. We may want to have that as the default, and the one from this PR maybe at a later stage. Regarding 2, I wouldn't do that, tbh. The example I added was just meant to showcase the usage with
Ah, interesting. I'm a little confused then about why we'd want to merge this in; don't we want all the code here to work and run with torch-spyre?
Agreed, and we should probably start integrating upstream tests soon. Is your plan to put the pytorch-native implementation here, or to put it upstream and consume it here? If we do want to merge a working implementation here that runs with torch-spyre, then I think, at least as a stop-gap, we should have a small test here covering it while we work towards consuming the upstream tests.
The idea with this PR is to put the proper skeleton in place, which also allows us to swap the attention backend for a different one. This attention formulation sticks closely to vLLM and may contain some operations that are not yet supported in torch-spyre, while the one @jvlunteren is working on is supported on torch-spyre but may not be properly supported by vLLM. That's the trade-off. In the short term, the attention backend from Jan, also tracked here, may be the way to go, but in the long run we will need a "paged" version of it for performance reasons.
I don't think the pytorch-native attention backend from here has any chance of landing in upstream vLLM, so it will just reside here. The tests, however, should be consumed from upstream I think, and I fully agree with you: we should start working on that. If you want, I can strip out one correctness test from upstream and include it here as a test to cover the gap until we have a proper way to consume the upstream tests? EDIT: I added a stripped test case from upstream to check for attention correctness.
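A stripped-down correctness test of this kind might look roughly as follows: compare a naive reference attention against PyTorch's SDPA (this is a hedged sketch, not the test actually added to the PR; names and shapes are made up):

```python
import torch
import torch.nn.functional as F


def naive_attention(q, v_q, v_v):
    """Reference attention: softmax(Q K^T / sqrt(d)) V, no mask, no dropout."""
    d = q.shape[-1]
    scores = q @ v_q.transpose(-2, -1) / d ** 0.5
    return torch.softmax(scores, dim=-1) @ v_v


torch.manual_seed(0)
# (batch, num_heads, seq_len, head_dim)
q = torch.randn(1, 4, 32, 64)
k = torch.randn(1, 4, 32, 64)
v = torch.randn(1, 4, 32, 64)

ref = F.scaled_dot_product_attention(q, k, v)  # default: non-causal, no dropout
out = naive_attention(q, k, v)

torch.testing.assert_close(out, ref, atol=1e-4, rtol=1e-4)
```

In the real test, `naive_attention` would be replaced by a call through the custom backend's forward path, with the reference kept as the ground truth.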
Signed-off-by: Thomas Ortner <boh@zurich.ibm.com>
Closing in favor of #798
Description
This PR provides the skeleton for adding new attention backends to the vllm_spyre_next plugin.
It uses the first draft implementation of the pytorch-native paged attention as an example.
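To make the "paged" part concrete, here is a toy sketch of a single paged-attention decode step: K/V live in fixed-size physical blocks, and a block table maps a sequence's logical positions to those blocks (all sizes and names below are hypothetical, not the PR's code):

```python
import torch

# Toy cache: [2, num_blocks, block_size, num_kv_heads, head_size]
num_blocks, block_size, num_kv_heads, head_size = 8, 4, 2, 16
kv_cache = torch.randn(2, num_blocks, block_size, num_kv_heads, head_size)

block_table = torch.tensor([5, 1, 3])  # logical blocks -> physical blocks
seq_len = 10                           # tokens actually stored for this sequence

# Gather this sequence's K/V from its (non-contiguous) physical blocks
k = kv_cache[0, block_table].reshape(-1, num_kv_heads, head_size)[:seq_len]
v = kv_cache[1, block_table].reshape(-1, num_kv_heads, head_size)[:seq_len]

# Attention for one decode-step query token, per head
q = torch.randn(num_kv_heads, head_size)
scores = torch.einsum("hd,shd->hs", q, k) / head_size ** 0.5
out = torch.einsum("hs,shd->hd", torch.softmax(scores, dim=-1), v)
```

A production kernel would avoid materializing the gathered `k`/`v` and read blocks in place; the sketch only shows the block-table indirection itself.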
Related Issues
#648
Checklist
- [ ] Code formatted (`bash format.sh`)
- [ ] Commits include a `Signed-off-by:` line (DCO compliance)

cc @tdoublep @dilipgb