[cpu][performance] CPU Paged Attention NEON BFMMLA BF16 Implementation #32263
Conversation
👋 Hi! Thank you for contributing to the vLLM project. 💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels. Just a reminder: PRs do not trigger a full CI run by default; only a subset of CI tests runs unless you ask your reviewers to trigger select CI tests. Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging. If you have any questions, please reach out to us on Slack at https://slack.vllm.ai. 🚀
Code Review
This pull request introduces a new CPU attention backend for ARM NEON with BFMMLA instruction support, specifically targeting bfloat16 data types. The implementation promises significant performance improvements for both prefill and decode stages. The changes include a new ISA dispatch path, a highly optimized BFMMLA GEMM kernel, and custom data layouts for key and value caches to leverage the hardware capabilities. The code is well-structured, using if constexpr for compile-time specialization and providing new tests for the added functionality. My review found the implementation to be solid, with one suggestion to improve the robustness of a kernel function by adding an explicit check for an implicit assumption.
```python
elif current_platform.get_cpu_architecture() == CpuArchEnum.ARM:
    if block_size % 128 == 0 and dtype == torch.bfloat16:
        return "neon_bfmmla"
```
Missing BFMMLA hardware capability check may cause crashes
Medium Severity
The _get_attn_isa function selects "neon_bfmmla" based solely on ARM architecture, block size alignment, and bfloat16 dtype, without checking if the CPU actually supports the BFMMLA instruction (ARMv8.6-A FEAT_BF16MM extension). This is inconsistent with how AMX is handled, which has an explicit torch._C._cpu._is_amx_tile_supported() check. On ARM64 CPUs that support BF16 conversions but lack BFMMLA (such as Apple Silicon M1/M2/M3 or AWS Graviton2), selecting this ISA would cause an illegal instruction crash at runtime.
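To illustrate the kind of gate the reviewer is asking for, here is a hedged sketch of a runtime BFMMLA capability check, analogous to the existing `torch._C._cpu._is_amx_tile_supported()` check for AMX. The function name `supports_bfmmla` and the approach (treating the Linux arm64 `bf16` hwcap flag in `/proc/cpuinfo` as a proxy for FEAT_BF16/BFMMLA) are illustrative assumptions, not vLLM API:

```python
# Hypothetical capability probe; not part of vLLM. On Linux/aarch64 the
# kernel advertises FEAT_BF16 (which includes BFMMLA) via the "bf16" flag
# in /proc/cpuinfo (HWCAP2_BF16). Elsewhere we conservatively report False
# rather than risk an illegal-instruction crash at runtime.
import platform


def supports_bfmmla() -> bool:
    """Best-effort check for the ARMv8.6-A BF16 matrix-multiply extension."""
    if platform.machine() not in ("aarch64", "arm64"):
        return False
    try:
        with open("/proc/cpuinfo") as f:
            cpuinfo = f.read()
    except OSError:
        return False
    for line in cpuinfo.splitlines():
        if line.startswith("Features"):
            # e.g. "Features : fp asimd ... bf16 ..."
            return "bf16" in line.split(":", 1)[1].split()
    return False
```

With such a check in place, `_get_attn_isa` could fall back to the plain NEON path on parts like Graviton2 that lack BFMMLA, mirroring how the AMX path is gated.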
Great work! Thank you :)
I put in some initial comments, around:
- do we need to introduce NEON_BFMMLA as a new ISA - can't your BFMMLA implementation just live under the existing NEON ISA?
- with the BFMMLA implementation we don't need to (and shouldn't) worry about any types other than BF16 - so please specialize for `c10::BFloat16` and not `scalar_t`. Once this is done, the code will be much clearer/simpler and we'll go through another round of reviews.
Hi @gassan-arm, the pre-commit checks have failed. Please run:

```shell
uv pip install pre-commit
pre-commit install
pre-commit run --all-files
```

Then, commit the changes and push to your branch.
csrc/cpu/cpu_attn_neon_bfmmla.hpp (outdated)

```cpp
float* C_blk = C + m * ldc;
```

```cpp
#define DISPATCH_MB(mb)                                             \
  gemm_packA_compute_MB_xN<mb, N, K, BFMMLABLayout::TokenColumn>(   \
```
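For context on what the kernel is built around (background, not part of the PR): a single BFMMLA instruction multiplies a 2x4 BF16 tile by the transpose of another 2x4 BF16 tile and accumulates into a 2x2 FP32 tile. A hedged NumPy emulation of that semantics, modeling BF16 as FP32 with the low 16 mantissa bits truncated (the real instruction performs widening 2-way dot products, which agrees with this up to intermediate rounding):

```python
import numpy as np


def to_bf16(x: np.ndarray) -> np.ndarray:
    """Model BF16 by truncating the low 16 bits of the FP32 encoding."""
    u = np.ascontiguousarray(x, dtype=np.float32).view(np.uint32)
    return (u & np.uint32(0xFFFF0000)).view(np.float32)


def bfmmla(acc: np.ndarray, a: np.ndarray, b: np.ndarray) -> np.ndarray:
    """Emulate one BFMMLA: acc(2x2, fp32) += A(2x4, bf16) @ B(2x4, bf16)^T."""
    a16 = to_bf16(a)
    b16 = to_bf16(b)
    return acc + a16 @ b16.T


acc = np.zeros((2, 2), dtype=np.float32)
a = np.arange(8, dtype=np.float32).reshape(2, 4)   # rows: 0..3 and 4..7
b = np.ones((2, 4), dtype=np.float32)
out = bfmmla(acc, a, b)
# out[0, :] == 0+1+2+3 == 6.0; out[1, :] == 4+5+6+7 == 22.0
```

The `DISPATCH_MB` macro above then stamps out the GEMM micro-kernel for each supported M-block size at compile time, with each micro-kernel issuing a grid of these 2x2 accumulations.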
For QK phase, why don't we just pack the query in copy_q_heads_tile while it's hot in cache, similar to what we do for AMX?
For PV phase, is the cost to pack P amortized?
we can revisit this later in a future PR
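To make the packing question above concrete: BFMMLA consumes its A operand as 2-row strips of 2x4 tiles rather than plain row-major data, so the operand must be repacked before the GEMM. A hedged sketch of such a pack step; the exact tile ordering here (K-major within a 2-row strip) is an illustrative assumption, not necessarily the layout this PR uses:

```python
import numpy as np


def pack_a_2x4(a: np.ndarray) -> np.ndarray:
    """Pack a row-major (M x K) fp32 operand into 2x4 tiles.

    Illustrative element order per 2-row strip and 4-wide K tile:
    m0k0 m0k1 m0k2 m0k3 m1k0 m1k1 m1k2 m1k3, then the next K tile.
    Requires M even and K a multiple of 4.
    """
    m, k = a.shape
    assert m % 2 == 0 and k % 4 == 0
    return (
        a.reshape(m // 2, 2, k // 4, 4)  # (row pair, row, K tile, k)
        .transpose(0, 2, 1, 3)           # (row pair, K tile, row, k)
        .reshape(-1)
    )


a = np.arange(16, dtype=np.float32).reshape(2, 8)
packed = pack_a_2x4(a)
# First tile holds k=0..3 of both rows, second tile k=4..7:
# [0,1,2,3, 8,9,10,11, 4,5,6,7, 12,13,14,15]
```

Doing this inside `copy_q_heads_tile` (as the reviewer suggests for the QK phase) would fold the repack into a pass that already touches the data while it is hot in cache.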
fadara01
left a comment
LGTM, could you please just test on a machine with no BF16 HW (e.g. c6g) to make sure this works as expected?
@bigPYJ1151 could you please take a look at this?
please see: #32932
aditew01
left a comment
Overall, neat changes. Thanks.
Nit: it'd be great if you could add `TODO:` comments for anything that needs to be addressed but is out of scope for this PR.
fadara01
left a comment
Actually, let's hold off merging this. I'm seeing regressions in end-to-end throughput benchmarks.
I'll add more details in a bit.
@bigPYJ1151 could you please review and hopefully merge this :)
This pull request has merge conflicts that must be resolved before it can be merged.
Hi @gassan-arm, the pre-commit checks have failed. Please run:

```shell
uv pip install pre-commit
pre-commit install
pre-commit run --all-files
```

Then, commit the changes and push to your branch.
Implementation of paged attention using BFMMLA for increased BF16 performance
Co-authored-by: GitHub Copilot
Signed-off-by: Gassan <gassan.salama@arm.com>
Looks like ARM testing is broken, but it should not be related to this PR.
Purpose
CPU Paged Attention NEON BFMMLA BF16 Implementation
Co-authored-by: GitHub Copilot
Test Results
Using the #31720 benchmark suite, against the current NEON implementation:
Prefill: 2.32x speedup
Decode: 2.07x speedup
cc @aditew01 @fadara01