
[Triton] FP8 with FA v3 API#917

Closed
micmelesse wants to merge 12 commits into main from micmelesse/fp8

Conversation

@micmelesse
Contributor

Motivation

Add support for fp8 in Flash Attention

Technical Details

Modify the existing code so that it conforms to the Flash Attention v3 API. A user provides fp8 values for q, k, and v, along with their descale values.
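As a rough illustration of where those descale values come from, here is a minimal per-tensor quantization sketch. This is plain Python with a hypothetical helper name (`quantize_fp8` is my own illustration, not part of the API); the real kernels consume torch fp8 tensors, and rounding to actual e4m3 values is omitted for brevity.

```python
# Largest finite value representable in fp8 e4m3.
FP8_E4M3_MAX = 448.0

def quantize_fp8(values):
    """Per-tensor fp8 quantization sketch: scale into the e4m3 range,
    and return the descale factor that maps quantized values back."""
    amax = max(abs(v) for v in values)
    scale = FP8_E4M3_MAX / amax if amax > 0 else 1.0
    # Clamp into the representable fp8 range (real code would also round
    # to the nearest representable e4m3 value).
    q = [max(-FP8_E4M3_MAX, min(FP8_E4M3_MAX, v * scale)) for v in values]
    descale = 1.0 / scale  # passed to the kernel as q_descale / k_descale / v_descale
    return q, descale
```

With this convention, `quantized * descale` approximately recovers the original value, which is what the kernel relies on when it applies the descales.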

Test Plan

Update the mha tests and bench code.

Test Result

Submission Checklist

@micmelesse
Contributor Author

micmelesse commented Aug 29, 2025

You can see examples of how to use the fp8 interface in the tests at op_tests/triton_tests/test_mha.py, which cover both regular and paged attention.

For regular-attention fp8, you will see code that looks like this:

import torch

from aiter.ops.triton.mha_v3 import (
    flash_attn_func as flash_attn_func_v3,
)

# enable backward for fp8 using dequantized values
set_fp8_dequantized_backward(True)

# forward
triton_out = flash_attn_func_v3(
    q_fp8,
    k_fp8,
    v_fp8,
    softmax_scale=None,
    causal=CAUSAL,
    q_descale=q_descale,
    k_descale=k_descale,
    v_descale=v_descale,
)

# backward
triton_dq, triton_dk, triton_dv = torch.autograd.grad(
    triton_out, (q_fp8, k_fp8, v_fp8), do.clone()
)

Here is a small model trained on wikitext to test convergence.

(combined_loss plot)

for paged attention fp8 which is available in the inference api flash_attn_with_kvcache , you will see code that looks like this

from aiter.ops.triton.mha_v3 import (
    flash_attn_with_kvcache as flash_attn_with_kvcache_v3,
)

# forward
out_kernel = flash_attn_with_kvcache_v3(
    q_fp8,
    k_cache_fp8,
    v_cache_fp8,
    cache_seqlens=cache_seqlens,
    causal=causal,
    q_descale=q_descale,
    k_descale=k_descale,
    v_descale=v_descale,
    page_table=page_table,
)
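For intuition on what the descale arguments do, here is a pure-Python reference sketch of fp8 attention (my own illustration, not the Triton kernel): the q/k descales rescale the score matrix after the fp8 QK^T matmul, and the v descale rescales the output after the PV matmul.

```python
import math

def attn_fp8_reference(Q8, K8, V8, q_descale, k_descale, v_descale, sm_scale):
    """Reference fp8 attention: Q8/K8/V8 hold fp8-quantized values as
    plain floats; the descales map scores and outputs back to real scale."""
    n, d = len(Q8), len(Q8[0])
    m = len(K8)
    out = []
    for i in range(n):
        # Scores computed on quantized values, then rescaled by both descales.
        scores = [
            sum(Q8[i][t] * K8[j][t] for t in range(d)) * q_descale * k_descale * sm_scale
            for j in range(m)
        ]
        # Numerically stable softmax over the rescaled scores.
        mx = max(scores)
        exps = [math.exp(s - mx) for s in scores]
        denom = sum(exps)
        probs = [e / denom for e in exps]
        # V is also quantized, so apply v_descale to the weighted sum.
        out.append([
            sum(probs[j] * V8[j][t] for j in range(m)) * v_descale
            for t in range(d)
        ])
    return out
```

With all descales set to 1.0 this reduces to ordinary softmax attention, which is why fp16/bf16 callers can ignore the new arguments.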

Commits

- Fa V3 api
- Compress fp8 work so far
- pull cast out of torch function
- e2e fp8 stub
- emulate fa v3
- ignore
- remove example
- clean up forward
- save
- fp8 backward
- ignore train artifacts
- just use return_attn_probs
- match fa behvaior
- save fa ref
- add fa_ref
- fix dropout bug
- add link
- optional fp8 p descale
- rename to v3
- fa v3
- clean up
- match backward
- min diff
- update varlen api
- clean up FP8_P_DESCALE
- update bench and test
- lint
- fix mha varlen bug
- remove .gitignore
- save
- lint
- remove skip
- bring back skips
@dhonnappa-amd
Collaborator

Jenkins CI skipped: Check lint failed. Exiting the entire job...

@micmelesse
Contributor Author

I will reopen in a bit.

@micmelesse micmelesse closed this Sep 18, 2025
@micmelesse micmelesse reopened this Sep 22, 2025
@dhonnappa-amd
Collaborator

Jenkins CI skipped: Required check(s) 'ruff_black' are missing. Exiting the entire job...

@dhonnappa-amd
Collaborator

Jenkins CI skipped: Check lint failed. Exiting the entire job...

@micmelesse micmelesse closed this Sep 22, 2025
@micmelesse
Contributor Author

Moved here: #1065
