ONNX export failure for models invoking SDPA attention #28610

@BowenBao

Description

Exporting a model that uses SDPA attention to ONNX fails during tracing with the following error:

ValueError: Attention using SDPA can not be traced with torch.jit.trace when no attention_mask is provided. To solve this issue, please either load your model with the argument attn_implementation="eager" or pass an attention_mask input when tracing the model.

The ONNX exporter team has discussed several possible resolutions. I'm posting the issue here as well to seek advice and preferences:

  1. Check torch.jit.is_tracing() and fall back to the eager attention implementation if needed.
  2. Create attention_mask before passing to SDPA if it is None.
  3. Support SDPA tracing without attention_mask (not sure how feasible this is).
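
To make options 1 and 2 concrete, here is a minimal sketch of the decision logic only. It is not the library's code: the function names `pick_attn_implementation` and `ensure_attention_mask` are hypothetical, the tracing flag is passed in as a plain bool (in a real model it would presumably come from torch.jit.is_tracing()), and nested lists stand in for tensors so the snippet runs without PyTorch.

```python
from typing import List, Optional


def pick_attn_implementation(
    requested: str, is_tracing: bool, has_attention_mask: bool
) -> str:
    """Option 1 (hypothetical helper): fall back to eager attention
    only in the case the error describes, i.e. SDPA requested while
    tracing and no attention_mask available."""
    if requested == "sdpa" and is_tracing and not has_attention_mask:
        return "eager"
    return requested


def ensure_attention_mask(
    attention_mask: Optional[List[List[int]]], batch_size: int, seq_len: int
) -> List[List[int]]:
    """Option 2 (hypothetical helper): if no mask was given, build an
    all-ones mask, assuming 1 means "attend to this position"."""
    if attention_mask is None:
        return [[1] * seq_len for _ in range(batch_size)]
    return attention_mask


# Option 1: only the failing combination is rerouted to eager.
print(pick_attn_implementation("sdpa", is_tracing=True, has_attention_mask=False))
print(pick_attn_implementation("sdpa", is_tracing=False, has_attention_mask=False))

# Option 2: a missing mask becomes an explicit all-ones mask.
print(ensure_attention_mask(None, batch_size=2, seq_len=3))
```

Option 1 leaves the non-tracing path untouched, so only the exported graph would use eager attention. Option 2 keeps SDPA in both paths, at the cost of always materializing a mask.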
