Feature request
PyTorch 2.0 introduced torch.nn.functional.scaled_dot_product_attention (SDPA), which supports more memory-efficient attention computation (official documentation: https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html). Three implementations are currently available, so the SDPA call can dispatch to one of the following kernels (a minimal usage sketch follows the list below):
- C++ math implementation
- Flash Attention 1
- xFormers memory-efficient attention
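For reference, a minimal sketch of what the SDPA call and its backend dispatch look like from user code, assuming PyTorch >= 2.0 (the `torch.backends.cuda.sdp_kernel` context manager is the backend-selection API in those releases):

```python
import torch
import torch.nn.functional as F

# Toy tensors in the (batch, heads, seq_len, head_dim) layout SDPA expects.
# Assumes a CUDA GPU so the fused kernels can be selected; on CPU only the
# math implementation is used.
device = "cuda" if torch.cuda.is_available() else "cpu"
dtype = torch.float16 if device == "cuda" else torch.float32
q = torch.randn(2, 8, 128, 64, device=device, dtype=dtype)
k = torch.randn_like(q)
v = torch.randn_like(q)

# Default call: PyTorch picks the fastest available backend automatically.
out = F.scaled_dot_product_attention(q, k, v, is_causal=True)

# Optionally restrict dispatch to a single backend, e.g. Flash Attention only.
if device == "cuda":
    with torch.backends.cuda.sdp_kernel(
        enable_flash=True, enable_math=False, enable_mem_efficient=False
    ):
        out = F.scaled_dot_product_attention(q, k, v, is_causal=True)
```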
In addition, upcoming PyTorch versions will add support for Flash Attention 2 (pytorch/pytorch#105602), which is already available in the PyTorch nightlies.
SDPA makes model inference faster and more memory efficient, and it supports multiple hardware backends (CPU, NVIDIA CUDA GPUs, AMD GPUs, ...).
Users can already benefit from SDPA through the BetterTransformer API of optimum:

```python
# pip install optimum
from transformers import AutoModel

model = AutoModel.from_pretrained("bert-base-uncased")  # any supported checkpoint
model = model.to_bettertransformer()
```

As SDPA is already quite stable and performant, we should migrate the BetterTransformer API into the native transformers codebase to support out-of-the-box model acceleration and memory efficiency.
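To illustrate the kind of change the migration implies, here is a hypothetical sketch (not the actual transformers code) contrasting an eager attention forward with an SDPA-based one; the function names `eager_attention` and `sdpa_attention` are illustrative only:

```python
import torch
import torch.nn.functional as F

def eager_attention(q, k, v, attn_mask=None):
    # Explicit matmul + softmax, as most modeling files implement today:
    # materializes the full (seq_len x seq_len) attention matrix.
    scores = torch.matmul(q, k.transpose(-2, -1)) / (q.size(-1) ** 0.5)
    if attn_mask is not None:
        scores = scores + attn_mask  # additive float mask
    return torch.matmul(torch.softmax(scores, dim=-1), v)

def sdpa_attention(q, k, v, attn_mask=None):
    # Single fused call; PyTorch dispatches to Flash Attention or the
    # memory-efficient kernel when the inputs allow it.
    return F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask)
```

With such an integration, users would get the fused kernels without installing optimum or calling `to_bettertransformer()` manually.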
Motivation
Make LLMs faster out of the box, simply by updating the PyTorch version.
Your contribution
Help implement this in the next versions.