Add Flash Attention 4 (CuTe DSL) Support #42404
sambhavnoobcoder wants to merge 8 commits into huggingface:main
Conversation
vasqu left a comment:
Hey sorry about this but #42435 will likely supersede this.
The current PR here just adds too many custom things that don't align with the library itself, e.g. a lot of random test files. Since this is probably a more important feature, I attempted a first version which you can also try.
Hi @vasqu,
Hey @sambhavnoobcoder, I'm really sorry about this, but this kind of PR is more involved and needs more careful checks on our side. It requires a lot of niche knowledge specific to our library, and I still encountered a lot of edge cases I wouldn't have found if I hadn't done it myself. You are still welcome to expand on and help with the other PR! For learning experiences / first PRs, I recommend looking into
Problem Statement
Flash Attention 4 represents a significant architectural shift in the flash-attention package:
- Lives in the `flash_attn.cute` submodule instead of the main `flash_attn` package
- `flash_attn_varlen_func` has a different signature - it does NOT accept `max_seqlen_q` and `max_seqlen_k` parameters (it calculates them internally from `cu_seqlens`)
- New parameters: `learnable_sink`, `num_splits`, and `pack_gqa`
- No `dropout_p` and `alibi_slopes` support

Without explicit FA4 support, users cannot leverage these improvements even when they have compatible hardware and the flash-attn package with CuTe DSL installed.
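Since FA4's `flash_attn_varlen_func` derives the maximum sequence lengths internally, the derivation from `cu_seqlens` is conceptually trivial. A plain-Python sketch (not the actual kernel code) of what that internal computation amounts to:

```python
def max_seqlen_from_cu_seqlens(cu_seqlens):
    """Recover the longest sequence length from cumulative sequence
    boundaries, e.g. [0, 3, 7, 9] -> lengths 3, 4, 2 -> max is 4.

    This mirrors what FA4 does internally, which is why its
    flash_attn_varlen_func no longer takes max_seqlen_q / max_seqlen_k.
    """
    return max(b - a for a, b in zip(cu_seqlens[:-1], cu_seqlens[1:]))
```

This also explains why the two parameters could be dropped without losing information: `cu_seqlens` fully determines them.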
Solution Design
The solution maintains full backward compatibility while adding FA4 support through:
1. Detection Layer
Added an `is_flash_attn_4_available()` function that checks for the `flash_attn.cute` submodule.

2. Priority-Based Auto-Selection
When `attn_implementation=None`, FA4 gets the highest priority in the selection order on compatible hardware, for optimal performance.
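The auto-selection can be sketched as a first-match scan over a priority list. The helper name and the exact order below are illustrative assumptions, not the PR's literal code:

```python
# Illustrative priority order: FA4 first, then older implementations.
_PRIORITY = ["flash_attention_4", "flash_attention_3",
             "flash_attention_2", "sdpa", "eager"]

def select_attn_implementation(available):
    """Pick the first available implementation in priority order.

    `available` maps implementation name -> bool (e.g. the result of
    checks like is_flash_attn_4_available()); "eager" always works.
    """
    for impl in _PRIORITY:
        if available.get(impl, impl == "eager"):
            return impl
    return "eager"
```

The key property is that adding FA4 at the front changes nothing for users on older hardware: the scan simply falls through to whatever was selected before.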
3. Runtime Introspection
Created a `_is_using_fa4()` helper that uses function signature inspection to detect FA4 vs FA2/FA3 at runtime. This enables conditional code paths without hardcoded version checks.

4. Conditional Varlen Calls
Modified two critical call sites in `_flash_attention_forward()` to conditionally pass parameters: FA4 does not accept `max_seqlen_q` and `max_seqlen_k` (it calculates them internally).

5. Parameter Support
Extended `_process_flash_attention_kwargs()` to handle FA4-specific parameters, with automatic filtering based on introspection to maintain compatibility across versions.

6. Registration
Registered `flash_attention_4` in `AttentionInterface._global_mapping` to enable explicit selection via `attn_implementation="flash_attention_4"`.

Implementation Details
Core Changes
Detection and Import
Imports come from the `flash_attn.cute` submodule.

Integration Layer
Interface Registration
Testing Infrastructure
New Files
Test Suite
Comprehensive test coverage including:
Validation Script
Quick validation script for SSH GPU access that checks:
Usage Examples
Demonstrates:
Testing Status
Automated Checks
Created and ran comprehensive verification script checking:
All 14 core integration checks passed, plus 7 additional file checks passed.
Pending Testing (Requires GPU)
GPU Validation Required
Due to lack of CUDA GPU access during development, the following tests are pending:
Basic Functionality
Integration Tests
Real-World Usage
Hardware Requirements
Known Limitations
- No `dropout_p` parameter - training with dropout will automatically fall back to FA2/eager
- `softcap != 0.0` may be restricted during the backward pass

All limitations are handled gracefully via automatic fallback.
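The graceful-fallback decision might look roughly like this; the function name and the exact conditions are illustrative, not the PR's exact code:

```python
def needs_fallback(dropout_p: float = 0.0, softcap: float = 0.0,
                   is_training: bool = False) -> bool:
    """Return True when a request uses features FA4 does not support,
    so the caller should fall back to FA2/eager instead of erroring."""
    if is_training and dropout_p > 0.0:
        return True  # FA4 has no dropout_p parameter
    if softcap != 0.0:
        return True  # softcap may be restricted in the backward pass
    return False
```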
Usage
Explicit FA4 Selection
Users can explicitly request FA4 when loading models.
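For example (the checkpoint name is illustrative, and running this requires a compatible GPU plus flash-attn with CuTe DSL installed):

```python
from transformers import AutoModelForCausalLM

# "flash_attention_4" is the string this PR registers in
# AttentionInterface._global_mapping.
model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-3.1-8B",  # illustrative checkpoint
    attn_implementation="flash_attention_4",
)
```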
Auto-Selection (Recommended)
When no attention implementation is specified, transformers will automatically select the best available implementation, with FA4 receiving highest priority on compatible hardware.
Check Availability
Users can check if FA4 is available using the `is_flash_attn_4_available()` function.

Fixes: #42405
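A minimal self-contained sketch of such an availability check, assuming it only probes for the `flash_attn.cute` submodule (the real helper presumably also verifies package versions and hardware):

```python
import importlib.util

def is_flash_attn_4_available() -> bool:
    """True when flash-attn exposes the CuTe DSL submodule."""
    try:
        return importlib.util.find_spec("flash_attn.cute") is not None
    except ModuleNotFoundError:
        # flash_attn itself is not installed
        return False
```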