From 262a242d25a896adbc9a3f3f3703b5a37d0f4ace Mon Sep 17 00:00:00 2001 From: Sahil Jain Date: Fri, 11 Apr 2025 16:07:03 -0700 Subject: [PATCH] Add total logging of generations in training Signed-off-by: Sahil Jain --- nemo_reinforcer/algorithms/grpo.py | 11 ++++++++++- nemo_reinforcer/utils/logger.py | 29 +++++++++++++++++++++++++++++ 2 files changed, 39 insertions(+), 1 deletion(-) diff --git a/nemo_reinforcer/algorithms/grpo.py b/nemo_reinforcer/algorithms/grpo.py index 0eda853375..f56062fa74 100644 --- a/nemo_reinforcer/algorithms/grpo.py +++ b/nemo_reinforcer/algorithms/grpo.py @@ -289,7 +289,7 @@ def generate_responses( tokenizer, input_lengths: torch.Tensor, include_logprobs: bool = True, -) -> Tuple[List[torch.Tensor], List[str], torch.Tensor]: +) -> Tuple[BatchedDataDict[DatumSpec], List[List[int]], Dict[str, float | int]]: """Generate responses from policy.""" # Generate responses generation_outputs = policy_generation.generate(generation_input_data) @@ -451,6 +451,7 @@ def grpo_train( logger.log_metrics(validation_timings, step, prefix="timing/validation") # Run grpo training (single-turn) + batch: BatchedDataDict[DatumSpec] for batch in dataloader: print( f"\n{'=' * 25} Step {step + 1}/{min(len(dataloader), master_config['grpo']['max_num_steps'])} {'=' * 25}" @@ -638,6 +639,14 @@ def grpo_train( policy.offload_after_refit() # Logging + # Log training data + log_data = {"content": flat_messages["content"]} + log_data["rewards"] = rewards.tolist() + log_data["generation_logprobs"] = train_data["generation_logprobs"].tolist() + log_data["prev_logprobs"] = train_data["prev_logprobs"].tolist() + log_data["input_lengths"] = input_lengths.tolist() + logger.log_batched_dict_as_jsonl(log_data, f"train_data_step{step}.jsonl") + print("\nšŸ“Š Training Results:") metrics = { "loss": train_results["loss"].numpy(), diff --git a/nemo_reinforcer/utils/logger.py b/nemo_reinforcer/utils/logger.py index bc0157d564..0442854576 100644 --- a/nemo_reinforcer/utils/logger.py +++ b/nemo_reinforcer/utils/logger.py @@ -19,6 +19,7 @@ import time import threading import requests +import json from abc import ABC, abstractmethod import logging from typing import List, Any, Dict, Optional, TypedDict, Union @@ -27,8 +28,10 @@ from rich.panel import Panel from rich.box import ROUNDED from rich.logging import RichHandler +import torch from nemo_reinforcer.data.interfaces import LLMMessageLogType +from nemo_reinforcer.distributed.batched_data_dict import BatchedDataDict from torch.utils.tensorboard import SummaryWriter import ray @@ -568,6 +571,32 @@ def log_hyperparams(self, params: Dict[str, Any]) -> None: for logger in self.loggers: logger.log_hyperparams(params) + def log_batched_dict_as_jsonl( + self, to_log: BatchedDataDict | Dict[str, Any], filename: str + ) -> None: + """Log a list of dictionaries to a JSONL file. + + Args: + to_log: BatchedDataDict to log + filename: Filename to log to (within the log directory) + """ + if not isinstance(to_log, BatchedDataDict): + to_log = BatchedDataDict(to_log) + + # Create full path within log directory + filepath = os.path.join(self.base_log_dir, filename) + os.makedirs(os.path.dirname(filepath), exist_ok=True) + + # Write to JSONL file + with open(filepath, "w") as f: + for i, sample in enumerate(to_log.make_microbatch_iterator(1)): + for key, value in sample.items(): + if isinstance(value, torch.Tensor): + sample[key] = value.tolist() + f.write(json.dumps({**sample, "idx": i}) + "\n") + + print(f"Logged data to {filepath}") + def __del__(self): """Clean up resources when the logger is destroyed.""" if self.gpu_monitor: