diff --git a/src/transformers/models/llama/modeling_llama.py b/src/transformers/models/llama/modeling_llama.py
index d9e3907fb971..703ebf0d9560 100644
--- a/src/transformers/models/llama/modeling_llama.py
+++ b/src/transformers/models/llama/modeling_llama.py
@@ -40,6 +40,7 @@
     logging,
     replace_return_docstrings,
 )
+from ...utils.import_utils import is_torch_fx_available
 from .configuration_llama import LlamaConfig
 
 
@@ -48,6 +49,12 @@
     from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input  # noqa
 
 
+# This makes `_prepare_4d_causal_attention_mask` a leaf function in the FX graph.
+# It means that the function will not be traced through and simply appear as a node in the graph.
+if is_torch_fx_available():
+    _prepare_4d_causal_attention_mask = torch.fx.wrap(_prepare_4d_causal_attention_mask)
+
+
 logger = logging.get_logger(__name__)
 
 _CONFIG_FOR_DOC = "LlamaConfig"
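
For context, `torch.fx.wrap` registers a function as a leaf for FX symbolic tracing: the tracer records a single call node for it instead of tracing into its body, which is what the added lines above rely on. Below is a minimal, self-contained sketch of that behavior; the helper `make_causal_mask` and the `ToyModel` module are hypothetical names used only for illustration and are not part of transformers.

```python
import torch
import torch.fx


def make_causal_mask(size):
    # Helper whose internals we do not want recorded in the FX graph.
    mask = torch.full((size, size), float("-inf"))
    return torch.triu(mask, diagonal=1)


# Must be called at module-level scope; `torch.fx.wrap` returns the function,
# so the reassignment pattern used in the diff above is also valid.
make_causal_mask = torch.fx.wrap(make_causal_mask)


class ToyModel(torch.nn.Module):
    def forward(self, x):
        mask = make_causal_mask(x.shape[-1])
        return x + mask


traced = torch.fx.symbolic_trace(ToyModel())
print(traced.graph)  # `make_causal_mask` appears as a single call_function node
```

Wrapping `_prepare_4d_causal_attention_mask` the same way keeps its mask-construction logic out of the traced graph, so FX tracing of the Llama model does not break on it.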