Releases: NVIDIA/TransformerEngine

v1.7

14 Jun 17:58

Release Notes – Release 1.7

Key Features and Enhancements

  • [JAX] Added support for SwiGLU, gated/non-gated ReLU, Quick GeLU, and squared ReLU activations.
  • [pyTorch] Added support for attention bias and various QKV formats when using context parallelism.
  • [pyTorch] Expanded the Linear API to handle zero input tokens for MoE-like use cases.
  • [pyTorch] Added support for upstream AMP (torch.amp.autocast) in the checkpoint API (see the sketch after this list).
  • [pyTorch] Added the squared ReLU activation.
  • [pyTorch] Updated flash-attention support to version 2.5.8.
  • [PaddlePaddle] Added support for gradient accumulation fusion.
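
For the upstream-AMP support in the checkpoint API noted above, here is a minimal sketch. It assumes the te.checkpoint(module, *inputs) call form; the module choice, sizes, and dtype are illustrative rather than taken from the release.

```python
import torch
import transformer_engine.pytorch as te

# Activation recompute through Transformer Engine's checkpoint API, wrapped in an
# upstream AMP autocast region. Module and tensor sizes are illustrative only.
layer = te.LayerNormMLP(hidden_size=1024, ffn_hidden_size=4096).cuda()
x = torch.randn(128, 8, 1024, device="cuda", requires_grad=True)

with torch.amp.autocast("cuda", dtype=torch.bfloat16):
    y = te.checkpoint(layer, x)  # activations are recomputed during the backward pass

y.sum().backward()
```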

Fixed Issues

  • [pyTorch] Fixed an uninitialized TP group error that could occur when training with certain tensor parallel configs.
  • [pyTorch] Fixed a bug that occurred when loading a checkpoint with calibrated high-precision weights.
  • [pyTorch] Improved the documentation for attention masks.
  • [JAX] Fixed a bug with mismatching shapes of activations and corresponding sharding constraints.
  • [JAX] Fixed an internal bug that caused an incorrect shape to be passed for the LayerNorm gradient.

Known Issues in This Release

There are no known issues in this release.

Breaking Changes in This Release

There are no breaking changes in this release.

Deprecated Features

There are no deprecated features in this release.

v1.6

13 May 16:36

Release Notes – Release 1.6

Key Features and Enhancements

  • [pyTorch] Added a new make_graphed_callables API call for NVIDIA® CUDA® graph capture, including FP8 support.
  • [pyTorch] Added beta support for two boolean arguments in the DelayedScaling FP8 recipe (fp8_dpa and fp8_mha) to support FP8 attention. Note that the API exposure of this feature may change in future releases.
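
For the FP8 attention flags above, a minimal sketch of how they can plug into the existing DelayedScaling / fp8_autocast workflow follows. The fp8_dpa and fp8_mha flag names come from the note above; the layer configuration and tensor sizes are illustrative.

```python
import torch
import transformer_engine.pytorch as te
from transformer_engine.common.recipe import DelayedScaling, Format

# Enable the beta FP8-attention paths via the new recipe flags (fp8_dpa, fp8_mha).
recipe = DelayedScaling(fp8_format=Format.HYBRID, fp8_dpa=True, fp8_mha=True)

layer = te.TransformerLayer(
    hidden_size=1024, ffn_hidden_size=4096, num_attention_heads=16
).cuda()
x = torch.randn(128, 8, 1024, device="cuda")  # [seq, batch, hidden]

with te.fp8_autocast(enabled=True, fp8_recipe=recipe):
    y = layer(x)
```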

Fixed Issues

  • [pyTorch] Fixed a numerical issue with storing weights in FP8 via the fp8_model_init API call.
  • [pyTorch] Fixed a bug that caused PyTorch modules to use excessive memory when training with frozen weights by storing unnecessary activations for the backward pass.
  • [JAX] Fixed a bug that caused an incorrect shape to be passed for LayerNorm gradient.

Known Issues in This Release

These issues are unchanged from the previous release.

FlashAttention v2, which is a dependency of this release of Transformer Engine, has a known issue with excessive memory usage during installation (Dao-AILab/flash-attention#358). You can work around this issue by setting the environment variable MAX_JOBS=1 during Transformer Engine installation.

[pyTorch] FlashAttention v2.1 changed the behavior of the causal mask when performing cross-attention (see https://github.com/Dao-AILab/flash-attention#21-change-behavior-of-causal-flag for reference). To keep Transformer Engine's behavior consistent across versions and backends, FlashAttention is disabled for this use case (cross-attention with causal masking) when FlashAttention 2.1 or later is installed.

Breaking Changes in This Release

There are no breaking changes in this release.

Deprecated Features

There are no deprecated features in this release.

v1.5

17 Apr 17:41

Release Notes – Release 1.5

Key Features and Enhancements

  • [pyTorch] Added support for non-reentrant mode for activation recompute in the checkpoint API.
  • [pyTorch] Added support for rectangular matrices in the unfused softmax backend in order to support speculative decoding.
  • [pyTorch] Added the inference_params argument to the DotProductAttention API to support KV caching (see the sketch after this list).
  • [JAX] Added the DotProductAttention API.
  • [JAX] Expanded RoPE support using the rotary_pos_emb_group_method argument.
  • [paddle] Added support for RMSNorm.
  • [paddle] Added support for RoPE.
  • [paddle] Added support for SwiGLU.
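
A minimal sketch of KV caching through the new inference_params argument mentioned above. The InferenceParams constructor arguments, the layer_number used to key the cache, and the per-step cache bookkeeping are assumptions for illustration, not details taken from the release note.

```python
import torch
import transformer_engine.pytorch as te

# Attention module plus a KV cache sized for the whole generation.
# layer_number is assumed here to identify this layer's slot in the cache.
attn = te.DotProductAttention(num_attention_heads=16, kv_channels=64, layer_number=1)
cache = te.InferenceParams(max_batch_size=4, max_sequence_length=2048)

# One decoding step: q/k/v in the default "sbhd" layout [seq, batch, heads, head_dim].
q = torch.randn(1, 4, 16, 64, device="cuda", dtype=torch.bfloat16)
k, v = torch.randn_like(q), torch.randn_like(q)
out = attn(q, k, v, inference_params=cache)
```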

Fixed Issues

  • [pyTorch] Fixed a numerical issue with storing weights in FP8 via the fp8_model_init API.

Known Issues in This Release

  • FlashAttention v2, which is a dependency of this release of Transformer Engine, has a known issue with excessive memory usage during installation (Dao-AILab/flash-attention#358). You can work around this issue by setting the environment variable MAX_JOBS=1 during Transformer Engine installation.
  • [pyTorch] FlashAttention v2.1 changed the behavior of the causal mask when performing cross-attention (see https://github.com/Dao-AILab/flash-attention#21-change-behavior-of-causal-flag for reference). To keep Transformer Engine's behavior consistent across versions and backends, FlashAttention is disabled for this use case (cross-attention with causal masking) when FlashAttention 2.1 or later is installed.

Breaking Changes in This Release

There are no breaking changes in this release.

Deprecated Features

  • [JAX] The arguments num_heads, dropout_rate, output_layernorm, apply_residual_connection_post_layernorm, and fuse_qkv are deprecated in the MultiHeadAttention API. They are replaced respectively with num_attention_heads, attention_dropout, input_layernorm, return_layernorm_output, and fused_qkv_params.
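
For convenience, the rename map implied by the note above, expressed as a small sketch. The mapping is taken directly from the text; nothing further is implied about the MultiHeadAttention signature.

```python
# Deprecated -> replacement argument names in transformer_engine.jax.flax.MultiHeadAttention.
RENAMED_ARGS = {
    "num_heads": "num_attention_heads",
    "dropout_rate": "attention_dropout",
    "output_layernorm": "input_layernorm",
    "apply_residual_connection_post_layernorm": "return_layernorm_output",
    "fuse_qkv": "fused_qkv_params",
}
```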

Miscellaneous Changes

There are no miscellaneous changes in this release.

v1.4

18 Mar 17:14

Release Notes – Release 1.4

Key Features and Enhancements

  • [C/pyTorch] Added support for QuickGELU activation.
  • [C/pyTorch] Added fused RoPE implementation for improved speedup.
  • [C/pyTorch] Added support for zero-centered gamma in RMSNorm (see the sketch after this list).
  • [C/pyTorch] Added support for ALiBi slopes to all attention backends.
  • [docs/pyTorch] Added a tutorial on accelerating HF Llama models with Transformer Engine.
  • [JAX] Added support for sequence parallelism.
  • [JAX] Added support for RoPE.
  • [JAX] Increased execution speed in GQA.
  • [paddle] Added support for grouped query attention (GQA).
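
A minimal sketch of the zero-centered-gamma option in RMSNorm noted above. With this flag the learnable scale is applied as (1 + gamma), so a freshly initialized layer starts from an all-ones effective scale; the sizes below are illustrative.

```python
import torch
import transformer_engine.pytorch as te

# RMSNorm with zero-centered gamma: the stored weight is initialized to zero and
# the effective scale applied to the normalized input is (1 + weight).
norm = te.RMSNorm(1024, eps=1e-5, zero_centered_gamma=True).cuda()
y = norm(torch.randn(8, 1024, device="cuda"))
```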

Fixed Issues

  • [pyTorch] Fixed an issue where uninitialized/unused module buffers resulted in increased memory usage with the fp8_model_init API call.
  • [pyTorch] Fixed an issue in MultiheadAttention where the attention type was not properly passed down into granular API calls.
  • [pyTorch] Fixed an issue that caused Transformer Engine to crash when used with pyTorch version >= 2.0 and < 2.1.
  • [pyTorch] Fixed a convergence issue when using FP8 with activation recompute.
  • [pyTorch] Fixed a numerical bug associated with use of pipeline parallelism.

Known Issues in This Release

  • FlashAttention v2, which is a dependency of this release of Transformer Engine, has a known issue with excessive memory usage during installation (Dao-AILab/flash-attention#358). You can work around this issue either by setting the environment variable MAX_JOBS=1 during Transformer Engine installation or by installing FlashAttention v1 (e.g. with the command pip install flash-attn==1.0.9) before attempting to install Transformer Engine.
  • [pyTorch] FlashAttention v2.1 changed the behavior of the causal mask when performing cross-attention (see https://github.com/Dao-AILab/flash-attention#21-change-behavior-of-causal-flag for reference). To keep Transformer Engine's behavior consistent across versions and backends, FlashAttention is disabled for this use case (cross-attention with causal masking) when FlashAttention 2.1 or later is installed.

Breaking Changes in This Release

There are no breaking changes in this release.

Deprecated Features

There are no deprecated features in this release.

Miscellaneous Changes

FlashAttention v1 is no longer supported in Transformer Engine. Support for it was dropped in version 1.3. The minimum required FlashAttention version is v2.0.6.

v1.3

26 Feb 22:11

Release Notes – Release 1.3

Key Features and Enhancements

  • [pyTorch] Added support for deferred parameter initialization in several Transformer Engine modules via the device="meta" parameter (see the sketch after this list):
    Linear
    LayerNorm
    RMSNorm
    LayerNormLinear
    LayerNormMLP
    MultiheadAttention
    TransformerLayer
  • [pyTorch] Added support for CPU offloading of weights and activations for tensors saved for the backward pass for additional memory savings.
  • [pyTorch] Added an additional attn_input_format parameter to TransformerLayer for the layout of the QKV tensor.
  • [pyTorch] Added support for non-tensor values of the forward parameter when using the checkpoint API call.
  • [PaddlePaddle] Added support for sequence parallelism.
  • [PaddlePaddle] Optimized memory usage for pipeline parallel training.
  • [JAX] Added support for grouped query attention (GQA).
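
A minimal sketch of deferred parameter initialization with device="meta", as referenced in the first item above. The materialization path shown (to_empty followed by reset_parameters) is an assumption for illustration, not a step prescribed by the release note.

```python
import torch
import transformer_engine.pytorch as te

# Construct on the meta device: no parameter storage is allocated yet.
layer = te.Linear(4096, 4096, device="meta")
assert layer.weight.is_meta

# One possible way to materialize later (assumed here): allocate empty storage on
# the target device, then re-run the module's parameter initialization.
layer = layer.to_empty(device="cuda")
layer.reset_parameters()
```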

Fixed Issues

  • [pyTorch] In LayerNormLinear and Linear, unused copies of weight and bias tensors were not deleted for the case when Q, K, and V tensors are fused.
  • [pyTorch] Faulty usage of pipeline parallelism with the FusedAttention backend.
  • [pyTorch] attention_type was not correctly passed from the MultiheadAttention call to the DotProductAttention call.
  • [pyTorch] Fused DPA backend reported spurious NaN errors during the backward pass.
  • [pyTorch] Crashes when running with PyTorch v2.0.1.
  • [pyTorch] Statistics could be computed incorrectly when training with FP8 in recent versions of pyTorch. For details see #600.
  • [JAX] Crashes when training in FP8 + FSDP.

Known Issues in This Release

  • FlashAttention v2, which is a dependency of this release of Transformer Engine, has a known issue with excessive memory usage during installation (Dao-AILab/flash-attention#358). You can work around this issue by setting the environment variable MAX_JOBS=1 during Transformer Engine installation.
  • [pyTorch] FlashAttention v2.1 changed the behavior of the causal mask when performing cross-attention (see https://github.com/Dao-AILab/flash-attention#21-change-behavior-of-causal-flag for reference). To keep Transformer Engine's behavior consistent across versions and backends, FlashAttention is disabled for this use case (cross-attention with causal masking) when FlashAttention 2.1 or later is installed.

Breaking Changes in This Release

There are no breaking changes in this release.

Deprecated Features

There are no deprecated features in this release.

Miscellaneous Changes

FlashAttention v1 is no longer supported in Transformer Engine. The minimum required version is v2.0.6.

v1.2.1

22 Jan 17:27

Release Notes – Release 1.2.1

Fixed Issues

  • Statistics could be computed incorrectly when training with FP8 in recent versions of pyTorch. For details see #600.

v1.2

11 Jan 17:54

Release Notes – Release 1.2.0

Key Features and Enhancements

  • [pyTorch] Sliding window support is added for DotProductAttention.
  • [pyTorch] Performance of DotProductAttention is increased on Hopper GPUs by utilizing cuDNN.
  • [pyTorch] Support for the Falcon architecture is added in TransformerLayer via the new option parallel_attention_mlp.
  • [pyTorch] Checkpointing logic when using fp8_model_init is improved.
  • [JAX] Support is added for controlling the SM margin of the LayerNorm and RMSNorm kernels via the environment variables NVTE_FWD_LAYERNORM_SM_MARGIN and NVTE_BWD_LAYERNORM_SM_MARGIN.
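
For the SM-margin control in the last item, a minimal sketch follows. The margin values are illustrative; set the variables before the first LayerNorm/RMSNorm kernel launch.

```python
import os

# Leave a margin of streaming multiprocessors free during the LayerNorm/RMSNorm
# forward and backward kernels, e.g. so other kernels can overlap with them.
os.environ["NVTE_FWD_LAYERNORM_SM_MARGIN"] = "8"
os.environ["NVTE_BWD_LAYERNORM_SM_MARGIN"] = "8"
```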

Fixed Issues

  • Weight gradient could be computed incorrectly in some cases when FP8 execution and sequence parallelism were used together.
  • Statistics were computed incorrectly during FP8 calibration.
  • Using torch.compile on DotProductAttention module caused a crash.
  • Rotary embeddings during pipeline-parallel inference did not operate correctly.
  • Incorrect mask type used by the decoder in encoder-decoder architectures.
  • Exporting Transformer Engine modules to ONNX in recent versions of pyTorch did not work correctly.

Known Issues in This Release

  • FlashAttention v2, which is a dependency of this release of Transformer Engine, has a known issue with excessive memory usage during installation (Dao-AILab/flash-attention#358). You can work around this issue either by setting the environment variable MAX_JOBS=1 during Transformer Engine installation, or by installing FlashAttention v1 (e.g. by running pip install flash-attn==1.0.9) before attempting to install Transformer Engine.
  • [pyTorch] FlashAttention v2.1 changed the behavior of the causal mask when performing cross-attention. (See https://github.com/Dao-AILab/flash-attention#21-change-behavior-of-causal-flag for reference.) To keep Transformer Engine behavior consistent between versions and backends, FlashAttention is disabled for this use case (cross-attention with causal masking) when FlashAttention 2.1 or later is installed.

Breaking Changes in This Release

There are no breaking changes in this release.

Deprecated Features

There are no deprecated features in this release.

v1.1

07 Dec 23:44

Release Notes – Release 1.1.0

Key Features and Enhancements

  • [pyTorch] Memory usage is reduced when using the fp8_model_init API during inference.
  • [pyTorch] Memory usage is reduced when using the LayerNormLinear, LayerNormMLP, and TransformerLayer APIs.
  • [JAX] Transformer Engine is migrated to the new Custom Partitioning mechanism of parallelism for custom ops in JAX.
  • [JAX] The attention operation’s performance is improved when using cuDNN version 8.9.6 or greater.
  • [C/C++] Transformer Engine can now be built as a subproject.

Fixed Issues

  • Fixed an issue where, in some cases, passing non-contiguous tensors as Q, K, or V to DotProductAttention resulted in the error "Exception: The provided qkv memory layout is not supported!"

Known Issues in This Release

  • FlashAttention v2, which is a dependency of this release of Transformer Engine, has a known issue with excessive memory usage during installation (Dao-AILab/flash-attention#358). You can work around this issue either by setting the environment variable MAX_JOBS=1 during Transformer Engine installation or by installing FlashAttention v1 (e.g. with pip install flash-attn==1.0.9) before attempting to install Transformer Engine.
  • [pyTorch] FlashAttention v2.1 changed the behavior of the causal mask when performing cross-attention (see https://github.com/Dao-AILab/flash-attention#21-change-behavior-of-causal-flag for reference). To keep Transformer Engine behavior consistent across versions and backends, FlashAttention is disabled for this use case (cross-attention with causal masking) when FlashAttention 2.1 or later is installed.

Breaking Changes in This Release

There are no breaking changes in this release.

Deprecated Features

There are no deprecated features in this release.

v1.0

09 Nov 19:14

Release Notes – Release 1.0.0

Key Features and Enhancements

  • [pyTorch] Expanded the support for different layouts in DotProductAttention.

  • [pyTorch] Added support for packed input for the FlashAttention backend of DotProductAttention.

  • [pyTorch] Better support for the KV cache during inference via the new InferenceParams class.

  • [pyTorch] Better support for parallel state handling for model parallelism via the new CUDARNGStatesTracker class.

  • [pyTorch] Added experimental support for the FP8 Tensor type and a new context manager, fp8_model_init (see the sketch after this list). Transformer Engine modules created inside an fp8_model_init region hold only FP8 copies of their parameters, as opposed to the default behavior where both higher-precision and FP8 copies are present. This may result in lower memory consumption and is especially useful for scenarios such as:

    • full model training using an optimizer with master weights, where the high-precision copies of the weights are already present in the optimizer.
    • inference, where only the FP8 copies of the parameters are used.
    • LoRA-like fine-tuning, where the main parameters of the model do not change.
  • [JAX] Added the ability to set the dropout rate for the activation output in LayerNormMLP.

  • [Paddle] Added documentation.
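
A minimal sketch of the fp8_model_init context manager described above; the module and sizes are illustrative.

```python
import transformer_engine.pytorch as te

# Modules created inside this region keep only FP8 copies of their parameters,
# instead of the default FP8 + higher-precision pair.
with te.fp8_model_init(enabled=True):
    proj = te.Linear(4096, 4096, bias=True)
```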

Fixed Issues

  • [pyTorch] Multiple fixes for activation recomputation when using FP8.
  • [pyTorch] Multiple fixes specific to the usage of Transformer Engine by Megatron-LM and NeMo.
  • [pyTorch] Fixed a crash occurring when trying to use LayerNormLinear with the return_layernorm_output option set.
  • [pyTorch] Fixes to the ONNX export of the attention layer.
  • [pyTorch] Fixed a crash happening when using RoPE.
  • [JAX] Fixed a crash occurring in some cases when using cross attention with FSDP.
  • [JAX] Fixed incorrect handling of the FP8 scaling factor.

Known Issues in This Release

  • FlashAttention v2, which is a dependency of this release of Transformer Engine, has a known issue with excessive memory usage during installation (Dao-AILab/flash-attention#358). You can work around this issue either by setting the environment variable MAX_JOBS=1 during Transformer Engine installation or by installing FlashAttention v1 (e.g. with pip install flash-attn==1.0.9) before attempting to install Transformer Engine.
  • [pyTorch] In some cases, passing non-contiguous tensors as Q, K, or V to DotProductAttention may result in the error "Exception: The provided qkv memory layout is not supported!" This will be fixed in a future release. In the meantime, the workaround is to call .contiguous() on those tensors before passing them to DotProductAttention.

Breaking Changes in This Release

  • The experimental support for TensorFlow has been removed.
  • [pyTorch] The deprecated TransformerLayer arguments attention_softmax_in_fp32 and apply_query_key_layer_scaling were removed.
  • [pyTorch] The deprecated argument skip_weight_param_allocation in the Linear and LayerNormLinear APIs has been removed. Consequently, the weight and bias arguments in the forward method of those APIs have also been removed.
  • [pyTorch] Support for loading old/deprecated checkpoint formats where the extra states for FP8 are not serialized into BytesIO or torch.Tensor objects has been removed.
  • [JAX] Deprecated modules and functions DenseGeneral, LayerNorm, LayerNormDenseGeneral, LayerNormMLP, TransformerEngineBase, extend_logical_axis_rules, MultiHeadAttention, RelativePositionBiases, TransformerLayer, and TransformerLayerType have been removed from transformer_engine.jax and must now only be imported from transformer_engine.jax.flax.
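
For the JAX module relocation in the last item, the new import path looks like the following; the specific modules imported are just examples from the list above.

```python
# Old (removed): from transformer_engine.jax import TransformerLayer
from transformer_engine.jax.flax import (
    TransformerLayer,
    MultiHeadAttention,
    LayerNormMLP,
)
```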

Deprecated Features

There are no deprecated features in this release.