Enable FP8 support for Flashinfer ROCm decode kernels on CDNA3#40

Merged
diptorupd merged 5 commits into ROCm:amd-integration from rtmadduri:feature/try-fp8
Nov 12, 2025

Conversation

@rtmadduri (Collaborator)

This PR enables support for the `__hip_fp8_e4m3fnuz` and `__hip_fp8_e5m2` dtypes for the decode kernels.

This PR adds:

- PyTorch support for the `__hip_fp8` variants for both AOT and JIT
- Additional utility conversion functions to move between `__hip_fp8_e4m3fnuz`, `__hip_fp8_e5m2`, `__half`, and `float`
- Modifications to the chunking logic to accommodate the fp8 dtypes
- A new batch decode pytest for fp8

Note: This PR does not add `RoPE-Llama` support for the `__hip_fp8_*` variants. This will be addressed in a separate PR.
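For reference, the e4m3fnuz layout these conversions target can be sketched in a few lines of pure Python. This is an illustration of the bit format only, not the HIP conversion code added by this PR; the function name is invented:

```python
def decode_e4m3fnuz(bits: int) -> float:
    """Decode one byte in e4m3fnuz layout: 1 sign bit, 4 exponent bits,
    3 mantissa bits, exponent bias 8, no infinities, and 0x80 as the
    single NaN encoding."""
    if bits == 0x80:
        return float("nan")
    sign = -1.0 if bits & 0x80 else 1.0
    exp = (bits >> 3) & 0xF
    man = bits & 0x7
    if exp == 0:  # subnormal range
        return sign * (man / 8.0) * 2.0 ** -7
    return sign * (1.0 + man / 8.0) * 2.0 ** (exp - 8)

print(decode_e4m3fnuz(0x40))  # 1.0
print(decode_e4m3fnuz(0x7F))  # 240.0, the largest finite e4m3fnuz value
```

The narrow range (max 240 for e4m3fnuz, with e5m2 trading mantissa bits for a wider exponent range) is why separate conversion utilities to and from `__half` and `float` are needed.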

PyTest results - `tests/test_batch_decode_kernels_hip_fp8.py`

```
===================== 864 passed, 4 warnings in 55.39s =====================
```

Running the entire test suite - `scripts/run_hip_tests.sh`

```
=========17252 passed, 18 skipped, 12 warnings in 332.42s (0:05:32) =========
```
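The chunking change listed above plausibly amounts to making chunk sizes dtype-aware: under a fixed on-chip memory budget, a 1-byte fp8 element fits twice as many KV entries per chunk as a 2-byte half element. A toy sketch; all names and numbers here are invented for illustration and are not taken from the PR:

```python
def kv_entries_per_chunk(budget_bytes: int, head_dim: int, itemsize: int) -> int:
    """How many KV-cache entries of head_dim elements fit in budget_bytes."""
    return budget_bytes // (head_dim * itemsize)

half_chunk = kv_entries_per_chunk(64 * 1024, head_dim=128, itemsize=2)
fp8_chunk = kv_entries_per_chunk(64 * 1024, head_dim=128, itemsize=1)
print(half_chunk, fp8_chunk)  # 256 512
```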

@demandal25 demandal25 self-requested a review November 10, 2025 15:08
@diptorupd diptorupd changed the title Enable FP8 support for Flashinfer ROCm on CDNA3 Enable FP8 support for Flashinfer ROCm decode kernels on CDNA3 Nov 10, 2025

@demandal25 (Collaborator) left a comment


Left some minor comments and questions for clarification.

@demandal25 demandal25 self-requested a review November 11, 2025 18:17

@demandal25 (Collaborator) left a comment


Thanks for addressing the comments.

```
FLASHINFER_ENABLE_FP8="OFF"
FLASHINFER_ENABLE_FP8_E4M3="OFF"
FLASHINFER_ENABLE_FP8_E5M2="OFF"
FLASHINFER_ENABLE_FP8="ON"
```


Looking at Options.cmake: setting `FLASHINFER_ENABLE_FP8` already sets `FLASHINFER_ENABLE_FP8_E4M3` and `FLASHINFER_ENABLE_FP8_E5M2` to true, so we should only use the `FLASHINFER_ENABLE_FP8` flag.
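A minimal sketch of the relationship being described; the actual Options.cmake may be structured differently:

```cmake
# The umbrella flag implies both per-variant flags, so only
# FLASHINFER_ENABLE_FP8 needs to be set explicitly.
option(FLASHINFER_ENABLE_FP8 "Enable FP8 support" OFF)
if(FLASHINFER_ENABLE_FP8)
  set(FLASHINFER_ENABLE_FP8_E4M3 ON)
  set(FLASHINFER_ENABLE_FP8_E5M2 ON)
endif()
```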



Can be fixed later.

Signed-off-by: Debasis Mandal <Debasis.Mandal@amd.com>
@diptorupd diptorupd merged commit 1d29f90 into ROCm:amd-integration Nov 12, 2025
1 check passed
diptorupd pushed a commit that referenced this pull request Dec 5, 2025
This PR fixes some of the unit test failures that occur in Single
Decode. It also disables clang formatting of headers: reformatting the
headers causes compilation issues in which the compiler is unable to
find the `HIP WARP SYNC INTRINSICS`, so disabling clang format fixes
these failures.

```
    Start 1: MathTest
1/6 Test #1: MathTest .........................   Passed    3.31 sec
    Start 2: PosEncTest
2/6 Test #2: PosEncTest .......................   Passed    3.36 sec
    Start 3: CascadeTest
3/6 Test #3: CascadeTest ......................   Passed    3.35 sec
    Start 4: PageTest
4/6 Test #4: PageTest .........................   Passed  114.08 sec
    Start 5: SingleDecodeTest
5/6 Test #5: SingleDecodeTest .................   Passed   35.22 sec
    Start 6: BatchDecodeTest
6/6 Test #6: BatchDecodeTest ..................   Passed  559.75 sec

100% tests passed, 0 tests failed out of 6

Total Test time (real) = 719.07 sec
```
diptorupd pushed a commit that referenced this pull request Dec 5, 2025
zhenhantech pushed a commit to zhenhantech/flashinfer that referenced this pull request Jan 9, 2026
zhenhantech pushed a commit to zhenhantech/flashinfer that referenced this pull request Jan 9, 2026
diptorupd pushed a commit to diptorupd/flashinfer that referenced this pull request Jan 28, 2026
diptorupd pushed a commit to diptorupd/flashinfer that referenced this pull request Jan 28, 2026