diff --git a/trl/models/modeling_base.py b/trl/models/modeling_base.py
index f7894ddedbe..a3a80e461a1 100644
--- a/trl/models/modeling_base.py
+++ b/trl/models/modeling_base.py
@@ -100,6 +100,9 @@ def __init__(
         if hasattr(pretrained_model, "gradient_checkpointing_enable"):
             self.gradient_checkpointing_enable = pretrained_model.gradient_checkpointing_enable
 
+        if hasattr(pretrained_model, "enable_input_require_grads"):
+            self.enable_input_require_grads = pretrained_model.enable_input_require_grads
+
         self.supports_rm_adapter = supports_rm_adapter
         self.rm_adapter_name = rm_adapter_name
         self.policy_adapter_name = "default"
diff --git a/trl/trainer/ppo_config.py b/trl/trainer/ppo_config.py
index b70610d9503..786d9d0d38d 100644
--- a/trl/trainer/ppo_config.py
+++ b/trl/trainer/ppo_config.py
@@ -124,6 +124,8 @@ class PPOConfig:
     """Score clipping"""
     whiten_rewards: bool = False
     """Whiten the rewards before compute advantages"""
+    gradient_checkpointing: bool = False
+    """Enable gradient checkpointing"""
 
     # computed hyperparameters at runtime; we use `tyro.conf.Suppress` to hide them from the help text
     is_encoder_decoder: Optional[tyro.conf.Suppress[bool]] = None
diff --git a/trl/trainer/ppo_trainer.py b/trl/trainer/ppo_trainer.py
index 0a60dcaffb3..e9613834634 100644
--- a/trl/trainer/ppo_trainer.py
+++ b/trl/trainer/ppo_trainer.py
@@ -319,6 +319,18 @@ def __init__(
             self.accelerator.state, "deepspeed_plugin"
         )
 
+        if config.gradient_checkpointing:
+            self.model.gradient_checkpointing_enable()
+
+            if hasattr(self.model, "enable_input_require_grads"):
+                self.model.enable_input_require_grads()
+            else:
+                # For backward compatibility with older versions of transformers
+                def make_inputs_require_grad(module, input, output):
+                    output.requires_grad_(True)
+
+                self.model.pretrained_model.get_input_embeddings().register_forward_hook(make_inputs_require_grad)
+
         (
             self.model,
             self.optimizer,
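
Usage sketch (not part of the diff): with this change applied, gradient checkpointing is switched on from the config, and `PPOTrainer.__init__` then calls `gradient_checkpointing_enable()` on the wrapped model and enables gradients on the embedding outputs, via `enable_input_require_grads()` or, on older `transformers` releases that lack that method, via the forward-hook fallback shown in the diff, so that checkpointed activations still receive a gradient during backprop. In the example below, only `gradient_checkpointing` comes from this diff; every other field value is an illustrative assumption.

```python
from trl import PPOConfig

# Hypothetical example: only `gradient_checkpointing` is introduced by this
# diff; the remaining fields are assumed placeholder values.
config = PPOConfig(
    model_name="gpt2",            # assumed base model for illustration
    learning_rate=1.41e-5,        # illustrative hyperparameter
    batch_size=32,                # illustrative hyperparameter
    gradient_checkpointing=True,  # new flag added in ppo_config.py
)
```

The `else` branch exists because older `transformers` versions do not expose `enable_input_require_grads`; registering a forward hook on the input embeddings that calls `output.requires_grad_(True)` reproduces the same effect by hand.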