
Conversation

Collaborator

@pggPL pggPL commented Nov 6, 2025

Description

JAX calls nvte_fused_attn_fwd_kvpacked(), nvte_fused_attn_fwd_qkvpacked(), or nvte_fused_attn_fwd(). The first two will be deprecated by #2287, so this PR changes the JAX extension code to use only the last one.
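
For orientation, a minimal sketch of the pointer handling this PR moves into the JAX extension, pieced together from the diff hunks quoted in the review comments below. The helper name, its signature, and the surrounding context are hypothetical, and the KV-packed stride uses the reviewer-suggested v_head_dim form (a runtime check enforces qk_head_dim == v_head_dim):

// Hypothetical helper, not the literal implementation. Assumes <cstdint>, the
// NVTE_QKV_Layout_Group enum from transformer_engine/fused_attn.h, and the JAX
// extension's DType/typeToSize utilities. typeToSize(dtype) is the element size in bytes.
static void GetQKVPointers(NVTE_QKV_Layout_Group layout_group, void *q, void *k, void *v,
                           DType dtype, size_t attn_heads, size_t num_gqa_groups,
                           size_t qk_head_dim, size_t v_head_dim,
                           void **q_ptr, void **k_ptr, void **v_ptr) {
  *q_ptr = q;
  *k_ptr = k;
  *v_ptr = v;
  if (layout_group == NVTE_QKV_Layout_Group::NVTE_3HD) {
    // QKV packed in q: [batch*seqlen, 3, attn_heads, qk_head_dim]
    size_t stride = typeToSize(dtype) * attn_heads * qk_head_dim;
    *k_ptr = static_cast<void *>(static_cast<int8_t *>(q) + stride);
    *v_ptr = static_cast<void *>(static_cast<int8_t *>(q) + 2 * stride);
  } else if (layout_group == NVTE_QKV_Layout_Group::NVTE_HD_2HD) {
    // KV packed in k: [batch*seqlen, 2, num_gqa_groups, v_head_dim]
    size_t stride = typeToSize(dtype) * num_gqa_groups * v_head_dim;
    *v_ptr = static_cast<void *>(static_cast<int8_t *>(k) + stride);
  }
  // NVTE_HD_HD_HD: q, k, v are already separate and used as-is.
}

The resulting pointers are wrapped with shape/dtype metadata (e.g. TensorWrapper(q_ptr, q_shape, dtype), as seen in the diff) and handed to the single nvte_fused_attn_fwd / nvte_fused_attn_bwd entry point, which is why the packed-specific APIs are no longer needed on the JAX side.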

Type of change

  • Documentation change (change only to the documentation, either a fix or new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

Signed-off-by: Pawel Gadzinski <[email protected]>
@pggPL pggPL changed the title from "[JAX] Make all jax attention calls to use non-packed common calls" to "[JAX] Make all jax attention calls use non-packed common calls" on Nov 6, 2025
Contributor

@greptile-apps greptile-apps bot left a comment


Greptile Overview

Greptile Summary

Refactors JAX attention extension to use only nvte_fused_attn_fwd/bwd instead of the deprecated packed variants (nvte_fused_attn_fwd_qkvpacked and nvte_fused_attn_fwd_kvpacked). The PR moves pointer arithmetic from the common API layer into the JAX extension code.

Key changes:

  • Unified all three layout types (QKV packed, KV packed, separate) to call single nvte_fused_attn_fwd/bwd API
  • Added pointer arithmetic in JAX extension to extract K and V pointers from packed tensors
  • Removed unused tensor shape definitions and layout-specific branching in workspace size calculations
  • Updated gradient zeroing logic in backward pass to correctly handle packed tensor memory layouts

Critical issue found:

  • Lines 287 and 517: Stride calculation for KV-packed layout uses qk_head_dim but should use v_head_dim since KV packed tensors have shape [batch*seqlen, 2, num_gqa_groups, v_head_dim]
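
To make the flagged offset concrete, a worked example with hypothetical sizes (not code from the PR):

// Byte offset of the V slab inside one packed KV row of shape
// [2, num_gqa_groups, v_head_dim], assuming bf16 (2 bytes per element),
// num_gqa_groups = 8, v_head_dim = 128:
size_t stride = typeToSize(dtype) * num_gqa_groups * v_head_dim;  // 2 * 8 * 128 = 2048 bytes
// Using qk_head_dim gives the same value only while qk_head_dim == v_head_dim,
// which an NVTE_CHECK currently enforces; the two formulas diverge if that
// restriction is ever relaxed.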

Confidence Score: 1/5

  • This PR contains critical pointer arithmetic bugs that will cause memory corruption or incorrect results
  • Score of 1 (critical issues) due to incorrect stride calculation in the KV-packed layout at lines 287 and 517. Using qk_head_dim instead of v_head_dim will cause the V pointer to point to the wrong memory location when these dimensions differ, leading to incorrect attention computation or potential memory access violations
  • transformer_engine/jax/csrc/extensions/attention.cpp - lines 287 and 517 must be fixed before merge

Important Files Changed

File Analysis

Filename Score Overview
transformer_engine/jax/csrc/extensions/attention.cpp 1/5 Refactors JAX attention to use only nvte_fused_attn_fwd/bwd, but has critical pointer arithmetic bug in KV-packed layout stride calculation using qk_head_dim instead of v_head_dim

Sequence Diagram

sequenceDiagram
    participant JAX as JAX Extension
    participant Old as Old API (deprecated)
    participant New as nvte_fused_attn_fwd
    participant Kernel as Attention Kernel

    Note over JAX,Kernel: Before PR #2358
    JAX->>Old: nvte_fused_attn_fwd_qkvpacked(QKV)
    Old->>Kernel: Extract Q,K,V pointers internally
    Kernel-->>Old: Result
    Old-->>JAX: Output

    JAX->>Old: nvte_fused_attn_fwd_kvpacked(Q, KV)
    Old->>Kernel: Extract K,V pointers internally
    Kernel-->>Old: Result
    Old-->>JAX: Output

    Note over JAX,Kernel: After PR #2358
    JAX->>JAX: Extract Q,K,V pointers from packed data
    JAX->>New: nvte_fused_attn_fwd(Q, K, V)
    New->>Kernel: Forward separate pointers
    Kernel-->>New: Result
    New-->>JAX: Output

    Note over JAX: Pointer arithmetic handles:<br/>- NVTE_3HD: QKV packed<br/>- NVTE_HD_2HD: KV packed<br/>- NVTE_HD_HD_HD: All separate

1 file reviewed, 2 comments


// Python passes: q=query, k=packed_kv, v=dummy
// Extract V pointer from the packed k data
// Note: For KV packed, qk_head_dim must equal v_head_dim
size_t stride = (typeToSize(dtype) * num_gqa_groups * qk_head_dim);
Contributor


logic: stride calculation uses qk_head_dim but KV packed tensor has shape [batch*seqlen, 2, num_gqa_groups, v_head_dim] - should use v_head_dim

Suggested change
size_t stride = (typeToSize(dtype) * num_gqa_groups * qk_head_dim);
size_t stride = (typeToSize(dtype) * num_gqa_groups * v_head_dim);

Collaborator


I think this comment can be ignored

cudaMemsetAsync(dv, 0, transformer_engine::jax::product(v_shape) * typeToSize(dtype), stream);
// Q separate, KV packed in k: [batch*seqlen, 2, num_gqa_groups, dim]
// Note: For KV packed, qk_head_dim must equal v_head_dim
size_t stride = (typeToSize(dtype) * num_gqa_groups * qk_head_dim);
Contributor


logic: stride calculation uses qk_head_dim but KV packed tensor has shape [batch*seqlen, 2, num_gqa_groups, v_head_dim] - should use v_head_dim

Suggested change
size_t stride = (typeToSize(dtype) * num_gqa_groups * qk_head_dim);
size_t stride = (typeToSize(dtype) * num_gqa_groups * v_head_dim);

pggPL and others added 2 commits November 6, 2025 23:01
Contributor

@greptile-apps greptile-apps bot left a comment


Greptile Overview

Greptile Summary

Refactors JAX attention code to use unified nvte_fused_attn_fwd() and nvte_fused_attn_bwd() APIs instead of deprecated packed variants (nvte_fused_attn_fwd_kvpacked() and nvte_fused_attn_fwd_qkvpacked()). The pointer arithmetic for extracting K/V pointers from packed tensors is now handled in the JAX layer.

Key Changes:

  • Removed conditional calls to packed-specific attention APIs
  • Added pointer calculation logic to extract K/V pointers from packed QKV/KV tensors based on layout
  • Unified all attention calls to use the single nvte_fused_attn_fwd/bwd API
  • Updated workspace size calculation functions similarly

Issue Found:

  • For the KV-packed layout (NVTE_HD_2HD), the stride calculation uses qk_head_dim but should use v_head_dim to match the actual tensor shape [batch*seqlen, 2, num_gqa_groups, v_head_dim]. While the two are enforced to be equal by a runtime check, using v_head_dim is semantically correct.

Confidence Score: 4/5

  • Safe to merge after fixing stride calculation to use v_head_dim instead of qk_head_dim for KV-packed layout
  • The refactoring is well-structured and aligns with the goal of deprecating packed-specific APIs. However, the stride calculation issue (using qk_head_dim instead of v_head_dim) in the KV-packed layout needs to be fixed for semantic correctness, even though runtime checks enforce equality. The logic is sound otherwise, with proper handling of different layouts and appropriate memory clearing for ragged sequences.
  • transformer_engine/jax/csrc/extensions/attention.cpp - Fix stride calculation on lines 290 and 523 to use v_head_dim instead of qk_head_dim

Important Files Changed

File Analysis

Filename Score Overview
transformer_engine/jax/csrc/extensions/attention.cpp 4/5 Refactors JAX attention to use unified nvte_fused_attn_fwd/bwd API; has stride calculation issue using qk_head_dim instead of v_head_dim for KV-packed layout

Sequence Diagram

sequenceDiagram
    participant JAX as JAX Python Layer
    participant FwdImpl as FusedAttnForwardImpl
    participant LayoutCheck as Layout Detection
    participant PtrCalc as Pointer Calculation
    participant API as nvte_fused_attn_fwd

    JAX->>FwdImpl: Call with q, k, v pointers
    FwdImpl->>LayoutCheck: Check layout_group (NVTE_3HD/HD_2HD/HD_HD_HD)
    
    alt NVTE_3HD (QKV packed)
        LayoutCheck->>PtrCalc: Extract K, V from packed Q
        PtrCalc->>PtrCalc: k_ptr = q + stride<br/>v_ptr = q + 2*stride<br/>stride = typeSize * attn_heads * qk_head_dim
    else NVTE_HD_2HD (KV packed)
        LayoutCheck->>PtrCalc: Extract V from packed K
        PtrCalc->>PtrCalc: v_ptr = k + stride<br/>stride = typeSize * num_gqa_groups * qk_head_dim
    else NVTE_HD_HD_HD (separate)
        LayoutCheck->>PtrCalc: Use pointers as-is
    end
    
    PtrCalc->>API: Call with separate q_ptr, k_ptr, v_ptr
    API-->>FwdImpl: Return results
    FwdImpl-->>JAX: Return output

1 file reviewed, no comments


Collaborator Author

pggPL commented Nov 6, 2025

/te-ci jax

phu0ngng previously approved these changes Nov 13, 2025
Collaborator

@phu0ngng phu0ngng left a comment


LGTM. Thanks.

Contributor

greptile-apps bot commented Nov 13, 2025

Greptile Overview

Greptile Summary

Refactors JAX attention extension to use only nvte_fused_attn_fwd and nvte_fused_attn_bwd instead of deprecated nvte_fused_attn_fwd_kvpacked and nvte_fused_attn_fwd_qkvpacked functions (related to #2287).

Key changes:

  • Removes conditional branching based on layout groups in forward/backward passes
  • Adds pointer arithmetic to extract K and V pointers from packed QKV/KV tensors
  • Simplifies workspace size calculation by removing packed-specific tensor definitions
  • Improves cudaMemset logic for ragged tensors to correctly handle packed layouts

The refactoring consolidates three code paths into one, reducing duplication while maintaining functional equivalence.
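
A hedged sketch of the zeroing pattern described above, assuming the gradient buffers mirror the packed input buffers (dq spans the packed QKV buffer for NVTE_3HD, dk spans the packed KV buffer for NVTE_HD_2HD); the names is_ragged, dq, dk, dv and the shape variables are illustrative, not the exact code:

// Illustrative only: for ragged (THD) batches, zero the gradient buffers up
// front so that padding regions the kernel never writes are well defined.
// One cudaMemsetAsync per physical buffer suffices for packed layouts.
if (is_ragged) {
  if (layout_group == NVTE_QKV_Layout_Group::NVTE_3HD) {
    // dQ, dK, dV share one packed buffer; q_shape covers [batch*seqlen, 3, heads, dim].
    cudaMemsetAsync(dq, 0, transformer_engine::jax::product(q_shape) * typeToSize(dtype), stream);
  } else if (layout_group == NVTE_QKV_Layout_Group::NVTE_HD_2HD) {
    // dQ separate; dK and dV share the packed buffer behind dk.
    cudaMemsetAsync(dq, 0, transformer_engine::jax::product(q_shape) * typeToSize(dtype), stream);
    cudaMemsetAsync(dk, 0, transformer_engine::jax::product(k_shape) * typeToSize(dtype), stream);
  } else {  // NVTE_HD_HD_HD: three separate buffers.
    cudaMemsetAsync(dq, 0, transformer_engine::jax::product(q_shape) * typeToSize(dtype), stream);
    cudaMemsetAsync(dk, 0, transformer_engine::jax::product(k_shape) * typeToSize(dtype), stream);
    cudaMemsetAsync(dv, 0, transformer_engine::jax::product(v_shape) * typeToSize(dtype), stream);
  }
}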

Confidence Score: 3/5

  • PR has logical correctness issues in stride calculation that should be fixed before merging
  • The refactoring successfully consolidates code paths and the overall logic is sound. However, there are stride calculation issues in the KV packed case (lines 290, 523) that use qk_head_dim instead of v_head_dim. While these are enforced to be equal via NVTE_CHECK, using v_head_dim is more semantically correct since we're computing the offset to the V tensor. This has already been flagged in previous comments.
  • Pay close attention to transformer_engine/jax/csrc/extensions/attention.cpp lines 290 and 523 - stride calculations for KV packed tensors

Important Files Changed

File Analysis

Filename Score Overview
transformer_engine/jax/csrc/extensions/attention.cpp 4/5 Refactors JAX attention calls to use only nvte_fused_attn_fwd/bwd instead of deprecated packed variants. Includes pointer arithmetic for unpacking QKV/KV tensors and improved cudaMemset logic for ragged tensors.

Sequence Diagram

sequenceDiagram
    participant JAX as JAX Python
    participant CPP as attention.cpp
    participant NVTE as nvte_fused_attn_fwd
    
    JAX->>CPP: FusedAttnForwardImpl(q, k, v, ...)
    Note over CPP: Determine layout_group from qkv_layout
    
    alt QKV Packed (NVTE_3HD)
        Note over CPP: Extract K, V from packed Q tensor<br/>stride = typeToSize * attn_heads * qk_head_dim<br/>k_ptr = q + stride<br/>v_ptr = q + 2*stride
    else KV Packed (NVTE_HD_2HD)
        Note over CPP: Extract V from packed K tensor<br/>stride = typeToSize * num_gqa_groups * qk_head_dim<br/>k_ptr = k<br/>v_ptr = k + stride
    else Separate (NVTE_HD_HD_HD)
        Note over CPP: Use pointers as-is<br/>q_ptr = q, k_ptr = k, v_ptr = v
    end
    
    CPP->>CPP: Create TensorWrappers(q_ptr, k_ptr, v_ptr)
    CPP->>NVTE: nvte_fused_attn_fwd(q_tensor, k_tensor, v_tensor, ...)
    NVTE-->>CPP: Compute attention
    CPP-->>JAX: Return output

Contributor

@greptile-apps greptile-apps bot left a comment


1 file reviewed, no comments


Collaborator Author

pggPL commented Nov 13, 2025

/te-ci jax

Collaborator

@KshitijLakhani KshitijLakhani left a comment


EDIT: Please ignore the below comment (I incorrectly thought that you had launched only L0 tests)

Please run L1 tests on this as well so as to exercise the L1 dist attn tests due to the nature of the change (though L0 should mostly be enough - better safe than sorry :) )

auto ragged_offset_tensor =
TensorWrapper(nullptr, std::vector<size_t>{num_segments + 1}, DType::kInt32);
if (layout_group == NVTE_QKV_Layout_Group::NVTE_3HD) {
NVTE_CHECK(q_max_seqlen == kv_max_seqlen, "q_max_seqlen must equal to kv_max_seqlen");
Collaborator


As part of this consolidation, will we lose this check? Is that okay, or does it need to be looked into?

"For QKV packed layout, qk_head_dim must equal v_head_dim");
size_t stride = (typeToSize(dtype) * attn_heads * qk_head_dim);
q_ptr = q;
k_ptr = static_cast<void *>(static_cast<int8_t *>(q) + stride);
Collaborator


Quick question @pggPL: why the choice of int8_t for static-casting the q and k void pointers?
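
Editorial context, not the author's reply: arithmetic on a void * is ill-formed in standard C++, and the stride here is a byte count, so the pointer is cast to a byte-sized type before the offset is added and then cast back to void *. A tiny self-contained illustration with made-up sizes:

#include <cstdint>
#include <vector>

int main() {
  // One hypothetical packed KV row: [2, num_gqa_groups = 8, head_dim = 128] in bf16 (2 bytes).
  std::vector<std::int8_t> packed_kv(2 * 8 * 128 * 2);
  void *kv = packed_kv.data();
  // Byte offset of the V slab that follows the K slab within the row.
  const std::size_t stride_bytes = 2 * 8 * 128;
  // void * cannot be offset directly, so cast to a byte-sized pointer first.
  void *v_ptr = static_cast<void *>(static_cast<std::int8_t *>(kv) + stride_bytes);
  (void)v_ptr;  // v_ptr now addresses the start of the V slab
  return 0;
}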

Comment on lines +272 to +274
// QKV packed in q: [batch*seqlen, 3, heads, dim]
// Python passes: q=packed_qkv, k=dummy, v=dummy
// Extract K and V pointers from the packed q data
Collaborator


Thanks for the comments

// Python passes: q=query, k=packed_kv, v=dummy
// Extract V pointer from the packed k data
// Note: For KV packed, qk_head_dim must equal v_head_dim
size_t stride = (typeToSize(dtype) * num_gqa_groups * qk_head_dim);
Collaborator


I think this comment can be ignored

@KshitijLakhani
Collaborator

Thanks for PR 2287. Quick nit from 2287: in calculate_qkv_stride, could you add "stride in bytes" in the comments instead of just "stride"?

v_shape = k_shape;
}

auto q_tensor = TensorWrapper(q_ptr, q_shape, dtype);
Collaborator


nit: Do we need these tensor wrappers? Or can we pass the pointers directly? They don't seem to do anything.

Collaborator

@mgoldfarb-nvidia mgoldfarb-nvidia left a comment


Overall LGTM assuming our CI passes.
