Native support of torch.nn.functional.scaled_dot_product_attention #26557

@younesbelkada

Description

Feature request

Since version 2.0, PyTorch ships torch.nn.functional.scaled_dot_product_attention (SDPA), which supports more memory-efficient attention computation.

Official documentation here. Three implementations are currently available in that method, making it possible to dispatch the SDPA kernel to the backends listed below (see the sketch after the list for how dispatch can be controlled):

  • C++ math implementation
  • FlashAttention (v1)
  • xFormers memory-efficient attention
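For illustration, here is a minimal sketch of calling SDPA directly and restricting dispatch to a single backend. The tensor shapes and the torch.backends.cuda.sdp_kernel context manager are taken from the PyTorch 2.0/2.1 API and are only an assumption for this example, not part of this issue:

import torch
import torch.nn.functional as F

# (batch, num_heads, seq_len, head_dim) half-precision tensors on GPU
q = torch.randn(1, 8, 128, 64, device="cuda", dtype=torch.float16)
k = torch.randn(1, 8, 128, 64, device="cuda", dtype=torch.float16)
v = torch.randn(1, 8, 128, 64, device="cuda", dtype=torch.float16)

# Let PyTorch pick the fastest backend available for these inputs
out = F.scaled_dot_product_attention(q, k, v)

# Or restrict dispatch to a specific kernel, e.g. Flash Attention only
with torch.backends.cuda.sdp_kernel(enable_flash=True, enable_math=False, enable_mem_efficient=False):
    out = F.scaled_dot_product_attention(q, k, v)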

In addition, upcoming PyTorch releases will add support for Flash Attention 2 (pytorch/pytorch#105602), which is already available in the PyTorch nightlies.

SDPA makes model inference faster and more memory efficient, and supports multiple hardware backends (CPU, CUDA GPUs, AMD GPUs, ...).

Users can already benefit from SDPA through the BetterTransformer API of optimum:

# pip install optimum
from transformers import AutoModel

model = AutoModel.from_pretrained(...)  # any supported checkpoint
model = model.to_bettertransformer()

As SDPA is already quite stable and performant, we should migrate the BetterTransformer API into the native transformers codebase to provide out-of-the-box model acceleration and memory savings.
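As a rough sketch of what such a native integration could look like inside a model's attention layer, the helper below simply routes the attention computation through SDPA. The function name and signature are purely illustrative and not an agreed-upon transformers API:

import torch
import torch.nn.functional as F

def sdpa_attention(query, key, value, attention_mask=None, dropout_p=0.0, is_causal=False):
    # query/key/value: (batch, num_heads, seq_len, head_dim)
    # Replaces the manual softmax(QK^T / sqrt(d)) @ V computation with the fused SDPA kernel
    return F.scaled_dot_product_attention(
        query, key, value,
        attn_mask=attention_mask,
        dropout_p=dropout_p,
        is_causal=is_causal,
    )

Each architecture could then call such a helper in its attention forward when torch>=2.0 is available, and fall back to the existing eager implementation otherwise.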

cc @LysandreJik @fxmarty

Motivation

Make LLMs faster out of the box, simply by updating the PyTorch version.

Your contribution

Help implement this in the next versions.
