
Conversation

@pggPL
Collaborator

@pggPL pggPL commented Oct 20, 2025

Description

There are three variants of the fused attention functions: separate QKV, KV-packed, and QKV-packed. They differ only in the pointers into the QKV data, which results in duplicated code for each type of fused attention kernel: arbitrary seqlen, max 512, and FP8. This PR deduplicates the code and moves the pointer computation up one abstraction layer, from functions like fused_attn_max_512_fwd_qkvpacked into functions like nvte_fused_attn_fwd_qkvpacked in the common C++ API.

These packed variants of the common attention API functions are used by JAX, so I think running the JAX CI is a good test of these changes. PyTorch uses only the non-packed functions.
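
A rough sketch of the new flow at the common API layer (the helper names follow make_tensor_view / calculate_kv_stride added in this PR, but the signatures below are simplified for illustration and are not the exact implementation):

```cpp
#include <cstddef>
#include <vector>

// Lightweight view into an existing packed buffer; no data is copied.
struct TensorView {
  void *dptr;                 // base pointer plus a byte offset into the packed KV buffer
  std::vector<size_t> shape;  // unpacked shape, e.g. [b, s_kv, h_kv, d]
};

// Byte offset from the start of K to the start of V in a ...2HD packed KV tensor.
size_t calculate_kv_stride(size_t type_bits, size_t h_kv, size_t d) {
  return (type_bits * h_kv * d) / 8;
}

TensorView make_tensor_view(void *base, std::vector<size_t> shape, size_t byte_offset) {
  return {static_cast<char *>(base) + byte_offset, std::move(shape)};
}

// Inside the (now deprecated) packed entry point, before calling the unified kernel:
//   size_t stride   = calculate_kv_stride(type_bits, h_kv, d);
//   TensorView K    = make_tensor_view(kv_dptr, {b, s_kv, h_kv, d}, 0);
//   TensorView V    = make_tensor_view(kv_dptr, {b, s_kv, h_kv, d}, stride);
//   fused_attn_max_512_fwd(..., K, V, ...);  // single non-packed kernel path
```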

Type of change

  • Documentation change (change only to the documentation, either a fix or new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

@pggPL pggPL marked this pull request as ready for review October 21, 2025 10:54
@pggPL
Collaborator Author

pggPL commented Oct 21, 2025

/te-ci jax


@greptile-apps greptile-apps bot left a comment


7 files reviewed, no comments


@cyanguwa
Collaborator

I think this is similar to #2272 :) Yes, Jax needs a bit of fixing in order to get its attention working.

@pggPL pggPL requested a review from cyanguwa November 3, 2025 17:32
@cyanguwa
Collaborator

cyanguwa commented Nov 4, 2025

Could you add the deprecation note for these qkvpacked/kvpacked APIs as we discussed offline please? Thanks.

Signed-off-by: Pawel Gadzinski <[email protected]>

@greptile-apps greptile-apps bot left a comment


Greptile Overview

Greptile Summary

Refactors fused attention code by eliminating duplicate kvpacked and qkvpacked kernel functions, moving pointer arithmetic into helper functions at the common API layer.

Key Changes

  • Adds helper functions (make_tensor_view, calculate_qkv_stride, calculate_kv_stride) to handle tensor unpacking
  • Deprecated packed API functions now unpack QKV/KV tensors and call non-packed kernel implementations
  • Removes ~1400 lines of duplicate kernel code from arbitrary_seqlen implementation
  • Adds deprecation warnings to packed API functions in header file

Issues Found

  • Critical bug in KV-packed max512 stride calculation: Lines 840 and 977 incorrectly use h_q instead of h_kv when calculating the byte offset between the K and V tensors. This will cause incorrect memory access in Grouped Query Attention (GQA) scenarios, where h_q ≠ h_kv.

Confidence Score: 2/5

  • This PR contains critical bugs in stride calculations that will cause memory corruption in GQA scenarios
  • The refactoring approach is sound and removes significant code duplication, but contains a critical logic error in KV-packed tensor stride calculations (lines 840, 977) where h_q is used instead of h_kv. This bug exists in both forward and backward passes for max512 kernels and will cause incorrect memory access when h_q ≠ h_kv (Grouped Query Attention). While the bug may not manifest in standard attention where h_q == h_kv, it represents a correctness issue that must be fixed before merge.
  • transformer_engine/common/fused_attn/fused_attn.cpp requires immediate attention - fix stride calculations on lines 840 and 977

Important Files Changed

File Analysis

Filename Score Overview
transformer_engine/common/fused_attn/fused_attn.cpp 2/5 Refactors packed attention functions to use helper functions and tensor views. Contains critical bugs in KV-packed stride calculations (lines 840, 977) where h_q is used instead of h_kv.
transformer_engine/common/include/transformer_engine/fused_attn.h 5/5 Adds deprecation warnings to packed QKV/KV API functions, directing users to use separate Q, K, V tensors instead.

Sequence Diagram

sequenceDiagram
    participant User as JAX/PyTorch
    participant API as nvte_fused_attn_*_kvpacked (Deprecated)
    participant Helper as Helper Functions
    participant Kernel as fused_attn_max_512_* (Non-packed)

    Note over User,Kernel: BEFORE: Separate functions for each kernel type
    User->>API: nvte_fused_attn_fwd_kvpacked(KV_packed)
    API->>Kernel: fused_attn_max_512_fwd_kvpacked(KV_packed)
    Note over Kernel: Unpacks KV internally<br/>stride = 2*h*d
    Kernel-->>API: result
    API-->>User: result

    Note over User,Kernel: AFTER: Unified using helper functions
    User->>API: nvte_fused_attn_fwd_kvpacked(KV_packed)
    API->>Helper: calculate_kv_stride(layout, dtype, h_kv, d)
    Helper-->>API: stride (in bytes)
    API->>Helper: make_tensor_view(KV, shape, offset=0)
    Helper-->>API: K_view
    API->>Helper: make_tensor_view(KV, shape, offset=stride)
    Helper-->>API: V_view
    API->>Kernel: fused_attn_max_512_fwd(Q, K_view, V_view)
    Note over Kernel: Works with unpacked views
    Kernel-->>API: result
    API-->>User: result

2 files reviewed, 2 comments


bias_type, attn_mask_type, input_Q, input_KV, input_Bias, output_O, Aux_CTX_Tensors,
input_cu_seqlens_q, input_cu_seqlens_kv, input_rng_state, wkspace, stream, handle);
// Unpack KV and call the non-packed function
size_t stride = 2 * h_q * d; // For max512, KV layout is BS2HD or SB2HD


logic: stride calculation uses h_q but should use h_kv for KV-packed tensors

For KV-packed layout BS2HD, the memory contains K[h_kv, d] followed by V[h_kv, d]. The stride to skip from K to V should be based on h_kv, not h_q (which can differ in GQA scenarios).

Suggested change
- size_t stride = 2 * h_q * d; // For max512, KV layout is BS2HD or SB2HD
+ size_t stride = 2 * h_kv * d; // For max512, KV layout is BS2HD or SB2HD
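
To make the impact concrete, a back-of-the-envelope check with illustrative GQA sizes (the numbers are not from the PR):

```cpp
#include <cstddef>

// Illustrative GQA sizes only: 32 query heads, 8 KV heads, head dim 64.
constexpr std::size_t h_q = 32, h_kv = 8, d = 64;
constexpr std::size_t stride_with_h_q  = 2 * h_q  * d;  // 4096: what the original line computed
constexpr std::size_t stride_with_h_kv = 2 * h_kv * d;  // 1024: what the suggested change computes
// The two differ by a factor of h_q / h_kv (4x here) whenever h_q != h_kv, so the
// V view would be offset well past the real V data in the packed KV buffer.
static_assert(stride_with_h_q == 4 * stride_with_h_kv, "wrong stride overshoots by h_q / h_kv");
```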

input_cu_seqlens_q, input_cu_seqlens_kv, wkspace, stream, handle);

// Unpack KV and dKV and call the non-packed function
size_t stride = 2 * h_q * d;


logic: stride calculation uses h_q but should use h_kv for KV-packed tensors (backward pass)

Same issue as in forward pass - the stride should be based on h_kv for KV-packed tensors.

Suggested change
- size_t stride = 2 * h_q * d;
+ size_t stride = 2 * h_kv * d;

@pggPL
Collaborator Author


That's a good point - this is a leftover from my previous change. Weird that the tests didn't catch it. Fixed.


@greptile-apps greptile-apps bot left a comment


Greptile Overview

Greptile Summary

This PR consolidates the fused attention implementations by removing duplicate QKV-packed and KV-packed kernel variants across three backend types (max512, arbitrary seqlen, FP8). The refactoring moves tensor unpacking logic from the kernel level to the common API layer in fused_attn.cpp, removing ~1,388 lines of duplicated code.

Key changes:

  • Removes *_qkvpacked and *_kvpacked functions from all three kernel types
  • Adds helper functions for stride calculation and tensor view creation
  • Unpacks QKV/KV tensors in the common API layer before calling unified kernel functions
  • Deprecates packed QKV/KV APIs in the public header

Critical issue found:

  • KV-packed max512 paths (lines 842, 979 in fused_attn.cpp) use h_q instead of h_kv for stride calculation, causing incorrect pointer arithmetic in Grouped Query Attention (GQA) scenarios where h_q != h_kv

Confidence Score: 2/5

  • Critical stride calculation bug will cause memory corruption in GQA with max512 backend
  • Good refactoring approach, but incorrect stride calculation using h_q instead of h_kv in max512 KV-packed paths will access wrong memory addresses when query and key/value have different head counts (GQA scenarios)
  • transformer_engine/common/fused_attn/fused_attn.cpp lines 842 and 979 - incorrect stride calculation for KV-packed tensors

Important Files Changed

File Analysis

Filename Score Overview
transformer_engine/common/fused_attn/fused_attn.cpp 2/5 Major refactoring that consolidates packed variants into unpacking logic. Critical bug: KV-packed max512 forward/backward use incorrect stride calculation (h_q instead of h_kv) for GQA scenarios.
transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu 5/5 Removes duplicate QKV/KV-packed function implementations, keeping only the unpacked variants. Clean removal of ~530 lines of duplicate code.
transformer_engine/common/fused_attn/fused_attn_f16_max512_seqlen.cu 5/5 Removes QKV/KV-packed implementations (~264 lines), retaining only unpacked function. Good code deduplication.
transformer_engine/common/fused_attn/fused_attn_fp8.cu 5/5 Removes FP8 packed variant implementations (~418 lines). Good deduplication, shifts unpacking to higher level.

Sequence Diagram

sequenceDiagram
    participant User as JAX/PyTorch User
    participant API as nvte_fused_attn_fwd_kvpacked<br/>(common API - fused_attn.cpp)
    participant Helper as Helper Functions<br/>(calculate_kv_stride, etc.)
    participant Kernel as Kernel Implementation<br/>(fused_attn_*_fwd)
    
    User->>API: Call with packed KV tensor<br/>[b, s, 2, h_kv, d]
    Note over API: Determine backend type<br/>(max512/arbitrary/fp8)
    
    alt max512 backend
        API->>Helper: Calculate stride for unpacking
        Note over Helper: Should use h_kv, not h_q!<br/>(Bug at lines 842, 979)
        Helper-->>API: stride = 2 * h_q * d (WRONG)
    else arbitrary/fp8 backend  
        API->>Helper: calculate_kv_stride(h_kv, d)
        Helper-->>API: stride = (bits * h_kv * d) / 8 (CORRECT)
    end
    
    API->>API: make_tensor_view(KV, shape, 0)<br/>→ K_view
    API->>API: make_tensor_view(KV, shape, stride)<br/>→ V_view
    
    API->>Kernel: fused_attn_*_fwd(Q, K_view, V_view)
    Note over Kernel: Operates on unpacked<br/>Q, K, V tensors
    Kernel-->>API: Output O
    API-->>User: Return attention output

8 files reviewed, no comments


Signed-off-by: Pawel Gadzinski <[email protected]>

@greptile-apps greptile-apps bot left a comment


Greptile Overview

Greptile Summary

This PR successfully refactors fused attention functions by eliminating code duplication across three kernel types (arbitrary seqlen, max512, and fp8). The key changes include:

  • New helper functions: Added make_tensor_view(), calculate_qkv_stride(), calculate_kv_stride(), and shape calculation helpers in an anonymous namespace
  • Unpacking moved up: KV/QKV tensor unpacking now happens in the common API layer (nvte_fused_attn_fwd_qkvpacked, nvte_fused_attn_fwd_kvpacked) instead of within each kernel-specific function
  • Kernel simplification: Removed _qkvpacked and _kvpacked variants of kernel functions (fused_attn_max_512_fwd_qkvpacked, fused_attn_max_512_fwd_kvpacked, etc.), reducing code duplication
  • Bug fix: Correctly uses h_kv instead of h_q for KV-packed stride calculations, fixing a bug in GQA (Grouped Query Attention) scenarios where the number of KV heads differs from Q heads
  • APIs marked deprecated: Added deprecation notices to nvte_fused_attn_fwd_qkvpacked and nvte_fused_attn_fwd_kvpacked functions

The refactoring maintains functional equivalence while improving code maintainability. The stride calculations are now correct for all layout types (NVTE_HD_2HD, NVTE_HD_H2D, NVTE_3HD, NVTE_H3D).
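
As a sketch of what layout-aware stride selection can look like (the enum values below mirror NVTE_HD_2HD / NVTE_HD_H2D named above; the per-layout offsets are assumptions about the packed memory order, not code copied from the PR):

```cpp
#include <cstddef>

// Stand-ins for the NVTE layout-group enum values named above.
enum class KVLayoutGroup { HD_2HD, HD_H2D };

// Byte offset from the K pointer to the V pointer inside a packed KV tensor.
std::size_t calculate_kv_stride(KVLayoutGroup group, std::size_t type_bits,
                                std::size_t h_kv, std::size_t d) {
  switch (group) {
    case KVLayoutGroup::HD_2HD:  // ...2HD: per token, K[h_kv][d] is followed by V[h_kv][d]
      return (type_bits * h_kv * d) / 8;
    case KVLayoutGroup::HD_H2D:  // ...H2D: per head, K[d] is followed by V[d]
      return (type_bits * d) / 8;
  }
  return 0;
}
```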

Confidence Score: 5/5

  • This PR is safe to merge with high confidence - it's a well-executed refactoring that improves code quality and fixes stride calculation bugs
  • The refactoring is clean and well-structured with proper helper functions. The stride calculations correctly use h_kv for KV-packed tensors, fixing potential bugs in GQA scenarios. All unpacking logic has been centralized with consistent patterns across forward/backward passes and all kernel types. The changes maintain backward compatibility by keeping deprecated API functions.
  • No files require special attention

Important Files Changed

File Analysis

Filename Score Overview
transformer_engine/common/fused_attn/fused_attn.cpp 5/5 Major refactoring that removes code duplication by moving KV/QKV unpacking logic from kernel-specific functions to the common API layer. Correctly fixes stride calculation bugs for GQA scenarios where h_q != h_kv.

Sequence Diagram

sequenceDiagram
    participant User as JAX/PyTorch API
    participant CommonAPI as nvte_fused_attn_fwd_kvpacked
    participant Helpers as Helper Functions
    participant Kernel as fused_attn_max_512_fwd

    User->>CommonAPI: Call with packed KV tensor
    CommonAPI->>Helpers: calculate_kv_stride(layout_group, dtype, h_kv, d)
    Helpers-->>CommonAPI: stride (bytes)
    CommonAPI->>Helpers: calculate_kv_unpacked_shape(KV, layout, h_kv, d)
    Helpers-->>CommonAPI: unpacked_shape
    CommonAPI->>Helpers: make_tensor_view(KV, shape, 0)
    Helpers-->>CommonAPI: K_view
    CommonAPI->>Helpers: make_tensor_view(KV, shape, stride)
    Helpers-->>CommonAPI: V_view
    CommonAPI->>Kernel: fused_attn_max_512_fwd(Q, K_view, V_view, ...)
    Kernel-->>CommonAPI: output
    CommonAPI-->>User: return output

1 file reviewed, no comments


@pggPL
Collaborator Author

pggPL commented Nov 6, 2025

/te-ci jax


/*! \brief Compute dot product attention with packed QKV input.
*
* \warning This API is **deprecated**.
Collaborator


Would \deprecated be a better marker for this message in the documentation?

Could we add something like this to give users a warning at compile time as well?

[[deprecated("nvte_fused_attn_fwd_qkvpacked() is deprecated. Please use nvte_fused_attn_fwd() instead.")]]
void nvte_fused_attn_fwd_qkvpacked() {
}
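
For illustration, one way the two suggestions could be combined in the header (a sketch only; the real declaration keeps its full parameter list):

```cpp
/*! \brief Compute dot product attention with packed QKV input.
 *
 *  \deprecated Use nvte_fused_attn_fwd() with separate Q, K and V tensors instead.
 */
[[deprecated(
    "nvte_fused_attn_fwd_qkvpacked() is deprecated. "
    "Please use nvte_fused_attn_fwd() instead.")]]
void nvte_fused_attn_fwd_qkvpacked(/* ... existing parameter list unchanged ... */);
```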

