diff --git a/nemo_rl/models/policy/megatron_policy_worker.py b/nemo_rl/models/policy/megatron_policy_worker.py index 4ee48901bc..6f57762af3 100644 --- a/nemo_rl/models/policy/megatron_policy_worker.py +++ b/nemo_rl/models/policy/megatron_policy_worker.py @@ -1770,6 +1770,8 @@ def save_checkpoint( if not is_training: self.model.eval() + if self.should_disable_forward_pre_hook: + self.disable_forward_pre_hook() save_checkpoint( state=self.mcore_state, model=[self.model], @@ -1784,6 +1786,8 @@ def save_checkpoint( blocking=True, terminate=True, ) + if self.should_disable_forward_pre_hook: + self.enable_forward_pre_hook() if not is_training: # Restore training state if it was changed self.model.train()