[JAX] Add support for sink attention in JAX #2225
base: main
Conversation
Signed-off-by: Pawel Gadzinski <[email protected]>
for more information, see https://pre-commit.ci
Signed-off-by: Pawel Gadzinski <[email protected]>
Signed-off-by: Pawel Gadzinski <[email protected]>
for more information, see https://pre-commit.ci
Signed-off-by: Pawel Gadzinski <[email protected]>
for more information, see https://pre-commit.ci
Signed-off-by: Pawel Gadzinski <[email protected]>
Signed-off-by: Pawel Gadzinski <[email protected]>
for more information, see https://pre-commit.ci
/te-ci jax
Signed-off-by: Pawel Gadzinski <[email protected]>
for more information, see https://pre-commit.ci
Signed-off-by: Pawel Gadzinski <[email protected]>
for more information, see https://pre-commit.ci
Signed-off-by: Pawel Gadzinski <[email protected]>
Signed-off-by: Pawel Gadzinski <[email protected]>
Signed-off-by: Pawel Gadzinski <[email protected]>
for more information, see https://pre-commit.ci
Signed-off-by: Pawel Gadzinski <[email protected]>
/te-ci jax L1
for more information, see https://pre-commit.ci
/te-ci jax L1
/te-ci jax L1
Signed-off-by: Pawel Gadzinski <[email protected]>
Signed-off-by: Pawel Gadzinski <[email protected]>
Signed-off-by: Pawel Gadzinski <[email protected]>
Signed-off-by: Pawel Gadzinski <[email protected]>
Signed-off-by: Pawel Gadzinski <[email protected]>
return jax_scaled_masked_softmax(logits, mask, scale_factor, softmax_offset)
…
def jax_general_softmax(
This softmax is copied from JAX. JAX also defines a custom JVP for this softmax, which could be generalized to support the softmax offset. I don't know how efficient we want unfused attention to be, so I decided not to include the custom JVP.
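For reference, a minimal sketch of what such an offset-aware softmax can look like; the function name and signature below are illustrative assumptions, not the PR's `jax_general_softmax`:

```python
import jax.numpy as jnp

def general_softmax_with_offset(logits, offset=None, axis=-1):
    """Standard softmax when offset is None; off-by-one when offset == 0.0;
    learnable-sink when offset is a trainable per-head tensor, e.g. [1, H, 1, 1]."""
    # Subtract the max for numerical stability, as jax.nn.softmax does.
    x_max = jnp.max(logits, axis=axis, keepdims=True)
    unnormalized = jnp.exp(logits - x_max)
    denom = jnp.sum(unnormalized, axis=axis, keepdims=True)
    if offset is not None:
        # The offset is interpreted as a logit: it contributes exp(offset - x_max) to the denominator.
        denom = denom + jnp.exp(offset - x_max)
    return unnormalized / denom
```

A custom JVP (as `jax.nn.softmax` has) could be layered on top of this, which is exactly the part the comment above decided to skip.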
/te-ci jax L1
Signed-off-by: Pawel Gadzinski <[email protected]>
for more information, see https://pre-commit.ci
Signed-off-by: Pawel Gadzinski <[email protected]>
for more information, see https://pre-commit.ci
/te-ci jax L1
Signed-off-by: Pawel Gadzinski <[email protected]>
@KshitijLakhani can you review from the JAX attention side?
Greptile Overview
Greptile Summary
This PR successfully adds sink attention support to JAX, following PR #2148 which added the feature to PyTorch. The implementation introduces three softmax variants: vanilla (standard softmax), off-by-one (adds +1 to denominator), and learnable (uses trainable offset parameters).
Key Changes:
- Adds `AttnSoftmaxType` enum with three variants (VANILLA, OFF_BY_ONE, LEARNABLE); the three variants are contrasted numerically in the sketch after this list
- Implements a learnable `softmax_offset` parameter with shape `[1, num_heads, 1, 1]` that receives gradients during training
- Updates attention APIs throughout the stack: Python → C++ extensions → CUDA kernels
- Renames `SoftmaxType` to `SoftmaxFusion` to distinguish between fusion modes and sink attention types
- Extends test coverage with parameterized tests for all three softmax variants
- Context parallel support explicitly excludes sink attention (uses vanilla softmax only)
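For intuition, a small numerical sketch of how the three variants differ. This is illustrative Python only; the PR's fused path runs inside cuDNN, and `alpha` stands in for one head's learnable offset:

```python
import jax.numpy as jnp

logits = jnp.array([1.0, 2.0, 3.0])
exp = jnp.exp(logits)
alpha = 0.5  # stand-in for one head's learnable sink logit

vanilla    = exp / exp.sum()                      # weights sum to 1
off_by_one = exp / (1.0 + exp.sum())              # implicit zero-logit sink, weights sum to < 1
learnable  = exp / (jnp.exp(alpha) + exp.sum())   # sink mass controlled by the learned offset

print(vanilla.sum(), off_by_one.sum(), learnable.sum())  # 1.0, ≈0.97, ≈0.95
```

The mass "missing" from the off-by-one and learnable rows is what gets attributed to the attention sink.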
Implementation Quality:
- Clean separation of concerns between `SoftmaxFusion` (masked/causal/scaled) and `AttnSoftmaxType` (sink variants)
- Proper sharding annotations for learnable parameters using `HEAD_AXES`
- Comprehensive test coverage including forward, backward, and distributed scenarios
- C++ refactoring simplifies QKV layout handling by consolidating to the unified `nvte_fused_attn_fwd` API
- Backward compatibility maintained through optional parameters with sensible defaults
Confidence Score: 5/5
- This PR is safe to merge with no critical issues found
- The implementation is well-structured, thoroughly tested, and follows established patterns from the PyTorch implementation. Tests show comparable runtime performance. The code includes proper gradient handling for learnable parameters, correct sharding for distributed training, and comprehensive test coverage across multiple configurations. No logic errors, syntax issues, or security concerns were identified.
- No files require special attention
Important Files Changed
File Analysis
| Filename | Score | Overview |
|---|---|---|
| transformer_engine/jax/attention.py | 5/5 | Adds AttnSoftmaxType enum with vanilla, off-by-one, and learnable softmax variants. Updates fused attention APIs to accept softmax_type and optional softmax_offset parameters. Clean implementation following existing patterns. |
| transformer_engine/jax/flax/transformer.py | 5/5 | Adds softmax_type parameter to attention modules. Implements learnable softmax offset as a trainable parameter with proper sharding. Updates both unfused and fused attention implementations to support sink attention. |
| transformer_engine/jax/cpp_extensions/attention.py | 5/5 | Updates C++ extension wrapper to pass softmax_type and softmax_offset through the FFI boundary. Modifies FusedAttnHelper to include softmax type in backend selection logic. |
| transformer_engine/jax/csrc/extensions/attention.cpp | 4/5 | Refactors C++ implementation to handle softmax_offset tensor in forward/backward passes. Simplifies QKV layout handling by using unified nvte_fused_attn_fwd API instead of separate qkvpacked/kvpacked functions. Adds softmax_offset to aux tensor pack. |
| transformer_engine/jax/softmax.py | 5/5 | Adds AttnSoftmaxType parameter to softmax primitives and modules. Distinguishes between SoftmaxFusion (fused operations) and AttnSoftmaxType (sink attention variants). |
| transformer_engine/jax/flax/module.py | 5/5 | Updates Softmax module to support softmax_type parameter and optional softmax_offset for learnable sink attention. Implements off-by-one and learnable softmax logic. |
| tests/jax/test_fused_attn.py | 5/5 | Adds comprehensive test coverage for all three softmax types (vanilla, off-by-one, learnable). Tests include forward, backward, and gradient checking across multiple configurations. |
Sequence Diagram
sequenceDiagram
participant User as User Code
participant Flax as Flax Module<br/>(transformer.py)
participant Attn as Attention Layer<br/>(attention.py)
participant CPP as C++ Extension<br/>(attention.py)
participant CUDA as CUDA Kernel<br/>(nvte_fused_attn_fwd)
User->>Flax: Call DotProductAttention
Note over Flax: softmax_type parameter<br/>(vanilla/off_by_one/learnable)
alt Learnable Softmax
Flax->>Flax: Initialize softmax_offset<br/>parameter [1, H, 1, 1]
end
Flax->>Attn: fused_attn(qkv, bias, softmax_offset, ...)
Note over Attn: Pack softmax_type and<br/>softmax_offset into call
Attn->>CPP: fused_attn_fwd(qkv, bias,<br/>softmax_offset, sequence_descriptor)
Note over CPP: Add softmax_offset to<br/>aux_output_tensors
CPP->>CUDA: nvte_fused_attn_fwd(..., softmax_type)
Note over CUDA: Compute attention with<br/>sink softmax variant
alt Off-by-One Softmax
CUDA->>CUDA: S[i] = exp(S[i])/(1 + sum(exp(S)))
else Learnable Softmax
CUDA->>CUDA: S[j,i] = exp(S[j,i])/(exp(alpha[j]) + sum(exp(S[j,:])))
else Vanilla Softmax
CUDA->>CUDA: S[i] = exp(S[i])/sum(exp(S))
end
CUDA-->>CPP: output, softmax_aux, rng_state
CPP-->>Attn: output
alt Training Mode & Learnable Softmax
Attn->>CPP: fused_attn_bwd(dz, ...)
Note over CPP: Compute gradients for<br/>Q, K, V, bias, softmax_offset
CPP->>CUDA: nvte_fused_attn_bwd(...)
CUDA-->>CPP: grad_q, grad_k, grad_v,<br/>grad_bias, grad_softmax_offset
CPP-->>Attn: gradients
Attn-->>Flax: gradients
Note over Flax: Update softmax_offset<br/>parameter via optimizer
end
Attn-->>User: attention output
15 files reviewed, no comments
/te-ci jax L1
Greptile Overview
Greptile Summary
This PR adds support for sink attention (attention with learnable softmax variants) to JAX, following the PyTorch implementation from PR #2148. The implementation introduces three softmax types: VANILLA (standard), OFF_BY_ONE (adds 1 to denominator), and LEARNABLE (uses learnable per-head offset parameters).
Key Changes:
- Adds `AttnSoftmaxType` enum with three variants: VANILLA_SOFTMAX, OFF_BY_ONE_SOFTMAX, LEARNABLE_SOFTMAX
- Threads `softmax_type` and `softmax_offset` parameters throughout the attention pipeline, from Flax modules through JAX primitives to the C++/cuDNN backends
- Renames `SoftmaxType` to `SoftmaxFusion` to distinguish kernel fusion strategies from sink attention variants
- Updates all attention primitives (standard, context parallel with all-gather, ring attention) to handle `softmax_offset`
- Implements proper gradient computation for learnable softmax parameters with appropriate all-reduce operations
- Adds comprehensive test coverage including forward/backward passes and distributed scenarios
Implementation Details:
- For LEARNABLE_SOFTMAX, creates a learnable parameter of shape `(1, num_heads, 1, 1)` with proper sharding by head dimension (see the Flax-style sketch after this list)
- OFF_BY_ONE_SOFTMAX is handled by setting `softmax_offset=1.0`
- Context parallel paths (ring attention, all-gather) return dummy gradients for `softmax_offset` as they don't support learnable variants
- The C++ layer properly packs `softmax_offset` tensors into the cuDNN tensor pack for both forward and backward passes
- Refactored C++ code consolidates multiple layout-specific calls into unified `nvte_fused_attn_fwd`/`bwd` calls
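A hedged sketch of how such a learnable per-head offset can be set up in a Flax module. The class name, parameter name, and initializer below are illustrative assumptions; the PR's actual module lives in transformer_engine/jax/flax/transformer.py and additionally applies head-axis sharding constraints, which this sketch omits:

```python
import flax.linen as nn
import jax.numpy as jnp

class SinkSoftmax(nn.Module):
    num_heads: int
    softmax_type: str = "vanilla"  # "vanilla" | "off_by_one" | "learnable"

    @nn.compact
    def __call__(self, logits):  # logits: [batch, heads, q_len, kv_len]
        offset = None
        if self.softmax_type == "learnable":
            # One learnable logit per head, broadcast over batch and sequence dims.
            offset = self.param(
                "softmax_offset", nn.initializers.zeros, (1, self.num_heads, 1, 1), jnp.float32
            )
        elif self.softmax_type == "off_by_one":
            offset = jnp.zeros((1, 1, 1, 1), dtype=logits.dtype)  # fixed zero-logit sink

        x_max = jnp.max(logits, axis=-1, keepdims=True)
        e = jnp.exp(logits - x_max)
        denom = jnp.sum(e, axis=-1, keepdims=True)
        if offset is not None:
            denom = denom + jnp.exp(offset - x_max)  # sink term joins the denominator
        return e / denom
```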
Testing:
Test runtime increased from ~3061s to ~3314s (+8%) due to additional test cases covering the three softmax variants
Confidence Score: 4/5
- This PR is generally safe to merge with one potential mathematical issue to verify
- The implementation is comprehensive and well-structured with proper gradient handling, sharding, and test coverage. However, there is one potential mathematical correctness issue in `jax_general_softmax` (transformer_engine/jax/cpp_extensions/softmax.py:834-853) where the offset handling needs verification - specifically whether callers pass the raw learnable parameter or its exponential for LEARNABLE_SOFTMAX (the two conventions are contrasted in the snippet after this list)
- transformer_engine/jax/cpp_extensions/softmax.py - verify that the `jax_general_softmax` function receives the correct pre-exponentiated offset values for LEARNABLE_SOFTMAX
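Concretely, the two conventions give different results, so callers and `jax_general_softmax` must agree on which one `softmax_offset` follows. A standalone illustration (not the PR's code):

```python
import jax.numpy as jnp

logits = jnp.array([0.5, 1.5])
alpha = 0.3  # raw learnable parameter for one head

# Convention A: the offset is a logit, so exp(alpha) joins the denominator.
conv_a = jnp.exp(logits) / (jnp.exp(alpha) + jnp.exp(logits).sum())

# Convention B: the offset is already the denominator contribution (pre-exponentiated).
conv_b = jnp.exp(logits) / (alpha + jnp.exp(logits).sum())

print(jnp.allclose(conv_a, conv_b))  # False -- the two interpretations diverge
```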
Important Files Changed
File Analysis
| Filename | Score | Overview |
|---|---|---|
| transformer_engine/jax/attention.py | 4/5 | Adds AttnSoftmaxType enum with support for VANILLA, OFF_BY_ONE, and LEARNABLE softmax variants. Updates function signatures to accept softmax_type and softmax_offset parameters. Implementation looks solid with proper enum handling and parameter validation. |
| transformer_engine/jax/cpp_extensions/attention.py | 4/5 | Extensive changes to support softmax_offset throughout the attention primitives. Properly handles forward/backward passes, sharding, and context parallelism. The gradient handling for softmax_offset includes proper all-reduce for learnable softmax and dummy returns for CP paths that don't use it. |
| transformer_engine/jax/csrc/extensions/attention.cpp | 4/5 | C++ implementation updated to thread softmax_type and softmax_offset through the call chain. Adds proper tensor pack handling for softmax_offset in both forward and backward passes. The refactoring consolidates multiple layout-specific calls into a unified nvte_fused_attn_fwd call. |
| transformer_engine/jax/cpp_extensions/softmax.py | 3/5 | Adds jax_general_softmax for sink attention support and updates softmax functions to accept softmax_offset. The implementation adds offset to the denominator, but needs verification that callers pass the correct pre-exponentiated values for LEARNABLE_SOFTMAX. |
| transformer_engine/jax/flax/transformer.py | 4/5 | Adds softmax_type parameter to both fused and unfused attention implementations. Creates learnable softmax_offset parameter for LEARNABLE_SOFTMAX with proper sharding. Handles OFF_BY_ONE internally. PRE_SCALE_BIAS handling sets bias to None after adding to prevent double-addition. |
Sequence Diagram
sequenceDiagram
participant User as User Code
participant Flax as Flax Transformer
participant Attn as JAX Attention
participant Prim as Attention Primitive
participant CPP as C++ Extension
participant cuDNN as cuDNN Backend
User->>Flax: Call attention with softmax_type
alt LEARNABLE_SOFTMAX
Flax->>Flax: Initialize learnable softmax_offset param
else OFF_BY_ONE_SOFTMAX
Flax->>Flax: Set softmax_offset = 1.0
else VANILLA_SOFTMAX
Flax->>Flax: Set softmax_offset = empty
end
Flax->>Attn: fused_attn(qkv, bias, softmax_offset, ...)
Attn->>Attn: Apply sharding constraints to softmax_offset
Attn->>Prim: FusedAttnFwdPrimitive.bind(q, k, v, bias, softmax_offset, ...)
Prim->>CPP: FusedAttnForwardImpl(q, k, v, bias, softmax_offset, ...)
CPP->>CPP: PrepareFusedAttnForwardAuxTensors (adds softmax_offset to tensor pack)
CPP->>cuDNN: nvte_fused_attn_fwd(tensor_pack with softmax_offset)
cuDNN-->>CPP: output, softmax_aux
CPP-->>Prim: output, softmax_aux, rng_state
Prim-->>Attn: output
Attn-->>Flax: output
Note over User,cuDNN: Backward Pass
User->>Flax: Gradient computation
Flax->>Attn: fused_attn backward
Attn->>Prim: FusedAttnBwdPrimitive.bind(...)
Prim->>CPP: FusedAttnBackwardImpl(..., softmax_offset)
CPP->>cuDNN: nvte_fused_attn_bwd(...)
cuDNN-->>CPP: dq, dk, dv, dbias, dsoftmax_offset
CPP-->>Prim: dq, dk, dv, dbias, dsoftmax_offset
alt LEARNABLE_SOFTMAX
Prim->>Prim: all_reduce dsoftmax_offset across DP/FSDP
end
Prim-->>Attn: gradients
Attn-->>Flax: gradients
Flax-->>User: Updated parameters
15 files reviewed, 1 comment
Signed-off-by: Pawel Gadzinski <[email protected]>
for more information, see https://pre-commit.ci
Greptile Overview
Greptile Summary
This PR adds support for sink attention to JAX, implementing three softmax variants: VANILLA, OFF_BY_ONE (zero sink), and LEARNABLE (learnable sink). The implementation follows the pattern established in PR #2148 for PyTorch.
Key Changes
- New `AttnSoftmaxType` enum: Defines three softmax types (VANILLA, OFF_BY_ONE, LEARNABLE)
- `softmax_offset` parameter: Added throughout the attention pipeline to support sink attention
  - For LEARNABLE_SOFTMAX: a learnable parameter `[1, num_heads, 1, 1]`
  - For OFF_BY_ONE_SOFTMAX: treated as an implicit offset of 0 (see the reference sketch after this list)
- Renamed `SoftmaxType` → `SoftmaxFusion`: Distinguishes fusion strategy from softmax variant
- Backend support: Updated C++ extensions and cuDNN integration to handle new softmax types
- Comprehensive tests: Added test coverage for all three softmax types
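The implicit zero offset of the off-by-one variant can be realized by appending a zero "sink" logit, which is reportedly how the test reference implements it (see the review comment further down). A standalone sketch of that reference approach; the function name is an assumption, not the test's actual helper:

```python
import jax
import jax.numpy as jnp

def off_by_one_softmax_reference(logits):
    """Append a zero sink logit on the last axis, softmax, then drop the sink column."""
    zeros = jnp.zeros(logits.shape[:-1] + (1,), dtype=logits.dtype)
    probs = jax.nn.softmax(jnp.concatenate([logits, zeros], axis=-1), axis=-1)
    return probs[..., :-1]  # each row now sums to sum(exp(x)) / (1 + sum(exp(x)))
```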
Implementation Quality
Strengths:
- Well-structured changes following existing code patterns
- Comprehensive test coverage with reference implementations
- Proper gradient handling for learnable parameters
- Clean separation between fusion strategy and softmax type
Critical Issue Found:
- Bug in transformer_engine/jax/flax/module.py:198: OFF_BY_ONE_SOFTMAX incorrectly uses `softmax_offset = 1.0` instead of `softmax_offset = 0.0`. This will produce incorrect attention weights when logits exceed 1.0. The test reference implementation correctly uses a zero logit, but the optimized path has the wrong value.
Confidence Score: 3/5
- This PR should not be merged without fixing the OFF_BY_ONE_SOFTMAX bug
- The implementation is well-structured and comprehensive, but contains a critical logical error in the OFF_BY_ONE_SOFTMAX implementation that will cause incorrect attention computation. The bug is in module.py:198, where `softmax_offset = 1.0` should be `softmax_offset = 0.0`. This discrepancy between the test reference implementation (which correctly uses zero) and the optimized path means tests may pass while the production code produces wrong results in certain cases.
- transformer_engine/jax/flax/module.py - Fix OFF_BY_ONE_SOFTMAX offset value from 1.0 to 0.0
Important Files Changed
File Analysis
| Filename | Score | Overview |
|---|---|---|
| transformer_engine/jax/attention.py | 3/5 | Added AttnSoftmaxType enum and softmax_offset parameter throughout attention pipeline. Implementation looks correct. |
| transformer_engine/jax/cpp_extensions/softmax.py | 2/5 | Added jax_general_softmax with offset support. Implementation is mathematically sound but usage in module.py has a bug. |
| transformer_engine/jax/flax/module.py | 2/5 | Updated Softmax module with sink attention support. Critical bug: OFF_BY_ONE_SOFTMAX uses offset=1.0 instead of offset=0.0. |
| transformer_engine/jax/flax/transformer.py | 4/5 | Added softmax_type parameter to attention modules with learnable parameter initialization for LEARNABLE_SOFTMAX. Implementation looks correct. |
| tests/jax/test_fused_attn.py | 5/5 | Comprehensive test coverage added for all three softmax types (VANILLA, OFF_BY_ONE, LEARNABLE). Reference implementation matches expected behavior. |
Sequence Diagram
sequenceDiagram
participant User as User Code
participant DPA as DotProductAttention
participant Fused as _FusedDotProductAttention
participant Attn as fused_attn
participant Prim as FusedAttnFwdPrimitive
participant CPP as C++ Backend
participant cuDNN as cuDNN Kernel
Note over User,cuDNN: Sink Attention Flow (JAX)
User->>DPA: __call__(query, key, value, softmax_type='off_by_one')
DPA->>Fused: forward with softmax_type
alt softmax_type == LEARNABLE_SOFTMAX
Fused->>Fused: Initialize learnable param<br/>softmax_offset [1, h, 1, 1]
else softmax_type == OFF_BY_ONE_SOFTMAX
Note over Fused: No offset param needed<br/>(handled by backend)
end
Fused->>Attn: fused_attn(qkv, bias, softmax_offset,<br/>sequence_descriptor, softmax_type)
Attn->>Prim: FusedAttnFwdPrimitive.bind(q, k, v,<br/>bias, softmax_offset, ...)
Prim->>CPP: FusedAttnForwardFFI(q, k, v, bias,<br/>softmax_offset, softmax_type)
CPP->>CPP: Prepare tensor pack:<br/>- softmax_aux<br/>- rng_state<br/>- bias (if provided)<br/>- softmax_offset (if provided)
CPP->>cuDNN: nvte_fused_attn_fwd(q, k, v, bias,<br/>softmax_offset, softmax_type, ...)
alt softmax_type == VANILLA_SOFTMAX
Note over cuDNN: S = exp(QK^T) / sum(exp(QK^T))
else softmax_type == OFF_BY_ONE_SOFTMAX
Note over cuDNN: S = exp(QK^T) / (1 + sum(exp(QK^T)))
else softmax_type == LEARNABLE_SOFTMAX
Note over cuDNN: S = exp(QK^T) / (exp(alpha) + sum(exp(QK^T)))
end
cuDNN-->>CPP: output, softmax_aux, rng_state
CPP-->>Prim: output, softmax_aux, rng_state
Prim-->>Attn: output
Attn-->>Fused: attention output
Fused-->>DPA: attention output
DPA-->>User: final output
Note over User,cuDNN: Backward Pass (if training)
User->>DPA: grad(output)
DPA->>Attn: backward
Attn->>Prim: FusedAttnBwdPrimitive
Prim->>CPP: FusedAttnBackwardFFI
CPP->>cuDNN: nvte_fused_attn_bwd
alt softmax_type == LEARNABLE_SOFTMAX
Note over cuDNN: Compute grad_softmax_offset
cuDNN-->>Prim: grad_q, grad_k, grad_v,<br/>grad_bias, grad_softmax_offset
else
cuDNN-->>Prim: grad_q, grad_k, grad_v,<br/>grad_bias, None
end
Prim-->>User: gradients
15 files reviewed, 2 comments
if self.softmax_type == AttnSoftmaxType.OFF_BY_ONE_SOFTMAX:
    softmax_offset = 1.0
logic: `softmax_offset = 1.0` is incorrect for OFF_BY_ONE_SOFTMAX. It should be `softmax_offset = 0.0`.
OFF_BY_ONE_SOFTMAX adds +1 to the denominator by including an extra logit position with value 0. The test implementation (test_fused_attn.py:111) confirms this by appending jnp.zeros. The `jax_general_softmax` function expects the offset to be the logit value, not its contribution to the denominator. With offset=0.0, after max-normalization, exp(0 - x_max) is added to the denominator, correctly contributing exp(-x_max); once numerator and denominator are rescaled by exp(x_max), that term becomes the +1.
Suggested change:
if self.softmax_type == AttnSoftmaxType.OFF_BY_ONE_SOFTMAX:
-    softmax_offset = 1.0
+    softmax_offset = 0.0
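A quick numerical check of the argument above (standalone snippet; it approximates the offset-as-logit behavior and is not the PR's code):

```python
import jax
import jax.numpy as jnp

x = jnp.array([5.0, 7.0, 9.0])

# Reference: append a zero logit, softmax, drop it -- the "+1 in the denominator" behavior.
ref = jax.nn.softmax(jnp.concatenate([x, jnp.zeros(1)]))[:-1]

x_max = x.max()
e = jnp.exp(x - x_max)
with_offset_0 = e / (e.sum() + jnp.exp(0.0 - x_max))  # offset as a logit of 0 -> matches reference
with_offset_1 = e / (e.sum() + jnp.exp(1.0 - x_max))  # offset of 1.0 -> adds exp(1), not 1

print(jnp.allclose(ref, with_offset_0), jnp.allclose(ref, with_offset_1))  # True False
```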
That's a good catch. It's not good that the tests I ran locally didn't catch it - this is the same tolerance issue as in PR #2300.
Signed-off-by: Pawel Gadzinski <[email protected]>
/te-ci jax
Description
PR #2148 added support for sink attention to common and PyTorch. This PR adds support for JAX.
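For context on what a "learnable sink" parameter means in practice, a self-contained sketch (not this PR's API) showing that the per-head offset is an ordinary parameter that receives its own gradient:

```python
import jax
import jax.numpy as jnp

def learnable_sink_softmax(logits, alpha):
    # alpha: per-head learnable logit, shape [1, num_heads, 1, 1]
    x_max = jnp.max(logits, axis=-1, keepdims=True)
    e = jnp.exp(logits - x_max)
    return e / (jnp.sum(e, axis=-1, keepdims=True) + jnp.exp(alpha - x_max))

logits = jnp.ones((2, 4, 8, 8))   # [batch, heads, q_len, kv_len]
alpha = jnp.zeros((1, 4, 1, 1))
loss = lambda a: learnable_sink_softmax(logits, a).sum()
print(jax.grad(loss)(alpha).shape)  # (1, 4, 1, 1) -- the offset gets its own gradient
```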
Fixes #2070
Type of change
Checklist: