Conversation

@pggPL (Collaborator) commented Oct 1, 2025

Description

PR #2148 added support for sink attention to common and PyTorch. This PR adds support for JAX.

Fixes #2070

BEFORE
================================================================================
TEST RUNTIME SUMMARY (grouped by function)
================================================================================
test                                                         |  12x |    1.97s | avg:   0.16s
test_autocast_with_mesh_resource                             |   1x |    0.00s | avg:   0.00s
test_context_parallel_allgather_attn                         | 160x |  612.61s | avg:   3.83s
test_context_parallel_allgather_attn_shardy                  |  20x |   90.95s | avg:   4.55s
test_context_parallel_ring_attn                              | 640x | 1042.37s | avg:   1.63s
test_context_parallel_ring_attn_shardy                       |  20x |   37.74s | avg:   1.89s
test_cross_attn                                              |   6x |   31.82s | avg:   5.30s
test_distributed_gemm                                        |   6x |    6.10s | avg:   1.02s
test_layernorm                                               | 144x |   81.39s | avg:   0.57s
test_layernorm_mlp_grad                                      | 240x |  301.51s | avg:   1.26s
test_layernorm_mlp_grad_shardy                               | 240x |  293.58s | avg:   1.22s
test_layernorm_mlp_layer                                     |  48x |   21.58s | avg:   0.45s
test_layernorm_mlp_layer_fp8                                 | 192x |   81.58s | avg:   0.42s
test_layernorm_mlp_layer_fp8_shardy                          | 192x |   91.23s | avg:   0.48s
test_layernorm_mlp_layer_shardy                              |  48x |   25.98s | avg:   0.54s
test_rmsnorm                                                 |  72x |   29.43s | avg:   0.41s
test_self_attn                                               |  18x |   89.75s | avg:   4.99s
test_self_attn_shardy                                        |   6x |   17.32s | avg:   2.89s
test_softmax                                                 | 288x |  185.44s | avg:   0.64s
test_softmax_gspmd                                           |  24x |   13.07s | avg:   0.54s
test_te_distributed_dense_grad                               |   6x |    5.12s | avg:   0.85s
================================================================================
TOTAL RUNTIME                                                |      | 3060.56s |
================================================================================

AFTER
================================================================================
TEST RUNTIME SUMMARY (grouped by function)
================================================================================
test                                                         |  12x |    2.20s | avg:   0.18s
test_autocast_with_mesh_resource                             |   1x |    0.00s | avg:   0.00s
test_context_parallel_allgather_attn                         | 160x |  587.44s | avg:   3.67s
test_context_parallel_allgather_attn_shardy                  |  20x |   87.95s | avg:   4.40s
test_context_parallel_ring_attn                              | 640x | 1037.16s | avg:   1.62s
test_context_parallel_ring_attn_shardy                       |  20x |   41.83s | avg:   2.09s
test_cross_attn                                              |  18x |   89.76s | avg:   4.99s
test_distributed_gemm                                        |   6x |    5.74s | avg:   0.96s
test_layernorm                                               | 144x |   83.85s | avg:   0.58s
test_layernorm_mlp_grad                                      | 240x |  301.73s | avg:   1.26s
test_layernorm_mlp_grad_shardy                               | 240x |  309.08s | avg:   1.29s
test_layernorm_mlp_layer                                     |  48x |   24.98s | avg:   0.52s
test_layernorm_mlp_layer_fp8                                 | 192x |   89.17s | avg:   0.46s
test_layernorm_mlp_layer_fp8_shardy                          | 192x |   92.58s | avg:   0.48s
test_layernorm_mlp_layer_shardy                              |  48x |   26.29s | avg:   0.55s
test_rmsnorm                                                 |  72x |   29.52s | avg:   0.41s
test_self_attn                                               |  54x |  259.63s | avg:   4.81s
test_self_attn_shardy                                        |  18x |   43.51s | avg:   2.42s
test_softmax                                                 | 288x |  183.87s | avg:   0.64s
test_softmax_gspmd                                           |  24x |   12.72s | avg:   0.53s
test_te_distributed_dense_grad                               |   6x |    4.74s | avg:   0.79s
================================================================================
TOTAL RUNTIME                                                |      | 3313.74s |
================================================================================

Type of change

  • Documentation change (change only to the documentation, either a fix or new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

pggPL and others added 4 commits October 1, 2025 14:45
Signed-off-by: Pawel Gadzinski <[email protected]>
Signed-off-by: Pawel Gadzinski <[email protected]>
Signed-off-by: Pawel Gadzinski <[email protected]>
pre-commit-ci bot and others added 5 commits October 2, 2025 15:54
Signed-off-by: Pawel Gadzinski <[email protected]>
Signed-off-by: Pawel Gadzinski <[email protected]>
Signed-off-by: Pawel Gadzinski <[email protected]>
@pggPL (Collaborator, Author) commented Oct 6, 2025

/te-ci jax

pggPL and others added 2 commits October 7, 2025 14:30
@phu0ngng phu0ngng self-requested a review October 8, 2025 18:33
pggPL and others added 7 commits October 14, 2025 14:41
Signed-off-by: Pawel Gadzinski <[email protected]>
Signed-off-by: Pawel Gadzinski <[email protected]>
Signed-off-by: Pawel Gadzinski <[email protected]>
Signed-off-by: Pawel Gadzinski <[email protected]>
Signed-off-by: Pawel Gadzinski <[email protected]>
@pggPL (Collaborator, Author) commented Oct 14, 2025

/te-ci jax L1

@pggPL (Collaborator, Author) commented Oct 15, 2025

/te-ci jax L1

@pggPL (Collaborator, Author) commented Oct 15, 2025

/te-ci jax L1

pggPL and others added 6 commits October 16, 2025 19:48
Signed-off-by: Pawel Gadzinski <[email protected]>
Signed-off-by: Pawel Gadzinski <[email protected]>
Signed-off-by: Pawel Gadzinski <[email protected]>
Signed-off-by: Pawel Gadzinski <[email protected]>
Signed-off-by: Pawel Gadzinski <[email protected]>
return jax_scaled_masked_softmax(logits, mask, scale_factor, softmax_offset)


def jax_general_softmax(
@pggPL (Collaborator, Author) commented:

This softmax is copied from JAX. JAX also defines a custom JVP for this softmax, which could be generalized to support the softmax offset. Since I'm not sure how efficient we want the unfused attention path to be, I decided not to include the custom JVP.
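
For context, a hedged sketch of what a general softmax with a sink offset can look like in JAX; the name and signature below are assumptions and may differ from the actual jax_general_softmax in this PR:

import jax
import jax.numpy as jnp

def general_softmax_with_offset(logits, offset_logits=None, axis=-1):
    # Softmax whose denominator also includes exp(offset_logits).
    # offset_logits == 0 gives the off-by-one variant (+1 in the denominator);
    # a learnable per-head tensor gives the learnable-sink variant.
    x_max = jnp.max(logits, axis=axis, keepdims=True)
    if offset_logits is not None:
        x_max = jnp.maximum(x_max, offset_logits)
    x_max = jax.lax.stop_gradient(x_max)  # the shift carries no gradient, as in jax.nn.softmax
    unnormalized = jnp.exp(logits - x_max)
    denom = jnp.sum(unnormalized, axis=axis, keepdims=True)
    if offset_logits is not None:
        denom = denom + jnp.exp(offset_logits - x_max)
    return unnormalized / denom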

@pggPL pggPL marked this pull request as ready for review October 17, 2025 11:04
@pggPL pggPL requested a review from KshitijLakhani October 17, 2025 11:05
@pggPL (Collaborator, Author) commented Oct 17, 2025

/te-ci jax L1

pggPL and others added 4 commits October 21, 2025 18:25
@pggPL (Collaborator, Author) commented Oct 21, 2025

/te-ci jax L1

Signed-off-by: Pawel Gadzinski <[email protected]>
@jberchtold-nvidia (Collaborator) commented:

@KshitijLakhani can you review from the JAX attention side?

@greptile-apps greptile-apps bot left a comment


Greptile Overview

Greptile Summary

This PR successfully adds sink attention support to JAX, following PR #2148 which added the feature to PyTorch. The implementation introduces three softmax variants: vanilla (standard softmax), off-by-one (adds +1 to denominator), and learnable (uses trainable offset parameters).
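
For reference, the three denominators described above can be written out directly in JAX; the sketch below is illustrative only (values are made up) and is not code from this PR:

import jax.numpy as jnp

s = jnp.array([1.0, 2.0, 0.5])   # example attention logits for one row
alpha = jnp.array(0.3)           # illustrative value of the learnable per-head offset

vanilla    = jnp.exp(s) / jnp.sum(jnp.exp(s))                       # sum(exp(s))
off_by_one = jnp.exp(s) / (1.0 + jnp.sum(jnp.exp(s)))               # 1 + sum(exp(s))
learnable  = jnp.exp(s) / (jnp.exp(alpha) + jnp.sum(jnp.exp(s)))    # exp(alpha) + sum(exp(s))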

Key Changes:

  • Adds AttnSoftmaxType enum with three variants (VANILLA, OFF_BY_ONE, LEARNABLE)
  • Implements learnable softmax_offset parameter with shape [1, num_heads, 1, 1] that receives gradients during training
  • Updates attention APIs throughout the stack: Python → C++ extensions → CUDA kernels
  • Renames SoftmaxType to SoftmaxFusion to distinguish between fusion modes and sink attention types
  • Extends test coverage with parameterized tests for all three softmax variants
  • Context parallel support explicitly excludes sink attention (uses vanilla softmax only)

Implementation Quality:

  • Clean separation of concerns between SoftmaxFusion (masked/causal/scaled) and AttnSoftmaxType (sink variants)
  • Proper sharding annotations for learnable parameters using HEAD_AXES
  • Comprehensive test coverage including forward, backward, and distributed scenarios
  • C++ refactoring simplifies QKV layout handling by consolidating to unified nvte_fused_attn_fwd API
  • Backward compatibility maintained through optional parameters with sensible defaults

Confidence Score: 5/5

  • This PR is safe to merge with no critical issues found
  • The implementation is well-structured, thoroughly tested, and follows established patterns from the PyTorch implementation. Tests show comparable runtime performance. The code includes proper gradient handling for learnable parameters, correct sharding for distributed training, and comprehensive test coverage across multiple configurations. No logic errors, syntax issues, or security concerns were identified.
  • No files require special attention

Important Files Changed

File Analysis

Filename Score Overview
transformer_engine/jax/attention.py 5/5 Adds AttnSoftmaxType enum with vanilla, off-by-one, and learnable softmax variants. Updates fused attention APIs to accept softmax_type and optional softmax_offset parameters. Clean implementation following existing patterns.
transformer_engine/jax/flax/transformer.py 5/5 Adds softmax_type parameter to attention modules. Implements learnable softmax offset as a trainable parameter with proper sharding. Updates both unfused and fused attention implementations to support sink attention.
transformer_engine/jax/cpp_extensions/attention.py 5/5 Updates C++ extension wrapper to pass softmax_type and softmax_offset through the FFI boundary. Modifies FusedAttnHelper to include softmax type in backend selection logic.
transformer_engine/jax/csrc/extensions/attention.cpp 4/5 Refactors C++ implementation to handle softmax_offset tensor in forward/backward passes. Simplifies QKV layout handling by using unified nvte_fused_attn_fwd API instead of separate qkvpacked/kvpacked functions. Adds softmax_offset to aux tensor pack.
transformer_engine/jax/softmax.py 5/5 Adds AttnSoftmaxType parameter to softmax primitives and modules. Distinguishes between SoftmaxFusion (fused operations) and AttnSoftmaxType (sink attention variants).
transformer_engine/jax/flax/module.py 5/5 Updates Softmax module to support softmax_type parameter and optional softmax_offset for learnable sink attention. Implements off-by-one and learnable softmax logic.
tests/jax/test_fused_attn.py 5/5 Adds comprehensive test coverage for all three softmax types (vanilla, off-by-one, learnable). Tests include forward, backward, and gradient checking across multiple configurations.

Sequence Diagram

sequenceDiagram
    participant User as User Code
    participant Flax as Flax Module<br/>(transformer.py)
    participant Attn as Attention Layer<br/>(attention.py)
    participant CPP as C++ Extension<br/>(attention.py)
    participant CUDA as CUDA Kernel<br/>(nvte_fused_attn_fwd)
    
    User->>Flax: Call DotProductAttention
    Note over Flax: softmax_type parameter<br/>(vanilla/off_by_one/learnable)
    
    alt Learnable Softmax
        Flax->>Flax: Initialize softmax_offset<br/>parameter [1, H, 1, 1]
    end
    
    Flax->>Attn: fused_attn(qkv, bias, softmax_offset, ...)
    Note over Attn: Pack softmax_type and<br/>softmax_offset into call
    
    Attn->>CPP: fused_attn_fwd(qkv, bias,<br/>softmax_offset, sequence_descriptor)
    Note over CPP: Add softmax_offset to<br/>aux_output_tensors
    
    CPP->>CUDA: nvte_fused_attn_fwd(..., softmax_type)
    Note over CUDA: Compute attention with<br/>sink softmax variant
    
    alt Off-by-One Softmax
        CUDA->>CUDA: S[i] = exp(S[i])/(1 + sum(exp(S)))
    else Learnable Softmax
        CUDA->>CUDA: S[j,i] = exp(S[j,i])/(exp(alpha[j]) + sum(exp(S[j,:])))
    else Vanilla Softmax
        CUDA->>CUDA: S[i] = exp(S[i])/sum(exp(S))
    end
    
    CUDA-->>CPP: output, softmax_aux, rng_state
    CPP-->>Attn: output
    
    alt Training Mode & Learnable Softmax
        Attn->>CPP: fused_attn_bwd(dz, ...)
        Note over CPP: Compute gradients for<br/>Q, K, V, bias, softmax_offset
        CPP->>CUDA: nvte_fused_attn_bwd(...)
        CUDA-->>CPP: grad_q, grad_k, grad_v,<br/>grad_bias, grad_softmax_offset
        CPP-->>Attn: gradients
        Attn-->>Flax: gradients
        Note over Flax: Update softmax_offset<br/>parameter via optimizer
    end
    
    Attn-->>User: attention output

15 files reviewed, no comments


@pggPL (Collaborator, Author) commented Nov 4, 2025

/te-ci jax L1

@greptile-apps greptile-apps bot left a comment


Greptile Overview

Greptile Summary

This PR adds support for sink attention (attention with learnable softmax variants) to JAX, following the PyTorch implementation from PR #2148. The implementation introduces three softmax types: VANILLA (standard), OFF_BY_ONE (adds 1 to denominator), and LEARNABLE (uses learnable per-head offset parameters).

Key Changes:

  • Adds AttnSoftmaxType enum with three variants: VANILLA_SOFTMAX, OFF_BY_ONE_SOFTMAX, LEARNABLE_SOFTMAX
  • Threads softmax_type and softmax_offset parameters throughout the attention pipeline from Flax modules through JAX primitives to C++/cuDNN backends
  • Renames SoftmaxType to SoftmaxFusion to distinguish kernel fusion strategies from sink attention variants
  • Updates all attention primitives (standard, context parallel with all-gather, ring attention) to handle softmax_offset
  • Implements proper gradient computation for learnable softmax parameters with appropriate all-reduce operations (see the sketch after this list)
  • Adds comprehensive test coverage including forward/backward passes and distributed scenarios
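
As a rough illustration of the all-reduce mentioned above, a minimal sketch; the axis name "dp" and the helper name are assumptions, not the actual TE code:

import jax

def reduce_softmax_offset_grad(grad_softmax_offset):
    # Inside a pmap/shard_map region, sum the per-device partial gradients of
    # the learnable softmax offset over the data-parallel axis (assumed "dp").
    return jax.lax.psum(grad_softmax_offset, axis_name="dp")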

Implementation Details:

  • For LEARNABLE_SOFTMAX, creates a learnable parameter of shape (1, num_heads, 1, 1) with proper sharding by head dimension (see the sketch after this list)
  • OFF_BY_ONE_SOFTMAX is handled by setting softmax_offset=1.0
  • Context parallel paths (ring attention, all-gather) return dummy gradients for softmax_offset as they don't support learnable variants
  • The C++ layer properly packs softmax_offset tensors into the cuDNN tensor pack for both forward and backward passes
  • Refactored C++ code consolidates multiple layout-specific calls into unified nvte_fused_attn_fwd/bwd calls
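
A minimal Flax-style sketch of how such a learnable per-head offset could be declared and applied (the module name and input layout are assumptions; this is not the TE implementation, which additionally annotates the parameter for head-axis sharding):

import flax.linen as nn
import jax.numpy as jnp

class LearnableSinkSoftmax(nn.Module):  # hypothetical module name
    num_heads: int

    @nn.compact
    def __call__(self, logits):  # logits assumed to be [batch, heads, q_len, kv_len]
        # One learnable offset logit per head, broadcast over batch and sequence dims.
        softmax_offset = self.param(
            "softmax_offset", nn.initializers.zeros, (1, self.num_heads, 1, 1), jnp.float32
        )
        # Numerically stable softmax whose denominator gains exp(offset) per head.
        x_max = jnp.maximum(jnp.max(logits, axis=-1, keepdims=True), softmax_offset)
        unnormalized = jnp.exp(logits - x_max)
        denom = jnp.sum(unnormalized, axis=-1, keepdims=True) + jnp.exp(softmax_offset - x_max)
        return unnormalized / denom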

Testing:
Test runtime increased from ~3061s to ~3314s (+8%) due to additional test cases covering the three softmax variants

Confidence Score: 4/5

  • This PR is generally safe to merge with one potential mathematical issue to verify
  • The implementation is comprehensive and well-structured with proper gradient handling, sharding, and test coverage. However, there's one potential mathematical correctness issue in jax_general_softmax (transformer_engine/jax/cpp_extensions/softmax.py:834-853) where the offset handling needs verification - specifically whether callers pass the raw learnable parameter or its exponential for LEARNABLE_SOFTMAX
  • transformer_engine/jax/cpp_extensions/softmax.py - verify that the jax_general_softmax function receives the correct pre-exponentiated offset values for LEARNABLE_SOFTMAX

Important Files Changed

File Analysis

Filename Score Overview
transformer_engine/jax/attention.py 4/5 Adds AttnSoftmaxType enum with support for VANILLA, OFF_BY_ONE, and LEARNABLE softmax variants. Updates function signatures to accept softmax_type and softmax_offset parameters. Implementation looks solid with proper enum handling and parameter validation.
transformer_engine/jax/cpp_extensions/attention.py 4/5 Extensive changes to support softmax_offset throughout the attention primitives. Properly handles forward/backward passes, sharding, and context parallelism. The gradient handling for softmax_offset includes proper all-reduce for learnable softmax and dummy returns for CP paths that don't use it.
transformer_engine/jax/csrc/extensions/attention.cpp 4/5 C++ implementation updated to thread softmax_type and softmax_offset through the call chain. Adds proper tensor pack handling for softmax_offset in both forward and backward passes. The refactoring consolidates multiple layout-specific calls into a unified nvte_fused_attn_fwd call.
transformer_engine/jax/cpp_extensions/softmax.py 3/5 Adds jax_general_softmax for sink attention support and updates softmax functions to accept softmax_offset. The implementation adds offset to the denominator, but needs verification that callers pass the correct pre-exponentiated values for LEARNABLE_SOFTMAX.
transformer_engine/jax/flax/transformer.py 4/5 Adds softmax_type parameter to both fused and unfused attention implementations. Creates learnable softmax_offset parameter for LEARNABLE_SOFTMAX with proper sharding. Handles OFF_BY_ONE internally. PRE_SCALE_BIAS handling sets bias to None after adding to prevent double-addition.

Sequence Diagram

sequenceDiagram
    participant User as User Code
    participant Flax as Flax Transformer
    participant Attn as JAX Attention
    participant Prim as Attention Primitive
    participant CPP as C++ Extension
    participant cuDNN as cuDNN Backend

    User->>Flax: Call attention with softmax_type
    
    alt LEARNABLE_SOFTMAX
        Flax->>Flax: Initialize learnable softmax_offset param
    else OFF_BY_ONE_SOFTMAX
        Flax->>Flax: Set softmax_offset = 1.0
    else VANILLA_SOFTMAX
        Flax->>Flax: Set softmax_offset = empty
    end
    
    Flax->>Attn: fused_attn(qkv, bias, softmax_offset, ...)
    Attn->>Attn: Apply sharding constraints to softmax_offset
    Attn->>Prim: FusedAttnFwdPrimitive.bind(q, k, v, bias, softmax_offset, ...)
    
    Prim->>CPP: FusedAttnForwardImpl(q, k, v, bias, softmax_offset, ...)
    CPP->>CPP: PrepareFusedAttnForwardAuxTensors (adds softmax_offset to tensor pack)
    CPP->>cuDNN: nvte_fused_attn_fwd(tensor_pack with softmax_offset)
    cuDNN-->>CPP: output, softmax_aux
    CPP-->>Prim: output, softmax_aux, rng_state
    Prim-->>Attn: output
    Attn-->>Flax: output
    
    Note over User,cuDNN: Backward Pass
    
    User->>Flax: Gradient computation
    Flax->>Attn: fused_attn backward
    Attn->>Prim: FusedAttnBwdPrimitive.bind(...)
    Prim->>CPP: FusedAttnBackwardImpl(..., softmax_offset)
    CPP->>cuDNN: nvte_fused_attn_bwd(...)
    cuDNN-->>CPP: dq, dk, dv, dbias, dsoftmax_offset
    CPP-->>Prim: dq, dk, dv, dbias, dsoftmax_offset
    
    alt LEARNABLE_SOFTMAX
        Prim->>Prim: all_reduce dsoftmax_offset across DP/FSDP
    end
    
    Prim-->>Attn: gradients
    Attn-->>Flax: gradients
    Flax-->>User: Updated parameters

15 files reviewed, 1 comment


pggPL and others added 2 commits November 5, 2025 14:31
@greptile-apps greptile-apps bot left a comment


Greptile Overview

Greptile Summary

This PR adds support for sink attention to JAX, implementing three softmax variants: VANILLA, OFF_BY_ONE (zero sink), and LEARNABLE (learnable sink). The implementation follows the pattern established in PR #2148 for PyTorch.

Key Changes

  • New AttnSoftmaxType enum: Defines three softmax types (VANILLA, OFF_BY_ONE, LEARNABLE)
  • softmax_offset parameter: Added throughout the attention pipeline to support sink attention
    • For LEARNABLE_SOFTMAX: a learnable parameter [1, num_heads, 1, 1]
    • For OFF_BY_ONE_SOFTMAX: treated as an implicit offset of 0
  • Renamed SoftmaxType → SoftmaxFusion: distinguishes fusion strategy from softmax variant
  • Backend support: Updated C++ extensions and cuDNN integration to handle new softmax types
  • Comprehensive tests: Added test coverage for all three softmax types

Implementation Quality

Strengths:

  • Well-structured changes following existing code patterns
  • Comprehensive test coverage with reference implementations
  • Proper gradient handling for learnable parameters
  • Clean separation between fusion strategy and softmax type

Critical Issue Found:

  • Bug in transformer_engine/jax/flax/module.py:198: OFF_BY_ONE_SOFTMAX incorrectly uses softmax_offset = 1.0 instead of softmax_offset = 0.0. This will produce incorrect attention weights when logits exceed 1.0. The test reference implementation correctly uses a zero logit, but the optimized path has the wrong value.

Confidence Score: 3/5

  • This PR should not be merged without fixing the OFF_BY_ONE_SOFTMAX bug
  • The implementation is well-structured and comprehensive, but contains a critical logical error in the OFF_BY_ONE_SOFTMAX implementation that will cause incorrect attention computation. The bug is in module.py:198 where softmax_offset = 1.0 should be softmax_offset = 0.0. This discrepancy between the test reference implementation (which correctly uses zero) and the optimized path means tests may pass while the production code produces wrong results in certain cases.
  • transformer_engine/jax/flax/module.py - Fix OFF_BY_ONE_SOFTMAX offset value from 1.0 to 0.0

Important Files Changed

File Analysis

Filename Score Overview
transformer_engine/jax/attention.py 3/5 Added AttnSoftmaxType enum and softmax_offset parameter throughout attention pipeline. Implementation looks correct.
transformer_engine/jax/cpp_extensions/softmax.py 2/5 Added jax_general_softmax with offset support. Implementation is mathematically sound but usage in module.py has a bug.
transformer_engine/jax/flax/module.py 2/5 Updated Softmax module with sink attention support. Critical bug: OFF_BY_ONE_SOFTMAX uses offset=1.0 instead of offset=0.0.
transformer_engine/jax/flax/transformer.py 4/5 Added softmax_type parameter to attention modules with learnable parameter initialization for LEARNABLE_SOFTMAX. Implementation looks correct.
tests/jax/test_fused_attn.py 5/5 Comprehensive test coverage added for all three softmax types (VANILLA, OFF_BY_ONE, LEARNABLE). Reference implementation matches expected behavior.

Sequence Diagram

sequenceDiagram
    participant User as User Code
    participant DPA as DotProductAttention
    participant Fused as _FusedDotProductAttention
    participant Attn as fused_attn
    participant Prim as FusedAttnFwdPrimitive
    participant CPP as C++ Backend
    participant cuDNN as cuDNN Kernel

    Note over User,cuDNN: Sink Attention Flow (JAX)
    
    User->>DPA: __call__(query, key, value, softmax_type='off_by_one')
    DPA->>Fused: forward with softmax_type
    
    alt softmax_type == LEARNABLE_SOFTMAX
        Fused->>Fused: Initialize learnable param<br/>softmax_offset [1, h, 1, 1]
    else softmax_type == OFF_BY_ONE_SOFTMAX
        Note over Fused: No offset param needed<br/>(handled by backend)
    end
    
    Fused->>Attn: fused_attn(qkv, bias, softmax_offset,<br/>sequence_descriptor, softmax_type)
    Attn->>Prim: FusedAttnFwdPrimitive.bind(q, k, v,<br/>bias, softmax_offset, ...)
    
    Prim->>CPP: FusedAttnForwardFFI(q, k, v, bias,<br/>softmax_offset, softmax_type)
    
    CPP->>CPP: Prepare tensor pack:<br/>- softmax_aux<br/>- rng_state<br/>- bias (if provided)<br/>- softmax_offset (if provided)
    
    CPP->>cuDNN: nvte_fused_attn_fwd(q, k, v, bias,<br/>softmax_offset, softmax_type, ...)
    
    alt softmax_type == VANILLA_SOFTMAX
        Note over cuDNN: S = exp(QK^T) / sum(exp(QK^T))
    else softmax_type == OFF_BY_ONE_SOFTMAX
        Note over cuDNN: S = exp(QK^T) / (1 + sum(exp(QK^T)))
    else softmax_type == LEARNABLE_SOFTMAX
        Note over cuDNN: S = exp(QK^T) / (exp(alpha) + sum(exp(QK^T)))
    end
    
    cuDNN-->>CPP: output, softmax_aux, rng_state
    CPP-->>Prim: output, softmax_aux, rng_state
    Prim-->>Attn: output
    Attn-->>Fused: attention output
    Fused-->>DPA: attention output
    DPA-->>User: final output
    
    Note over User,cuDNN: Backward Pass (if training)
    User->>DPA: grad(output)
    DPA->>Attn: backward
    Attn->>Prim: FusedAttnBwdPrimitive
    Prim->>CPP: FusedAttnBackwardFFI
    CPP->>cuDNN: nvte_fused_attn_bwd
    
    alt softmax_type == LEARNABLE_SOFTMAX
        Note over cuDNN: Compute grad_softmax_offset
        cuDNN-->>Prim: grad_q, grad_k, grad_v,<br/>grad_bias, grad_softmax_offset
    else
        cuDNN-->>Prim: grad_q, grad_k, grad_v,<br/>grad_bias, None
    end
    
    Prim-->>User: gradients

15 files reviewed, 2 comments


Comment on lines 197 to 198
if self.softmax_type == AttnSoftmaxType.OFF_BY_ONE_SOFTMAX:
    softmax_offset = 1.0

logic: softmax_offset = 1.0 is incorrect for OFF_BY_ONE_SOFTMAX. Should be softmax_offset = 0.0.

OFF_BY_ONE_SOFTMAX adds +1 to the denominator by including an extra logit position with value 0. The test implementation (test_fused_attn.py:111) confirms this by appending jnp.zeros. The jax_general_softmax function expects offset to be the logit value, not the contribution to the denominator. With offset=0.0, after max-normalization: exp(0 - x_max) is added to denominator, which correctly contributes exp(-x_max). When all terms are normalized by dividing by exp(-x_max), this becomes +1.

Suggested change:
-if self.softmax_type == AttnSoftmaxType.OFF_BY_ONE_SOFTMAX:
-    softmax_offset = 1.0
+if self.softmax_type == AttnSoftmaxType.OFF_BY_ONE_SOFTMAX:
+    softmax_offset = 0.0
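
A quick numerical check of the equivalence described above, illustrative only (it uses jax.nn.softmax, not TE code):

import jax.numpy as jnp
from jax.nn import softmax

x = jnp.array([2.0, -1.0, 0.5])

# Reference: append a zero logit (the "sink") and drop it after the softmax.
reference = softmax(jnp.append(x, 0.0))[:-1]

# Off-by-one softmax written directly: an extra +1 in the denominator.
direct = jnp.exp(x) / (1.0 + jnp.sum(jnp.exp(x)))

print(jnp.allclose(reference, direct))  # True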

@pggPL (Collaborator, Author) replied:

That's a good catch. It's unfortunate that the tests I ran locally didn't catch it - this is the same tolerance issue as in PR #2300.

Signed-off-by: Pawel Gadzinski <[email protected]>
@pggPL (Collaborator, Author) commented Nov 5, 2025

/te-ci jax

Development

Successfully merging this pull request may close these issues.

Add attention sink to flash attention

4 participants