diff --git a/docs/source/model_doc/led.rst b/docs/source/model_doc/led.rst index 2e05163d37b4..1eaa9e325ffa 100644 --- a/docs/source/model_doc/led.rst +++ b/docs/source/model_doc/led.rst @@ -46,8 +46,8 @@ Tips: - LED makes use of *global attention* by means of the ``global_attention_mask`` (see :class:`~transformers.LongformerModel`). For summarization, it is advised to put *global attention* only on the first ``<s>`` token. For question answering, it is advised to put *global attention* on all tokens of the question. -- To fine-tune LED on all 16384, it is necessary to enable *gradient checkpointing* by setting - ``config.gradient_checkpointing = True``. +- To fine-tune LED on all 16384, it is necessary to enable *gradient checkpointing* by executing + ``model.gradient_checkpointing_enable()``. - A notebook showing how to evaluate LED, can be accessed `here `__. - A notebook showing how to fine-tune LED, can be accessed `here diff --git a/docs/source/performance.md b/docs/source/performance.md index 4f479d857569..c3239f3b0c08 100644 --- a/docs/source/performance.md +++ b/docs/source/performance.md @@ -53,6 +53,7 @@ Software: - Tensor Parallelism - Low-memory Optimizers - fp16/bf16 (smaller data) +- Gradient checkpointing @@ -226,6 +227,21 @@ pytorch `autocast` which performs AMP include a caching feature, which speed thi Autocast maintains a cache of the FP16 casts of model params (leaves). This helps streamline parameter reuse: if the same FP32 param is used in several different FP16list ops, like several matmuls, instead of re-casting the param to FP16 on entering each matmul, the cast will occur on the first matmul, the casted FP16 copy will be cached, and for all later matmuls the FP16 copy will be reused. The cache is maintained only within a particular outermost autocast context. When you exit the autocast context the cache is dropped. For recommended usage, in which autocast wraps the forward pass, and then you exit the context before calling backward(), this means the cache only lasts the duration of the forward pass each iteration, and will be rebuilt next iteration. (The cache of FP16-casted copies MUST be rebuilt each iteration. The FP32 params get updated by the optimizer, so the FP16 copies must be recreated, otherwise the FP16 values will be stale.) + +### Gradient Checkpointing + +One way to use significantly less GPU memory is to enable "Gradient Checkpointing" (also known as "activation checkpointing"). When enabled, a lot of memory can be freed at the cost of a small decrease in the training speed due to recomputing parts of the graph during back-propagation. + +This technique was first shared in the paper: [Training Deep Nets with Sublinear Memory Cost](https://arxiv.org/abs/1604.06174). The paper will also give you the exact details on the savings, but it's in the ballpark of `O(sqrt(n))`, where `n` is the number of feed-forward layers. + +To activate this feature in 🤗 Transformers for models that support it, use: + +```python +model.gradient_checkpointing_enable() +``` +or add `--gradient_checkpointing` to the Trainer arguments. + + ### Batch sizes One gets the most efficient performance when batch sizes and input/output neuron counts are divisible by a certain number, which typically starts at 8, but can be much higher as well. That number varies a lot depending on the specific hardware being used and the dtype of the model.
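To make the two activation paths from the new `performance.md` section concrete, here is a minimal sketch. It is not part of this diff: the checkpoint name, `output_dir`, and the commented-out dataset are placeholders chosen for illustration, and the `TrainingArguments(gradient_checkpointing=True)` path follows the deprecation message added in `configuration_utils.py`.

```python
from transformers import AutoModelForSequenceClassification, TrainingArguments

# Any architecture that sets `supports_gradient_checkpointing = True` works here;
# "bert-base-uncased" is only an example checkpoint.
model = AutoModelForSequenceClassification.from_pretrained("bert-base-uncased")

# Option 1: toggle the feature directly on the model.
model.gradient_checkpointing_enable()
print(model.bert.encoder.gradient_checkpointing)  # True -- for BERT the flag lives on the encoder
model.gradient_checkpointing_disable()

# Option 2: let the Trainer enable it, mirroring the --gradient_checkpointing flag
# of the example scripts (assumes a training dataset is available).
args = TrainingArguments(output_dir="out", gradient_checkpointing=True)
# Trainer(model=model, args=args, train_dataset=...).train()
```

Under the hood, `gradient_checkpointing_enable()` calls `self.apply(partial(self._set_gradient_checkpointing, value=True))`, which is why every model touched by this diff gains a `_set_gradient_checkpointing` hook on its pretrained base class and a `gradient_checkpointing` attribute on its encoder/decoder modules.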
diff --git a/examples/pytorch/language-modeling/README.md b/examples/pytorch/language-modeling/README.md index 23989d7ed1a0..c768f5ec31bb 100644 --- a/examples/pytorch/language-modeling/README.md +++ b/examples/pytorch/language-modeling/README.md @@ -174,8 +174,3 @@ python run_clm.py --model_type gpt2 --tokenizer_name gpt2 \ --config_overrides=" ``` This feature is only available in `run_clm.py`, `run_plm.py` and `run_mlm.py`. - -This feature can also be used to activate gradient checkpointing by passing: -``` ---config_overrides "gradient_checkpointing=true,use_cache=False" -``` diff --git a/src/transformers/configuration_utils.py b/src/transformers/configuration_utils.py index 45683ac801a3..bc3ecf77ba1c 100755 --- a/src/transformers/configuration_utils.py +++ b/src/transformers/configuration_utils.py @@ -19,6 +19,7 @@ import copy import json import os +import warnings from typing import Any, Dict, Tuple, Union from . import __version__ @@ -330,6 +331,14 @@ def __init__(self, **kwargs): # Drop the transformers version info self.transformers_version = kwargs.pop("transformers_version", None) + # Deal with gradient checkpointing + if "gradient_checkpointing" in kwargs: + warnings.warn( + "Passing `gradient_checkpointing` to a config initialization is deprecated and will be removed in v5 of " + "Transformers. Use `model.gradient_checkpointing_enable()` instead, or if you are using the " + "`Trainer` API, pass `gradient_checkpointing=True` in your `TrainingArguments`." + ) + # Additional attributes without default values for key, value in kwargs.items(): try: diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py index e82d0ad9e31f..21a1b09f30d6 100644 --- a/src/transformers/modeling_utils.py +++ b/src/transformers/modeling_utils.py @@ -20,6 +20,7 @@ import warnings from contextlib import contextmanager from dataclasses import dataclass +from functools import partial from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Union import torch @@ -450,6 +451,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix _keys_to_ignore_on_save = None is_parallelizable = False + supports_gradient_checkpointing = False @property def dummy_inputs(self) -> Dict[str, torch.Tensor]: @@ -469,6 +471,10 @@ def __init__(self, config: PretrainedConfig, *inputs, **kwargs): # Save config and origin of the pretrained weights if given in model self.config = config self.name_or_path = config.name_or_path + if getattr(self.config, "gradient_checkpointing", False): + self.gradient_checkpointing_enable() + # Remove the attribute now that it has been consumed, so it's not saved in the config. + delattr(self.config, "gradient_checkpointing") @classmethod def _from_config(cls, config, **kwargs): @@ -932,6 +938,27 @@ def prune_heads(self, heads_to_prune: Dict[int, List[int]]): self.base_model._prune_heads(heads_to_prune) + def gradient_checkpointing_enable(self): + """ + Activates gradient checkpointing for the current model. + + Note that in other frameworks this feature can be referred to as "activation checkpointing" or "checkpoint + activations". + """ + if not self.supports_gradient_checkpointing: + raise ValueError(f"{self.__class__.__name__} does not support gradient checkpointing.") + self.apply(partial(self._set_gradient_checkpointing, value=True)) + + def gradient_checkpointing_disable(self): + """ + Deactivates gradient checkpointing for the current model.
+ + Note that in other frameworks this feature can be referred to as "activation checkpointing" or "checkpoint + activations". + """ + if self.supports_gradient_checkpointing: + self.apply(partial(self._set_gradient_checkpointing, value=False)) + def save_pretrained( self, save_directory: Union[str, os.PathLike], diff --git a/src/transformers/models/bart/configuration_bart.py b/src/transformers/models/bart/configuration_bart.py index e26afb2ab4b6..6efbe4ca510a 100644 --- a/src/transformers/models/bart/configuration_bart.py +++ b/src/transformers/models/bart/configuration_bart.py @@ -82,8 +82,6 @@ class BartConfig(PretrainedConfig): decoder_layerdrop: (:obj:`float`, `optional`, defaults to 0.0): The LayerDrop probability for the decoder. See the `LayerDrop paper `__ for more details. - gradient_checkpointing (:obj:`bool`, `optional`, defaults to :obj:`False`): - If True, use gradient checkpointing to save memory at the expense of slower backward pass. scale_embedding (:obj:`bool`, `optional`, defaults to :obj:`False`): Scale embeddings by diving by sqrt(d_model). use_cache (:obj:`bool`, `optional`, defaults to :obj:`True`): @@ -131,7 +129,6 @@ def __init__( init_std=0.02, classifier_dropout=0.0, scale_embedding=False, - gradient_checkpointing=False, use_cache=True, num_labels=3, pad_token_id=1, @@ -161,7 +158,6 @@ def __init__( self.classifier_dropout = classifier_dropout self.use_cache = use_cache self.num_hidden_layers = encoder_layers - self.gradient_checkpointing = gradient_checkpointing self.scale_embedding = scale_embedding # scale factor will be sqrt(d_model) if True super().__init__( diff --git a/src/transformers/models/bart/modeling_bart.py b/src/transformers/models/bart/modeling_bart.py index a466be30a688..134669cee4ba 100755 --- a/src/transformers/models/bart/modeling_bart.py +++ b/src/transformers/models/bart/modeling_bart.py @@ -471,6 +471,7 @@ def forward(self, hidden_states: torch.Tensor): class BartPretrainedModel(PreTrainedModel): config_class = BartConfig base_model_prefix = "model" + supports_gradient_checkpointing = True _keys_to_ignore_on_load_unexpected = [r"encoder\.version", r"decoder\.version"] def _init_weights(self, module): @@ -484,6 +485,10 @@ def _init_weights(self, module): if module.padding_idx is not None: module.weight.data[module.padding_idx].zero_() + def _set_gradient_checkpointing(self, module, value=False): + if isinstance(module, (BartDecoder, BartEncoder)): + module.gradient_checkpointing = value + @property def dummy_inputs(self): pad_token = self.config.pad_token_id @@ -687,6 +692,7 @@ def __init__(self, config: BartConfig, embed_tokens: Optional[nn.Embedding] = No self.layernorm_embedding = nn.LayerNorm(embed_dim) self.init_weights() + self.gradient_checkpointing = False def forward( self, @@ -782,7 +788,7 @@ def forward( if self.training and (dropout_probability < self.layerdrop): # skip the layer layer_outputs = (None, None) else: - if getattr(self.config, "gradient_checkpointing", False) and self.training: + if self.gradient_checkpointing and self.training: def create_custom_forward(module): def custom_forward(*inputs): @@ -849,6 +855,7 @@ def __init__(self, config: BartConfig, embed_tokens: Optional[nn.Embedding] = No self.layernorm_embedding = nn.LayerNorm(config.d_model) self.init_weights() + self.gradient_checkpointing = False def get_input_embeddings(self): return self.embed_tokens @@ -1020,12 +1027,11 @@ def forward( past_key_value = past_key_values[idx] if past_key_values is not None else None - if getattr(self.config, 
"gradient_checkpointing", False) and self.training: + if self.gradient_checkpointing and self.training: if use_cache: logger.warning( - "`use_cache=True` is incompatible with `config.gradient_checkpointing=True`. Setting " - "`use_cache=False`..." + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." ) use_cache = False diff --git a/src/transformers/models/beit/configuration_beit.py b/src/transformers/models/beit/configuration_beit.py index 08ecc6064697..d31f83dd3a5e 100644 --- a/src/transformers/models/beit/configuration_beit.py +++ b/src/transformers/models/beit/configuration_beit.py @@ -57,8 +57,6 @@ class BeitConfig(PretrainedConfig): The standard deviation of the truncated_normal_initializer for initializing all weight matrices. layer_norm_eps (:obj:`float`, `optional`, defaults to 1e-12): The epsilon used by the layer normalization layers. - gradient_checkpointing (:obj:`bool`, `optional`, defaults to :obj:`False`): - If True, use gradient checkpointing to save memory at the expense of slower backward pass. image_size (:obj:`int`, `optional`, defaults to :obj:`224`): The size (resolution) of each image. patch_size (:obj:`int`, `optional`, defaults to :obj:`16`): diff --git a/src/transformers/models/beit/modeling_beit.py b/src/transformers/models/beit/modeling_beit.py index 236551d27cc6..1ad3fcd1e6d1 100755 --- a/src/transformers/models/beit/modeling_beit.py +++ b/src/transformers/models/beit/modeling_beit.py @@ -432,6 +432,7 @@ def __init__(self, config, window_size=None): for i in range(config.num_hidden_layers) ] ) + self.gradient_checkpointing = False def forward( self, @@ -450,7 +451,7 @@ def forward( layer_head_mask = head_mask[i] if head_mask is not None else None - if getattr(self.config, "gradient_checkpointing", False) and self.training: + if self.gradient_checkpointing and self.training: def create_custom_forward(module): def custom_forward(*inputs): @@ -494,6 +495,7 @@ class BeitPreTrainedModel(PreTrainedModel): config_class = BeitConfig base_model_prefix = "beit" + supports_gradient_checkpointing = True def _init_weights(self, module): """Initialize the weights""" @@ -511,6 +513,10 @@ def _init_weights(self, module): module.bias.data.zero_() module.weight.data.fill_(1.0) + def _set_gradient_checkpointing(self, module, value=False): + if isinstance(module, BeitEncoder): + module.gradient_checkpointing = value + BEIT_START_DOCSTRING = r""" This model is a PyTorch `torch.nn.Module `_ subclass. Use diff --git a/src/transformers/models/bert/configuration_bert.py b/src/transformers/models/bert/configuration_bert.py index 8359f0c3b7e2..861cdfbc8ea6 100644 --- a/src/transformers/models/bert/configuration_bert.py +++ b/src/transformers/models/bert/configuration_bert.py @@ -92,8 +92,6 @@ class BertConfig(PretrainedConfig): The standard deviation of the truncated_normal_initializer for initializing all weight matrices. layer_norm_eps (:obj:`float`, `optional`, defaults to 1e-12): The epsilon used by the layer normalization layers. - gradient_checkpointing (:obj:`bool`, `optional`, defaults to :obj:`False`): - If True, use gradient checkpointing to save memory at the expense of slower backward pass. position_embedding_type (:obj:`str`, `optional`, defaults to :obj:`"absolute"`): Type of position embedding. Choose one of :obj:`"absolute"`, :obj:`"relative_key"`, :obj:`"relative_key_query"`. For positional embeddings use :obj:`"absolute"`. 
For more information on @@ -137,7 +135,6 @@ def __init__( initializer_range=0.02, layer_norm_eps=1e-12, pad_token_id=0, - gradient_checkpointing=False, position_embedding_type="absolute", use_cache=True, classifier_dropout=None, @@ -157,7 +154,6 @@ def __init__( self.type_vocab_size = type_vocab_size self.initializer_range = initializer_range self.layer_norm_eps = layer_norm_eps - self.gradient_checkpointing = gradient_checkpointing self.position_embedding_type = position_embedding_type self.use_cache = use_cache self.classifier_dropout = classifier_dropout diff --git a/src/transformers/models/bert/modeling_bert.py b/src/transformers/models/bert/modeling_bert.py index f02d67a31a21..ecb0d184a4a4 100755 --- a/src/transformers/models/bert/modeling_bert.py +++ b/src/transformers/models/bert/modeling_bert.py @@ -529,6 +529,7 @@ def __init__(self, config): super().__init__() self.config = config self.layer = nn.ModuleList([BertLayer(config) for _ in range(config.num_hidden_layers)]) + self.gradient_checkpointing = False def forward( self, @@ -555,12 +556,11 @@ def forward( layer_head_mask = head_mask[i] if head_mask is not None else None past_key_value = past_key_values[i] if past_key_values is not None else None - if getattr(self.config, "gradient_checkpointing", False) and self.training: + if self.gradient_checkpointing and self.training: if use_cache: logger.warning( - "`use_cache=True` is incompatible with `config.gradient_checkpointing=True`. Setting " - "`use_cache=False`..." + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." ) use_cache = False @@ -714,6 +714,7 @@ class BertPreTrainedModel(PreTrainedModel): config_class = BertConfig load_tf_weights = load_tf_weights_in_bert base_model_prefix = "bert" + supports_gradient_checkpointing = True _keys_to_ignore_on_load_missing = [r"position_ids"] def _init_weights(self, module): @@ -732,6 +733,10 @@ def _init_weights(self, module): module.bias.data.zero_() module.weight.data.fill_(1.0) + def _set_gradient_checkpointing(self, module, value=False): + if isinstance(module, BertEncoder): + module.gradient_checkpointing = value + @dataclass class BertForPreTrainingOutput(ModelOutput): diff --git a/src/transformers/models/bert_generation/configuration_bert_generation.py b/src/transformers/models/bert_generation/configuration_bert_generation.py index 54659f4394a5..2284f873e708 100644 --- a/src/transformers/models/bert_generation/configuration_bert_generation.py +++ b/src/transformers/models/bert_generation/configuration_bert_generation.py @@ -52,8 +52,6 @@ class BertGenerationConfig(PretrainedConfig): The standard deviation of the truncated_normal_initializer for initializing all weight matrices. layer_norm_eps (:obj:`float`, `optional`, defaults to 1e-12): The epsilon used by the layer normalization layers. - gradient_checkpointing (:obj:`bool`, `optional`, defaults to :obj:`False`): - If :obj:`True`, use gradient checkpointing to save memory at the expense of slower backward pass. position_embedding_type (:obj:`str`, `optional`, defaults to :obj:`"absolute"`): Type of position embedding. Choose one of :obj:`"absolute"`, :obj:`"relative_key"`, :obj:`"relative_key_query"`. For positional embeddings use :obj:`"absolute"`. 
For more information on @@ -96,7 +94,6 @@ def __init__( pad_token_id=0, bos_token_id=2, eos_token_id=1, - gradient_checkpointing=False, position_embedding_type="absolute", use_cache=True, **kwargs @@ -114,6 +111,5 @@ def __init__( self.max_position_embeddings = max_position_embeddings self.initializer_range = initializer_range self.layer_norm_eps = layer_norm_eps - self.gradient_checkpointing = gradient_checkpointing self.position_embedding_type = position_embedding_type self.use_cache = use_cache diff --git a/src/transformers/models/big_bird/configuration_big_bird.py b/src/transformers/models/big_bird/configuration_big_bird.py index e6fdfd1d14cd..85dd8de7dd9a 100644 --- a/src/transformers/models/big_bird/configuration_big_bird.py +++ b/src/transformers/models/big_bird/configuration_big_bird.py @@ -82,8 +82,6 @@ class BigBirdConfig(PretrainedConfig): num_random_blocks (:obj:`int`, `optional`, defaults to 3) Each query is going to attend these many number of random blocks. Useful only when :obj:`attention_type == "block_sparse"`. - gradient_checkpointing (:obj:`bool`, `optional`, defaults to :obj:`False`): - If True, use gradient checkpointing to save memory at the expense of slower backward pass. classifier_dropout (:obj:`float`, `optional`): The dropout ratio for the classification head. @@ -127,7 +125,6 @@ def __init__( rescale_embeddings=False, block_size=64, num_random_blocks=3, - gradient_checkpointing=False, classifier_dropout=None, **kwargs ): @@ -153,7 +150,6 @@ def __init__( self.layer_norm_eps = layer_norm_eps self.use_cache = use_cache self.is_encoder_decoder = is_encoder_decoder - self.gradient_checkpointing = gradient_checkpointing self.rescale_embeddings = rescale_embeddings self.attention_type = attention_type diff --git a/src/transformers/models/big_bird/modeling_big_bird.py b/src/transformers/models/big_bird/modeling_big_bird.py index f7d0d857bc45..84a428591e69 100755 --- a/src/transformers/models/big_bird/modeling_big_bird.py +++ b/src/transformers/models/big_bird/modeling_big_bird.py @@ -1555,6 +1555,7 @@ def __init__(self, config): self.layer = nn.ModuleList( [BigBirdLayer(config, seed=layer_idx) for layer_idx in range(config.num_hidden_layers)] ) + self.gradient_checkpointing = False def set_attention_type(self, value: str): if value not in ["original_full", "block_sparse"]: @@ -1598,12 +1599,11 @@ def forward( layer_head_mask = head_mask[i] if head_mask is not None else None past_key_value = past_key_values[i] if past_key_values is not None else None - if getattr(self.config, "gradient_checkpointing", False) and self.training: + if self.gradient_checkpointing and self.training: if use_cache: logger.warning( - "`use_cache=True` is incompatible with `config.gradient_checkpointing=True`. Setting " - "`use_cache=False`..." + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." ) use_cache = False @@ -1756,6 +1756,7 @@ class BigBirdPreTrainedModel(PreTrainedModel): config_class = BigBirdConfig load_tf_weights = load_tf_weights_in_big_bird base_model_prefix = "bert" + supports_gradient_checkpointing = True _keys_to_ignore_on_load_missing = [r"position_ids"] def _init_weights(self, module): @@ -1774,6 +1775,10 @@ def _init_weights(self, module): module.bias.data.zero_() module.weight.data.fill_(1.0) + def _set_gradient_checkpointing(self, module, value=False): + if isinstance(module, BigBirdEncoder): + module.gradient_checkpointing = value + BIG_BIRD_START_DOCSTRING = r""" This model is a PyTorch `torch.nn.Module `_ sub-class. 
Use diff --git a/src/transformers/models/bigbird_pegasus/configuration_bigbird_pegasus.py b/src/transformers/models/bigbird_pegasus/configuration_bigbird_pegasus.py index 28211c9b164f..297e2cede4da 100644 --- a/src/transformers/models/bigbird_pegasus/configuration_bigbird_pegasus.py +++ b/src/transformers/models/bigbird_pegasus/configuration_bigbird_pegasus.py @@ -94,8 +94,6 @@ class BigBirdPegasusConfig(PretrainedConfig): "block_sparse"`. scale_embeddings (:obj:`bool`, `optional`, defaults to :obj:`True`) Whether to rescale embeddings with (hidden_size ** 0.5). - gradient_checkpointing (:obj:`bool`, `optional`, defaults to :obj:`False`): - If True, use gradient checkpointing to save memory at the expense of slower backward pass. Example:: @@ -141,7 +139,6 @@ def __init__( decoder_start_token_id=2, classifier_dropout=0.0, scale_embedding=True, - gradient_checkpointing=False, pad_token_id=0, bos_token_id=2, eos_token_id=1, @@ -170,7 +167,6 @@ def __init__( self.classifier_dropout = classifier_dropout self.use_cache = use_cache self.num_hidden_layers = encoder_layers - self.gradient_checkpointing = gradient_checkpointing self.scale_embedding = scale_embedding # scale factor will be sqrt(d_model) if True # extra config diff --git a/src/transformers/models/bigbird_pegasus/modeling_bigbird_pegasus.py b/src/transformers/models/bigbird_pegasus/modeling_bigbird_pegasus.py index 2fd765eb5dd0..536cd784daaf 100755 --- a/src/transformers/models/bigbird_pegasus/modeling_bigbird_pegasus.py +++ b/src/transformers/models/bigbird_pegasus/modeling_bigbird_pegasus.py @@ -1567,6 +1567,7 @@ def forward(self, hidden_states: torch.Tensor): class BigBirdPegasusPreTrainedModel(PreTrainedModel): config_class = BigBirdPegasusConfig base_model_prefix = "model" + supports_gradient_checkpointing = True def _init_weights(self, module): std = self.config.init_std @@ -1579,6 +1580,10 @@ def _init_weights(self, module): if module.padding_idx is not None: module.weight.data[module.padding_idx].zero_() + def _set_gradient_checkpointing(self, module, value=False): + if isinstance(module, (BigBirdPegasusDecoder, BigBirdPegasusEncoder)): + module.gradient_checkpointing = value + @property def dummy_inputs(self): pad_token = self.config.pad_token_id @@ -1764,6 +1769,7 @@ def __init__(self, config: BigBirdPegasusConfig, embed_tokens: Optional[nn.Embed self.layernorm_embedding = nn.LayerNorm(embed_dim) self.init_weights() + self.gradient_checkpointing = False def forward( self, @@ -1894,7 +1900,7 @@ def forward( if self.training and (dropout_probability < self.layerdrop): # skip the layer layer_outputs = (None, None) else: - if getattr(self.config, "gradient_checkpointing", False) and self.training: + if self.gradient_checkpointing and self.training: def create_custom_forward(module): def custom_forward(*inputs): @@ -2054,6 +2060,7 @@ def __init__(self, config: BigBirdPegasusConfig, embed_tokens: Optional[nn.Embed self.layernorm_embedding = nn.LayerNorm(config.d_model) self.init_weights() + self.gradient_checkpointing = False def get_input_embeddings(self): return self.embed_tokens @@ -2225,12 +2232,11 @@ def forward( past_key_value = past_key_values[idx] if past_key_values is not None else None - if getattr(self.config, "gradient_checkpointing", False) and self.training: + if self.gradient_checkpointing and self.training: if use_cache: logger.warning( - "`use_cache=True` is incompatible with `config.gradient_checkpointing=True`. Setting " - "`use_cache=False`..." + "`use_cache=True` is incompatible with gradient checkpointing. 
Setting `use_cache=False`..." ) use_cache = False diff --git a/src/transformers/models/blenderbot/configuration_blenderbot.py b/src/transformers/models/blenderbot/configuration_blenderbot.py index c2b272af034e..13acbdf699aa 100644 --- a/src/transformers/models/blenderbot/configuration_blenderbot.py +++ b/src/transformers/models/blenderbot/configuration_blenderbot.py @@ -78,8 +78,6 @@ class BlenderbotConfig(PretrainedConfig): decoder_layerdrop: (:obj:`float`, `optional`, defaults to 0.0): The LayerDrop probability for the decoder. See the `LayerDrop paper `__ for more details. - gradient_checkpointing (:obj:`bool`, `optional`, defaults to :obj:`False`): - If True, use gradient checkpointing to save memory at the expense of slower backward pass. scale_embedding (:obj:`bool`, `optional`, defaults to :obj:`False`): Scale embeddings by diving by sqrt(d_model). use_cache (:obj:`bool`, `optional`, defaults to :obj:`True`): @@ -128,7 +126,6 @@ def __init__( decoder_start_token_id=1, classifier_dropout=0.0, scale_embedding=False, - gradient_checkpointing=False, pad_token_id=0, bos_token_id=1, eos_token_id=2, @@ -155,7 +152,6 @@ def __init__( self.classifier_dropout = classifier_dropout self.use_cache = use_cache self.num_hidden_layers = encoder_layers - self.gradient_checkpointing = gradient_checkpointing self.scale_embedding = scale_embedding # scale factor will be sqrt(d_model) if True super().__init__( diff --git a/src/transformers/models/blenderbot/modeling_blenderbot.py b/src/transformers/models/blenderbot/modeling_blenderbot.py index e6bc6f657141..11e866594a37 100755 --- a/src/transformers/models/blenderbot/modeling_blenderbot.py +++ b/src/transformers/models/blenderbot/modeling_blenderbot.py @@ -451,6 +451,7 @@ def forward( class BlenderbotPreTrainedModel(PreTrainedModel): config_class = BlenderbotConfig base_model_prefix = "model" + supports_gradient_checkpointing = True def _init_weights(self, module): std = self.config.init_std @@ -463,6 +464,10 @@ def _init_weights(self, module): if module.padding_idx is not None: module.weight.data[module.padding_idx].zero_() + def _set_gradient_checkpointing(self, module, value=False): + if isinstance(module, (BlenderbotDecoder, BlenderbotEncoder)): + module.gradient_checkpointing = value + @property def dummy_inputs(self): pad_token = self.config.pad_token_id @@ -644,6 +649,7 @@ def __init__(self, config: BlenderbotConfig, embed_tokens: Optional[nn.Embedding self.layer_norm = nn.LayerNorm(config.d_model) self.init_weights() + self.gradient_checkpointing = False def forward( self, @@ -738,7 +744,7 @@ def forward( if self.training and (dropout_probability < self.layerdrop): # skip the layer layer_outputs = (None, None) else: - if getattr(self.config, "gradient_checkpointing", False) and self.training: + if self.gradient_checkpointing and self.training: def create_custom_forward(module): def custom_forward(*inputs): @@ -808,6 +814,7 @@ def __init__(self, config: BlenderbotConfig, embed_tokens: Optional[nn.Embedding self.layer_norm = nn.LayerNorm(config.d_model) self.init_weights() + self.gradient_checkpointing = False def get_input_embeddings(self): return self.embed_tokens @@ -980,12 +987,11 @@ def forward( past_key_value = past_key_values[idx] if past_key_values is not None else None - if getattr(self.config, "gradient_checkpointing", False) and self.training: + if self.gradient_checkpointing and self.training: if use_cache: logger.warning( - "`use_cache=True` is incompatible with `config.gradient_checkpointing=True`. Setting " - "`use_cache=False`..." 
+ "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." ) use_cache = False diff --git a/src/transformers/models/blenderbot_small/configuration_blenderbot_small.py b/src/transformers/models/blenderbot_small/configuration_blenderbot_small.py index de8927a4ffe9..0f76e2e3ae0e 100644 --- a/src/transformers/models/blenderbot_small/configuration_blenderbot_small.py +++ b/src/transformers/models/blenderbot_small/configuration_blenderbot_small.py @@ -78,8 +78,6 @@ class BlenderbotSmallConfig(PretrainedConfig): decoder_layerdrop: (:obj:`float`, `optional`, defaults to 0.0): The LayerDrop probability for the decoder. See the `LayerDrop paper `__ for more details. - gradient_checkpointing (:obj:`bool`, `optional`, defaults to :obj:`False`): - If True, use gradient checkpointing to save memory at the expense of slower backward pass. scale_embedding (:obj:`bool`, `optional`, defaults to :obj:`False`): Scale embeddings by diving by sqrt(d_model). use_cache (:obj:`bool`, `optional`, defaults to :obj:`True`): @@ -128,7 +126,6 @@ def __init__( decoder_start_token_id=1, classifier_dropout=0.0, scale_embedding=False, - gradient_checkpointing=False, pad_token_id=0, bos_token_id=1, eos_token_id=2, @@ -154,7 +151,6 @@ def __init__( self.classifier_dropout = classifier_dropout self.use_cache = use_cache self.num_hidden_layers = encoder_layers - self.gradient_checkpointing = gradient_checkpointing self.scale_embedding = scale_embedding # scale factor will be sqrt(d_model) if True super().__init__( diff --git a/src/transformers/models/blenderbot_small/modeling_blenderbot_small.py b/src/transformers/models/blenderbot_small/modeling_blenderbot_small.py index 81188488fe0a..a15c8276c377 100755 --- a/src/transformers/models/blenderbot_small/modeling_blenderbot_small.py +++ b/src/transformers/models/blenderbot_small/modeling_blenderbot_small.py @@ -449,6 +449,7 @@ def forward( class BlenderbotSmallPreTrainedModel(PreTrainedModel): config_class = BlenderbotSmallConfig base_model_prefix = "model" + supports_gradient_checkpointing = True def _init_weights(self, module): std = self.config.init_std @@ -461,6 +462,10 @@ def _init_weights(self, module): if module.padding_idx is not None: module.weight.data[module.padding_idx].zero_() + def _set_gradient_checkpointing(self, module, value=False): + if isinstance(module, (BlenderbotSmallDecoder, BlenderbotSmallEncoder)): + module.gradient_checkpointing = value + @property def dummy_inputs(self): pad_token = self.config.pad_token_id @@ -645,6 +650,7 @@ def __init__(self, config: BlenderbotSmallConfig, embed_tokens: Optional[nn.Embe self.layernorm_embedding = nn.LayerNorm(embed_dim) self.init_weights() + self.gradient_checkpointing = False def forward( self, @@ -740,7 +746,7 @@ def forward( if self.training and (dropout_probability < self.layerdrop): # skip the layer layer_outputs = (None, None) else: - if getattr(self.config, "gradient_checkpointing", False) and self.training: + if self.gradient_checkpointing and self.training: def create_custom_forward(module): def custom_forward(*inputs): @@ -808,6 +814,7 @@ def __init__(self, config: BlenderbotSmallConfig, embed_tokens: Optional[nn.Embe self.layernorm_embedding = nn.LayerNorm(config.d_model) self.init_weights() + self.gradient_checkpointing = False def get_input_embeddings(self): return self.embed_tokens @@ -981,12 +988,11 @@ def forward( past_key_value = past_key_values[idx] if past_key_values is not None else None - if getattr(self.config, "gradient_checkpointing", False) and 
self.training: + if self.gradient_checkpointing and self.training: if use_cache: logger.warning( - "`use_cache=True` is incompatible with `config.gradient_checkpointing=True`. Setting " - "`use_cache=False`..." + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." ) use_cache = False diff --git a/src/transformers/models/canine/configuration_canine.py b/src/transformers/models/canine/configuration_canine.py index 3feef5ac75be..79be54a8247b 100644 --- a/src/transformers/models/canine/configuration_canine.py +++ b/src/transformers/models/canine/configuration_canine.py @@ -61,8 +61,6 @@ class CanineConfig(PretrainedConfig): The standard deviation of the truncated_normal_initializer for initializing all weight matrices. layer_norm_eps (:obj:`float`, `optional`, defaults to 1e-12): The epsilon used by the layer normalization layers. - gradient_checkpointing (:obj:`bool`, `optional`, defaults to :obj:`False`): - If :obj:`True`, use gradient checkpointing to save memory at the expense of slower backward pass. downsampling_rate (:obj:`int`, `optional`, defaults to 4): The rate at which to downsample the original character sequence length before applying the deep Transformer encoder. diff --git a/src/transformers/models/canine/modeling_canine.py b/src/transformers/models/canine/modeling_canine.py index 18ca01031cc4..a13505d3a052 100644 --- a/src/transformers/models/canine/modeling_canine.py +++ b/src/transformers/models/canine/modeling_canine.py @@ -772,6 +772,7 @@ def __init__( for _ in range(config.num_hidden_layers) ] ) + self.gradient_checkpointing = False def forward( self, @@ -791,7 +792,7 @@ def forward( layer_head_mask = head_mask[i] if head_mask is not None else None - if getattr(self.config, "gradient_checkpointing", False) and self.training: + if self.gradient_checkpointing and self.training: def create_custom_forward(module): def custom_forward(*inputs): @@ -895,6 +896,7 @@ class CaninePreTrainedModel(PreTrainedModel): config_class = CanineConfig load_tf_weights = load_tf_weights_in_canine base_model_prefix = "canine" + supports_gradient_checkpointing = True _keys_to_ignore_on_load_missing = [r"position_ids"] def _init_weights(self, module): @@ -913,6 +915,10 @@ def _init_weights(self, module): module.bias.data.zero_() module.weight.data.fill_(1.0) + def _set_gradient_checkpointing(self, module, value=False): + if isinstance(module, CanineEncoder): + module.gradient_checkpointing = value + CANINE_START_DOCSTRING = r""" This model is a PyTorch `torch.nn.Module `_ sub-class. Use diff --git a/src/transformers/models/clip/configuration_clip.py b/src/transformers/models/clip/configuration_clip.py index b82428871195..0f8b6fa9a4de 100644 --- a/src/transformers/models/clip/configuration_clip.py +++ b/src/transformers/models/clip/configuration_clip.py @@ -68,8 +68,6 @@ class CLIPTextConfig(PretrainedConfig): initializer_factor (:obj:`float`, `optional`, defaults to 1): A factor for initializing all weight matrices (should be kept to 1, used internally for initialization testing). - gradient_checkpointing (:obj:`bool`, `optional`, defaults to :obj:`False`): - If True, use gradient checkpointing to save memory at the expense of slower backward pass. 
Example:: @@ -103,7 +101,6 @@ def __init__( pad_token_id=1, bos_token_id=0, eos_token_id=2, - gradient_checkpointing=False, **kwargs ): super().__init__(pad_token_id=pad_token_id, bos_token_id=bos_token_id, eos_token_id=eos_token_id, **kwargs) @@ -120,7 +117,6 @@ def __init__( self.initializer_range = initializer_range self.initializer_factor = initializer_factor self.attention_dropout = attention_dropout - self.gradient_checkpointing = gradient_checkpointing class CLIPVisionConfig(PretrainedConfig): @@ -161,8 +157,6 @@ class CLIPVisionConfig(PretrainedConfig): initializer_factor (:obj:`float`, `optional`, defaults to 1): A factor for initializing all weight matrices (should be kept to 1, used internally for initialization testing). - gradient_checkpointing (:obj:`bool`, `optional`, defaults to :obj:`False`): - If True, use gradient checkpointing to save memory at the expense of slower backward pass. Example:: @@ -194,7 +188,6 @@ def __init__( attention_dropout=0.0, initializer_range=0.02, initializer_factor=1.0, - gradient_checkpointing=False, **kwargs ): super().__init__(**kwargs) @@ -211,7 +204,6 @@ def __init__( self.attention_dropout = attention_dropout self.layer_norm_eps = layer_norm_eps self.hidden_act = hidden_act - self.gradient_checkpointing = gradient_checkpointing class CLIPConfig(PretrainedConfig): diff --git a/src/transformers/models/clip/modeling_clip.py b/src/transformers/models/clip/modeling_clip.py index 8d723e05fc29..4f3b280a1bc5 100755 --- a/src/transformers/models/clip/modeling_clip.py +++ b/src/transformers/models/clip/modeling_clip.py @@ -338,6 +338,7 @@ class CLIPPreTrainedModel(PreTrainedModel): config_class = CLIPConfig base_model_prefix = "clip" + supports_gradient_checkpointing = True _keys_to_ignore_on_load_missing = [r"position_ids"] def _init_weights(self, module): @@ -383,6 +384,10 @@ def _init_weights(self, module): if isinstance(module, nn.Linear) and module.bias is not None: module.bias.data.zero_() + def _set_gradient_checkpointing(self, module, value=False): + if isinstance(module, CLIPEncoder): + module.gradient_checkpointing = value + CLIP_START_DOCSTRING = r""" This model is a PyTorch `torch.nn.Module `_ subclass. 
Use @@ -499,6 +504,7 @@ def __init__(self, config: CLIPConfig): super().__init__() self.config = config self.layers = nn.ModuleList([CLIPEncoderLayer(config) for _ in range(config.num_hidden_layers)]) + self.gradient_checkpointing = False def forward( self, @@ -551,7 +557,7 @@ def forward( for idx, encoder_layer in enumerate(self.layers): if output_hidden_states: encoder_states = encoder_states + (hidden_states,) - if getattr(self.config, "gradient_checkpointing", False) and self.training: + if self.gradient_checkpointing and self.training: def create_custom_forward(module): def custom_forward(*inputs): diff --git a/src/transformers/models/convbert/modeling_convbert.py b/src/transformers/models/convbert/modeling_convbert.py index fbd0cdfc5ec7..99d8ae5dd4cf 100755 --- a/src/transformers/models/convbert/modeling_convbert.py +++ b/src/transformers/models/convbert/modeling_convbert.py @@ -248,6 +248,7 @@ class ConvBertPreTrainedModel(PreTrainedModel): config_class = ConvBertConfig load_tf_weights = load_tf_weights_in_convbert base_model_prefix = "convbert" + supports_gradient_checkpointing = True authorized_missing_keys = [r"position_ids"] authorized_unexpected_keys = [r"convbert\.embeddings_project\.weight", r"convbert\.embeddings_project\.bias"] @@ -267,6 +268,10 @@ def _init_weights(self, module): module.bias.data.zero_() module.weight.data.fill_(1.0) + def _set_gradient_checkpointing(self, module, value=False): + if isinstance(module, ConvBertEncoder): + module.gradient_checkpointing = value + class SeparableConv1D(nn.Module): """This class implements separable convolution, i.e. a depthwise and a pointwise layer""" @@ -603,6 +608,7 @@ def __init__(self, config): super().__init__() self.config = config self.layer = nn.ModuleList([ConvBertLayer(config) for _ in range(config.num_hidden_layers)]) + self.gradient_checkpointing = False def forward( self, @@ -624,7 +630,7 @@ def forward( layer_head_mask = head_mask[i] if head_mask is not None else None - if getattr(self.config, "gradient_checkpointing", False): + if self.gradient_checkpointing and self.training: def create_custom_forward(module): def custom_forward(*inputs): diff --git a/src/transformers/models/deit/configuration_deit.py b/src/transformers/models/deit/configuration_deit.py index 0bbbff709b83..98bbe1b01ba8 100644 --- a/src/transformers/models/deit/configuration_deit.py +++ b/src/transformers/models/deit/configuration_deit.py @@ -58,8 +58,6 @@ class DeiTConfig(PretrainedConfig): The standard deviation of the truncated_normal_initializer for initializing all weight matrices. layer_norm_eps (:obj:`float`, `optional`, defaults to 1e-12): The epsilon used by the layer normalization layers. - gradient_checkpointing (:obj:`bool`, `optional`, defaults to :obj:`False`): - If True, use gradient checkpointing to save memory at the expense of slower backward pass. image_size (:obj:`int`, `optional`, defaults to :obj:`224`): The size (resolution) of each image. 
patch_size (:obj:`int`, `optional`, defaults to :obj:`16`): diff --git a/src/transformers/models/deit/modeling_deit.py b/src/transformers/models/deit/modeling_deit.py index b848376817b1..6ffa6afa3a5b 100644 --- a/src/transformers/models/deit/modeling_deit.py +++ b/src/transformers/models/deit/modeling_deit.py @@ -324,6 +324,7 @@ def __init__(self, config): super().__init__() self.config = config self.layer = nn.ModuleList([DeiTLayer(config) for _ in range(config.num_hidden_layers)]) + self.gradient_checkpointing = False def forward( self, @@ -342,7 +343,7 @@ def forward( layer_head_mask = head_mask[i] if head_mask is not None else None - if getattr(self.config, "gradient_checkpointing", False) and self.training: + if self.gradient_checkpointing and self.training: def create_custom_forward(module): def custom_forward(*inputs): @@ -384,6 +385,7 @@ class DeiTPreTrainedModel(PreTrainedModel): config_class = DeiTConfig base_model_prefix = "deit" + supports_gradient_checkpointing = True def _init_weights(self, module): """Initialize the weights""" @@ -401,6 +403,10 @@ def _init_weights(self, module): module.bias.data.zero_() module.weight.data.fill_(1.0) + def _set_gradient_checkpointing(self, module, value=False): + if isinstance(module, DeiTEncoder): + module.gradient_checkpointing = value + DEIT_START_DOCSTRING = r""" This model is a PyTorch `torch.nn.Module `_ subclass. Use diff --git a/src/transformers/models/detr/modeling_detr.py b/src/transformers/models/detr/modeling_detr.py index 3061addadaed..af650e75e1a6 100644 --- a/src/transformers/models/detr/modeling_detr.py +++ b/src/transformers/models/detr/modeling_detr.py @@ -783,6 +783,7 @@ def forward(self, hidden_states: torch.Tensor): class DetrPreTrainedModel(PreTrainedModel): config_class = DetrConfig base_model_prefix = "model" + supports_gradient_checkpointing = True def _init_weights(self, module): std = self.config.init_std @@ -807,6 +808,10 @@ def _init_weights(self, module): if module.padding_idx is not None: module.weight.data[module.padding_idx].zero_() + def _set_gradient_checkpointing(self, module, value=False): + if isinstance(module, DetrDecoder): + module.gradient_checkpointing = value + DETR_START_DOCSTRING = r""" This model inherits from :class:`~transformers.PreTrainedModel`. Check the superclass documentation for the generic @@ -997,6 +1002,7 @@ def __init__(self, config: DetrConfig): self.layernorm = nn.LayerNorm(config.d_model) self.init_weights() + self.gradient_checkpointing = False def forward( self, @@ -1084,7 +1090,7 @@ def forward( if self.training and (dropout_probability < self.layerdrop): continue - if getattr(self.config, "gradient_checkpointing", False) and self.training: + if self.gradient_checkpointing and self.training: def create_custom_forward(module): def custom_forward(*inputs): diff --git a/src/transformers/models/dpr/configuration_dpr.py b/src/transformers/models/dpr/configuration_dpr.py index 2773835f721c..a9b5f96556c7 100644 --- a/src/transformers/models/dpr/configuration_dpr.py +++ b/src/transformers/models/dpr/configuration_dpr.py @@ -69,8 +69,6 @@ class DPRConfig(PretrainedConfig): The standard deviation of the truncated_normal_initializer for initializing all weight matrices. layer_norm_eps (:obj:`float`, `optional`, defaults to 1e-12): The epsilon used by the layer normalization layers. - gradient_checkpointing (:obj:`bool`, `optional`, defaults to :obj:`False`): - If True, use gradient checkpointing to save memory at the expense of slower backward pass. 
position_embedding_type (:obj:`str`, `optional`, defaults to :obj:`"absolute"`): Type of position embedding. Choose one of :obj:`"absolute"`, :obj:`"relative_key"`, :obj:`"relative_key_query"`. For positional embeddings use :obj:`"absolute"`. For more information on @@ -99,7 +97,6 @@ def __init__( initializer_range=0.02, layer_norm_eps=1e-12, pad_token_id=0, - gradient_checkpointing=False, position_embedding_type="absolute", projection_dim: int = 0, **kwargs @@ -118,6 +115,5 @@ def __init__( self.type_vocab_size = type_vocab_size self.initializer_range = initializer_range self.layer_norm_eps = layer_norm_eps - self.gradient_checkpointing = gradient_checkpointing self.projection_dim = projection_dim self.position_embedding_type = position_embedding_type diff --git a/src/transformers/models/dpr/modeling_dpr.py b/src/transformers/models/dpr/modeling_dpr.py index 37fa61b70651..c1a3fa618d4e 100644 --- a/src/transformers/models/dpr/modeling_dpr.py +++ b/src/transformers/models/dpr/modeling_dpr.py @@ -30,7 +30,7 @@ from ...modeling_outputs import BaseModelOutputWithPooling from ...modeling_utils import PreTrainedModel from ...utils import logging -from ..bert.modeling_bert import BertModel +from ..bert.modeling_bert import BertEncoder, BertModel from .configuration_dpr import DPRConfig @@ -300,6 +300,10 @@ class DPRPretrainedQuestionEncoder(PreTrainedModel): def init_weights(self): self.question_encoder.init_weights() + def _set_gradient_checkpointing(self, module, value=False): + if isinstance(module, BertEncoder): + module.gradient_checkpointing = value + class DPRPretrainedReader(PreTrainedModel): """ @@ -317,6 +321,10 @@ def init_weights(self): self.span_predictor.qa_classifier.apply(self.span_predictor.encoder.bert_model._init_weights) self.span_predictor.qa_outputs.apply(self.span_predictor.encoder.bert_model._init_weights) + def _set_gradient_checkpointing(self, module, value=False): + if isinstance(module, BertEncoder): + module.gradient_checkpointing = value + ############### # Actual Models diff --git a/src/transformers/models/electra/modeling_electra.py b/src/transformers/models/electra/modeling_electra.py index 867a7c091513..1f44b23522b7 100644 --- a/src/transformers/models/electra/modeling_electra.py +++ b/src/transformers/models/electra/modeling_electra.py @@ -527,6 +527,7 @@ def __init__(self, config): super().__init__() self.config = config self.layer = nn.ModuleList([ElectraLayer(config) for _ in range(config.num_hidden_layers)]) + self.gradient_checkpointing = False def forward( self, @@ -553,12 +554,11 @@ def forward( layer_head_mask = head_mask[i] if head_mask is not None else None past_key_value = past_key_values[i] if past_key_values is not None else None - if getattr(self.config, "gradient_checkpointing", False) and self.training: + if self.gradient_checkpointing and self.training: if use_cache: logger.warning( - "`use_cache=True` is incompatible with `config.gradient_checkpointing=True`. Setting " - "`use_cache=False`..." + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." 
) use_cache = False @@ -663,6 +663,7 @@ class ElectraPreTrainedModel(PreTrainedModel): config_class = ElectraConfig load_tf_weights = load_tf_weights_in_electra base_model_prefix = "electra" + supports_gradient_checkpointing = True _keys_to_ignore_on_load_missing = [r"position_ids"] _keys_to_ignore_on_load_unexpected = [r"electra\.embeddings_project\.weight", r"electra\.embeddings_project\.bias"] @@ -683,6 +684,10 @@ def _init_weights(self, module): module.bias.data.zero_() module.weight.data.fill_(1.0) + def _set_gradient_checkpointing(self, module, value=False): + if isinstance(module, ElectraEncoder): + module.gradient_checkpointing = value + @dataclass class ElectraForPreTrainingOutput(ModelOutput): diff --git a/src/transformers/models/fnet/configuration_fnet.py b/src/transformers/models/fnet/configuration_fnet.py index 047190a3ed24..a6922f835588 100644 --- a/src/transformers/models/fnet/configuration_fnet.py +++ b/src/transformers/models/fnet/configuration_fnet.py @@ -64,8 +64,6 @@ class FNetConfig(PretrainedConfig): The standard deviation of the truncated_normal_initializer for initializing all weight matrices. layer_norm_eps (:obj:`float`, `optional`, defaults to 1e-12): The epsilon used by the layer normalization layers. - gradient_checkpointing (:obj:`bool`, `optional`, defaults to :obj:`False`): - If :obj:`True`, use gradient checkpointing to save memory at the expense of slower backward pass. use_tpu_fourier_optimizations (:obj:`bool`, `optional`, defaults to :obj:`False`): Determines whether to use TPU optimized FFTs. If :obj:`True`, the model will favor axis-wise FFTs transforms. Set to :obj:`False` for GPU/CPU hardware, in which case n-dimensional FFTs are used. diff --git a/src/transformers/models/fnet/modeling_fnet.py b/src/transformers/models/fnet/modeling_fnet.py index 2a1b7f5f2ab2..9340eb04f3c4 100755 --- a/src/transformers/models/fnet/modeling_fnet.py +++ b/src/transformers/models/fnet/modeling_fnet.py @@ -284,6 +284,7 @@ def __init__(self, config): super().__init__() self.config = config self.layer = nn.ModuleList([FNetLayer(config) for _ in range(config.num_hidden_layers)]) + self.gradient_checkpointing = False def forward(self, hidden_states, output_hidden_states=False, return_dict=True): all_hidden_states = () if output_hidden_states else None @@ -292,7 +293,7 @@ def forward(self, hidden_states, output_hidden_states=False, return_dict=True): if output_hidden_states: all_hidden_states = all_hidden_states + (hidden_states,) - if getattr(self.config, "gradient_checkpointing", False) and self.training: + if self.gradient_checkpointing and self.training: def create_custom_forward(module): def custom_forward(*inputs): @@ -413,6 +414,7 @@ class FNetPreTrainedModel(PreTrainedModel): config_class = FNetConfig base_model_prefix = "fnet" + supports_gradient_checkpointing = True _keys_to_ignore_on_load_missing = [r"position_ids"] def _init_weights(self, module): @@ -432,6 +434,10 @@ def _init_weights(self, module): module.bias.data.zero_() module.weight.data.fill_(1.0) + def _set_gradient_checkpointing(self, module, value=False): + if isinstance(module, FNetEncoder): + module.gradient_checkpointing = value + @dataclass class FNetForPreTrainingOutput(ModelOutput): diff --git a/src/transformers/models/gpt2/configuration_gpt2.py b/src/transformers/models/gpt2/configuration_gpt2.py index f003023ca8b0..41120c94daad 100644 --- a/src/transformers/models/gpt2/configuration_gpt2.py +++ b/src/transformers/models/gpt2/configuration_gpt2.py @@ -108,9 +108,7 @@ class 
GPT2Config(PretrainedConfig): The dropout ratio to be used after the projection and activation. scale_attn_weights (:obj:`bool`, `optional`, defaults to :obj:`True`): - Scale attention weights by dividing by sqrt(hidden_size). - gradient_checkpointing (:obj:`bool`, `optional`, defaults to :obj:`False`): - Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass. + Scale attention weights by dividing by sqrt(hidden_size). use_cache (:obj:`bool`, `optional`, defaults to :obj:`True`): Whether or not the model should return the last key/values attentions (not used by all models). @@ -158,7 +156,6 @@ def __init__( summary_proj_to_labels=True, summary_first_dropout=0.1, scale_attn_weights=True, - gradient_checkpointing=False, use_cache=True, bos_token_id=50256, eos_token_id=50256, @@ -182,7 +179,6 @@ def __init__( self.summary_activation = summary_activation self.summary_first_dropout = summary_first_dropout self.summary_proj_to_labels = summary_proj_to_labels - self.gradient_checkpointing = gradient_checkpointing self.scale_attn_weights = scale_attn_weights self.use_cache = use_cache diff --git a/src/transformers/models/gpt2/modeling_gpt2.py b/src/transformers/models/gpt2/modeling_gpt2.py index 43419e66151d..d6fab7f7ff24 100644 --- a/src/transformers/models/gpt2/modeling_gpt2.py +++ b/src/transformers/models/gpt2/modeling_gpt2.py @@ -374,6 +374,7 @@ class GPT2PreTrainedModel(PreTrainedModel): load_tf_weights = load_tf_weights_in_gpt2 base_model_prefix = "transformer" is_parallelizable = True + supports_gradient_checkpointing = True def __init__(self, *inputs, **kwargs): super().__init__(*inputs, **kwargs) @@ -394,6 +395,10 @@ def _init_weights(self, module): module.bias.data.zero_() module.weight.data.fill_(1.0) + def _set_gradient_checkpointing(self, module, value=False): + if isinstance(module, GPT2Model): + module.gradient_checkpointing = value + @dataclass class GPT2DoubleHeadsModelOutput(ModelOutput): @@ -589,6 +594,7 @@ def __init__(self, config): # Model parallel self.model_parallel = False self.device_map = None + self.gradient_checkpointing = False @add_start_docstrings(PARALLELIZE_DOCSTRING) def parallelize(self, device_map=None): @@ -764,12 +770,11 @@ def forward( if output_hidden_states: all_hidden_states = all_hidden_states + (hidden_states,) - if getattr(self.config, "gradient_checkpointing", False) and self.training: + if self.gradient_checkpointing and self.training: if use_cache: logger.warning( - "`use_cache=True` is incompatible with `config.gradient_checkpointing=True`. Setting " - "`use_cache=False`..." + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." ) use_cache = False diff --git a/src/transformers/models/gpt_neo/configuration_gpt_neo.py b/src/transformers/models/gpt_neo/configuration_gpt_neo.py index e5b7e683d99a..d5069fb01711 100644 --- a/src/transformers/models/gpt_neo/configuration_gpt_neo.py +++ b/src/transformers/models/gpt_neo/configuration_gpt_neo.py @@ -79,8 +79,6 @@ class GPTNeoConfig(PretrainedConfig): use_cache (:obj:`bool`, `optional`, defaults to :obj:`True`): Whether or not the model should return the last key/values attentions (not used by all models). Only relevant if ``config.is_decoder=True``. - gradient_checkpointing (:obj:`bool`, `optional`, defaults to :obj:`False`): - If True, use gradient checkpointing to save memory at the expense of slower backward pass.
Example:: @@ -120,7 +118,6 @@ def __init__( summary_activation=None, summary_proj_to_labels=True, summary_first_dropout=0.1, - gradient_checkpointing=False, use_cache=True, bos_token_id=50256, eos_token_id=50256, @@ -144,7 +141,6 @@ def __init__( self.summary_activation = summary_activation self.summary_first_dropout = summary_first_dropout self.summary_proj_to_labels = summary_proj_to_labels - self.gradient_checkpointing = gradient_checkpointing self.use_cache = use_cache self.bos_token_id = bos_token_id diff --git a/src/transformers/models/gpt_neo/modeling_gpt_neo.py b/src/transformers/models/gpt_neo/modeling_gpt_neo.py index 353d3b0fb6ce..3fafd75ac21a 100755 --- a/src/transformers/models/gpt_neo/modeling_gpt_neo.py +++ b/src/transformers/models/gpt_neo/modeling_gpt_neo.py @@ -361,6 +361,7 @@ class GPTNeoPreTrainedModel(PreTrainedModel): config_class = GPTNeoConfig load_tf_weights = load_tf_weights_in_gpt_neo base_model_prefix = "transformer" + supports_gradient_checkpointing = True def __init__(self, *inputs, **kwargs): super().__init__(*inputs, **kwargs) @@ -381,6 +382,10 @@ def _init_weights(self, module): module.bias.data.zero_() module.weight.data.fill_(1.0) + def _set_gradient_checkpointing(self, module, value=False): + if isinstance(module, GPTNeoModel): + module.gradient_checkpointing = value + GPT_NEO_START_DOCSTRING = r""" @@ -482,6 +487,7 @@ def __init__(self, config): self.ln_f = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon) self.init_weights() + self.gradient_checkpointing = False def get_input_embeddings(self): return self.wte @@ -592,12 +598,11 @@ def forward( if output_hidden_states: all_hidden_states = all_hidden_states + (hidden_states,) - if getattr(self.config, "gradient_checkpointing", False) and self.training: + if self.gradient_checkpointing and self.training: if use_cache: logger.warning( - "`use_cache=True` is incompatible with `config.gradient_checkpointing=True`. Setting " - "`use_cache=False`..." + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." ) use_cache = False diff --git a/src/transformers/models/gptj/configuration_gptj.py b/src/transformers/models/gptj/configuration_gptj.py index 93018fdcb60b..61dfd4e66393 100644 --- a/src/transformers/models/gptj/configuration_gptj.py +++ b/src/transformers/models/gptj/configuration_gptj.py @@ -68,8 +68,6 @@ class GPTJConfig(PretrainedConfig): The standard deviation of the truncated_normal_initializer for initializing all weight matrices. scale_attn_weights (:obj:`bool`, `optional`, defaults to :obj:`True`): Scale attention weights by dividing by sqrt(hidden_size). - gradient_checkpointing (:obj:`bool`, `optional`, defaults to :obj:`False`): - Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass. use_cache (:obj:`bool`, `optional`, defaults to :obj:`True`): Whether or not the model should return the last key/values attentions (not used by all models). 
@@ -111,7 +109,6 @@ def __init__( layer_norm_epsilon=1e-5, initializer_range=0.02, scale_attn_weights=True, - gradient_checkpointing=False, use_cache=True, bos_token_id=50256, eos_token_id=50256, @@ -131,7 +128,6 @@ def __init__( self.attn_pdrop = attn_pdrop self.layer_norm_epsilon = layer_norm_epsilon self.initializer_range = initializer_range - self.gradient_checkpointing = gradient_checkpointing self.scale_attn_weights = scale_attn_weights self.use_cache = use_cache diff --git a/src/transformers/models/gptj/modeling_gptj.py b/src/transformers/models/gptj/modeling_gptj.py index 2d7781a2758b..a23da0834711 100755 --- a/src/transformers/models/gptj/modeling_gptj.py +++ b/src/transformers/models/gptj/modeling_gptj.py @@ -303,6 +303,7 @@ class GPTJPreTrainedModel(PreTrainedModel): config_class = GPTJConfig base_model_prefix = "transformer" is_parallelizable = True + supports_gradient_checkpointing = True def __init__(self, *inputs, **kwargs): super().__init__(*inputs, **kwargs) @@ -323,6 +324,10 @@ def _init_weights(self, module): module.bias.data.zero_() module.weight.data.fill_(1.0) + def _set_gradient_checkpointing(self, module, value=False): + if isinstance(module, GPTJModel): + module.gradient_checkpointing = value + GPTJ_START_DOCSTRING = r""" This model is a PyTorch `torch.nn.Module `_ sub-class. Use @@ -445,6 +450,7 @@ def __init__(self, config): # Model parallel self.model_parallel = False self.device_map = None + self.gradient_checkpointing = False @add_start_docstrings(PARALLELIZE_DOCSTRING) def parallelize(self, device_map=None): @@ -598,12 +604,11 @@ def forward( if output_hidden_states: all_hidden_states = all_hidden_states + (hidden_states,) - if getattr(self.config, "gradient_checkpointing", False) and self.training: + if self.gradient_checkpointing and self.training: if use_cache: logger.warning( - "`use_cache=True` is incompatible with `config.gradient_checkpointing=True`. Setting " - "`use_cache=False`..." + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." ) use_cache = False diff --git a/src/transformers/models/hubert/configuration_hubert.py b/src/transformers/models/hubert/configuration_hubert.py index 633807684fcc..624211431c55 100644 --- a/src/transformers/models/hubert/configuration_hubert.py +++ b/src/transformers/models/hubert/configuration_hubert.py @@ -120,8 +120,6 @@ class HubertConfig(PretrainedConfig): instance of :class:`~transformers.HubertForSequenceClassification`. classifier_proj_size (:obj:`int`, `optional`, defaults to 256): Dimensionality of the projection before token mean-pooling for classification. - gradient_checkpointing (:obj:`bool`, `optional`, defaults to :obj:`False`): - If True, use gradient checkpointing to save memory at the expense of slower backward pass. 
Example:: @@ -172,7 +170,6 @@ def __init__( ctc_zero_infinity=False, use_weighted_layer_sum=False, classifier_proj_size=256, - gradient_checkpointing=False, pad_token_id=0, bos_token_id=1, eos_token_id=2, @@ -203,7 +200,6 @@ def __init__( self.initializer_range = initializer_range self.vocab_size = vocab_size self.do_stable_layer_norm = do_stable_layer_norm - self.gradient_checkpointing = gradient_checkpointing self.use_weighted_layer_sum = use_weighted_layer_sum self.classifier_proj_size = classifier_proj_size diff --git a/src/transformers/models/hubert/modeling_hubert.py b/src/transformers/models/hubert/modeling_hubert.py index 95d5c91f5ae8..6575f4932b91 100755 --- a/src/transformers/models/hubert/modeling_hubert.py +++ b/src/transformers/models/hubert/modeling_hubert.py @@ -525,6 +525,7 @@ def __init__(self, config): self.layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) self.dropout = nn.Dropout(config.hidden_dropout) self.layers = nn.ModuleList([HubertEncoderLayer(config) for _ in range(config.num_hidden_layers)]) + self.gradient_checkpointing = False def forward( self, @@ -564,7 +565,7 @@ def forward( skip_the_layer = True if self.training and (dropout_probability < self.config.layerdrop) else False if not skip_the_layer or deepspeed_zero3_is_enabled: # under deepspeed zero3 all gpus must run in sync - if getattr(self.config, "gradient_checkpointing", False) and self.training: + if self.gradient_checkpointing and self.training: # create gradient checkpointing function def create_custom_forward(module): def custom_forward(*inputs): @@ -612,6 +613,7 @@ def __init__(self, config): self.layers = nn.ModuleList( [HubertEncoderLayerStableLayerNorm(config) for _ in range(config.num_hidden_layers)] ) + self.gradient_checkpointing = False def forward( self, @@ -651,7 +653,7 @@ def forward( if not skip_the_layer or deepspeed_zero3_is_enabled: # under deepspeed zero3 all gpus must run in sync # XXX: could optimize this like synced_gpus in generate_utils but not sure if it's worth the code complication - if getattr(self.config, "gradient_checkpointing", False) and self.training: + if self.gradient_checkpointing and self.training: # create gradient checkpointing function def create_custom_forward(module): def custom_forward(*inputs): @@ -698,6 +700,7 @@ class HubertPreTrainedModel(PreTrainedModel): config_class = HubertConfig base_model_prefix = "hubert" + supports_gradient_checkpointing = True _keys_to_ignore_on_load_missing = [r"position_ids"] def _init_weights(self, module): @@ -725,6 +728,10 @@ def _init_weights(self, module): if isinstance(module, (nn.Linear, nn.Conv1d)) and module.bias is not None: module.bias.data.zero_() + def _set_gradient_checkpointing(self, module, value=False): + if isinstance(module, (HubertEncoder, HubertEncoderStableLayerNorm)): + module.gradient_checkpointing = value + def _get_feat_extract_output_lengths(self, input_lengths: Union[torch.LongTensor, int]): """ Computes the output length of the convolutional layers diff --git a/src/transformers/models/ibert/modeling_ibert.py b/src/transformers/models/ibert/modeling_ibert.py index 61c775d9fff8..d4f74ff47e7d 100644 --- a/src/transformers/models/ibert/modeling_ibert.py +++ b/src/transformers/models/ibert/modeling_ibert.py @@ -579,17 +579,13 @@ def forward( layer_head_mask = head_mask[i] if head_mask is not None else None - if getattr(self.config, "gradient_checkpointing", False) and self.training: - raise NotImplementedError("gradient checkpointing is not currently supported") - - else: - 
layer_outputs = layer_module( - hidden_states, - hidden_states_scaling_factor, - attention_mask, - layer_head_mask, - output_attentions, - ) + layer_outputs = layer_module( + hidden_states, + hidden_states_scaling_factor, + attention_mask, + layer_head_mask, + output_attentions, + ) hidden_states = layer_outputs[0] if output_attentions: diff --git a/src/transformers/models/layoutlm/configuration_layoutlm.py b/src/transformers/models/layoutlm/configuration_layoutlm.py index ee9a10e82451..a8dac8cd9dfe 100644 --- a/src/transformers/models/layoutlm/configuration_layoutlm.py +++ b/src/transformers/models/layoutlm/configuration_layoutlm.py @@ -66,8 +66,6 @@ class LayoutLMConfig(BertConfig): The standard deviation of the truncated_normal_initializer for initializing all weight matrices. layer_norm_eps (:obj:`float`, `optional`, defaults to 1e-12): The epsilon used by the layer normalization layers. - gradient_checkpointing (:obj:`bool`, `optional`, defaults to :obj:`False`): - If True, use gradient checkpointing to save memory at the expense of slower backward pass. max_2d_position_embeddings (:obj:`int`, `optional`, defaults to 1024): The maximum value that the 2D position embedding might ever used. Typically set this to something large just in case (e.g., 1024). @@ -103,7 +101,6 @@ def __init__( initializer_range=0.02, layer_norm_eps=1e-12, pad_token_id=0, - gradient_checkpointing=False, max_2d_position_embeddings=1024, **kwargs ): @@ -121,7 +118,6 @@ def __init__( initializer_range=initializer_range, layer_norm_eps=layer_norm_eps, pad_token_id=pad_token_id, - gradient_checkpointing=gradient_checkpointing, **kwargs, ) self.max_2d_position_embeddings = max_2d_position_embeddings diff --git a/src/transformers/models/layoutlm/modeling_layoutlm.py b/src/transformers/models/layoutlm/modeling_layoutlm.py index 3e7dfe8560c7..b47d2793d141 100644 --- a/src/transformers/models/layoutlm/modeling_layoutlm.py +++ b/src/transformers/models/layoutlm/modeling_layoutlm.py @@ -442,6 +442,7 @@ def __init__(self, config): super().__init__() self.config = config self.layer = nn.ModuleList([LayoutLMLayer(config) for _ in range(config.num_hidden_layers)]) + self.gradient_checkpointing = False def forward( self, @@ -468,12 +469,11 @@ def forward( layer_head_mask = head_mask[i] if head_mask is not None else None past_key_value = past_key_values[i] if past_key_values is not None else None - if getattr(self.config, "gradient_checkpointing", False) and self.training: + if self.gradient_checkpointing and self.training: if use_cache: logger.warning( - "`use_cache=True` is incompatible with `config.gradient_checkpointing=True`. Setting " - "`use_cache=False`..." + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." 
) use_cache = False @@ -609,6 +609,7 @@ class LayoutLMPreTrainedModel(PreTrainedModel): config_class = LayoutLMConfig pretrained_model_archive_map = LAYOUTLM_PRETRAINED_MODEL_ARCHIVE_LIST base_model_prefix = "layoutlm" + supports_gradient_checkpointing = True _keys_to_ignore_on_load_missing = [r"position_ids"] def _init_weights(self, module): @@ -627,6 +628,10 @@ def _init_weights(self, module): module.bias.data.zero_() module.weight.data.fill_(1.0) + def _set_gradient_checkpointing(self, module, value=False): + if isinstance(module, LayoutLMEncoder): + module.gradient_checkpointing = value + LAYOUTLM_START_DOCSTRING = r""" The LayoutLM model was proposed in `LayoutLM: Pre-training of Text and Layout for Document Image Understanding diff --git a/src/transformers/models/layoutlmv2/modeling_layoutlmv2.py b/src/transformers/models/layoutlmv2/modeling_layoutlmv2.py index e42d77bab262..6c42ce1ccc9a 100755 --- a/src/transformers/models/layoutlmv2/modeling_layoutlmv2.py +++ b/src/transformers/models/layoutlmv2/modeling_layoutlmv2.py @@ -378,6 +378,8 @@ def __init__(self, config): self.rel_pos_x_bias = nn.Linear(self.rel_2d_pos_onehot_size, config.num_attention_heads, bias=False) self.rel_pos_y_bias = nn.Linear(self.rel_2d_pos_onehot_size, config.num_attention_heads, bias=False) + self.gradient_checkpointing = False + def _calculate_1d_position_embeddings(self, hidden_states, position_ids): rel_pos_mat = position_ids.unsqueeze(-2) - position_ids.unsqueeze(-1) rel_pos = relative_position_bucket( @@ -443,7 +445,7 @@ def forward( layer_head_mask = head_mask[i] if head_mask is not None else None - if getattr(self.config, "gradient_checkpointing", False) and self.training: + if self.gradient_checkpointing and self.training: def create_custom_forward(module): def custom_forward(*inputs): @@ -502,6 +504,7 @@ class LayoutLMv2PreTrainedModel(PreTrainedModel): config_class = LayoutLMv2Config pretrained_model_archive_map = LAYOUTLMV2_PRETRAINED_MODEL_ARCHIVE_LIST base_model_prefix = "layoutlmv2" + supports_gradient_checkpointing = True _keys_to_ignore_on_load_missing = [r"position_ids"] def _init_weights(self, module): @@ -520,6 +523,10 @@ def _init_weights(self, module): module.bias.data.zero_() module.weight.data.fill_(1.0) + def _set_gradient_checkpointing(self, module, value=False): + if isinstance(module, LayoutLMv2Encoder): + module.gradient_checkpointing = value + def my_convert_sync_batchnorm(module, process_group=None): # same as `nn.modules.SyncBatchNorm.convert_sync_batchnorm` but allowing converting from `detectron2.layers.FrozenBatchNorm2d` diff --git a/src/transformers/models/led/configuration_led.py b/src/transformers/models/led/configuration_led.py index 5992d275ed9f..e30c3e04c4f5 100644 --- a/src/transformers/models/led/configuration_led.py +++ b/src/transformers/models/led/configuration_led.py @@ -82,8 +82,6 @@ class LEDConfig(PretrainedConfig): https://arxiv.org/abs/1909.11556>`__ for more details. use_cache (:obj:`bool`, `optional`, defaults to :obj:`True`): Whether or not the model should return the last key/values attentions (not used by all models) - gradient_checkpointing (:obj:`bool`, `optional`, defaults to :obj:`False`): - If True, use gradient checkpointing to save memory at the expense of slower backward pass. 
Example:: @@ -132,7 +130,6 @@ def __init__( pad_token_id=1, bos_token_id=0, eos_token_id=2, - gradient_checkpointing=False, attention_window: Union[List[int], int] = 512, **kwargs ): @@ -157,7 +154,6 @@ def __init__( self.use_cache = use_cache self.num_hidden_layers = encoder_layers self.attention_window = attention_window - self.gradient_checkpointing = gradient_checkpointing super().__init__( pad_token_id=pad_token_id, diff --git a/src/transformers/models/led/modeling_led.py b/src/transformers/models/led/modeling_led.py index c1c5af6d1ec1..926da161a97d 100755 --- a/src/transformers/models/led/modeling_led.py +++ b/src/transformers/models/led/modeling_led.py @@ -1077,6 +1077,7 @@ def forward(self, hidden_states: torch.Tensor): class LEDPreTrainedModel(PreTrainedModel): config_class = LEDConfig base_model_prefix = "led" + supports_gradient_checkpointing = True def _init_weights(self, module): std = self.config.init_std @@ -1089,6 +1090,10 @@ def _init_weights(self, module): if module.padding_idx is not None: module.weight.data[module.padding_idx].zero_() + def _set_gradient_checkpointing(self, module, value=False): + if isinstance(module, (LEDDecoder, LEDEncoder)): + module.gradient_checkpointing = value + @property def dummy_inputs(self): pad_token = self.config.pad_token_id @@ -1625,6 +1630,7 @@ def __init__(self, config: LEDConfig, embed_tokens: Optional[nn.Embedding] = Non self.layernorm_embedding = nn.LayerNorm(embed_dim) self.init_weights() + self.gradient_checkpointing = False def _merge_to_attention_mask(self, attention_mask: torch.Tensor, global_attention_mask: torch.Tensor): # longformer self attention expects attention mask to have 0 (no attn), 1 (local attn), 2 (global attn) @@ -1809,7 +1815,7 @@ def forward( if self.training and (dropout_probability < self.layerdrop): # skip the layer layer_outputs = (None, None, None) else: - if getattr(self.config, "gradient_checkpointing", False) and self.training: + if self.gradient_checkpointing and self.training: def create_custom_forward(module): def custom_forward(*inputs): @@ -1894,6 +1900,7 @@ def __init__(self, config: LEDConfig, embed_tokens: Optional[nn.Embedding] = Non self.layernorm_embedding = nn.LayerNorm(config.d_model) self.init_weights() + self.gradient_checkpointing = False def forward( self, @@ -2061,12 +2068,11 @@ def forward( past_key_value = past_key_values[idx] if past_key_values is not None else None - if getattr(self.config, "gradient_checkpointing", False) and self.training: + if self.gradient_checkpointing and self.training: if use_cache: logger.warning( - "`use_cache=True` is incompatible with `config.gradient_checkpointing=True`. Setting " - "`use_cache=False`..." + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." 
) use_cache = False diff --git a/src/transformers/models/longformer/modeling_longformer.py b/src/transformers/models/longformer/modeling_longformer.py index 6fbdfb12f57c..3e327c5c688e 100755 --- a/src/transformers/models/longformer/modeling_longformer.py +++ b/src/transformers/models/longformer/modeling_longformer.py @@ -1231,6 +1231,7 @@ def __init__(self, config): super().__init__() self.config = config self.layer = nn.ModuleList([LongformerLayer(config, layer_id=i) for i in range(config.num_hidden_layers)]) + self.gradient_checkpointing = False def forward( self, @@ -1259,7 +1260,7 @@ def forward( if output_hidden_states: all_hidden_states = all_hidden_states + (hidden_states,) - if getattr(self.config, "gradient_checkpointing", False) and self.training: + if self.gradient_checkpointing and self.training: def create_custom_forward(module): def custom_forward(*inputs): @@ -1363,6 +1364,7 @@ class LongformerPreTrainedModel(PreTrainedModel): config_class = LongformerConfig base_model_prefix = "longformer" + supports_gradient_checkpointing = True _keys_to_ignore_on_load_missing = [r"position_ids"] def _init_weights(self, module): @@ -1381,6 +1383,10 @@ def _init_weights(self, module): module.bias.data.zero_() module.weight.data.fill_(1.0) + def _set_gradient_checkpointing(self, module, value=False): + if isinstance(module, LongformerEncoder): + module.gradient_checkpointing = value + LONGFORMER_START_DOCSTRING = r""" diff --git a/src/transformers/models/luke/configuration_luke.py b/src/transformers/models/luke/configuration_luke.py index befd3e45e5de..ba6dc4964386 100644 --- a/src/transformers/models/luke/configuration_luke.py +++ b/src/transformers/models/luke/configuration_luke.py @@ -68,8 +68,6 @@ class LukeConfig(PretrainedConfig): The standard deviation of the truncated_normal_initializer for initializing all weight matrices. layer_norm_eps (:obj:`float`, `optional`, defaults to 1e-12): The epsilon used by the layer normalization layers. - gradient_checkpointing (:obj:`bool`, `optional`, defaults to :obj:`False`): - If True, use gradient checkpointing to save memory at the expense of slower backward pass. use_entity_aware_attention (:obj:`bool`, defaults to :obj:`True`): Whether or not the model should use the entity-aware self-attention mechanism proposed in `LUKE: Deep Contextualized Entity Representations with Entity-aware Self-attention (Yamada et al.) 
@@ -106,7 +104,6 @@ def __init__( type_vocab_size=2, initializer_range=0.02, layer_norm_eps=1e-12, - gradient_checkpointing=False, use_entity_aware_attention=True, pad_token_id=1, bos_token_id=0, @@ -130,5 +127,4 @@ def __init__( self.type_vocab_size = type_vocab_size self.initializer_range = initializer_range self.layer_norm_eps = layer_norm_eps - self.gradient_checkpointing = gradient_checkpointing self.use_entity_aware_attention = use_entity_aware_attention diff --git a/src/transformers/models/luke/modeling_luke.py b/src/transformers/models/luke/modeling_luke.py index b9004c1d4970..97d1f1adfd9c 100644 --- a/src/transformers/models/luke/modeling_luke.py +++ b/src/transformers/models/luke/modeling_luke.py @@ -579,6 +579,7 @@ def __init__(self, config): super().__init__() self.config = config self.layer = nn.ModuleList([LukeLayer(config) for _ in range(config.num_hidden_layers)]) + self.gradient_checkpointing = False def forward( self, @@ -600,7 +601,7 @@ def forward( all_entity_hidden_states = all_entity_hidden_states + (entity_hidden_states,) layer_head_mask = head_mask[i] if head_mask is not None else None - if getattr(self.config, "gradient_checkpointing", False): + if self.gradient_checkpointing and self.training: def create_custom_forward(module): def custom_forward(*inputs): @@ -681,6 +682,7 @@ class LukePreTrainedModel(PreTrainedModel): config_class = LukeConfig base_model_prefix = "luke" + supports_gradient_checkpointing = True def _init_weights(self, module: nn.Module): """Initialize the weights""" @@ -699,6 +701,10 @@ def _init_weights(self, module: nn.Module): module.bias.data.zero_() module.weight.data.fill_(1.0) + def _set_gradient_checkpointing(self, module, value=False): + if isinstance(module, LukeEncoder): + module.gradient_checkpointing = value + LUKE_START_DOCSTRING = r""" diff --git a/src/transformers/models/m2m_100/configuration_m2m_100.py b/src/transformers/models/m2m_100/configuration_m2m_100.py index 765bcb4cd1b4..a4a0df749c29 100644 --- a/src/transformers/models/m2m_100/configuration_m2m_100.py +++ b/src/transformers/models/m2m_100/configuration_m2m_100.py @@ -79,8 +79,6 @@ class M2M100Config(PretrainedConfig): https://arxiv.org/abs/1909.11556>`__ for more details. use_cache (:obj:`bool`, `optional`, defaults to :obj:`True`): Whether or not the model should return the last key/values attentions (not used by all models). - gradient_checkpointing (:obj:`bool`, `optional`, defaults to :obj:`False`): - If True, use gradient checkpointing to save memory at the expense of slower backward pass. 
Example:: @@ -121,7 +119,6 @@ def __init__( init_std=0.02, decoder_start_token_id=2, scale_embedding=True, - gradient_checkpointing=False, pad_token_id=1, bos_token_id=0, eos_token_id=2, @@ -145,7 +142,6 @@ def __init__( self.decoder_layerdrop = decoder_layerdrop self.use_cache = use_cache self.num_hidden_layers = encoder_layers - self.gradient_checkpointing = gradient_checkpointing self.scale_embedding = scale_embedding # scale factor will be sqrt(d_model) if True super().__init__( diff --git a/src/transformers/models/m2m_100/modeling_m2m_100.py b/src/transformers/models/m2m_100/modeling_m2m_100.py index 3c0d246c404a..87af135bf82a 100755 --- a/src/transformers/models/m2m_100/modeling_m2m_100.py +++ b/src/transformers/models/m2m_100/modeling_m2m_100.py @@ -520,6 +520,7 @@ def forward( class M2M100PreTrainedModel(PreTrainedModel): config_class = M2M100Config base_model_prefix = "model" + supports_gradient_checkpointing = True def _init_weights(self, module): std = self.config.init_std @@ -532,6 +533,10 @@ def _init_weights(self, module): if module.padding_idx is not None: module.weight.data[module.padding_idx].zero_() + def _set_gradient_checkpointing(self, module, value=False): + if isinstance(module, (M2M100Decoder, M2M100Encoder)): + module.gradient_checkpointing = value + M2M_100_START_DOCSTRING = r""" This model inherits from :class:`~transformers.PreTrainedModel`. Check the superclass documentation for the generic @@ -693,6 +698,7 @@ def __init__(self, config: M2M100Config, embed_tokens: Optional[nn.Embedding] = self.layer_norm = nn.LayerNorm(config.d_model) self.init_weights() + self.gradient_checkpointing = False def forward( self, @@ -787,7 +793,7 @@ def forward( if self.training and (dropout_probability < self.layerdrop): # skip the layer layer_outputs = (None, None) else: - if getattr(self.config, "gradient_checkpointing", False) and self.training: + if self.gradient_checkpointing and self.training: def create_custom_forward(module): def custom_forward(*inputs): @@ -857,6 +863,7 @@ def __init__(self, config: M2M100Config, embed_tokens: Optional[nn.Embedding] = self.layer_norm = nn.LayerNorm(config.d_model) self.init_weights() + self.gradient_checkpointing = False def forward( self, @@ -1013,12 +1020,11 @@ def forward( past_key_value = past_key_values[idx] if past_key_values is not None else None - if getattr(self.config, "gradient_checkpointing", False) and self.training: + if self.gradient_checkpointing and self.training: if use_cache: logger.warning( - "`use_cache=True` is incompatible with `config.gradient_checkpointing=True`. Setting " - "`use_cache=False`..." + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." ) use_cache = False diff --git a/src/transformers/models/marian/configuration_marian.py b/src/transformers/models/marian/configuration_marian.py index 1b974badfa66..825c7d707a73 100644 --- a/src/transformers/models/marian/configuration_marian.py +++ b/src/transformers/models/marian/configuration_marian.py @@ -78,8 +78,6 @@ class MarianConfig(PretrainedConfig): decoder_layerdrop: (:obj:`float`, `optional`, defaults to 0.0): The LayerDrop probability for the decoder. See the `LayerDrop paper `__ for more details. - gradient_checkpointing (:obj:`bool`, `optional`, defaults to :obj:`False`): - If True, use gradient checkpointing to save memory at the expense of slower backward pass. scale_embedding (:obj:`bool`, `optional`, defaults to :obj:`False`): Scale embeddings by diving by sqrt(d_model). 
use_cache (:obj:`bool`, `optional`, defaults to :obj:`True`): @@ -128,7 +126,6 @@ def __init__( decoder_start_token_id=58100, classifier_dropout=0.0, scale_embedding=False, - gradient_checkpointing=False, pad_token_id=58100, eos_token_id=0, forced_eos_token_id=0, @@ -153,7 +150,6 @@ def __init__( self.classifier_dropout = classifier_dropout self.use_cache = use_cache self.num_hidden_layers = encoder_layers - self.gradient_checkpointing = gradient_checkpointing self.scale_embedding = scale_embedding # scale factor will be sqrt(d_model) if True super().__init__( pad_token_id=pad_token_id, diff --git a/src/transformers/models/marian/modeling_marian.py b/src/transformers/models/marian/modeling_marian.py index e2feb549b70f..a2df63735038 100755 --- a/src/transformers/models/marian/modeling_marian.py +++ b/src/transformers/models/marian/modeling_marian.py @@ -466,6 +466,7 @@ def forward( class MarianPreTrainedModel(PreTrainedModel): config_class = MarianConfig base_model_prefix = "model" + supports_gradient_checkpointing = True def _init_weights(self, module): std = self.config.init_std @@ -480,6 +481,10 @@ def _init_weights(self, module): if module.padding_idx is not None: module.weight.data[module.padding_idx].zero_() + def _set_gradient_checkpointing(self, module, value=False): + if isinstance(module, (MarianDecoder, MarianEncoder)): + module.gradient_checkpointing = value + @property def dummy_inputs(self): pad_token = self.config.pad_token_id @@ -656,6 +661,7 @@ def __init__(self, config: MarianConfig, embed_tokens: Optional[nn.Embedding] = ) self.layers = nn.ModuleList([MarianEncoderLayer(config) for _ in range(config.encoder_layers)]) self.init_weights() + self.gradient_checkpointing = False def forward( self, @@ -750,7 +756,7 @@ def forward( if self.training and (dropout_probability < self.layerdrop): # skip the layer layer_outputs = (None, None) else: - if getattr(self.config, "gradient_checkpointing", False) and self.training: + if self.gradient_checkpointing and self.training: def create_custom_forward(module): def custom_forward(*inputs): @@ -816,6 +822,7 @@ def __init__(self, config: MarianConfig, embed_tokens: Optional[nn.Embedding] = ) self.layers = nn.ModuleList([MarianDecoderLayer(config) for _ in range(config.decoder_layers)]) self.init_weights() + self.gradient_checkpointing = False def get_input_embeddings(self): return self.embed_tokens @@ -987,12 +994,11 @@ def forward( past_key_value = past_key_values[idx] if past_key_values is not None else None - if getattr(self.config, "gradient_checkpointing", False) and self.training: + if self.gradient_checkpointing and self.training: if use_cache: logger.warning( - "`use_cache=True` is incompatible with `config.gradient_checkpointing=True`. Setting " - "`use_cache=False`..." + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." ) use_cache = False diff --git a/src/transformers/models/mbart/configuration_mbart.py b/src/transformers/models/mbart/configuration_mbart.py index 05857241b4ba..d1eb27c0e808 100644 --- a/src/transformers/models/mbart/configuration_mbart.py +++ b/src/transformers/models/mbart/configuration_mbart.py @@ -82,8 +82,6 @@ class MBartConfig(PretrainedConfig): decoder_layerdrop: (:obj:`float`, `optional`, defaults to 0.0): The LayerDrop probability for the decoder. See the `LayerDrop paper `__ for more details. - gradient_checkpointing (:obj:`bool`, `optional`, defaults to :obj:`False`): - If True, use gradient checkpointing to save memory at the expense of slower backward pass. 
scale_embedding (:obj:`bool`, `optional`, defaults to :obj:`False`):
Scale embeddings by diving by sqrt(d_model).
use_cache (:obj:`bool`, `optional`, defaults to :obj:`True`):
@@ -131,7 +129,6 @@ def __init__(
init_std=0.02,
classifier_dropout=0.0,
scale_embedding=False,
- gradient_checkpointing=False,
pad_token_id=1,
bos_token_id=0,
eos_token_id=2,
@@ -157,7 +154,6 @@ def __init__(
self.classifier_dropout = classifier_dropout
self.use_cache = use_cache
self.num_hidden_layers = encoder_layers
- self.gradient_checkpointing = gradient_checkpointing
self.scale_embedding = scale_embedding  # scale factor will be sqrt(d_model) if True
super().__init__(
pad_token_id=pad_token_id,
diff --git a/src/transformers/models/mbart/modeling_mbart.py b/src/transformers/models/mbart/modeling_mbart.py
index 0412eccaaab7..0ebb5a1a8f34 100755
--- a/src/transformers/models/mbart/modeling_mbart.py
+++ b/src/transformers/models/mbart/modeling_mbart.py
@@ -479,6 +479,7 @@ def forward(self, hidden_states: torch.Tensor):
class MBartPreTrainedModel(PreTrainedModel):
config_class = MBartConfig
base_model_prefix = "model"
+ supports_gradient_checkpointing = True
def _init_weights(self, module):
std = self.config.init_std
@@ -491,6 +492,10 @@ def _init_weights(self, module):
if module.padding_idx is not None:
module.weight.data[module.padding_idx].zero_()
+ def _set_gradient_checkpointing(self, module, value=False):
+ if isinstance(module, (MBartDecoder, MBartEncoder)):
+ module.gradient_checkpointing = value
+
@property
def dummy_inputs(self):
pad_token = self.config.pad_token_id
@@ -685,6 +690,7 @@ def __init__(self, config: MBartConfig, embed_tokens: Optional[nn.Embedding] = N
self.layer_norm = nn.LayerNorm(config.d_model)
self.init_weights()
+ self.gradient_checkpointing = False
def forward(
self,
@@ -780,7 +786,7 @@ def forward(
if self.training and (dropout_probability < self.layerdrop):  # skip the layer
layer_outputs = (None, None)
else:
- if getattr(self.config, "gradient_checkpointing", False) and self.training:
+ if self.gradient_checkpointing and self.training:
def create_custom_forward(module):
def custom_forward(*inputs):
@@ -850,6 +856,7 @@ def __init__(self, config: MBartConfig, embed_tokens: Optional[nn.Embedding] = N
self.layer_norm = nn.LayerNorm(config.d_model)
self.init_weights()
+ self.gradient_checkpointing = False
def get_input_embeddings(self):
return self.embed_tokens
@@ -1022,12 +1029,11 @@ def forward(
past_key_value = past_key_values[idx] if past_key_values is not None else None
- if getattr(self.config, "gradient_checkpointing", False) and self.training:
+ if self.gradient_checkpointing and self.training:
if use_cache:
logger.warning(
- "`use_cache=True` is incompatible with `config.gradient_checkpointing=True`. Setting "
- "`use_cache=False`..."
+ "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
)
use_cache = False
diff --git a/src/transformers/models/megatron_bert/configuration_megatron_bert.py b/src/transformers/models/megatron_bert/configuration_megatron_bert.py
index 19171e70da1b..d6e32cd49630 100644
--- a/src/transformers/models/megatron_bert/configuration_megatron_bert.py
+++ b/src/transformers/models/megatron_bert/configuration_megatron_bert.py
@@ -65,8 +65,6 @@ class MegatronBertConfig(PretrainedConfig):
The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
layer_norm_eps (:obj:`float`, `optional`, defaults to 1e-12):
The epsilon used by the layer normalization layers.
- gradient_checkpointing (:obj:`bool`, `optional`, defaults to :obj:`False`): - If True, use gradient checkpointing to save memory at the expense of slower backward pass. position_embedding_type (:obj:`str`, `optional`, defaults to :obj:`"absolute"`): Type of position embedding. Choose one of :obj:`"absolute"`, :obj:`"relative_key"`, :obj:`"relative_key_query"`. For positional embeddings use :obj:`"absolute"`. For more information on @@ -108,7 +106,6 @@ def __init__( initializer_range=0.02, layer_norm_eps=1e-12, pad_token_id=0, - gradient_checkpointing=False, position_embedding_type="absolute", use_cache=True, **kwargs @@ -127,6 +124,5 @@ def __init__( self.type_vocab_size = type_vocab_size self.initializer_range = initializer_range self.layer_norm_eps = layer_norm_eps - self.gradient_checkpointing = gradient_checkpointing self.position_embedding_type = position_embedding_type self.use_cache = use_cache diff --git a/src/transformers/models/megatron_bert/convert_megatron_bert_checkpoint.py b/src/transformers/models/megatron_bert/convert_megatron_bert_checkpoint.py index 3d7f03dcbb76..1d33ef91e624 100644 --- a/src/transformers/models/megatron_bert/convert_megatron_bert_checkpoint.py +++ b/src/transformers/models/megatron_bert/convert_megatron_bert_checkpoint.py @@ -180,7 +180,6 @@ def convert_megatron_checkpoint(args, input_state_dict): "type_vocab_size": 2, "initializer_range": 0.2, "layer_norm_eps": 1e-12, - "gradient_checkpointing": False, "position_embedding_type": "absolute", "use_cache": False, } diff --git a/src/transformers/models/megatron_bert/modeling_megatron_bert.py b/src/transformers/models/megatron_bert/modeling_megatron_bert.py index 3c49ea88b873..80337b2dabf9 100755 --- a/src/transformers/models/megatron_bert/modeling_megatron_bert.py +++ b/src/transformers/models/megatron_bert/modeling_megatron_bert.py @@ -508,6 +508,7 @@ def __init__(self, config): # The final layer norm. We removed the 1st LN, moved LN to each hidden layer and this one # is simply the final LN (Transformer's BERT has it attached to each hidden layer). self.ln = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + self.gradient_checkpointing = False def forward( self, @@ -534,12 +535,11 @@ def forward( layer_head_mask = head_mask[i] if head_mask is not None else None past_key_value = past_key_values[i] if past_key_values is not None else None - if getattr(self.config, "gradient_checkpointing", False) and self.training: + if self.gradient_checkpointing and self.training: if use_cache: logger.warn( - "`use_cache=True` is incompatible with `config.gradient_checkpointing=True`. Setting " - "`use_cache=False`..." + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." 
) use_cache = False @@ -705,6 +705,7 @@ class MegatronBertPreTrainedModel(PreTrainedModel): config_class = MegatronBertConfig load_tf_weights = load_tf_weights_in_megatron_bert base_model_prefix = "bert" + supports_gradient_checkpointing = True _keys_to_ignore_on_load_missing = [r"position_ids"] def _init_weights(self, module): @@ -719,6 +720,10 @@ def _init_weights(self, module): if isinstance(module, nn.Linear) and module.bias is not None: module.bias.data.zero_() + def _set_gradient_checkpointing(self, module, value=False): + if isinstance(module, MegatronBertEncoder): + module.gradient_checkpointing = value + @dataclass # Copied from transformers.models.bert.modeling_bert.BertForPreTrainingOutput with Bert->MegatronBert diff --git a/src/transformers/models/megatron_gpt2/convert_megatron_gpt2_checkpoint.py b/src/transformers/models/megatron_gpt2/convert_megatron_gpt2_checkpoint.py index 7dd5d18b64f7..d930c5bb6506 100644 --- a/src/transformers/models/megatron_gpt2/convert_megatron_gpt2_checkpoint.py +++ b/src/transformers/models/megatron_gpt2/convert_megatron_gpt2_checkpoint.py @@ -257,7 +257,6 @@ def main(): summary_proj_to_labels=True, summary_first_dropout=0.1, scale_attn_weights=True, - gradient_checkpointing=False, use_cache=True, bos_token_id=50256, eos_token_id=50256, diff --git a/src/transformers/models/pegasus/configuration_pegasus.py b/src/transformers/models/pegasus/configuration_pegasus.py index 2e815c2e486b..8cf76c482bc1 100644 --- a/src/transformers/models/pegasus/configuration_pegasus.py +++ b/src/transformers/models/pegasus/configuration_pegasus.py @@ -78,8 +78,6 @@ class PegasusConfig(PretrainedConfig): decoder_layerdrop: (:obj:`float`, `optional`, defaults to 0.0): The LayerDrop probability for the decoder. See the `LayerDrop paper `__ for more details. - gradient_checkpointing (:obj:`bool`, `optional`, defaults to :obj:`False`): - If True, use gradient checkpointing to save memory at the expense of slower backward pass. scale_embedding (:obj:`bool`, `optional`, defaults to :obj:`False`): Scale embeddings by diving by sqrt(d_model). 
use_cache (:obj:`bool`, `optional`, defaults to :obj:`True`): @@ -128,7 +126,6 @@ def __init__( decoder_start_token_id=0, classifier_dropout=0.0, scale_embedding=False, - gradient_checkpointing=False, pad_token_id=0, eos_token_id=1, forced_eos_token_id=1, @@ -153,7 +150,6 @@ def __init__( self.classifier_dropout = classifier_dropout self.use_cache = use_cache self.num_hidden_layers = encoder_layers - self.gradient_checkpointing = gradient_checkpointing self.scale_embedding = scale_embedding # scale factor will be sqrt(d_model) if True super().__init__( pad_token_id=pad_token_id, diff --git a/src/transformers/models/pegasus/modeling_pegasus.py b/src/transformers/models/pegasus/modeling_pegasus.py index ab1009b33937..2728f144b352 100755 --- a/src/transformers/models/pegasus/modeling_pegasus.py +++ b/src/transformers/models/pegasus/modeling_pegasus.py @@ -466,6 +466,7 @@ def forward( class PegasusPreTrainedModel(PreTrainedModel): config_class = PegasusConfig base_model_prefix = "model" + supports_gradient_checkpointing = True def _init_weights(self, module): std = self.config.init_std @@ -480,6 +481,10 @@ def _init_weights(self, module): if module.padding_idx is not None: module.weight.data[module.padding_idx].zero_() + def _set_gradient_checkpointing(self, module, value=False): + if isinstance(module, (PegasusDecoder, PegasusEncoder)): + module.gradient_checkpointing = value + PEGASUS_START_DOCSTRING = r""" This model inherits from :class:`~transformers.PreTrainedModel`. Check the superclass documentation for the generic @@ -646,6 +651,7 @@ def __init__(self, config: PegasusConfig, embed_tokens: Optional[nn.Embedding] = self.layer_norm = nn.LayerNorm(config.d_model) self.init_weights() + self.gradient_checkpointing = False def resize_position_embeddings(self, new_num_position_embeddings: int): """ @@ -770,7 +776,7 @@ def forward( if self.training and (dropout_probability < self.layerdrop): # skip the layer layer_outputs = (None, None) else: - if getattr(self.config, "gradient_checkpointing", False) and self.training: + if self.gradient_checkpointing and self.training: def create_custom_forward(module): def custom_forward(*inputs): @@ -840,6 +846,7 @@ def __init__(self, config: PegasusConfig, embed_tokens: Optional[nn.Embedding] = self.layer_norm = nn.LayerNorm(config.d_model) self.init_weights() + self.gradient_checkpointing = False def get_input_embeddings(self): return self.embed_tokens @@ -1040,12 +1047,11 @@ def forward( past_key_value = past_key_values[idx] if past_key_values is not None else None - if getattr(self.config, "gradient_checkpointing", False) and self.training: + if self.gradient_checkpointing and self.training: if use_cache: logger.warning( - "`use_cache=True` is incompatible with `config.gradient_checkpointing=True`. Setting " - "`use_cache=False`..." + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." ) use_cache = False diff --git a/src/transformers/models/prophetnet/configuration_prophetnet.py b/src/transformers/models/prophetnet/configuration_prophetnet.py index c19e4a106f2b..074bad3e24d8 100644 --- a/src/transformers/models/prophetnet/configuration_prophetnet.py +++ b/src/transformers/models/prophetnet/configuration_prophetnet.py @@ -92,8 +92,6 @@ class ProphetNetConfig(PretrainedConfig): smoothing is performed. use_cache (:obj:`bool`, `optional`, defaults to :obj:`True`): Whether or not the model should return the last key/values attentions (not used by all models). 
- gradient_checkpointing (:obj:`bool`, `optional`, defaults to :obj:`False`): - If True, use gradient checkpointing to save memory at the expense of slower backward pass. """ model_type = "prophetnet" keys_to_ignore_at_inference = ["past_key_values"] @@ -124,7 +122,6 @@ def __init__( num_buckets=32, relative_max_distance=128, disable_ngram_loss=False, - gradient_checkpointing=False, eps=0.0, use_cache=True, pad_token_id=0, @@ -158,9 +155,6 @@ def __init__( self.use_cache = use_cache - # 4 Training Args (should be removed soon) - self.gradient_checkpointing = gradient_checkpointing - super().__init__( pad_token_id=pad_token_id, bos_token_id=bos_token_id, diff --git a/src/transformers/models/prophetnet/modeling_prophetnet.py b/src/transformers/models/prophetnet/modeling_prophetnet.py index ed4c79265749..9f72a35f0dfd 100644 --- a/src/transformers/models/prophetnet/modeling_prophetnet.py +++ b/src/transformers/models/prophetnet/modeling_prophetnet.py @@ -547,6 +547,7 @@ class ProphetNetDecoderLMOutput(ModelOutput): class ProphetNetPreTrainedModel(PreTrainedModel): config_class = ProphetNetConfig base_model_prefix = "prophetnet" + supports_gradient_checkpointing = True def _init_weights(self, module): if isinstance(module, nn.Linear): @@ -558,6 +559,10 @@ def _init_weights(self, module): if module.padding_idx is not None: module.weight.data[module.padding_idx].zero_() + def _set_gradient_checkpointing(self, module, value=False): + if isinstance(module, (ProphetNetDecoder, ProphetNetEncoder)): + module.gradient_checkpointing = value + def _shift_right(self, input_ids): decoder_start_token_id = self.config.decoder_start_token_id pad_token_id = self.config.pad_token_id @@ -1262,6 +1267,7 @@ def __init__(self, config: ProphetNetConfig, word_embeddings: nn.Embedding = Non self.layers = nn.ModuleList([ProphetNetEncoderLayer(config) for _ in range(config.num_encoder_layers)]) self.init_weights() + self.gradient_checkpointing = False def get_input_embeddings(self): return self.word_embeddings @@ -1337,7 +1343,7 @@ def forward( if output_hidden_states: encoder_hidden_states = encoder_hidden_states + (hidden_states,) - if getattr(self.config, "gradient_checkpointing", False) and self.training: + if self.gradient_checkpointing and self.training: def create_custom_forward(module): def custom_forward(*inputs): @@ -1406,6 +1412,7 @@ def __init__(self, config: ProphetNetConfig, word_embeddings: nn.Embedding = Non self.embeddings_layer_norm = LayerNorm(config.hidden_size) self.init_weights() + self.gradient_checkpointing = False def get_input_embeddings(self): return self.word_embeddings @@ -1566,12 +1573,11 @@ def forward( past_key_value = past_key_values[idx] if past_key_values is not None else None - if getattr(self.config, "gradient_checkpointing", False) and self.training: + if self.gradient_checkpointing and self.training: if use_cache: logger.warning( - "`use_cache=True` is incompatible with `config.gradient_checkpointing=True`. Setting " - "`use_cache=False`..." + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." 
) use_cache = False diff --git a/src/transformers/models/rembert/configuration_rembert.py b/src/transformers/models/rembert/configuration_rembert.py index d9432d20a985..51c899dfc985 100644 --- a/src/transformers/models/rembert/configuration_rembert.py +++ b/src/transformers/models/rembert/configuration_rembert.py @@ -76,8 +76,6 @@ class RemBertConfig(PretrainedConfig): use_cache (:obj:`bool`, `optional`, defaults to :obj:`True`): Whether or not the model should return the last key/values attentions (not used by all models). Only relevant if ``config.is_decoder=True``. - gradient_checkpointing (:obj:`bool`, `optional`, defaults to :obj:`False`): - If True, use gradient checkpointing to save memory at the expense of slower backward pass. Example:: diff --git a/src/transformers/models/rembert/modeling_rembert.py b/src/transformers/models/rembert/modeling_rembert.py index 46524ce9cbb8..ab3874865afc 100755 --- a/src/transformers/models/rembert/modeling_rembert.py +++ b/src/transformers/models/rembert/modeling_rembert.py @@ -501,6 +501,7 @@ def __init__(self, config): self.embedding_hidden_mapping_in = nn.Linear(config.input_embedding_size, config.hidden_size) self.layer = nn.ModuleList([RemBertLayer(config) for _ in range(config.num_hidden_layers)]) + self.gradient_checkpointing = False def forward( self, @@ -528,12 +529,11 @@ def forward( layer_head_mask = head_mask[i] if head_mask is not None else None past_key_value = past_key_values[i] if past_key_values is not None else None - if getattr(self.config, "gradient_checkpointing", False) and self.training: + if self.gradient_checkpointing and self.training: if use_cache: logger.warn( - "`use_cache=True` is incompatible with `config.gradient_checkpointing=True`. Setting " - "`use_cache=False`..." + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." ) use_cache = False @@ -648,6 +648,7 @@ class RemBertPreTrainedModel(PreTrainedModel): config_class = RemBertConfig load_tf_weights = load_tf_weights_in_rembert base_model_prefix = "rembert" + supports_gradient_checkpointing = True _keys_to_ignore_on_load_missing = [r"position_ids"] def _init_weights(self, module): @@ -666,6 +667,10 @@ def _init_weights(self, module): module.bias.data.zero_() module.weight.data.fill_(1.0) + def _set_gradient_checkpointing(self, module, value=False): + if isinstance(module, RemBertEncoder): + module.gradient_checkpointing = value + REMBERT_START_DOCSTRING = r""" This model is a PyTorch `torch.nn.Module `_ sub-class. Use diff --git a/src/transformers/models/roberta/modeling_roberta.py b/src/transformers/models/roberta/modeling_roberta.py index 09472da7674a..f74954ac6428 100644 --- a/src/transformers/models/roberta/modeling_roberta.py +++ b/src/transformers/models/roberta/modeling_roberta.py @@ -469,6 +469,7 @@ def __init__(self, config): super().__init__() self.config = config self.layer = nn.ModuleList([RobertaLayer(config) for _ in range(config.num_hidden_layers)]) + self.gradient_checkpointing = False def forward( self, @@ -495,12 +496,11 @@ def forward( layer_head_mask = head_mask[i] if head_mask is not None else None past_key_value = past_key_values[i] if past_key_values is not None else None - if getattr(self.config, "gradient_checkpointing", False) and self.training: + if self.gradient_checkpointing and self.training: if use_cache: logger.warning( - "`use_cache=True` is incompatible with `config.gradient_checkpointing=True`. Setting " - "`use_cache=False`..." 
+ "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." ) use_cache = False @@ -585,6 +585,7 @@ class RobertaPreTrainedModel(PreTrainedModel): config_class = RobertaConfig base_model_prefix = "roberta" + supports_gradient_checkpointing = True # Copied from transformers.models.bert.modeling_bert.BertPreTrainedModel._init_weights def _init_weights(self, module): @@ -603,6 +604,10 @@ def _init_weights(self, module): module.bias.data.zero_() module.weight.data.fill_(1.0) + def _set_gradient_checkpointing(self, module, value=False): + if isinstance(module, RobertaEncoder): + module.gradient_checkpointing = value + def update_keys_to_ignore(self, config, del_keys_to_ignore): """Remove some keys from ignore list""" if not config.tie_word_embeddings: diff --git a/src/transformers/models/roformer/configuration_roformer.py b/src/transformers/models/roformer/configuration_roformer.py index 945d1064a10e..5027b3be1fb8 100644 --- a/src/transformers/models/roformer/configuration_roformer.py +++ b/src/transformers/models/roformer/configuration_roformer.py @@ -80,8 +80,6 @@ class RoFormerConfig(PretrainedConfig): relevant if ``config.is_decoder=True``. rotary_value (:obj:`bool`, `optional`, defaults to :obj:`False`): Whether or not apply rotary position embeddings on value layer. - gradient_checkpointing (:obj:`bool`, `optional`, defaults to :obj:`False`): - If :obj:`True`, use gradient checkpointing to save memory at the expense of slower backward pass. Example:: @@ -114,7 +112,6 @@ def __init__( initializer_range=0.02, layer_norm_eps=1e-12, pad_token_id=0, - gradient_checkpointing=False, rotary_value=False, use_cache=True, **kwargs @@ -134,6 +131,5 @@ def __init__( self.type_vocab_size = type_vocab_size self.initializer_range = initializer_range self.layer_norm_eps = layer_norm_eps - self.gradient_checkpointing = gradient_checkpointing self.rotary_value = rotary_value self.use_cache = use_cache diff --git a/src/transformers/models/roformer/modeling_roformer.py b/src/transformers/models/roformer/modeling_roformer.py index f08d3e5c8f6a..23929a4c6131 100644 --- a/src/transformers/models/roformer/modeling_roformer.py +++ b/src/transformers/models/roformer/modeling_roformer.py @@ -551,6 +551,7 @@ def __init__(self, config): config.max_position_embeddings, config.hidden_size // config.num_attention_heads ) self.layer = nn.ModuleList([RoFormerLayer(config) for _ in range(config.num_hidden_layers)]) + self.gradient_checkpointing = False def forward( self, @@ -580,12 +581,11 @@ def forward( layer_head_mask = head_mask[i] if head_mask is not None else None past_key_value = past_key_values[i] if past_key_values is not None else None - if getattr(self.config, "gradient_checkpointing", False) and self.training: + if self.gradient_checkpointing and self.training: if use_cache: logger.warning( - "`use_cache=True` is incompatible with `config.gradient_checkpointing=True`. Setting " - "`use_cache=False`..." + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." 
) use_cache = False @@ -705,6 +705,7 @@ class RoFormerPreTrainedModel(PreTrainedModel): config_class = RoFormerConfig load_tf_weights = load_tf_weights_in_roformer base_model_prefix = "roformer" + supports_gradient_checkpointing = True _keys_to_ignore_on_load_missing = [] _keys_to_ignore_on_load_unexpected = [ r"roformer\.embeddings_project\.weight", @@ -729,6 +730,10 @@ def _init_weights(self, module): module.bias.data.zero_() module.weight.data.fill_(1.0) + def _set_gradient_checkpointing(self, module, value=False): + if isinstance(module, RoFormerEncoder): + module.gradient_checkpointing = value + ROFORMER_START_DOCSTRING = r""" This model is a PyTorch `torch.nn.Module `_ sub-class. Use diff --git a/src/transformers/models/speech_to_text/configuration_speech_to_text.py b/src/transformers/models/speech_to_text/configuration_speech_to_text.py index ff16601030db..821362d2e636 100644 --- a/src/transformers/models/speech_to_text/configuration_speech_to_text.py +++ b/src/transformers/models/speech_to_text/configuration_speech_to_text.py @@ -134,7 +134,6 @@ def __init__( decoder_start_token_id=2, classifier_dropout=0.0, scale_embedding=True, - gradient_checkpointing=False, pad_token_id=1, bos_token_id=0, eos_token_id=2, @@ -165,7 +164,6 @@ def __init__( self.classifier_dropout = classifier_dropout self.use_cache = use_cache self.num_hidden_layers = encoder_layers - self.gradient_checkpointing = gradient_checkpointing self.scale_embedding = scale_embedding # scale factor will be sqrt(d_model) if True self.max_source_positions = max_source_positions self.max_target_positions = max_target_positions diff --git a/src/transformers/models/speech_to_text/modeling_speech_to_text.py b/src/transformers/models/speech_to_text/modeling_speech_to_text.py index 491983751b66..a99e95e9d740 100755 --- a/src/transformers/models/speech_to_text/modeling_speech_to_text.py +++ b/src/transformers/models/speech_to_text/modeling_speech_to_text.py @@ -531,6 +531,7 @@ def forward( class Speech2TextPreTrainedModel(PreTrainedModel): config_class = Speech2TextConfig base_model_prefix = "model" + supports_gradient_checkpointing = True def _init_weights(self, module): std = self.config.init_std @@ -543,6 +544,10 @@ def _init_weights(self, module): if module.padding_idx is not None: module.weight.data[module.padding_idx].zero_() + def _set_gradient_checkpointing(self, module, value=False): + if isinstance(module, (Speech2TextDecoder, Speech2TextEncoder)): + module.gradient_checkpointing = value + def _get_feat_extract_output_lengths(self, input_lengths: torch.LongTensor): """ Computes the output length of the convolutional layers @@ -711,6 +716,7 @@ def __init__(self, config: Speech2TextConfig): self.layer_norm = nn.LayerNorm(config.d_model) self.init_weights() + self.gradient_checkpointing = False def forward( self, @@ -795,7 +801,7 @@ def forward( if self.training and (dropout_probability < self.layerdrop): # skip the layer layer_outputs = (None, None) else: - if getattr(self.config, "gradient_checkpointing", False) and self.training: + if self.gradient_checkpointing and self.training: def create_custom_forward(module): def custom_forward(*inputs): @@ -863,6 +869,7 @@ def __init__(self, config: Speech2TextConfig): self.layer_norm = nn.LayerNorm(config.d_model) self.init_weights() + self.gradient_checkpointing = False def get_input_embeddings(self): return self.embed_tokens @@ -1032,11 +1039,11 @@ def forward( past_key_value = past_key_values[idx] if past_key_values is not None else None - if getattr(self.config, 
"gradient_checkpointing", False) and self.training: + if self.gradient_checkpointing and self.training: if use_cache: logger.warning( - "`use_cache = True` is incompatible with `config.gradient_checkpointing = True`. Setting `use_cache = False`..." + "`use_cache = True` is incompatible with gradient checkpointing. Setting `use_cache = False`..." ) use_cache = False diff --git a/src/transformers/models/speech_to_text_2/configuration_speech_to_text_2.py b/src/transformers/models/speech_to_text_2/configuration_speech_to_text_2.py index f1f950599058..abeac09a105d 100644 --- a/src/transformers/models/speech_to_text_2/configuration_speech_to_text_2.py +++ b/src/transformers/models/speech_to_text_2/configuration_speech_to_text_2.py @@ -108,7 +108,6 @@ def __init__( decoder_start_token_id=2, classifier_dropout=0.0, scale_embedding=True, - gradient_checkpointing=False, pad_token_id=1, bos_token_id=0, eos_token_id=2, @@ -130,7 +129,6 @@ def __init__( self.classifier_dropout = classifier_dropout self.use_cache = use_cache self.num_hidden_layers = decoder_layers - self.gradient_checkpointing = gradient_checkpointing self.scale_embedding = scale_embedding # scale factor will be sqrt(d_model) if True self.max_source_positions = max_source_positions self.max_target_positions = max_target_positions diff --git a/src/transformers/models/speech_to_text_2/modeling_speech_to_text_2.py b/src/transformers/models/speech_to_text_2/modeling_speech_to_text_2.py index 35bd8f308817..20aa2b87884e 100755 --- a/src/transformers/models/speech_to_text_2/modeling_speech_to_text_2.py +++ b/src/transformers/models/speech_to_text_2/modeling_speech_to_text_2.py @@ -407,6 +407,7 @@ def forward( class Speech2Text2PreTrainedModel(PreTrainedModel): config_class = Speech2Text2Config base_model_prefix = "model" + supports_gradient_checkpointing = True def _init_weights(self, module): std = self.config.init_std @@ -419,6 +420,10 @@ def _init_weights(self, module): if module.padding_idx is not None: module.weight.data[module.padding_idx].zero_() + def _set_gradient_checkpointing(self, module, value=False): + if isinstance(module, Speech2Text2Decoder): + module.gradient_checkpointing = value + SPEECH_TO_TEXT_2_START_DOCSTRING = r""" This model inherits from :class:`~transformers.PreTrainedModel`. Check the superclass documentation for the generic @@ -465,6 +470,7 @@ def __init__(self, config: Speech2Text2Config): self.layers = nn.ModuleList([Speech2Text2DecoderLayer(config) for _ in range(config.decoder_layers)]) self.init_weights() + self.gradient_checkpointing = False def get_input_embeddings(self): return self.embed_tokens @@ -635,11 +641,11 @@ def forward( past_key_value = past_key_values[idx] if past_key_values is not None else None - if getattr(self.config, "gradient_checkpointing", False) and self.training: + if self.gradient_checkpointing and self.training: if use_cache: logger.warning( - "`use_cache = True` is incompatible with `config.gradient_checkpointing = True`. Setting `use_cache = False`..." + "`use_cache = True` is incompatible with gradient checkpointing. Setting `use_cache = False`..." 
) use_cache = False diff --git a/src/transformers/models/splinter/configuration_splinter.py b/src/transformers/models/splinter/configuration_splinter.py index 879451bbe50b..986e436fe757 100644 --- a/src/transformers/models/splinter/configuration_splinter.py +++ b/src/transformers/models/splinter/configuration_splinter.py @@ -71,8 +71,6 @@ class SplinterConfig(PretrainedConfig): use_cache (:obj:`bool`, `optional`, defaults to :obj:`True`): Whether or not the model should return the last key/values attentions (not used by all models). Only relevant if ``config.is_decoder=True``. - gradient_checkpointing (:obj:`bool`, `optional`, defaults to :obj:`False`): - If :obj:`True`, use gradient checkpointing to save memory at the expense of slower backward pass. question_token_id (:obj:`int`, `optional`, defaults to 104): The id of the ``[QUESTION]`` token. diff --git a/src/transformers/models/splinter/modeling_splinter.py b/src/transformers/models/splinter/modeling_splinter.py index 1296db12508d..381a280ebb22 100755 --- a/src/transformers/models/splinter/modeling_splinter.py +++ b/src/transformers/models/splinter/modeling_splinter.py @@ -409,6 +409,7 @@ def __init__(self, config): super().__init__() self.config = config self.layer = nn.ModuleList([SplinterLayer(config) for _ in range(config.num_hidden_layers)]) + self.gradient_checkpointing = False def forward( self, @@ -435,12 +436,11 @@ def forward( layer_head_mask = head_mask[i] if head_mask is not None else None past_key_value = past_key_values[i] if past_key_values is not None else None - if getattr(self.config, "gradient_checkpointing", False) and self.training: + if self.gradient_checkpointing and self.training: if use_cache: logger.warning( - "`use_cache=True` is incompatible with `config.gradient_checkpointing=True`. Setting " - "`use_cache=False`..." + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." ) use_cache = False @@ -509,6 +509,7 @@ class SplinterPreTrainedModel(PreTrainedModel): config_class = SplinterConfig base_model_prefix = "splinter" + supports_gradient_checkpointing = True _keys_to_ignore_on_load_missing = [r"position_ids"] # Copied from transformers.models.bert.modeling_bert.BertPreTrainedModel._init_weights @@ -528,6 +529,10 @@ def _init_weights(self, module): module.bias.data.zero_() module.weight.data.fill_(1.0) + def _set_gradient_checkpointing(self, module, value=False): + if isinstance(module, SplinterEncoder): + module.gradient_checkpointing = value + SPLINTER_START_DOCSTRING = r""" This model is a PyTorch `torch.nn.Module `_ sub-class. Use diff --git a/src/transformers/models/t5/configuration_t5.py b/src/transformers/models/t5/configuration_t5.py index 9a406591279e..bb16a5fb0f50 100644 --- a/src/transformers/models/t5/configuration_t5.py +++ b/src/transformers/models/t5/configuration_t5.py @@ -77,8 +77,6 @@ class T5Config(PretrainedConfig): the :obj:`"gated-gelu"` feed forward projection. Original T5 uses :obj:`"relu"`. use_cache (:obj:`bool`, `optional`, defaults to :obj:`True`): Whether or not the model should return the last key/values attentions (not used by all models). - gradient_checkpointing (:obj:`bool`, `optional`, defaults to :obj:`False`): - If True, use gradient checkpointing to save memory at the expense of slower backward pass. 
""" model_type = "t5" keys_to_ignore_at_inference = ["past_key_values"] @@ -102,7 +100,6 @@ def __init__( use_cache=True, pad_token_id=0, eos_token_id=1, - gradient_checkpointing=False, **kwargs ): self.vocab_size = vocab_size @@ -120,7 +117,6 @@ def __init__( self.initializer_factor = initializer_factor self.feed_forward_proj = feed_forward_proj self.use_cache = use_cache - self.gradient_checkpointing = gradient_checkpointing super().__init__( pad_token_id=pad_token_id, eos_token_id=eos_token_id, diff --git a/src/transformers/models/t5/modeling_t5.py b/src/transformers/models/t5/modeling_t5.py index 27ef440bfb15..f18c9e66f5e6 100644 --- a/src/transformers/models/t5/modeling_t5.py +++ b/src/transformers/models/t5/modeling_t5.py @@ -325,7 +325,7 @@ def __init__(self, config: T5Config, has_relative_attention_bias=False): if self.has_relative_attention_bias: self.relative_attention_bias = nn.Embedding(self.relative_attention_num_buckets, self.n_heads) self.pruned_heads = set() - self.gradient_checkpointing = getattr(config, "gradient_checkpointing", False) + self.gradient_checkpointing = False def prune_heads(self, heads): if len(heads) == 0: @@ -489,7 +489,7 @@ def project(hidden_states, proj_layer, key_value_states, past_key_value): position_bias = torch.zeros( (1, self.n_heads, real_seq_length, key_length), device=scores.device, dtype=scores.dtype ) - if self.training and self.gradient_checkpointing: + if self.gradient_checkpointing and self.training: position_bias.requires_grad = True else: position_bias = self.compute_bias(real_seq_length, key_length) @@ -715,6 +715,7 @@ class T5PreTrainedModel(PreTrainedModel): load_tf_weights = load_tf_weights_in_t5 base_model_prefix = "transformer" is_parallelizable = True + supports_gradient_checkpointing = True @property def dummy_inputs(self): @@ -769,6 +770,10 @@ def _init_weights(self, module): if module.has_relative_attention_bias: module.relative_attention_bias.weight.data.normal_(mean=0.0, std=factor * ((d_model) ** -0.5)) + def _set_gradient_checkpointing(self, module, value=False): + if isinstance(module, (T5Attention, T5Stack)): + module.gradient_checkpointing = value + def _shift_right(self, input_ids): decoder_start_token_id = self.config.decoder_start_token_id pad_token_id = self.config.pad_token_id @@ -813,6 +818,7 @@ def __init__(self, config, embed_tokens=None): # Model parallel self.model_parallel = False self.device_map = None + self.gradient_checkpointing = False @add_start_docstrings(PARALLELIZE_DOCSTRING) def parallelize(self, device_map=None): @@ -968,11 +974,10 @@ def forward( if output_hidden_states: all_hidden_states = all_hidden_states + (hidden_states,) - if getattr(self.config, "gradient_checkpointing", False) and self.training: + if self.gradient_checkpointing and self.training: if use_cache: logger.warn( - "`use_cache=True` is incompatible with `config.gradient_checkpointing=True`. Setting " - "`use_cache=False`..." + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." ) use_cache = False diff --git a/src/transformers/models/tapas/configuration_tapas.py b/src/transformers/models/tapas/configuration_tapas.py index 834cae0c7ea6..d59dc00f4515 100644 --- a/src/transformers/models/tapas/configuration_tapas.py +++ b/src/transformers/models/tapas/configuration_tapas.py @@ -73,8 +73,6 @@ class TapasConfig(PretrainedConfig): The standard deviation of the truncated_normal_initializer for initializing all weight matrices. 
layer_norm_eps (:obj:`float`, `optional`, defaults to 1e-12): The epsilon used by the layer normalization layers. - gradient_checkpointing (:obj:`bool`, `optional`, defaults to :obj:`False`): - Whether to use gradient checkpointing to save memory at the expense of a slower backward pass. positive_label_weight (:obj:`float`, `optional`, defaults to 10.0): Weight for positive labels. num_aggregation_labels (:obj:`int`, `optional`, defaults to 0): @@ -159,7 +157,6 @@ def __init__( initializer_range=0.02, layer_norm_eps=1e-12, pad_token_id=0, - gradient_checkpointing=False, positive_label_weight=10.0, num_aggregation_labels=0, aggregation_loss_weight=1.0, @@ -202,7 +199,6 @@ def __init__( self.type_vocab_sizes = type_vocab_sizes self.initializer_range = initializer_range self.layer_norm_eps = layer_norm_eps - self.gradient_checkpointing = gradient_checkpointing # Fine-tuning task hyperparameters self.positive_label_weight = positive_label_weight diff --git a/src/transformers/models/tapas/modeling_tapas.py b/src/transformers/models/tapas/modeling_tapas.py index 29d4a3ef4f34..9506216522dc 100644 --- a/src/transformers/models/tapas/modeling_tapas.py +++ b/src/transformers/models/tapas/modeling_tapas.py @@ -627,6 +627,7 @@ def __init__(self, config): super().__init__() self.config = config self.layer = nn.ModuleList([TapasLayer(config) for _ in range(config.num_hidden_layers)]) + self.gradient_checkpointing = False def forward( self, @@ -649,7 +650,7 @@ def forward( layer_head_mask = head_mask[i] if head_mask is not None else None - if getattr(self.config, "gradient_checkpointing", False): + if self.gradient_checkpointing and self.training: def create_custom_forward(module): def custom_forward(*inputs): @@ -763,6 +764,7 @@ class TapasPreTrainedModel(PreTrainedModel): config_class = TapasConfig base_model_prefix = "tapas" + supports_gradient_checkpointing = True # Copied from transformers.models.bert.modeling_bert.BertPreTrainedModel._init_weights def _init_weights(self, module): @@ -781,6 +783,10 @@ def _init_weights(self, module): module.bias.data.zero_() module.weight.data.fill_(1.0) + def _set_gradient_checkpointing(self, module, value=False): + if isinstance(module, TapasEncoder): + module.gradient_checkpointing = value + TAPAS_START_DOCSTRING = r""" This model inherits from :class:`~transformers.PreTrainedModel`. 
Check the superclass documentation for the generic diff --git a/src/transformers/models/visual_bert/modeling_visual_bert.py b/src/transformers/models/visual_bert/modeling_visual_bert.py index 21f5e01362ce..c6c01010081e 100755 --- a/src/transformers/models/visual_bert/modeling_visual_bert.py +++ b/src/transformers/models/visual_bert/modeling_visual_bert.py @@ -398,6 +398,7 @@ def __init__(self, config): super().__init__() self.config = config self.layer = nn.ModuleList([VisualBertLayer(config) for _ in range(config.num_hidden_layers)]) + self.gradient_checkpointing = False def forward( self, @@ -417,7 +418,7 @@ def forward( layer_head_mask = head_mask[i] if head_mask is not None else None - if getattr(self.config, "gradient_checkpointing", False) and self.training: + if self.gradient_checkpointing and self.training: def create_custom_forward(module): def custom_forward(*inputs): @@ -532,6 +533,7 @@ class VisualBertPreTrainedModel(PreTrainedModel): config_class = VisualBertConfig base_model_prefix = "visual_bert" + supports_gradient_checkpointing = True _keys_to_ignore_on_load_missing = [r"position_ids"] def _init_weights(self, module): @@ -547,6 +549,10 @@ def _init_weights(self, module): if isinstance(module, nn.Linear) and module.bias is not None: module.bias.data.zero_() + def _set_gradient_checkpointing(self, module, value=False): + if isinstance(module, VisualBertEncoder): + module.gradient_checkpointing = value + @dataclass class VisualBertForPreTrainingOutput(ModelOutput): diff --git a/src/transformers/models/vit/configuration_vit.py b/src/transformers/models/vit/configuration_vit.py index 5e53df4cddfd..9c64be5141bc 100644 --- a/src/transformers/models/vit/configuration_vit.py +++ b/src/transformers/models/vit/configuration_vit.py @@ -57,8 +57,6 @@ class ViTConfig(PretrainedConfig): The standard deviation of the truncated_normal_initializer for initializing all weight matrices. layer_norm_eps (:obj:`float`, `optional`, defaults to 1e-12): The epsilon used by the layer normalization layers. - gradient_checkpointing (:obj:`bool`, `optional`, defaults to :obj:`False`): - If True, use gradient checkpointing to save memory at the expense of slower backward pass. image_size (:obj:`int`, `optional`, defaults to :obj:`224`): The size (resolution) of each image. 
patch_size (:obj:`int`, `optional`, defaults to :obj:`16`): diff --git a/src/transformers/models/vit/modeling_vit.py b/src/transformers/models/vit/modeling_vit.py index 5b147f285632..78911f7b4186 100644 --- a/src/transformers/models/vit/modeling_vit.py +++ b/src/transformers/models/vit/modeling_vit.py @@ -352,6 +352,7 @@ def __init__(self, config): super().__init__() self.config = config self.layer = nn.ModuleList([ViTLayer(config) for _ in range(config.num_hidden_layers)]) + self.gradient_checkpointing = False def forward( self, @@ -370,7 +371,7 @@ def forward( layer_head_mask = head_mask[i] if head_mask is not None else None - if getattr(self.config, "gradient_checkpointing", False) and self.training: + if self.gradient_checkpointing and self.training: def create_custom_forward(module): def custom_forward(*inputs): @@ -411,6 +412,7 @@ class ViTPreTrainedModel(PreTrainedModel): config_class = ViTConfig base_model_prefix = "vit" + supports_gradient_checkpointing = True def _init_weights(self, module): """Initialize the weights""" @@ -428,6 +430,10 @@ def _init_weights(self, module): module.bias.data.zero_() module.weight.data.fill_(1.0) + def _set_gradient_checkpointing(self, module, value=False): + if isinstance(module, ViTEncoder): + module.gradient_checkpointing = value + VIT_START_DOCSTRING = r""" This model is a PyTorch `torch.nn.Module `_ subclass. Use diff --git a/src/transformers/models/wav2vec2/configuration_wav2vec2.py b/src/transformers/models/wav2vec2/configuration_wav2vec2.py index d82e6a6d3457..49818feb22df 100644 --- a/src/transformers/models/wav2vec2/configuration_wav2vec2.py +++ b/src/transformers/models/wav2vec2/configuration_wav2vec2.py @@ -138,8 +138,6 @@ class Wav2Vec2Config(PretrainedConfig): instance of :class:`~transformers.Wav2Vec2ForSequenceClassification`. classifier_proj_size (:obj:`int`, `optional`, defaults to 256): Dimensionality of the projection before token mean-pooling for classification. - gradient_checkpointing (:obj:`bool`, `optional`, defaults to :obj:`False`): - If True, use gradient checkpointing to save memory at the expense of slower backward pass. 
Example:: @@ -198,7 +196,6 @@ def __init__( ctc_zero_infinity=False, use_weighted_layer_sum=False, classifier_proj_size=256, - gradient_checkpointing=False, pad_token_id=0, bos_token_id=1, eos_token_id=2, @@ -229,7 +226,6 @@ def __init__( self.initializer_range = initializer_range self.vocab_size = vocab_size self.do_stable_layer_norm = do_stable_layer_norm - self.gradient_checkpointing = gradient_checkpointing self.use_weighted_layer_sum = use_weighted_layer_sum self.classifier_proj_size = classifier_proj_size diff --git a/src/transformers/models/wav2vec2/modeling_wav2vec2.py b/src/transformers/models/wav2vec2/modeling_wav2vec2.py index ade54417f1df..71f431ca976d 100755 --- a/src/transformers/models/wav2vec2/modeling_wav2vec2.py +++ b/src/transformers/models/wav2vec2/modeling_wav2vec2.py @@ -590,6 +590,7 @@ def __init__(self, config): self.layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) self.dropout = nn.Dropout(config.hidden_dropout) self.layers = nn.ModuleList([Wav2Vec2EncoderLayer(config) for _ in range(config.num_hidden_layers)]) + self.gradient_checkpointing = False def forward( self, @@ -629,7 +630,7 @@ def forward( skip_the_layer = True if self.training and (dropout_probability < self.config.layerdrop) else False if not skip_the_layer or deepspeed_zero3_is_enabled: # under deepspeed zero3 all gpus must run in sync - if getattr(self.config, "gradient_checkpointing", False) and self.training: + if self.gradient_checkpointing and self.training: # create gradient checkpointing function def create_custom_forward(module): def custom_forward(*inputs): @@ -676,6 +677,7 @@ def __init__(self, config): self.layers = nn.ModuleList( [Wav2Vec2EncoderLayerStableLayerNorm(config) for _ in range(config.num_hidden_layers)] ) + self.gradient_checkpointing = False def forward( self, @@ -715,7 +717,7 @@ def forward( if not skip_the_layer or deepspeed_zero3_is_enabled: # under deepspeed zero3 all gpus must run in sync # XXX: could optimize this like synced_gpus in generate_utils but not sure if it's worth the code complication - if getattr(self.config, "gradient_checkpointing", False) and self.training: + if self.gradient_checkpointing and self.training: # create gradient checkpointing function def create_custom_forward(module): def custom_forward(*inputs): @@ -842,6 +844,7 @@ class Wav2Vec2PreTrainedModel(PreTrainedModel): config_class = Wav2Vec2Config base_model_prefix = "wav2vec2" + supports_gradient_checkpointing = True _keys_to_ignore_on_load_missing = [r"position_ids"] def _init_weights(self, module): @@ -864,6 +867,10 @@ def _init_weights(self, module): if isinstance(module, (nn.Linear, nn.Conv1d)) and module.bias is not None: module.bias.data.zero_() + def _set_gradient_checkpointing(self, module, value=False): + if isinstance(module, (Wav2Vec2Encoder, Wav2Vec2EncoderStableLayerNorm)): + module.gradient_checkpointing = value + def _get_feat_extract_output_lengths(self, input_lengths: Union[torch.LongTensor, int]): """ Computes the output length of the convolutional layers diff --git a/src/transformers/trainer.py b/src/transformers/trainer.py index f5aa74616c47..d39a24bf46cb 100755 --- a/src/transformers/trainer.py +++ b/src/transformers/trainer.py @@ -990,7 +990,7 @@ def _wrap_model(self, model, training=True): elif isinstance(model, PreTrainedModel): # find_unused_parameters breaks checkpointing as per # https://github.com/huggingface/transformers/pull/4659#issuecomment-643356021 - find_unused_parameters = not getattr(model.config, "gradient_checkpointing", False) + 
find_unused_parameters = not getattr(model.config, "_gradient_checkpointing", False) else: find_unused_parameters = True model = nn.parallel.DistributedDataParallel( @@ -1162,6 +1162,10 @@ def train( self.state = TrainerState() self.state.is_hyper_param_search = trial is not None + # Activate gradient checkpointing if needed + if args.gradient_checkpointing: + self.model.gradient_checkpointing_enable() + model = self._wrap_model(self.model_wrapped) # for the rest of this function `model` is the outside model, whether it was wrapped or not diff --git a/src/transformers/training_args.py b/src/transformers/training_args.py index d34622abc03e..ce330a254c41 100644 --- a/src/transformers/training_args.py +++ b/src/transformers/training_args.py @@ -372,6 +372,8 @@ class TrainingArguments: hub_token (:obj:`str`, `optional`): The token to use to push the model to the Hub. Will default to the token in the cache folder obtained with :obj:`huggingface-cli login`. + gradient_checkpointing (:obj:`bool`, `optional`, defaults to :obj:`False`): + If True, use gradient checkpointing to save memory at the expense of slower backward pass. """ output_dir: str = field( @@ -650,6 +652,12 @@ class TrainingArguments: metadata={"help": "The hub strategy to use when `--push_to_hub` is activated."}, ) hub_token: str = field(default=None, metadata={"help": "The token to use to push to the Model Hub."}) + gradient_checkpointing: bool = field( + default=False, + metadata={ + "help": "If True, use gradient checkpointing to save memory at the expense of slower backward pass." + }, + ) # Deprecated arguments push_to_hub_model_id: str = field( default=None, metadata={"help": "The name of the repository to which push the `Trainer`."} diff --git a/templates/adding_a_new_model/cookiecutter-template-{{cookiecutter.modelname}}/configuration_{{cookiecutter.lowercase_modelname}}.py b/templates/adding_a_new_model/cookiecutter-template-{{cookiecutter.modelname}}/configuration_{{cookiecutter.lowercase_modelname}}.py index 93da35a5d98b..6978a3ddf3fd 100644 --- a/templates/adding_a_new_model/cookiecutter-template-{{cookiecutter.modelname}}/configuration_{{cookiecutter.lowercase_modelname}}.py +++ b/templates/adding_a_new_model/cookiecutter-template-{{cookiecutter.modelname}}/configuration_{{cookiecutter.lowercase_modelname}}.py @@ -72,8 +72,6 @@ class {{cookiecutter.camelcase_modelname}}Config(PretrainedConfig): use_cache (:obj:`bool`, `optional`, defaults to :obj:`True`): Whether or not the model should return the last key/values attentions (not used by all models). Only relevant if ``config.is_decoder=True``. - gradient_checkpointing (:obj:`bool`, `optional`, defaults to :obj:`False`): - If :obj:`True`, use gradient checkpointing to save memory at the expense of slower backward pass. {% else -%} vocab_size (:obj:`int`, `optional`, defaults to 50265): Vocabulary size of the {{cookiecutter.modelname}} model. 
Defines the number of different tokens that can be represented by the @@ -186,7 +184,6 @@ def __init__( decoder_start_token_id=2, classifier_dropout=0.0, scale_embedding=False, - gradient_checkpointing=False, {% endif -%} pad_token_id=1, bos_token_id=0, @@ -225,7 +222,6 @@ def __init__( self.classifier_dropout = classifier_dropout self.use_cache = use_cache self.num_hidden_layers = encoder_layers - self.gradient_checkpointing = gradient_checkpointing self.scale_embedding = scale_embedding # scale factor will be sqrt(d_model) if True {% endif -%} diff --git a/templates/adding_a_new_model/cookiecutter-template-{{cookiecutter.modelname}}/modeling_{{cookiecutter.lowercase_modelname}}.py b/templates/adding_a_new_model/cookiecutter-template-{{cookiecutter.modelname}}/modeling_{{cookiecutter.lowercase_modelname}}.py index 835382396cd5..b0482f706212 100755 --- a/templates/adding_a_new_model/cookiecutter-template-{{cookiecutter.modelname}}/modeling_{{cookiecutter.lowercase_modelname}}.py +++ b/templates/adding_a_new_model/cookiecutter-template-{{cookiecutter.modelname}}/modeling_{{cookiecutter.lowercase_modelname}}.py @@ -513,6 +513,7 @@ def __init__(self, config): super().__init__() self.config = config self.layer = nn.ModuleList([{{cookiecutter.camelcase_modelname}}Layer(config) for _ in range(config.num_hidden_layers)]) + self.gradient_checkpointing = False def forward( self, @@ -539,12 +540,11 @@ def forward( layer_head_mask = head_mask[i] if head_mask is not None else None past_key_value = past_key_values[i] if past_key_values is not None else None - if getattr(self.config, "gradient_checkpointing", False) and self.training: + if self.gradient_checkpointing and self.training: if use_cache: logger.warning( - "`use_cache=True` is incompatible with `config.gradient_checkpointing=True`. Setting " - "`use_cache=False`..." + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." ) use_cache = False @@ -664,6 +664,7 @@ class {{cookiecutter.camelcase_modelname}}PreTrainedModel(PreTrainedModel): config_class = {{cookiecutter.camelcase_modelname}}Config load_tf_weights = load_tf_weights_in_{{cookiecutter.lowercase_modelname}} base_model_prefix = "{{cookiecutter.lowercase_modelname}}" + supports_gradient_checkpointing = True _keys_to_ignore_on_load_missing = [r"position_ids"] def _init_weights(self, module): @@ -682,6 +683,10 @@ def _init_weights(self, module): module.bias.data.zero_() module.weight.data.fill_(1.0) + def _set_gradient_checkpointing(self, module, value=False): + if isinstance(module, {{cookiecutter.camelcase_modelname}}Encoder): + module.gradient_checkpointing = value + {{cookiecutter.uppercase_modelname}}_START_DOCSTRING = r""" This model is a PyTorch `torch.nn.Module `_ sub-class. 
@@ -2006,6 +2011,7 @@ def forward(self, hidden_states: torch.Tensor): class {{cookiecutter.camelcase_modelname}}PreTrainedModel(PreTrainedModel): config_class = {{cookiecutter.camelcase_modelname}}Config base_model_prefix = "model" + supports_gradient_checkpointing = True def _init_weights(self, module): std = self.config.init_std @@ -2017,16 +2023,10 @@ def _init_weights(self, module): module.weight.data.normal_(mean=0.0, std=std) if module.padding_idx is not None: module.weight.data[module.padding_idx].zero_() - - @property - def dummy_inputs(self): - pad_token = self.config.pad_token_id - input_ids = torch.tensor([[0, 6, 10, 4, 2], [0, 8, 12, 2, pad_token]], device=self.device) - dummy_inputs = { - "attention_mask": input_ids.ne(pad_token), - "input_ids": input_ids, - } - return dummy_inputs + + def _set_gradient_checkpointing(self, module, value=False): + if isinstance(module, ({{cookiecutter.camelcase_modelname}}Decoder, {{cookiecutter.camelcase_modelname}}Encoder)): + module.gradient_checkpointing = value {{cookiecutter.uppercase_modelname}}_START_DOCSTRING = r""" @@ -2213,6 +2213,7 @@ def __init__(self, config: {{cookiecutter.camelcase_modelname}}Config, embed_tok self.layernorm_embedding = nn.LayerNorm(embed_dim) self.init_weights() + self.gradient_checkpointing = False def forward( self, @@ -2309,7 +2310,7 @@ def forward( if self.training and (dropout_probability < self.layerdrop): # skip the layer layer_outputs = (None, None) else: - if getattr(self.config, "gradient_checkpointing", False) and self.training: + if self.gradient_checkpointing and self.training: def create_custom_forward(module): def custom_forward(*inputs): @@ -2376,6 +2377,7 @@ def __init__(self, config: {{cookiecutter.camelcase_modelname}}Config, embed_tok self.layernorm_embedding = nn.LayerNorm(config.d_model) self.init_weights() + self.gradient_checkpointing = False def get_input_embeddings(self): return self.embed_tokens @@ -2545,10 +2547,10 @@ def forward( past_key_value = past_key_values[idx] if past_key_values is not None else None - if getattr(self.config, "gradient_checkpointing", False) and self.training: + if self.gradient_checkpointing and self.training: if use_cache: - logger.warning("`use_cache = True` is incompatible with `config.gradient_checkpointing = True`. Setting `use_cache = False`...") + logger.warning("`use_cache = True` is incompatible with gradient checkpointing. 
Setting `use_cache = False`...") use_cache = False def create_custom_forward(module): diff --git a/tests/test_modeling_beit.py b/tests/test_modeling_beit.py index b4c670356f5c..6557936d59b0 100644 --- a/tests/test_modeling_beit.py +++ b/tests/test_modeling_beit.py @@ -224,6 +224,27 @@ def test_training(self): loss = model(**inputs).loss loss.backward() + def test_training_gradient_checkpointing(self): + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + if not self.model_tester.is_training: + return + + config.use_cache = False + config.return_dict = True + + for model_class in self.all_model_classes: + if model_class in get_values(MODEL_MAPPING) or not model_class.supports_gradient_checkpointing: + continue + # we don't test BeitForMaskedImageModeling + if model_class.__name__ == "BeitForMaskedImageModeling": + continue + model = model_class(config) + model.to(torch_device) + model.train() + inputs = self._prepare_for_class(inputs_dict, model_class, return_labels=True) + loss = model(**inputs).loss + loss.backward() + def test_initialization(self): config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() diff --git a/tests/test_modeling_common.py b/tests/test_modeling_common.py index a306e630480f..0e3b31f5d324 100755 --- a/tests/test_modeling_common.py +++ b/tests/test_modeling_common.py @@ -369,15 +369,14 @@ def test_training(self): def test_training_gradient_checkpointing(self): config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() - if not self.model_tester.is_training or not hasattr(config, "gradient_checkpointing"): + if not self.model_tester.is_training: return - config.gradient_checkpointing = True config.use_cache = False config.return_dict = True for model_class in self.all_model_classes: - if model_class in get_values(MODEL_MAPPING): + if model_class in get_values(MODEL_MAPPING) or not model_class.supports_gradient_checkpointing: continue model = model_class(config) model.to(torch_device) diff --git a/tests/test_modeling_deit.py b/tests/test_modeling_deit.py index c689a90af785..119e0988916b 100644 --- a/tests/test_modeling_deit.py +++ b/tests/test_modeling_deit.py @@ -20,6 +20,7 @@ from transformers import DeiTConfig from transformers.file_utils import cached_property, is_torch_available, is_vision_available +from transformers.models.auto import get_values from transformers.testing_utils import require_torch, require_vision, slow, torch_device from .test_configuration_common import ConfigTester @@ -340,7 +341,7 @@ def test_training(self): for model_class in self.all_model_classes: # DeiTForImageClassificationWithTeacher supports inference-only if ( - model_class in MODEL_MAPPING.values() + model_class in get_values(MODEL_MAPPING) or model_class.__name__ == "DeiTForImageClassificationWithTeacher" ): continue @@ -351,6 +352,27 @@ def test_training(self): loss = model(**inputs).loss loss.backward() + def test_training_gradient_checkpointing(self): + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + if not self.model_tester.is_training: + return + + config.use_cache = False + config.return_dict = True + + for model_class in self.all_model_classes: + if model_class in get_values(MODEL_MAPPING) or not model_class.supports_gradient_checkpointing: + continue + # DeiTForImageClassificationWithTeacher supports inference-only + if model_class.__name__ == "DeiTForImageClassificationWithTeacher": + continue + model = model_class(config) + model.to(torch_device) + model.train() + inputs 
= self._prepare_for_class(inputs_dict, model_class, return_labels=True) + loss = model(**inputs).loss + loss.backward() + def test_for_image_classification(self): config_and_inputs = self.model_tester.prepare_config_and_inputs() self.model_tester.create_and_check_for_image_classification(*config_and_inputs) diff --git a/tests/test_modeling_flax_gpt2.py b/tests/test_modeling_flax_gpt2.py index 0c793ebd27b7..3b2e43680e60 100644 --- a/tests/test_modeling_flax_gpt2.py +++ b/tests/test_modeling_flax_gpt2.py @@ -82,7 +82,7 @@ def __init__( self.eos_token_id = vocab_size - 1 self.pad_token_id = vocab_size - 1 - def prepare_config_and_inputs(self, gradient_checkpointing=False): + def prepare_config_and_inputs(self): input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size) input_mask = None @@ -100,7 +100,6 @@ def prepare_config_and_inputs(self, gradient_checkpointing=False): bos_token_id=self.bos_token_id, eos_token_id=self.eos_token_id, pad_token_id=self.pad_token_id, - gradient_checkpointing=gradient_checkpointing, ) return (config, input_ids, input_mask) diff --git a/tests/test_modeling_flax_gpt_neo.py b/tests/test_modeling_flax_gpt_neo.py index 2916bec5b94f..7d0d832295a4 100644 --- a/tests/test_modeling_flax_gpt_neo.py +++ b/tests/test_modeling_flax_gpt_neo.py @@ -86,7 +86,7 @@ def __init__( self.eos_token_id = vocab_size - 1 self.pad_token_id = vocab_size - 1 - def prepare_config_and_inputs(self, gradient_checkpointing=False): + def prepare_config_and_inputs(self): input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size) input_mask = None @@ -105,7 +105,6 @@ def prepare_config_and_inputs(self, gradient_checkpointing=False): pad_token_id=self.pad_token_id, window_size=self.window_size, attention_types=self.attention_types, - gradient_checkpointing=gradient_checkpointing, ) return (config, input_ids, input_mask) diff --git a/tests/test_modeling_gpt2.py b/tests/test_modeling_gpt2.py index 91d2edcdc838..214a17f0508f 100644 --- a/tests/test_modeling_gpt2.py +++ b/tests/test_modeling_gpt2.py @@ -96,7 +96,7 @@ def __init__( def get_large_model_config(self): return GPT2Config.from_pretrained("gpt2") - def prepare_config_and_inputs(self, gradient_checkpointing=False): + def prepare_config_and_inputs(self): input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size) input_mask = None @@ -119,7 +119,7 @@ def prepare_config_and_inputs(self, gradient_checkpointing=False): token_labels = ids_tensor([self.batch_size, self.seq_length], self.num_labels) choice_labels = ids_tensor([self.batch_size], self.num_choices) - config = self.get_config(gradient_checkpointing=gradient_checkpointing) + config = self.get_config() head_mask = ids_tensor([self.num_hidden_layers, self.num_attention_heads], 2) @@ -135,7 +135,7 @@ def prepare_config_and_inputs(self, gradient_checkpointing=False): choice_labels, ) - def get_config(self, gradient_checkpointing=False): + def get_config(self): return GPT2Config( vocab_size=self.vocab_size, n_embd=self.hidden_size, @@ -149,11 +149,10 @@ def get_config(self, gradient_checkpointing=False): n_ctx=self.max_position_embeddings, type_vocab_size=self.type_vocab_size, initializer_range=self.initializer_range, - use_cache=not gradient_checkpointing, + use_cache=True, bos_token_id=self.bos_token_id, eos_token_id=self.eos_token_id, pad_token_id=self.pad_token_id, - gradient_checkpointing=gradient_checkpointing, ) def prepare_config_and_inputs_for_decoder(self): @@ -322,9 +321,13 @@ def create_and_check_lm_head_model(self, config, 
input_ids, input_mask, head_mas self.parent.assertEqual(result.loss.shape, ()) self.parent.assertEqual(result.logits.shape, (self.batch_size, self.seq_length, self.vocab_size)) - def create_and_check_forward_and_backwards(self, config, input_ids, input_mask, head_mask, token_type_ids, *args): + def create_and_check_forward_and_backwards( + self, config, input_ids, input_mask, head_mask, token_type_ids, *args, gradient_checkpointing=False + ): model = GPT2LMHeadModel(config) model.to(torch_device) + if gradient_checkpointing: + model.gradient_checkpointing_enable() result = model(input_ids, token_type_ids=token_type_ids, labels=input_ids) self.parent.assertEqual(result.loss.shape, ()) @@ -478,8 +481,8 @@ def test_gpt2_token_classification_model(self): self.model_tester.create_and_check_gpt2_for_token_classification(*config_and_inputs) def test_gpt2_gradient_checkpointing(self): - config_and_inputs = self.model_tester.prepare_config_and_inputs(gradient_checkpointing=True) - self.model_tester.create_and_check_forward_and_backwards(*config_and_inputs) + config_and_inputs = self.model_tester.prepare_config_and_inputs() + self.model_tester.create_and_check_forward_and_backwards(*config_and_inputs, gradient_checkpointing=True) @slow def test_batch_generation(self): @@ -612,7 +615,11 @@ class GPT2ModelLanguageGenerationTest(unittest.TestCase): @slow def test_lm_generate_gpt2(self): for checkpointing in [True, False]: - model = GPT2LMHeadModel.from_pretrained("gpt2", gradient_checkpointing=checkpointing) + model = GPT2LMHeadModel.from_pretrained("gpt2") + if checkpointing: + model.gradient_checkpointing_enable() + else: + model.gradient_checkpointing_disable() model.to(torch_device) input_ids = torch.tensor([[464, 3290]], dtype=torch.long, device=torch_device) # The dog expected_output_ids = [ diff --git a/tests/test_modeling_gpt_neo.py b/tests/test_modeling_gpt_neo.py index fa1b63b4f616..a8e5b4babc57 100644 --- a/tests/test_modeling_gpt_neo.py +++ b/tests/test_modeling_gpt_neo.py @@ -97,7 +97,7 @@ def __init__( def get_large_model_config(self): return GPTNeoConfig.from_pretrained("gpt_neo") - def prepare_config_and_inputs(self, gradient_checkpointing=False): + def prepare_config_and_inputs(self): input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size) input_mask = None @@ -120,7 +120,7 @@ def prepare_config_and_inputs(self, gradient_checkpointing=False): token_labels = ids_tensor([self.batch_size, self.seq_length], self.num_labels) choice_labels = ids_tensor([self.batch_size], self.num_choices) - config = self.get_config(gradient_checkpointing=False) + config = self.get_config() head_mask = ids_tensor([self.num_hidden_layers, self.num_attention_heads], 2) @@ -136,18 +136,17 @@ def prepare_config_and_inputs(self, gradient_checkpointing=False): choice_labels, ) - def get_config(self, gradient_checkpointing=False): + def get_config(self): return GPTNeoConfig( vocab_size=self.vocab_size, hidden_size=self.hidden_size, num_layers=self.num_hidden_layers, num_heads=self.num_attention_heads, max_position_embeddings=self.max_position_embeddings, - use_cache=not gradient_checkpointing, + use_cache=True, bos_token_id=self.bos_token_id, eos_token_id=self.eos_token_id, pad_token_id=self.pad_token_id, - gradient_checkpointing=gradient_checkpointing, window_size=self.window_size, attention_types=self.attention_types, ) @@ -329,8 +328,12 @@ def create_and_check_gpt_neo_for_sequence_classification( result = model(input_ids, attention_mask=input_mask, token_type_ids=token_type_ids, 
labels=sequence_labels) self.parent.assertEqual(result.logits.shape, (self.batch_size, self.num_labels)) - def create_and_check_forward_and_backwards(self, config, input_ids, input_mask, head_mask, token_type_ids, *args): + def create_and_check_forward_and_backwards( + self, config, input_ids, input_mask, head_mask, token_type_ids, *args, gradient_checkpointing=False + ): model = GPTNeoForCausalLM(config) + if gradient_checkpointing: + model.gradient_checkpointing_enable() model.to(torch_device) result = model(input_ids, token_type_ids=token_type_ids, labels=input_ids) @@ -411,8 +414,8 @@ def test_gpt_neo_sequence_classification_model(self): self.model_tester.create_and_check_gpt_neo_for_sequence_classification(*config_and_inputs) def test_gpt_neo_gradient_checkpointing(self): - config_and_inputs = self.model_tester.prepare_config_and_inputs(gradient_checkpointing=True) - self.model_tester.create_and_check_forward_and_backwards(*config_and_inputs) + config_and_inputs = self.model_tester.prepare_config_and_inputs() + self.model_tester.create_and_check_forward_and_backwards(*config_and_inputs, gradient_checkpointing=True) def _get_hidden_states(self): return torch.tensor( @@ -473,7 +476,10 @@ def tokenizer(self): def test_lm_generate_gpt_neo(self): for checkpointing in [True, False]: model = self.model - model.config.gradient_checkpointing = checkpointing + if checkpointing: + model.gradient_checkpointing_enable() + else: + model.gradient_checkpointing_disable() input_ids = torch.tensor([[464, 3290]], dtype=torch.long, device=torch_device) # The dog # fmt: off # The dog-eared copy of the book, which is a collection of essays by the late author, diff --git a/tests/test_modeling_gptj.py b/tests/test_modeling_gptj.py index 5739aed5a1f7..06979a2c7f82 100644 --- a/tests/test_modeling_gptj.py +++ b/tests/test_modeling_gptj.py @@ -92,7 +92,7 @@ def __init__( def get_large_model_config(self): return GPTJConfig.from_pretrained("EleutherAI/gpt-j-6B") - def prepare_config_and_inputs(self, gradient_checkpointing=False): + def prepare_config_and_inputs(self): input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size) input_mask = None @@ -115,7 +115,7 @@ def prepare_config_and_inputs(self, gradient_checkpointing=False): token_labels = ids_tensor([self.batch_size, self.seq_length], self.num_labels) choice_labels = ids_tensor([self.batch_size], self.num_choices) - config = self.get_config(gradient_checkpointing=gradient_checkpointing) + config = self.get_config() head_mask = ids_tensor([self.num_hidden_layers, self.num_attention_heads], 2) @@ -131,7 +131,7 @@ def prepare_config_and_inputs(self, gradient_checkpointing=False): choice_labels, ) - def get_config(self, gradient_checkpointing=False): + def get_config(self): return GPTJConfig( vocab_size=self.vocab_size, n_embd=self.hidden_size, @@ -145,11 +145,10 @@ def get_config(self, gradient_checkpointing=False): n_ctx=self.max_position_embeddings, type_vocab_size=self.type_vocab_size, initializer_range=self.initializer_range, - use_cache=not gradient_checkpointing, + use_cache=True, bos_token_id=self.bos_token_id, eos_token_id=self.eos_token_id, pad_token_id=self.pad_token_id, - gradient_checkpointing=gradient_checkpointing, ) def prepare_config_and_inputs_for_decoder(self): @@ -318,8 +317,12 @@ def create_and_check_lm_head_model(self, config, input_ids, input_mask, head_mas self.parent.assertEqual(result.loss.shape, ()) self.parent.assertEqual(result.logits.shape, (self.batch_size, self.seq_length, self.vocab_size)) - def 
create_and_check_forward_and_backwards(self, config, input_ids, input_mask, head_mask, token_type_ids, *args): + def create_and_check_forward_and_backwards( + self, config, input_ids, input_mask, head_mask, token_type_ids, *args, gradient_checkpointing=False + ): model = GPTJForCausalLM(config) + if gradient_checkpointing: + model.gradient_checkpointing_enable() model.to(torch_device) result = model(input_ids, token_type_ids=token_type_ids, labels=input_ids) @@ -390,8 +393,8 @@ def test_gptj_lm_head_model(self): self.model_tester.create_and_check_lm_head_model(*config_and_inputs) def test_gptj_gradient_checkpointing(self): - config_and_inputs = self.model_tester.prepare_config_and_inputs(gradient_checkpointing=True) - self.model_tester.create_and_check_forward_and_backwards(*config_and_inputs) + config_and_inputs = self.model_tester.prepare_config_and_inputs() + self.model_tester.create_and_check_forward_and_backwards(*config_and_inputs, gradient_checkpointing=True) @slow def test_batch_generation(self): @@ -464,7 +467,11 @@ class GPTJModelLanguageGenerationTest(unittest.TestCase): @slow def test_lm_generate_gptj(self): for checkpointing in [True, False]: - model = GPTJForCausalLM.from_pretrained("EleutherAI/gpt-j-6B", gradient_checkpointing=checkpointing) + model = GPTJForCausalLM.from_pretrained("EleutherAI/gpt-j-6B") + if checkpointing: + model.gradient_checkpointing_enable() + else: + model.gradient_checkpointing_disable() model.to(torch_device) input_ids = torch.tensor([[464, 3290]], dtype=torch.long, device=torch_device) # The dog expected_output_ids = [