[common] Remove kvpacked and qkvpacked attention functions for every kernel type. #2287
Conversation
Signed-off-by: Pawel Gadzinski <[email protected]>
for more information, see https://pre-commit.ci
Signed-off-by: Pawel Gadzinski <[email protected]>
for more information, see https://pre-commit.ci
/te-ci jax
7 files reviewed, no comments
I think this is similar to #2272 :) Yes, JAX needs a bit of fixing in order to get its attention working.

Could you add the deprecation note for these qkvpacked/kvpacked APIs as we discussed offline please? Thanks.
Signed-off-by: Pawel Gadzinski <[email protected]>
Greptile Overview
Greptile Summary
Refactors fused attention code by eliminating duplicate kvpacked and qkvpacked kernel functions, moving pointer arithmetic into helper functions at the common API layer.
Key Changes
- Adds helper functions (`make_tensor_view`, `calculate_qkv_stride`, `calculate_kv_stride`) to handle tensor unpacking; a minimal sketch of the idea behind such a view helper follows this list
- Deprecated packed API functions now unpack QKV/KV tensors and call the non-packed kernel implementations
- Removes ~1400 lines of duplicate kernel code from the arbitrary_seqlen implementation
- Adds deprecation warnings to the packed API functions in the header file
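The helper names above come from the PR; the following is only a minimal sketch of the idea behind such a view helper, using a hypothetical simplified `SimpleView` struct rather than TransformerEngine's actual internal tensor type.

```cpp
#include <cstddef>
#include <cstdint>
#include <utility>
#include <vector>

// Hypothetical, simplified stand-in for the real tensor type; for illustration only.
struct SimpleView {
  void *dptr;                      // base pointer into an existing buffer (not owned)
  std::vector<std::size_t> shape;  // logical shape of the view
};

// Sketch of the idea behind a make_tensor_view-style helper: reinterpret a slice of a
// packed buffer (e.g. the V half of a KV-packed tensor) as its own tensor by offsetting
// the data pointer by offset_bytes and assigning the unpacked shape.
SimpleView make_view(void *packed_ptr, std::vector<std::size_t> unpacked_shape,
                     std::size_t offset_bytes) {
  return {static_cast<std::uint8_t *>(packed_ptr) + offset_bytes,
          std::move(unpacked_shape)};
}
```

With a BS2HD-packed KV buffer, the K view would use offset 0 and the V view the K-to-V byte stride derived from h_kv and d.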
Issues Found
- Critical bug in the KV-packed max512 stride calculation: lines 840 and 977 incorrectly use `h_q` instead of `h_kv` when calculating the byte offset between the K and V tensors. This will cause incorrect memory access in Grouped Query Attention (GQA) scenarios where h_q ≠ h_kv.
Confidence Score: 2/5
- This PR contains critical bugs in stride calculations that will cause memory corruption in GQA scenarios
- The refactoring approach is sound and removes significant code duplication, but contains a critical logic error in the KV-packed tensor stride calculations (lines 840, 977) where `h_q` is used instead of `h_kv`. This bug exists in both the forward and backward passes for the max512 kernels and will cause incorrect memory access when h_q ≠ h_kv (Grouped Query Attention). While the bug may not manifest in standard attention where h_q == h_kv, it represents a correctness issue that must be fixed before merge.
- transformer_engine/common/fused_attn/fused_attn.cpp requires immediate attention - fix the stride calculations on lines 840 and 977
Important Files Changed
File Analysis
| Filename | Score | Overview |
|---|---|---|
| transformer_engine/common/fused_attn/fused_attn.cpp | 2/5 | Refactors packed attention functions to use helper functions and tensor views. Contains critical bugs in KV-packed stride calculations (lines 840, 977) where h_q is used instead of h_kv. |
| transformer_engine/common/include/transformer_engine/fused_attn.h | 5/5 | Adds deprecation warnings to packed QKV/KV API functions, directing users to use separate Q, K, V tensors instead. |
Sequence Diagram
sequenceDiagram
participant User as JAX/PyTorch
participant API as nvte_fused_attn_*_kvpacked (Deprecated)
participant Helper as Helper Functions
participant Kernel as fused_attn_max_512_* (Non-packed)
Note over User,Kernel: BEFORE: Separate functions for each kernel type
User->>API: nvte_fused_attn_fwd_kvpacked(KV_packed)
API->>Kernel: fused_attn_max_512_fwd_kvpacked(KV_packed)
Note over Kernel: Unpacks KV internally<br/>stride = 2*h*d
Kernel-->>API: result
API-->>User: result
Note over User,Kernel: AFTER: Unified using helper functions
User->>API: nvte_fused_attn_fwd_kvpacked(KV_packed)
API->>Helper: calculate_kv_stride(layout, dtype, h_kv, d)
Helper-->>API: stride (in bytes)
API->>Helper: make_tensor_view(KV, shape, offset=0)
Helper-->>API: K_view
API->>Helper: make_tensor_view(KV, shape, offset=stride)
Helper-->>API: V_view
API->>Kernel: fused_attn_max_512_fwd(Q, K_view, V_view)
Note over Kernel: Works with unpacked views
Kernel-->>API: result
API-->>User: result
2 files reviewed, 2 comments
bias_type, attn_mask_type, input_Q, input_KV, input_Bias, output_O, Aux_CTX_Tensors,
input_cu_seqlens_q, input_cu_seqlens_kv, input_rng_state, wkspace, stream, handle);
// Unpack KV and call the non-packed function
size_t stride = 2 * h_q * d;  // For max512, KV layout is BS2HD or SB2HD
logic: stride calculation uses h_q but should use h_kv for KV-packed tensors
For KV-packed layout BS2HD, the memory contains K[h_kv, d] followed by V[h_kv, d]. The stride to skip from K to V should be based on h_kv, not h_q (which can differ in GQA scenarios).
- size_t stride = 2 * h_q * d;   // For max512, KV layout is BS2HD or SB2HD
+ size_t stride = 2 * h_kv * d;  // For max512, KV layout is BS2HD or SB2HD
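To make the GQA failure mode concrete, here is a small illustrative calculation; the h_q/h_kv/d values are made up, and only the 2-byte element size of the fp16/bf16 max512 path is taken from the code above.

```cpp
#include <cstddef>

int main() {
  // Hypothetical GQA config: 16 query heads, 4 KV heads, head dim 64, fp16 (2 bytes).
  constexpr std::size_t h_q = 16, h_kv = 4, d = 64, elem_bytes = 2;

  // In a BS2HD-packed KV buffer, one token's block is K[h_kv, d] followed by V[h_kv, d],
  // so the V view starts h_kv * d elements after the K view.
  constexpr std::size_t correct_offset = elem_bytes * h_kv * d;  //  512 bytes: start of V
  constexpr std::size_t buggy_offset   = elem_bytes * h_q * d;   // 2048 bytes: beyond this
                                                                 // token's entire KV block
  static_assert(correct_offset != buggy_offset, "the two only agree when h_q == h_kv");
  return 0;
}
```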
input_cu_seqlens_q, input_cu_seqlens_kv, wkspace, stream, handle);
// Unpack KV and dKV and call the non-packed function
size_t stride = 2 * h_q * d;
logic: stride calculation uses h_q but should use h_kv for KV-packed tensors (backward pass)
Same issue as in forward pass - the stride should be based on h_kv for KV-packed tensors.
- size_t stride = 2 * h_q * d;
+ size_t stride = 2 * h_kv * d;
That's a good point - this is a leftover from my previous change. Weird that the tests didn't catch it. Fixed.
for more information, see https://pre-commit.ci
Greptile Overview
Greptile Summary
This PR consolidates fused attention implementations by removing duplicate QKV-packed and KV-packed kernel variants across three backend types (max512, arbitrary seqlen, FP8). The refactoring moves tensor unpacking logic from kernel-level to the common API layer in fused_attn.cpp, reducing ~1,388 lines of duplicated code.
Key changes:
- Removes `*_qkvpacked` and `*_kvpacked` functions from all three kernel types
- Adds helper functions for stride calculation and tensor view creation
- Unpacks QKV/KV tensors in the common API layer before calling unified kernel functions
- Deprecates packed QKV/KV APIs in the public header
Critical issue found:
- KV-packed max512 paths (lines 842, 979 in fused_attn.cpp) use `h_q` instead of `h_kv` for the stride calculation, causing incorrect pointer arithmetic in Grouped Query Attention (GQA) scenarios where h_q != h_kv
Confidence Score: 2/5
- Critical stride calculation bug will cause memory corruption in GQA with max512 backend
- Good refactoring approach, but the incorrect stride calculation using `h_q` instead of `h_kv` in the max512 KV-packed paths will access the wrong memory addresses when query and key/value have different head counts (GQA scenarios)
- transformer_engine/common/fused_attn/fused_attn.cpp, lines 842 and 979 - incorrect stride calculation for KV-packed tensors
Important Files Changed
File Analysis
| Filename | Score | Overview |
|---|---|---|
| transformer_engine/common/fused_attn/fused_attn.cpp | 2/5 | Major refactoring that consolidates packed variants into unpacking logic. Critical bug: KV-packed max512 forward/backward use incorrect stride calculation (h_q instead of h_kv) for GQA scenarios. |
| transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu | 5/5 | Removes duplicate QKV/KV-packed function implementations, keeping only the unpacked variants. Clean removal of ~530 lines of duplicate code. |
| transformer_engine/common/fused_attn/fused_attn_f16_max512_seqlen.cu | 5/5 | Removes QKV/KV-packed implementations (~264 lines), retaining only unpacked function. Good code deduplication. |
| transformer_engine/common/fused_attn/fused_attn_fp8.cu | 5/5 | Removes FP8 packed variant implementations (~418 lines). Good deduplication, shifts unpacking to higher level. |
Sequence Diagram
sequenceDiagram
participant User as JAX/PyTorch User
participant API as nvte_fused_attn_fwd_kvpacked<br/>(common API - fused_attn.cpp)
participant Helper as Helper Functions<br/>(calculate_kv_stride, etc.)
participant Kernel as Kernel Implementation<br/>(fused_attn_*_fwd)
User->>API: Call with packed KV tensor<br/>[b, s, 2, h_kv, d]
Note over API: Determine backend type<br/>(max512/arbitrary/fp8)
alt max512 backend
API->>Helper: Calculate stride for unpacking
Note over Helper: Should use h_kv, not h_q!<br/>(Bug at lines 842, 979)
Helper-->>API: stride = 2 * h_q * d (WRONG)
else arbitrary/fp8 backend
API->>Helper: calculate_kv_stride(h_kv, d)
Helper-->>API: stride = (bits * h_kv * d) / 8 (CORRECT)
end
API->>API: make_tensor_view(KV, shape, 0)<br/>→ K_view
API->>API: make_tensor_view(KV, shape, stride)<br/>→ V_view
API->>Kernel: fused_attn_*_fwd(Q, K_view, V_view)
Note over Kernel: Operates on unpacked<br/>Q, K, V tensors
Kernel-->>API: Output O
API-->>User: Return attention output
8 files reviewed, no comments
Signed-off-by: Pawel Gadzinski <[email protected]>
Greptile Overview
Greptile Summary
This PR successfully refactors fused attention functions by eliminating code duplication across three kernel types (arbitrary seqlen, max512, and fp8). The key changes include:
- New helper functions: Added `make_tensor_view()`, `calculate_qkv_stride()`, `calculate_kv_stride()`, and shape calculation helpers in an anonymous namespace
- Unpacking moved up: KV/QKV tensor unpacking now happens in the common API layer (`nvte_fused_attn_fwd_qkvpacked`, `nvte_fused_attn_fwd_kvpacked`) instead of within each kernel-specific function
- Kernel simplification: Removed the `_qkvpacked` and `_kvpacked` variants of kernel functions (`fused_attn_max_512_fwd_qkvpacked`, `fused_attn_max_512_fwd_kvpacked`, etc.), reducing code duplication
- Bug fix: Correctly uses `h_kv` instead of `h_q` for KV-packed stride calculations, fixing a bug in GQA (Grouped Query Attention) scenarios where the number of KV heads differs from the number of Q heads
- APIs marked deprecated: Added deprecation notices to the `nvte_fused_attn_fwd_qkvpacked` and `nvte_fused_attn_fwd_kvpacked` functions
The refactoring maintains functional equivalence while improving code maintainability. The stride calculations are now correct for all layout types (NVTE_HD_2HD, NVTE_HD_H2D, NVTE_3HD, NVTE_H3D).
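For reference, one plausible shape of the KV stride helper described above (consistent with the `bits * h_kv * d / 8` arithmetic in the diagram below) might look like the following sketch; the signature and the `type_bits` parameter are assumptions, and only the NVTE layout-group names come from the existing header.

```cpp
#include <cstddef>
#include <transformer_engine/fused_attn.h>  // for NVTE_QKV_Layout_Group

// Sketch only: byte offset from the start of the K block to the start of the V block
// inside a packed KV tensor, for the two KV-packed layout groups.
std::size_t calculate_kv_stride_sketch(NVTE_QKV_Layout_Group layout_group,
                                       std::size_t type_bits, std::size_t h_kv,
                                       std::size_t d) {
  switch (layout_group) {
    case NVTE_HD_2HD:                     // KV packed as [..., 2, h_kv, d]
      return (type_bits * h_kv * d) / 8;  // V follows a full K block of h_kv * d elements
    case NVTE_HD_H2D:                     // KV packed as [..., h_kv, 2, d]
      return (type_bits * d) / 8;         // K and V interleave per head
    default:
      return 0;                           // non-packed layouts need no K-to-V offset
  }
}
```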
Confidence Score: 5/5
- This PR is safe to merge with high confidence - it's a well-executed refactoring that improves code quality and fixes stride calculation bugs
- The refactoring is clean and well-structured with proper helper functions. The stride calculations correctly use `h_kv` for KV-packed tensors, fixing potential bugs in GQA scenarios. All unpacking logic has been centralized with consistent patterns across forward/backward passes and all kernel types. The changes maintain backward compatibility by keeping the deprecated API functions.
- No files require special attention
Important Files Changed
File Analysis
| Filename | Score | Overview |
|---|---|---|
| transformer_engine/common/fused_attn/fused_attn.cpp | 5/5 | Major refactoring that removes code duplication by moving KV/QKV unpacking logic from kernel-specific functions to the common API layer. Correctly fixes stride calculation bugs for GQA scenarios where h_q != h_kv. |
Sequence Diagram
sequenceDiagram
participant User as JAX/PyTorch API
participant CommonAPI as nvte_fused_attn_fwd_kvpacked
participant Helpers as Helper Functions
participant Kernel as fused_attn_max_512_fwd
User->>CommonAPI: Call with packed KV tensor
CommonAPI->>Helpers: calculate_kv_stride(layout_group, dtype, h_kv, d)
Helpers-->>CommonAPI: stride (bytes)
CommonAPI->>Helpers: calculate_kv_unpacked_shape(KV, layout, h_kv, d)
Helpers-->>CommonAPI: unpacked_shape
CommonAPI->>Helpers: make_tensor_view(KV, shape, 0)
Helpers-->>CommonAPI: K_view
CommonAPI->>Helpers: make_tensor_view(KV, shape, stride)
Helpers-->>CommonAPI: V_view
CommonAPI->>Kernel: fused_attn_max_512_fwd(Q, K_view, V_view, ...)
Kernel-->>CommonAPI: output
CommonAPI-->>User: return output
1 file reviewed, no comments
/te-ci jax
/*! \brief Compute dot product attention with packed QKV input.
 *
 * \warning This API is **deprecated**.
Would `\deprecated` be a better marker for this message in the documentation?
Could we also add something like this to give users a warning at compile time?
[[deprecated("nvte_fused_attn_fwd_qkvpacked() is deprecated. Please use nvte_fused_attn_fwd() instead.")]]
void nvte_fused_attn_fwd_qkvpacked() {
}
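Putting the two suggestions together, the declaration could look roughly like this (the parameter list is elided and the exact wording is only a sketch, not the PR's final text):

```cpp
/*! \brief Compute dot product attention with packed QKV input.
 *
 *  \deprecated Use nvte_fused_attn_fwd() with separate Q, K and V tensors instead.
 */
[[deprecated(
    "nvte_fused_attn_fwd_qkvpacked() is deprecated. "
    "Please use nvte_fused_attn_fwd() instead.")]]
void nvte_fused_attn_fwd_qkvpacked(/* ...existing parameter list unchanged... */);
```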
Description
There are 3 variants of the fused attention functions: for separate QKV, packed KV, and packed QKV, which differ only in the pointers to Q, K, and V. This results in code duplication for each type of fused attention kernel: arbitrary seqlen, max 512, and FP8. This PR deduplicates the code and moves the pointer computation up one abstraction layer - from functions like `fused_attn_max_512_fwd_qkvpacked` into functions like `nvte_fused_attn_fwd_qkvpacked` in the common C++ API. These packed versions of the common attention API functions are used by JAX, so running the JAX CI is a good test of these changes. PyTorch uses only the non-packed functions. A rough illustration of the unpacking idea is sketched below.
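As an illustration of "moving the pointer computation up one abstraction layer" (the function and variable names below are made up for the sketch; only the 3HD layout semantics match the PR):

```cpp
#include <cstddef>
#include <cstdint>
#include <tuple>

// For a 3HD-packed QKV buffer ([..., 3, h, d]), the Q, K and V blocks of one token
// follow each other at a fixed byte stride, so the packed API entry point can derive
// three pointers (views) and forward them to the single non-packed kernel path.
std::tuple<void *, void *, void *> unpack_qkv_3hd(void *qkv_ptr, std::size_t elem_bytes,
                                                  std::size_t h, std::size_t d) {
  const std::size_t stride = elem_bytes * h * d;    // byte offset between Q, K, V blocks
  auto *base = static_cast<std::uint8_t *>(qkv_ptr);
  return {base, base + stride, base + 2 * stride};  // Q, K, V base pointers
}
```

The removed `*_qkvpacked`/`*_kvpacked` kernel variants previously repeated this kind of pointer arithmetic inside each kernel family.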
Type of change
Checklist: