diff --git a/nemo_reinforcer/algorithms/dpo.py b/nemo_reinforcer/algorithms/dpo.py index 64d14164d5..c4a6af834a 100644 --- a/nemo_reinforcer/algorithms/dpo.py +++ b/nemo_reinforcer/algorithms/dpo.py @@ -482,6 +482,7 @@ def dpo_train( losses = train_results["loss"] metrics = { "loss": train_results["loss"].numpy(), + "grad_norm": train_results["grad_norm"].numpy(), } metrics.update(train_results["all_mb_metrics"]) for k, v in metrics.items(): diff --git a/nemo_reinforcer/algorithms/grpo.py b/nemo_reinforcer/algorithms/grpo.py index 592a4bc4d8..3db313fcd2 100644 --- a/nemo_reinforcer/algorithms/grpo.py +++ b/nemo_reinforcer/algorithms/grpo.py @@ -576,6 +576,7 @@ def grpo_train( metrics = { "loss": train_results["loss"].numpy(), "reward": rewards.numpy(), + "grad_norm": train_results["grad_norm"].numpy(), } metrics.update(train_results["all_mb_metrics"]) for k, v in metrics.items(): diff --git a/nemo_reinforcer/algorithms/sft.py b/nemo_reinforcer/algorithms/sft.py index a42967aa31..4bcc9a8a41 100644 --- a/nemo_reinforcer/algorithms/sft.py +++ b/nemo_reinforcer/algorithms/sft.py @@ -486,6 +486,7 @@ def sft_train( losses = train_results["loss"] metrics = { "loss": train_results["loss"].numpy(), + "grad_norm": train_results["grad_norm"].numpy(), } metrics.update(train_results["all_mb_metrics"]) for k, v in metrics.items(): diff --git a/nemo_reinforcer/models/policy/dtensor_policy_worker.py b/nemo_reinforcer/models/policy/dtensor_policy_worker.py index cf0f06bbfa..50016c8cef 100644 --- a/nemo_reinforcer/models/policy/dtensor_policy_worker.py +++ b/nemo_reinforcer/models/policy/dtensor_policy_worker.py @@ -369,6 +369,7 @@ def train( total_norm=grad_norm, dtype=torch.float32, ) + grad_norm = torch.tensor([grad_norm]) # Update parameters self.optimizer.step() diff --git a/nemo_reinforcer/models/policy/fsdp1_policy_worker.py b/nemo_reinforcer/models/policy/fsdp1_policy_worker.py index 89b46fd6ac..a75f501393 100644 --- a/nemo_reinforcer/models/policy/fsdp1_policy_worker.py +++ b/nemo_reinforcer/models/policy/fsdp1_policy_worker.py @@ -310,10 +310,20 @@ def train( all_mb_metrics.append(loss_metrics) # Clip gradients + grad_norm = None if not eval_mode: - torch.nn.utils.clip_grad_norm_( - self.model.parameters(), max_norm=self.cfg["max_grad_norm"] - ) + if isinstance(self.model, FullyShardedDataParallel): + # when using FSDP1, use FSDP's clip_grad_norm_ + # to ensure grad norm is being computed over all parameters + # see https://pytorch.org/docs/stable/fsdp.html#torch.distributed.fsdp.FullyShardedDataParallel.clip_grad_norm_ + grad_norm = self.model.clip_grad_norm_( + max_norm=self.cfg["max_grad_norm"] + ) + else: + grad_norm = torch.nn.utils.clip_grad_norm_( + self.model.parameters(), max_norm=self.cfg["max_grad_norm"] + ) + grad_norm = grad_norm.cpu() # Update parameters self.optimizer.step() @@ -336,6 +346,7 @@ def train( metrics = { "global_loss": global_loss.cpu(), "local_loss": local_loss.cpu(), + "grad_norm": grad_norm, "rank": torch.distributed.get_rank(), "all_mb_metrics": dict(mb_metrics), } diff --git a/nemo_reinforcer/models/policy/hf_policy.py b/nemo_reinforcer/models/policy/hf_policy.py index c4b0718288..0b56b002e7 100644 --- a/nemo_reinforcer/models/policy/hf_policy.py +++ b/nemo_reinforcer/models/policy/hf_policy.py @@ -177,8 +177,10 @@ def train( results = self.worker_group.get_all_worker_results(futures) # Aggregate the results - aggregated_results = {} - aggregated_results["loss"] = results[0]["global_loss"] + aggregated_results = { + "loss": results[0]["global_loss"], + "grad_norm": results[0]["grad_norm"], + } # Aggregate metrics across all workers all_mb_metrics = defaultdict(list) diff --git a/tests/unit/algorithms/test_sft.py b/tests/unit/algorithms/test_sft.py index 3b63c9a8b5..737108115f 100644 --- a/tests/unit/algorithms/test_sft.py +++ b/tests/unit/algorithms/test_sft.py @@ -24,7 +24,11 @@ def mock_components(): # Create mock components policy = MagicMock() - policy.train.return_value = {"loss": torch.tensor(0.5), "all_mb_metrics": {}} + policy.train.return_value = { + "loss": torch.tensor(0.5), + "grad_norm": torch.tensor(1.0), + "all_mb_metrics": {}, + } # Create a proper message log structure with token_ids mock_batch = {