From 5b3c24948e754f66e187ae8009eed25b651780dd Mon Sep 17 00:00:00 2001
From: BowenBao
Date: Thu, 1 Feb 2024 19:27:14 +0000
Subject: [PATCH] Unblock Llama2 ONNX export w/ sdpa by falling back to manual impl

---
 .../models/llama/modeling_llama.py | 20 +++++++++++++++----
 1 file changed, 16 insertions(+), 4 deletions(-)

diff --git a/src/transformers/models/llama/modeling_llama.py b/src/transformers/models/llama/modeling_llama.py
index 90706cae68c0..98fcde5886ff 100644
--- a/src/transformers/models/llama/modeling_llama.py
+++ b/src/transformers/models/llama/modeling_llama.py
@@ -673,12 +673,22 @@ def forward(
         output_attentions: bool = False,
         use_cache: bool = False,
     ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
+        _jit_tracing = torch.jit.is_tracing()
+        _fallback_to_manual = _jit_tracing or output_attentions
+        # TODO: Improve these warnings with e.g. `model.config.attn_implementation = "manual"` once this is implemented.
+        if _jit_tracing:
+            logger.warning_once(
+                "LlamaModel is using LlamaSdpaAttention, but SDPA can not be traced with torch.jit.trace when no attention_mask is provided. Falling back to the manual attention implementation, "
+                'but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
+            )
+
         if output_attentions:
-            # TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented.
             logger.warning_once(
                 "LlamaModel is using LlamaSdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to the manual attention implementation, "
                 'but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
             )
+
+        if _fallback_to_manual:
             return super().forward(
                 hidden_states=hidden_states,
                 attention_mask=attention_mask,
@@ -1029,9 +1039,11 @@ def forward(
         if self._use_flash_attention_2:
             # 2d mask is passed through the layers
             attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None
-        elif self._use_sdpa and not output_attentions:
-            # output_attentions=True can not be supported when using SDPA, and we fall back on
-            # the manual implementation that requires a 4D causal mask in all cases.
+        elif self._use_sdpa and not output_attentions and not torch.jit.is_tracing():
+            # NOTE: Fall back to eager attention in any of the following cases:
+            # - output_attentions=True can not be supported when using SDPA, and we fall back on
+            #   the manual implementation that requires a 4D causal mask in all cases.
+            # - Attention using SDPA can not be traced with torch.jit.trace when no attention_mask is provided.
             attention_mask = _prepare_4d_causal_attention_mask_for_sdpa(
                 attention_mask,
                 (batch_size, seq_length),
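
Below is a minimal usage sketch of the export path this patch unblocks: loading a Llama checkpoint with the SDPA attention implementation and running it through torch.onnx.export, which traces the model with torch.jit.trace and therefore now takes the manual-attention fallback added above. The checkpoint name, example input, and export arguments are illustrative assumptions, not part of the patch.

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

# Assumed checkpoint for illustration; any Llama 2 checkpoint works.
model_id = "meta-llama/Llama-2-7b-hf"
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id, attn_implementation="sdpa").eval()
model.config.use_cache = False     # keep the traced graph to a single logits output
model.config.return_dict = False   # return tuples, which trace/export handle cleanly

inputs = tokenizer("Hello, my dog is cute", return_tensors="pt")

# torch.onnx.export drives torch.jit.trace, so torch.jit.is_tracing() is True inside
# LlamaSdpaAttention.forward and the module falls back to the manual implementation,
# even though no attention_mask is passed here.
torch.onnx.export(
    model,
    (inputs["input_ids"],),
    "llama2.onnx",
    input_names=["input_ids"],
    output_names=["logits"],
    dynamic_axes={"input_ids": {0: "batch", 1: "sequence"}},
    opset_version=17,
)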