diff --git a/transformer_engine/jax/flax/transformer.py b/transformer_engine/jax/flax/transformer.py index 1eafed4131..42c9451245 100644 --- a/transformer_engine/jax/flax/transformer.py +++ b/transformer_engine/jax/flax/transformer.py @@ -197,6 +197,7 @@ def __call__( fused_scale_factor = scale_factor if self.attn_bias_type == AttnBiasType.PRE_SCALE_BIAS: attn_weights += bias + bias = None def apply_swa_mask(original_mask: Array) -> Array: """Apply the sliding window mask to a given mask"""