Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 4 additions & 2 deletions docs/source/en/model_doc/led.mdx
Original file line number Diff line number Diff line change
Expand Up @@ -44,8 +44,10 @@ Tips:
- LED makes use of *global attention* by means of the `global_attention_mask` (see
[`LongformerModel`]). For summarization, it is advised to put *global attention* only on the first
`<s>` token. For question answering, it is advised to put *global attention* on all tokens of the question.
- To fine-tune LED on all 16384, it is necessary to enable *gradient checkpointing* by executing
`model.gradient_checkpointing_enable()`.
- To fine-tune LED on all 16384, *gradient checkpointing* can be enabled in case training leads to out-of-memory (OOM)
errors. This can be done by executing `model.gradient_checkpointing_enable()`.
Moreover, the `use_cache=False`
flag can be used to disable the caching mechanism to save memory.
- A notebook showing how to evaluate LED, can be accessed [here](https://colab.research.google.com/drive/12INTTR6n64TzS4RrXZxMSXfrOd9Xzamo?usp=sharing).
- A notebook showing how to fine-tune LED, can be accessed [here](https://colab.research.google.com/drive/12LjJazBl7Gam0XBPy_y0CTOJZeZ34c2v?usp=sharing).

Expand Down
2 changes: 2 additions & 0 deletions src/transformers/models/led/modeling_led.py
Original file line number Diff line number Diff line change
Expand Up @@ -2426,6 +2426,7 @@ def prepare_inputs_for_generation(
decoder_input_ids,
past=None,
attention_mask=None,
global_attention_mask=None,
head_mask=None,
decoder_head_mask=None,
cross_attn_head_mask=None,
Expand All @@ -2443,6 +2444,7 @@ def prepare_inputs_for_generation(
"past_key_values": past,
"decoder_input_ids": decoder_input_ids,
"attention_mask": attention_mask,
"global_attention_mask": global_attention_mask,
"head_mask": head_mask,
"decoder_head_mask": decoder_head_mask,
"cross_attn_head_mask": cross_attn_head_mask,
Expand Down