
[ROCM] Add support with Infinity Cache (LLC) awareness for improved performance #2147

Closed
tianwyan wants to merge 6 commits into Dao-AILab:main from ROCm:tianwyan/navi_experiment

Conversation

@tianwyan

@tianwyan tianwyan commented Jan 7, 2026

Motivation

This PR enables Flash Attention Triton support for AMD RDNA3 (Navi) GPUs, specifically targeting the gfx1100 architecture. The goal is to bring Flash Attention performance optimizations to consumer-grade AMD GPUs while leveraging the unique Infinity Cache (LLC) architecture for improved memory throughput.

Technical Details

New Architecture Support:

  • Added gfx1100 (RDNA3/Navi 31) to the supported GPU architectures in the Triton Flash Attention backend

Performance Optimizations:

  • Implemented Infinity Cache (LLC) awareness to optimize memory access patterns and reduce DRAM bandwidth pressure
  • Enabled exp2 instruction by default for faster exponential calculations on RDNA3
  • Added additional Triton autotuning configurations optimized for Navi's wavefront and cache characteristics
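The exp2 optimization rests on a simple identity: exp(x) = 2^(x · log2 e), so a kernel can pre-scale softmax logits by log2(e) once and then use the hardware's fast exp2 instruction in the inner loop. A minimal sketch of the identity (illustrative only, not the PR's kernel code):

```python
import math

LOG2E = 1.4426950408889634  # log2(e), the constant the logits are pre-scaled by

def exp_via_exp2(x: float) -> float:
    # exp(x) == 2**(x * log2(e)); GPUs expose a fast exp2 instruction,
    # so rescaling once lets the softmax inner loop avoid the slower exp
    return 2.0 ** (x * LOG2E)
```

In a real Triton kernel the same rescaling is applied to the attention scores before calling `tl.exp2`, but the arithmetic is identical.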

Code Cleanup:

  • Renamed "L2 cache" terminology to "Infinity Cache (LLC)" throughout the codebase to accurately reflect AMD's cache hierarchy and avoid confusion with the traditional L2 cache

Test Plan

  • Functional testing on AMD Radeon RX 7900 XTX (gfx1100)
  • Verified Flash Attention forward pass correctness against reference implementation
  • Benchmarked memory bandwidth utilization with and without LLC awareness

Test Result

  • All existing Triton Flash Attention tests pass on gfx1100
  • ~2-4x performance improvement with LLC-aware implementation on memory-bound attention workloads
  • LLC awareness significantly reduces DRAM bandwidth pressure by better utilizing the 96MB Infinity Cache on RDNA3
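The LLC-awareness decision comes down to comparing the K,V working set against the cache size. A back-of-the-envelope sketch of that sizing math, with illustrative names (one plausible heuristic, not the PR's exact policy):

```python
def kv_bytes(batch: int, seqlen: int, nheads: int, head_dim: int,
             bytes_per_elt: int = 2) -> int:
    # K and V each hold batch * seqlen * nheads * head_dim elements (fp16 assumed)
    return 2 * batch * seqlen * nheads * head_dim * bytes_per_elt

def plan_head_groups(total_kv_bytes: int, nheads: int, llc_bytes: int):
    # If the whole K,V working set fits in the Infinity Cache, no grouping needed
    if total_kv_bytes <= llc_bytes:
        return nheads, 1
    # Otherwise split heads so each group's K,V slice is closer to cache-resident
    ngroups = -(-total_kv_bytes // llc_bytes)   # ceil division
    group_size = -(-nheads // ngroups)
    return group_size, ngroups
```

For batch=1, seqlen=8192, nheads=32, head_dim=128 in fp16 this gives a 128 MiB working set, which overflows even the 96 MiB LLC on gfx1100 and triggers grouping; the PR's actual grouping policy may choose different group sizes.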

@tianwyan
Author

tianwyan commented Jan 7, 2026

@tridao Would you be interested in reviewing this for potential upstream to Dao-AILab/flash-attention? This adds RDNA3/gfx1100 support with ~2x performance improvement via AMD LLC awareness.

@tridao
Member

tridao commented Jan 13, 2026

Cc @rocking5566 are there folks who can review this PR?

@micmelesse
Collaborator

micmelesse commented Jan 13, 2026

@tridao @tianwyan and I are coordinating on this PR; I will review it soon. I have an incoming PR that enables a bunch of features, so we have to coordinate the ordering.

@rocking5566
Contributor

@tridao
@micmelesse is our Triton backend person. He will review this PR.

@tianwyan
Author

As discussed with @micmelesse, I'll rebase my current PR to coordinate with his incoming one. :)

@micmelesse
Collaborator

The PR I mentioned is up here, #2178.

@0xDELUXA
Contributor

0xDELUXA commented Jan 23, 2026

RDNA4 (gfx1200 in my case) support would be great. I made some local changes like:
In setup.py:

def validate_and_update_archs(archs):
    # List of allowed architectures
    allowed_archs = ["native", "gfx90a", "gfx950", "gfx942", "gfx1100", "gfx1200"]

And in flash_attn/flash_attn_triton_amd/l2_cache_aware.py:

AMD_LLC_CACHE_SIZES: Dict[str, int] = {
    "gfx1100": 96 * 1024 * 1024,   # RX 7900 XTX/XT - 96 MB Infinity Cache
    "gfx1101": 64 * 1024 * 1024,   # RX 7800 XT - 64 MB Infinity Cache
    "gfx1102": 32 * 1024 * 1024,   # RX 7600 - 32 MB Infinity Cache
    "gfx1200": 32 * 1024 * 1024,   # RX 9060 XT - 32 MB Infinity Cache
}
    known_cus = {
        "gfx1100": 96,   # RX 7900 XTX
        "gfx1101": 60,   # RX 7800 XT  
        "gfx1102": 32,   # RX 7600
        "gfx1200": 32,   # RX 9060 XT
    }

I think the autotune configs are also suitable for RDNA4.
Do I need to make any additional changes to have gfx1200 fully supported as well? This way, gfx1201 would also be easy to include in the script. Or would these require some fundamental changes?
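On the gfx1201 question, the lookup side could stay trivial if the table falls back to a conservative default for unlisted architectures. A hypothetical sketch (the dict mirrors the PR's `AMD_LLC_CACHE_SIZES`; the fallback value and the `llc_bytes_for` helper name are assumptions, not the PR's actual API):

```python
# The table mirrors the PR's AMD_LLC_CACHE_SIZES; the 32 MB fallback for
# unlisted archs (e.g. gfx1201) is an assumption for illustration.
AMD_LLC_CACHE_SIZES = {
    "gfx1100": 96 * 1024 * 1024,   # RX 7900 XTX/XT
    "gfx1101": 64 * 1024 * 1024,   # RX 7800 XT
    "gfx1102": 32 * 1024 * 1024,   # RX 7600
    "gfx1200": 32 * 1024 * 1024,   # RX 9060 XT
}
DEFAULT_LLC_BYTES = 32 * 1024 * 1024  # conservative fallback

def llc_bytes_for(arch: str) -> int:
    # Unknown archs get the conservative default rather than an error
    return AMD_LLC_CACHE_SIZES.get(arch, DEFAULT_LLC_BYTES)
```

With a fallback like this, new RDNA4 parts would only need a table entry when their Infinity Cache size differs from the default.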

Edit:
Based on my testing using this script, I can confirm that after making the above changes, this PR becomes compatible with gfx1200. My output is:

✓ flash_attn imported successfully

=== GPU Info ===
Device: AMD Radeon RX 9060 XT
Architecture: gfx1200
================

Test config: batch=1, seqlen=8192, nheads=32, head_dim=128
Total K,V memory: 128.0 MB
LLC size (gfx1200): 32 MB
Should trigger head grouping: True

Running flash attention (check output above for head grouping info)...


=== Infinity Cache (LLC) Aware Head Grouping ===
GPU: gfx1200 (32 CUs)
Infinity Cache (LLC): 32.0 MB
Heads: 32, SeqLen: 8192, HeadDim: 128
Total K,V Memory: 128.0 MB
LLC Ratio: 4.00x
Should Group: True
Group Size: 12 heads (3 groups)
K,V per Group: 48.0 MB
================================================

[L2 Head Grouping] Processing 32 heads in groups of 12

✓ Flash attention succeeded
Output shape: torch.Size([1, 8192, 32, 128])
✓ Output is valid (no NaN)

=== Result ===
If you see 'Infinity Cache (LLC) Aware Head Grouping' output above,
then the PR features are working on gfx1200.
If not, it's using fallback FA-2 without LLC awareness.

Furthermore, this second script confirms the improved performance on gfx1200 thanks to LLC awareness:

GPU: AMD Radeon RX 9060 XT
Architecture: gfx1200

=== WITHOUT LLC-Aware Head Grouping (Original FA-2) ===
Average time: 39.226 ms
Throughput: 25.49 iter/s

=== WITH LLC-Aware Head Grouping (PR) ===
Average time: 19.667 ms
Throughput: 50.85 iter/s

=== Results ===
Original FA-2: 39.226 ms
With LLC-aware: 19.667 ms
Speedup: 1.99x
✓ PR is 1.99x faster!

Therefore, I think RDNA4 support as a whole could be added to this PR.

@tianwyan
Author


Thanks for the information! A new PR with LLC-aware head grouping will be created soon, rebased on #2178.

@micmelesse
Collaborator

#2178 is merged. @tianwyan will rebase this pr. We will then discuss and ping here when ready to merge.

@tianwyan
Author

> #2178 is merged. @tianwyan will rebase this pr. We will then discuss and ping here when ready to merge.

The rebased PR is #2217 @micmelesse @tridao

@tianwyan
Author

This PR is going to be closed; please go to #2217.

@tianwyan tianwyan closed this Jan 29, 2026
