src/transformers/modeling_utils.py (14 changes: 12 additions & 2 deletions)

@@ -939,7 +939,7 @@ def prune_heads(self, heads_to_prune: Dict[int, List[int]]):

         self.base_model._prune_heads(heads_to_prune)

-    def gradient_checkpointing_enable(self, flag: bool = True):
+    def gradient_checkpointing_enable(self):
         """
         Activates gradient checkpointing for the current model.

@@ -950,7 +950,7 @@ def gradient_checkpointing_enable(self, flag: bool = True):
             raise ValueError(f"{self.__class__.__name__} does not support gradient checkpointing.")
         self.apply(partial(self._set_gradient_checkpointing, value=True))

-    def gradient_checkpointing_disable(self, flag: bool = True):
+    def gradient_checkpointing_disable(self):
         """
         Deactivates gradient checkpointing for the current model.

@@ -960,6 +960,16 @@ def gradient_checkpointing_disable(self, flag: bool = True):
         if self.supports_gradient_checkpointing:
             self.apply(partial(self._set_gradient_checkpointing, value=False))

+    @property
+    def is_gradient_checkpointing(self) -> bool:
+        """
+        Whether gradient checkpointing is activated for this model or not.
+
+        Note that in other frameworks this feature can be referred to as "activation checkpointing" or "checkpoint
+        activations".
+        """
+        return any(hasattr(m, "gradient_checkpointing") and m.gradient_checkpointing for m in self.modules())
+
     def save_pretrained(
         self,
         save_directory: Union[str, os.PathLike],
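For context, a minimal usage sketch of how the new property pairs with the existing enable/disable methods. The bert-base-uncased checkpoint is only an illustrative choice; any model class that sets supports_gradient_checkpointing behaves the same way:

from transformers import AutoModel

# load a checkpoint whose architecture supports gradient checkpointing
model = AutoModel.from_pretrained("bert-base-uncased")

# a freshly loaded model starts with gradient checkpointing disabled
assert not model.is_gradient_checkpointing

# trade extra forward compute for lower activation memory during training
model.gradient_checkpointing_enable()
assert model.is_gradient_checkpointing

# turn it back off, e.g. before inference
model.gradient_checkpointing_disable()
assert not model.is_gradient_checkpointing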
src/transformers/trainer.py (2 changes: 1 addition & 1 deletion)

@@ -996,7 +996,7 @@ def _wrap_model(self, model, training=True):
             elif isinstance(model, PreTrainedModel):
                 # find_unused_parameters breaks checkpointing as per
                 # https://github.com/huggingface/transformers/pull/4659#issuecomment-643356021
-                find_unused_parameters = not getattr(model.config, "_gradient_checkpointing", False)
+                find_unused_parameters = not model.is_gradient_checkpointing
             else:
                 find_unused_parameters = True
             model = nn.parallel.DistributedDataParallel(
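The Trainer change only swaps where this flag is read from, but the interaction is worth spelling out: with find_unused_parameters=True, DDP's unused-parameter detection conflicts with the re-run forward pass that checkpointing performs (see the issue comment linked above), so it must be disabled whenever checkpointing is active. A rough sketch of that wrapping logic; wrap_for_ddp and local_rank are hypothetical names for illustration, not Trainer internals:

import torch.nn as nn

def wrap_for_ddp(model, local_rank):
    # gradient checkpointing recomputes the forward pass during backward,
    # which breaks DDP's unused-parameter tracking, so disable it in that case
    find_unused_parameters = not model.is_gradient_checkpointing
    return nn.parallel.DistributedDataParallel(
        model,
        device_ids=[local_rank],
        output_device=local_rank,
        find_unused_parameters=find_unused_parameters,
    )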
tests/test_modeling_common.py (19 changes: 19 additions & 0 deletions)

@@ -196,6 +196,25 @@ def test_save_load_keys_to_ignore_on_save(self):
         )
         self.assertTrue(len(load_result.unexpected_keys) == 0)

+    def test_gradient_checkpointing_enable_disable(self):
+        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
+
+        for model_class in self.all_model_classes:
+            if not model_class.supports_gradient_checkpointing:
+                continue
+
+            # at init model should have gradient checkpointing disabled
+            model = model_class(config)
+            self.assertFalse(model.is_gradient_checkpointing)
+
+            # check enable works
+            model.gradient_checkpointing_enable()
+            self.assertTrue(model.is_gradient_checkpointing)
+
+            # check disable works
+            model.gradient_checkpointing_disable()
+            self.assertFalse(model.is_gradient_checkpointing)
+
     def _mock_init_weights(self, module):
         if hasattr(module, "weight") and module.weight is not None:
             module.weight.data.fill_(3)