diff --git a/docs/source/en/model_doc/led.mdx b/docs/source/en/model_doc/led.mdx
index db6b559bc2d6..63880d874fe9 100644
--- a/docs/source/en/model_doc/led.mdx
+++ b/docs/source/en/model_doc/led.mdx
@@ -44,8 +44,10 @@ Tips:
 - LED makes use of *global attention* by means of the `global_attention_mask` (see
   [`LongformerModel`]). For summarization, it is advised to put *global attention* only on the first `<s>` token. For
   question answering, it is advised to put *global attention* on all tokens of the question.
-- To fine-tune LED on all 16384, it is necessary to enable *gradient checkpointing* by executing
-  `model.gradient_checkpointing_enable()`.
+- To fine-tune LED on all 16384 tokens, *gradient checkpointing* can be enabled in case training leads to out-of-memory
+  (OOM) errors. This can be done by executing `model.gradient_checkpointing_enable()`.
+  Moreover, the `use_cache=False`
+  flag can be used to disable the caching mechanism to save memory.
 - A notebook showing how to evaluate LED, can be accessed [here](https://colab.research.google.com/drive/12INTTR6n64TzS4RrXZxMSXfrOd9Xzamo?usp=sharing).
 - A notebook showing how to fine-tune LED, can be accessed [here](https://colab.research.google.com/drive/12LjJazBl7Gam0XBPy_y0CTOJZeZ34c2v?usp=sharing).
 
diff --git a/src/transformers/models/led/modeling_led.py b/src/transformers/models/led/modeling_led.py
index 3e852cf2a67d..4b7eb5b3c6ac 100755
--- a/src/transformers/models/led/modeling_led.py
+++ b/src/transformers/models/led/modeling_led.py
@@ -2426,6 +2426,7 @@ def prepare_inputs_for_generation(
         decoder_input_ids,
         past=None,
         attention_mask=None,
+        global_attention_mask=None,
         head_mask=None,
         decoder_head_mask=None,
         cross_attn_head_mask=None,
@@ -2443,6 +2444,7 @@
             "past_key_values": past,
             "decoder_input_ids": decoder_input_ids,
             "attention_mask": attention_mask,
+            "global_attention_mask": global_attention_mask,
             "head_mask": head_mask,
             "decoder_head_mask": decoder_head_mask,
             "cross_attn_head_mask": cross_attn_head_mask,
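
Not part of the diff: a minimal usage sketch of what the two changes allow together. It assumes the public `allenai/led-base-16384` checkpoint and a toy input document (both illustrative, not taken from this PR). With the modeling change, the `global_attention_mask` passed to `generate()` is forwarded through `prepare_inputs_for_generation`; the memory-saving options from the docs change are shown as comments since they apply at fine-tuning time rather than inference.

```python
# Illustrative sketch only; checkpoint name and input text are assumptions, not part of this PR.
import torch
from transformers import LEDForConditionalGeneration, LEDTokenizer

tokenizer = LEDTokenizer.from_pretrained("allenai/led-base-16384")
model = LEDForConditionalGeneration.from_pretrained("allenai/led-base-16384")

# Memory-saving options from the docs change, relevant when fine-tuning on the full 16384 tokens:
# model.gradient_checkpointing_enable()
# model.config.use_cache = False

article = "Transformers are deep learning models ..."  # a (long) document to summarize
inputs = tokenizer(article, max_length=16384, truncation=True, return_tensors="pt")

# Global attention only on the first (<s>) token, as advised for summarization.
global_attention_mask = torch.zeros_like(inputs["input_ids"])
global_attention_mask[:, 0] = 1

# With the modeling change, `prepare_inputs_for_generation` forwards `global_attention_mask`
# to the model at every generation step.
summary_ids = model.generate(
    inputs["input_ids"],
    attention_mask=inputs["attention_mask"],
    global_attention_mask=global_attention_mask,
    num_beams=4,
    max_length=256,
)
print(tokenizer.decode(summary_ids[0], skip_special_tokens=True))
```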