From 6ceb34581811841ac150e8828846f26ada6b58a8 Mon Sep 17 00:00:00 2001 From: Sadegh Mahdavi Date: Tue, 9 Dec 2025 12:31:08 -0800 Subject: [PATCH 1/4] allow zero grad norm for consistency with megatron Signed-off-by: Sadegh Mahdavi --- .../policy/workers/dtensor_policy_worker.py | 4 ++- .../workers/dtensor_policy_worker_v2.py | 35 +++++++++---------- 2 files changed, 19 insertions(+), 20 deletions(-) diff --git a/nemo_rl/models/policy/workers/dtensor_policy_worker.py b/nemo_rl/models/policy/workers/dtensor_policy_worker.py index e9c57cfb55..c003f194cc 100644 --- a/nemo_rl/models/policy/workers/dtensor_policy_worker.py +++ b/nemo_rl/models/policy/workers/dtensor_policy_worker.py @@ -191,6 +191,8 @@ def __init__( self.cpu_offload = self.cfg["dtensor_cfg"]["cpu_offload"] self.offload_optimizer_for_logprob = self.cfg["offload_optimizer_for_logprob"] self.max_grad_norm = self.cfg["max_grad_norm"] + if self.max_grad_norm == 0.0: # allow zero grad norm for consistency with megatron + self.max_grad_norm = None if self.cfg["precision"] == "float32": self.dtype = torch.float32 @@ -829,7 +831,7 @@ def train( tp_group=self.tp_mesh.get_group(), dtype=torch.float32, ) - if self.max_grad_norm is not None: + if self.max_grad_norm is not None and self.max_grad_norm > 0: clip_grad_by_total_norm_( self.model.parameters(), max_grad_norm=self.max_grad_norm, diff --git a/nemo_rl/models/policy/workers/dtensor_policy_worker_v2.py b/nemo_rl/models/policy/workers/dtensor_policy_worker_v2.py index 49e1360c57..1dc89cf43b 100644 --- a/nemo_rl/models/policy/workers/dtensor_policy_worker_v2.py +++ b/nemo_rl/models/policy/workers/dtensor_policy_worker_v2.py @@ -235,6 +235,8 @@ def __init__( self.cpu_offload = self.cfg["dtensor_cfg"]["cpu_offload"] self.offload_optimizer_for_logprob = self.cfg["offload_optimizer_for_logprob"] self.max_grad_norm = self.cfg["max_grad_norm"] + if self.max_grad_norm == 0.0: # allow zero grad norm for consistency with megatron + self.max_grad_norm = None try: self.dtype = STRING_TO_DTYPE[self.cfg["precision"]] @@ -952,25 +954,20 @@ def train( grad_norm: Optional[float | torch.Tensor] = None if not eval_mode: - grad_norm = scale_grads_and_clip_grad_norm( - self.max_grad_norm, - [self.model], - norm_type=2.0, - pp_enabled=False, - device_mesh=self.device_mesh, - moe_mesh=self.moe_mesh, - ep_axis_name="ep" - if self.moe_mesh is not None - and "ep" in self.moe_mesh.mesh_dim_names - else None, - pp_axis_name=None, - foreach=True, - num_label_tokens=1, - dp_group_size=self.dp_size * self.cp_size, - ) - grad_norm = torch.tensor( - grad_norm, device="cpu", dtype=torch.float32 - ) + with torch.no_grad(): + grad_norm = get_grad_norm( + self.model.parameters(), + dp_cp_group=self.dp_cp_mesh.get_group(), + tp_group=self.tp_mesh.get_group(), + dtype=torch.float32, + ) + if self.max_grad_norm is not None and self.max_grad_norm > 0: + clip_grad_by_total_norm_( + self.model.parameters(), + max_grad_norm=self.max_grad_norm, + total_norm=grad_norm, + ) + grad_norm = torch.tensor([grad_norm]) # Update parameters self.optimizer.step() From 36874c6fe1e916be1167148853d803ff110b4d74 Mon Sep 17 00:00:00 2001 From: Sadegh Mahdavi Date: Tue, 13 Jan 2026 13:18:37 -0800 Subject: [PATCH 2/4] fix Signed-off-by: Sadegh Mahdavi --- .../policy/workers/dtensor_policy_worker.py | 2 +- .../workers/dtensor_policy_worker_v2.py | 33 +++++++++++-------- 2 files changed, 20 insertions(+), 15 deletions(-) diff --git a/nemo_rl/models/policy/workers/dtensor_policy_worker.py b/nemo_rl/models/policy/workers/dtensor_policy_worker.py index c003f194cc..1f8582a65d 100644 --- a/nemo_rl/models/policy/workers/dtensor_policy_worker.py +++ b/nemo_rl/models/policy/workers/dtensor_policy_worker.py @@ -831,7 +831,7 @@ def train( tp_group=self.tp_mesh.get_group(), dtype=torch.float32, ) - if self.max_grad_norm is not None and self.max_grad_norm > 0: + if self.max_grad_norm is not None: clip_grad_by_total_norm_( self.model.parameters(), max_grad_norm=self.max_grad_norm, diff --git a/nemo_rl/models/policy/workers/dtensor_policy_worker_v2.py b/nemo_rl/models/policy/workers/dtensor_policy_worker_v2.py index 1dc89cf43b..047ebbcd3b 100644 --- a/nemo_rl/models/policy/workers/dtensor_policy_worker_v2.py +++ b/nemo_rl/models/policy/workers/dtensor_policy_worker_v2.py @@ -954,20 +954,25 @@ def train( grad_norm: Optional[float | torch.Tensor] = None if not eval_mode: - with torch.no_grad(): - grad_norm = get_grad_norm( - self.model.parameters(), - dp_cp_group=self.dp_cp_mesh.get_group(), - tp_group=self.tp_mesh.get_group(), - dtype=torch.float32, - ) - if self.max_grad_norm is not None and self.max_grad_norm > 0: - clip_grad_by_total_norm_( - self.model.parameters(), - max_grad_norm=self.max_grad_norm, - total_norm=grad_norm, - ) - grad_norm = torch.tensor([grad_norm]) + grad_norm = scale_grads_and_clip_grad_norm( + self.max_grad_norm, + [self.model], + norm_type=2.0, + pp_enabled=False, + device_mesh=self.device_mesh, + moe_mesh=self.moe_mesh, + ep_axis_name="ep" + if self.moe_mesh is not None + and "ep" in self.moe_mesh.mesh_dim_names + else None, + pp_axis_name=None, + foreach=True, + num_label_tokens=1, + dp_group_size=self.dp_size * self.cp_size, + ) + grad_norm = torch.tensor( + grad_norm, device="cpu", dtype=torch.float32 + ) # Update parameters self.optimizer.step() From de24b4b8a33e2bb349644631a4bf250a8fa2ef47 Mon Sep 17 00:00:00 2001 From: Sadegh Mahdavi Date: Tue, 13 Jan 2026 13:25:56 -0800 Subject: [PATCH 3/4] fix Signed-off-by: Sadegh Mahdavi --- nemo_rl/models/policy/workers/dtensor_policy_worker.py | 3 ++- nemo_rl/models/policy/workers/dtensor_policy_worker_v2.py | 3 ++- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/nemo_rl/models/policy/workers/dtensor_policy_worker.py b/nemo_rl/models/policy/workers/dtensor_policy_worker.py index 1f8582a65d..69a8c42092 100644 --- a/nemo_rl/models/policy/workers/dtensor_policy_worker.py +++ b/nemo_rl/models/policy/workers/dtensor_policy_worker.py @@ -191,7 +191,8 @@ def __init__( self.cpu_offload = self.cfg["dtensor_cfg"]["cpu_offload"] self.offload_optimizer_for_logprob = self.cfg["offload_optimizer_for_logprob"] self.max_grad_norm = self.cfg["max_grad_norm"] - if self.max_grad_norm == 0.0: # allow zero grad norm for consistency with megatron + # allow zero grad norm for consistency with megatron + if self.max_grad_norm == 0.0: self.max_grad_norm = None if self.cfg["precision"] == "float32": diff --git a/nemo_rl/models/policy/workers/dtensor_policy_worker_v2.py b/nemo_rl/models/policy/workers/dtensor_policy_worker_v2.py index 047ebbcd3b..ff3eb149ae 100644 --- a/nemo_rl/models/policy/workers/dtensor_policy_worker_v2.py +++ b/nemo_rl/models/policy/workers/dtensor_policy_worker_v2.py @@ -235,7 +235,8 @@ def __init__( self.cpu_offload = self.cfg["dtensor_cfg"]["cpu_offload"] self.offload_optimizer_for_logprob = self.cfg["offload_optimizer_for_logprob"] self.max_grad_norm = self.cfg["max_grad_norm"] - if self.max_grad_norm == 0.0: # allow zero grad norm for consistency with megatron + # allow zero grad norm for consistency with megatron + if self.max_grad_norm == 0.0: self.max_grad_norm = None try: From 030e6e06a95073da2ca4a33be9cb1a79861e834a Mon Sep 17 00:00:00 2001 From: Sadegh Mahdavi Date: Wed, 14 Jan 2026 06:07:09 -0800 Subject: [PATCH 4/4] linting fix Signed-off-by: Sadegh Mahdavi --- nemo_rl/models/policy/workers/dtensor_policy_worker.py | 2 +- nemo_rl/models/policy/workers/dtensor_policy_worker_v2.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/nemo_rl/models/policy/workers/dtensor_policy_worker.py b/nemo_rl/models/policy/workers/dtensor_policy_worker.py index 69a8c42092..af483adb6e 100644 --- a/nemo_rl/models/policy/workers/dtensor_policy_worker.py +++ b/nemo_rl/models/policy/workers/dtensor_policy_worker.py @@ -192,7 +192,7 @@ def __init__( self.offload_optimizer_for_logprob = self.cfg["offload_optimizer_for_logprob"] self.max_grad_norm = self.cfg["max_grad_norm"] # allow zero grad norm for consistency with megatron - if self.max_grad_norm == 0.0: + if self.max_grad_norm == 0.0: self.max_grad_norm = None if self.cfg["precision"] == "float32": diff --git a/nemo_rl/models/policy/workers/dtensor_policy_worker_v2.py b/nemo_rl/models/policy/workers/dtensor_policy_worker_v2.py index ff3eb149ae..d35c931a69 100644 --- a/nemo_rl/models/policy/workers/dtensor_policy_worker_v2.py +++ b/nemo_rl/models/policy/workers/dtensor_policy_worker_v2.py @@ -236,7 +236,7 @@ def __init__( self.offload_optimizer_for_logprob = self.cfg["offload_optimizer_for_logprob"] self.max_grad_norm = self.cfg["max_grad_norm"] # allow zero grad norm for consistency with megatron - if self.max_grad_norm == 0.0: + if self.max_grad_norm == 0.0: self.max_grad_norm = None try: