diff --git a/nemo_reinforcer/models/policy/dtensor_policy_worker.py b/nemo_reinforcer/models/policy/dtensor_policy_worker.py
index 2c4bd78efd..a2c4b10a07 100644
--- a/nemo_reinforcer/models/policy/dtensor_policy_worker.py
+++ b/nemo_reinforcer/models/policy/dtensor_policy_worker.py
@@ -351,11 +351,11 @@ def train(
                 dtype=torch.float32,
             )
 
-        # Update parameters
-        self.optimizer.step()
-        self.scheduler.step()
+            # Update parameters
+            self.optimizer.step()
+            self.scheduler.step()
 
-        losses.append(torch.tensor(mb_losses).sum().item())
+            losses.append(torch.tensor(mb_losses).sum().item())
 
             # Compute global loss across all ranks
             with torch.no_grad():