diff --git a/nemo_rl/models/policy/megatron_policy_worker.py b/nemo_rl/models/policy/megatron_policy_worker.py index 3a94ffa98c..b8ebe9c348 100644 --- a/nemo_rl/models/policy/megatron_policy_worker.py +++ b/nemo_rl/models/policy/megatron_policy_worker.py @@ -919,7 +919,12 @@ def train( torch.cuda.empty_cache() # Update parameters. - update_successful, grad_norm, num_zeros_in_grad = self.optimizer.step() + if not eval_mode: + update_successful, grad_norm, num_zeros_in_grad = ( + self.optimizer.step() + ) + else: + update_successful, grad_norm, num_zeros_in_grad = (True, 0.0, 0.0) # when freezing sub-models we may have a mixture of successful and unsucessful ranks, # so we must gather across mp ranks