diff --git a/nemo_rl/models/policy/workers/dtensor_policy_worker.py b/nemo_rl/models/policy/workers/dtensor_policy_worker.py
index e9c57cfb55..af483adb6e 100644
--- a/nemo_rl/models/policy/workers/dtensor_policy_worker.py
+++ b/nemo_rl/models/policy/workers/dtensor_policy_worker.py
@@ -191,6 +191,9 @@ def __init__(
         self.cpu_offload = self.cfg["dtensor_cfg"]["cpu_offload"]
         self.offload_optimizer_for_logprob = self.cfg["offload_optimizer_for_logprob"]
         self.max_grad_norm = self.cfg["max_grad_norm"]
+        # allow zero grad norm for consistency with megatron
+        if self.max_grad_norm == 0.0:
+            self.max_grad_norm = None
 
         if self.cfg["precision"] == "float32":
             self.dtype = torch.float32
diff --git a/nemo_rl/models/policy/workers/dtensor_policy_worker_v2.py b/nemo_rl/models/policy/workers/dtensor_policy_worker_v2.py
index 49e1360c57..d35c931a69 100644
--- a/nemo_rl/models/policy/workers/dtensor_policy_worker_v2.py
+++ b/nemo_rl/models/policy/workers/dtensor_policy_worker_v2.py
@@ -235,6 +235,9 @@ def __init__(
         self.cpu_offload = self.cfg["dtensor_cfg"]["cpu_offload"]
         self.offload_optimizer_for_logprob = self.cfg["offload_optimizer_for_logprob"]
         self.max_grad_norm = self.cfg["max_grad_norm"]
+        # allow zero grad norm for consistency with megatron
+        if self.max_grad_norm == 0.0:
+            self.max_grad_norm = None
 
         try:
             self.dtype = STRING_TO_DTYPE[self.cfg["precision"]]
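
A minimal sketch (not part of the patch) of why `0.0` is mapped to `None` here. With `torch.nn.utils.clip_grad_norm_`, a `max_norm` of `0.0` does not disable clipping; it rescales every gradient toward zero. Normalizing `0.0` to `None` lets the training step skip the clip call entirely, matching the Megatron convention where a zero clip value means "no clipping". The `clip_if_configured` helper below is hypothetical, shown only to illustrate the intended downstream behavior:

```python
import torch


def clip_if_configured(model: torch.nn.Module, max_grad_norm: float | None) -> None:
    # Only clip when a max norm is actually configured; None disables clipping.
    if max_grad_norm is not None:
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)


model = torch.nn.Linear(4, 4)
model(torch.randn(2, 4)).sum().backward()

clip_if_configured(model, None)  # no-op: gradients are left untouched
clip_if_configured(model, 1.0)   # standard clipping to total norm <= 1.0
# clip_if_configured(model, 0.0) would rescale all gradients to ~0,
# which is why the workers normalize 0.0 to None up front.
```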