```
ValueError: Attention using SDPA can not be traced with torch.jit.trace when no attention_mask is provided. To solve this issue, please either load your model with the argument attn_implementation="eager" or pass an attention_mask input when tracing the model.
```
There has been some discussion within the ONNX exporter team about possible resolutions. I'd like to open an issue here as well to seek advice and preferences. The options considered so far:
- Check `torch.jit.is_tracing()` and fall back to the eager attention implementation if needed.
- Create an `attention_mask` before passing it to SDPA if it is `None`.
- Support SDPA tracing without an `attention_mask` (unsure how feasible this is).
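For illustration, the first two options could be combined along these lines. This is a minimal, torch-free sketch; the function name and the plain-list mask are hypothetical and do not correspond to actual `transformers` internals:

```python
def select_attention(attention_mask, is_tracing, seq_len):
    """Hypothetical dispatch for attention when no mask is given.

    Returns the attention implementation to use and the (possibly
    materialized) mask, mirroring options 1 and 2 above.
    """
    if attention_mask is None and is_tracing:
        # Option 1: fall back to eager attention, which torch.jit.trace
        # can handle without an explicit attention_mask.
        return "eager", attention_mask
    if attention_mask is None:
        # Option 2: materialize an all-ones (i.e. no-padding) mask so
        # SDPA always receives an explicit attention_mask.
        attention_mask = [1] * seq_len
    return "sdpa", attention_mask
```

In real code the tracing check would be `torch.jit.is_tracing()` and the mask would be a tensor; the sketch only shows the control flow being proposed.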