
fp8 support #54

Closed
endurehero wants to merge 31 commits into deepseek-ai:main from endurehero:will_fp8_mr

Conversation

@endurehero endurehero commented Feb 28, 2025

Functionality

Support FP8 WGMMA based on the async pipeline design of FlashMLA. The TransV part draws on the implementation of SmemTranspose64x64 in FA3.
Currently, Q/K/V support only symmetric per-tensor quantization. Since the maximum value of P never exceeds 1, a direct f32-to-fp8 cast is used to quantize it.
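For readers unfamiliar with the scheme, here is a minimal pure-Python sketch of symmetric per-tensor quantization into the e4m3 range (function names are illustrative, not taken from this PR; a real kernel would additionally round the scaled values to actual fp8):

```python
E4M3_MAX = 448.0  # largest finite value representable by float8 e4m3fn

def quantize_per_tensor(x):
    """Symmetric per-tensor quantization: one scale for the whole tensor.

    Only the scale/clamp step (where accuracy is decided) is modeled here;
    a real kernel would then cast the scaled values to fp8.
    """
    amax = max(abs(v) for v in x)
    scale = amax / E4M3_MAX if amax > 0 else 1.0
    q = [max(-E4M3_MAX, min(E4M3_MAX, v / scale)) for v in x]
    return q, scale

def dequantize(q, scale):
    return [v * scale for v in q]

# P (the softmax output) never exceeds 1, so it already fits the fp8 range
# with scale = 1.0; that is why a direct f32->fp8 cast suffices for P.
```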

Performance

CUDA driver version: 535.183.06
NVCC version: 12.8
PyTorch version: 2.6

On the H20, MLA kernels typically exhibit high arithmetic intensity and are compute-bound. Consequently, Model FLOPs Utilization (MFU) is employed as the performance metric.
(benchmark chart: MFU on H20; image not reproduced)

On the H800, MLA is typically memory-bound. Consequently, the Memory Bandwidth Utilization (MBU) metric is adopted to evaluate the kernel. There is still plenty of room for optimization on the H800; we look forward to working on it together.
(benchmark chart: MBU on H800; image not reproduced)
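For context, the two metrics used above reduce to simple ratios of achieved over peak throughput (generic sketch; the numbers in the example are illustrative, not measurements from this PR):

```python
def mfu(achieved_flops_per_s, peak_flops_per_s):
    """Model FLOPs Utilization: fraction of peak compute actually achieved."""
    return achieved_flops_per_s / peak_flops_per_s

def mbu(achieved_bytes_per_s, peak_bytes_per_s):
    """Memory Bandwidth Utilization: fraction of peak DRAM bandwidth achieved."""
    return achieved_bytes_per_s / peak_bytes_per_s

# Illustrative only: a kernel streaming 2.4 TB/s on a GPU whose peak
# bandwidth is 3.0 TB/s runs at 80% MBU.
ratio = mbu(2.4e12, 3.0e12)
```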

Reproduction

python3 ./tests/test_flash_mla.py --dtype e4m3

@endurehero endurehero closed this Feb 28, 2025
@endurehero endurehero changed the title support fp8 fp8 support Feb 28, 2025
@endurehero endurehero reopened this Feb 28, 2025
@endurehero endurehero mentioned this pull request Feb 28, 2025
sijiac (Contributor) commented Mar 1, 2025

Awesome! Would you mind adding a compile flag to skip building FP8 when it is not needed? Thanks.

endurehero (Author) commented Mar 1, 2025

> Awesome! Would you mind adding a compile flag to skip building FP8 when it is not needed? Thanks.

Of course. Already done.

beginlner (Collaborator) commented Mar 1, 2025

Great work! However, I can’t merge this PR at the moment because, based on our tests, per-sequence kvcache scaling significantly reduces accuracy for MLA.

endurehero (Author) commented:

> Great work! However, I can’t merge this PR at the moment because, based on our tests, per-sequence kvcache scaling significantly reduces accuracy for MLA.

What about the granularity of PerPageBlock? I can easily adapt it.

beginlner (Collaborator) commented Mar 1, 2025

> What about the granularity of PerPageBlock? I can easily adapt it.

We think PerPageBlock is not fine-grained enough either; kv_rope (64 channels) needs to stay bf16.

endurehero (Author) commented:

> What about the granularity of PerPageBlock? I can easily adapt it.

> We think PerPageBlock is not fine-grained enough either; kv_rope (64 channels) needs to stay bf16.

Got it!

@beginlner beginlner closed this Mar 11, 2025
moses3017 commented:

> What about the granularity of PerPageBlock? I can easily adapt it.

> We think PerPageBlock is not fine-grained enough either; kv_rope (64 channels) needs to stay bf16.

How about Qnope and Knope using 8-bit quantization, while Qrope and Krope maintain 16-bit data types?

beginlner (Collaborator) commented May 21, 2025

It's acceptable for Qnope and Knope to use per-(1 token × 128 channel) 8-bit quantization, while Qrope and Krope retain 16-bit precision.
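To make that granularity concrete, here is a hypothetical pure-Python sketch of per-(1 token × 128 channel) symmetric quantization; Qnope/Knope have a head dim of 128, so each token gets one scale per 128-channel group (names and structure are illustrative, not from FlashMLA):

```python
E4M3_MAX = 448.0  # finite max of float8 e4m3fn

def quantize_per_token_group(tokens, group=128):
    """Symmetric quantization with one scale per (token, 128-channel group).

    `tokens` is a list of rows (one per token); each row is split into
    contiguous groups of `group` channels, and every (token, group) pair
    gets its own scale. Only the scale/clamp step is modeled here.
    """
    q_rows, scale_rows = [], []
    for row in tokens:
        q, scales = [], []
        for i in range(0, len(row), group):
            blk = row[i:i + group]
            amax = max(abs(v) for v in blk)
            s = amax / E4M3_MAX if amax > 0 else 1.0
            scales.append(s)
            q.extend(max(-E4M3_MAX, min(E4M3_MAX, v / s)) for v in blk)
        q_rows.append(q)
        scale_rows.append(scales)
    return q_rows, scale_rows
```

Under this scheme the 64 rope channels would simply be left out of the grouping and kept in bf16, as suggested above.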

shinezyy commented May 21, 2025

> It's acceptable for Qnope and Knope to use per-(1 token × 128 channel) 8-bit quantization, while Qrope and Krope retain 16-bit precision.

The outliers in the RoPE cache are also discussed in this paper: https://arxiv.org/pdf/2502.01563

Can we add a Hadamard transform right after RoPE to distribute the outliers across multiple head dims? (https://arxiv.org/abs/2404.00456)
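The suggestion above can be sketched with a fast Walsh-Hadamard transform. Since H/√n is orthogonal, rotating both Q and K by it leaves Q·Kᵀ unchanged while spreading a single outlier channel evenly across the head dimension (illustrative sketch, not part of this PR):

```python
import math

def fwht(x):
    """In-place fast Walsh-Hadamard transform; len(x) must be a power of 2."""
    n = len(x)
    h = 1
    while h < n:
        for i in range(0, n, h * 2):
            for j in range(i, i + h):
                a, b = x[j], x[j + h]
                x[j], x[j + h] = a + b, a - b
        h *= 2
    return x

def hadamard_rotate(x):
    """Orthonormal Hadamard rotation (self-inverse): H/sqrt(n) applied to x."""
    n = len(x)
    y = fwht(list(x))
    return [v / math.sqrt(n) for v in y]

# One large outlier in an 8-dim vector becomes uniform magnitude 8/sqrt(8)
# in every dim after rotation, which is far friendlier to low-bit formats.
spiky = [8.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]
flat = hadamard_rotate(spiky)
```

Because the rotation is self-inverse, applying it again recovers the original vector, so it can be fused into the projections on both sides of the dot product.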

TheTinyTeddy commented Jun 3, 2025

> It's acceptable for Qnope and Knope to use per-(1 token × 128 channel) 8-bit quantization, while Qrope and Krope retain 16-bit precision.

What precision should S×V use: BF16×BF16, BF16×FP8, or FP8×FP8 with per-(1 token × 128 channel) scales?

hypdeb commented Jun 29, 2025

> Great work! However, I can’t merge this PR at the moment because, based on our tests, per-sequence kvcache scaling significantly reduces accuracy for MLA.

@beginlner Hello there, on which workload did you observe these accuracy issues?

MatthewBonanni added a commit to MatthewBonanni/FlashMLA that referenced this pull request Aug 6, 2025
Signed-off-by: Matthew Bonanni <mbonanni001@gmail.com>
MatthewBonanni added a commit to MatthewBonanni/FlashMLA that referenced this pull request Aug 7, 2025
Signed-off-by: Matthew Bonanni <mbonanni001@gmail.com>
MatthewBonanni added a commit to MatthewBonanni/FlashMLA that referenced this pull request Aug 11, 2025
Signed-off-by: Matthew Bonanni <mbonanni001@gmail.com>
MicroZHY commented:

Hi @endurehero,

Thank you very much for sharing this impressive FP8 implementation and for the detailed performance numbers!

I have two quick questions that would help me understand the design better:

  1. Branch difference
    Could you kindly clarify the relationship between will_fp8_mr and the earlier will_fp8 branch? I noticed both names appear in the commit history, and I’d love to know what motivated the new branch and which improvements or fixes it contains.

  2. Necessity of TransV
    I see that the new FP8 path relies on the TransV routine, which borrows the 64×64 shared-memory transpose from FlashAttention-3. Would it be possible to briefly explain why TransV is indispensable for FP8 correctness or performance? I’m curious whether the same result could be achieved with a different layout or if this is a hard requirement for the WGMMA pipeline.

Thanks again for your time and for open-sourcing this work!

@MicroZHY MicroZHY mentioned this pull request Aug 14, 2025
LucasWilkinson pushed a commit to vllm-project/FlashMLA that referenced this pull request Aug 18, 2025
* Add files from deepseek-ai#54
* FP8 now extends base implementation
* Fix typo
* Update tests
* Add to build
* Fix installation
* Fix FLASH_MLA_DISABLE_FP8 flag
* Fix param matchup
* typo
* Fix out dtype
* Fix IMA
* Extension name should be _flashmla_C
* Clean up
* Tighten FP8 error tolerance
* Add attribution to copied files
* Remove breakpoint
* Port cudagraph fix from #3

Signed-off-by: Matthew Bonanni <mbonanni001@gmail.com>
