Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 10 additions & 1 deletion nemo_reinforcer/algorithms/grpo.py
Original file line number Diff line number Diff line change
Expand Up @@ -289,7 +289,7 @@ def generate_responses(
tokenizer,
input_lengths: torch.Tensor,
include_logprobs: bool = True,
) -> Tuple[List[torch.Tensor], List[str], torch.Tensor]:
) -> Tuple[BatchedDataDict[DatumSpec], List[List[int]], Dict[str, float | int]]:
"""Generate responses from policy."""
# Generate responses
generation_outputs = policy_generation.generate(generation_input_data)
Expand Down Expand Up @@ -451,6 +451,7 @@ def grpo_train(
logger.log_metrics(validation_timings, step, prefix="timing/validation")

# Run grpo training (single-turn)
batch: BatchedDataDict[DatumSpec]
for batch in dataloader:
print(
f"\n{'=' * 25} Step {step + 1}/{min(len(dataloader), master_config['grpo']['max_num_steps'])} {'=' * 25}"
Expand Down Expand Up @@ -638,6 +639,14 @@ def grpo_train(
policy.offload_after_refit()

# Logging
# Log training data
log_data = {"content": flat_messages["content"]}
log_data["rewards"] = rewards.tolist()
log_data["generation_logprobs"] = train_data["generation_logprobs"].tolist()
log_data["prev_logprobs"] = train_data["prev_logprobs"].tolist()
log_data["input_lengths"] = input_lengths.tolist()
logger.log_batched_dict_as_jsonl(log_data, f"train_data_step{step}.jsonl")

print("\n📊 Training Results:")
metrics = {
"loss": train_results["loss"].numpy(),
Expand Down
29 changes: 29 additions & 0 deletions nemo_reinforcer/utils/logger.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
import time
import threading
import requests
import json
from abc import ABC, abstractmethod
import logging
from typing import List, Any, Dict, Optional, TypedDict, Union
Expand All @@ -27,8 +28,10 @@
from rich.panel import Panel
from rich.box import ROUNDED
from rich.logging import RichHandler
import torch

from nemo_reinforcer.data.interfaces import LLMMessageLogType
from nemo_reinforcer.distributed.batched_data_dict import BatchedDataDict
from torch.utils.tensorboard import SummaryWriter

import ray
Expand Down Expand Up @@ -568,6 +571,32 @@ def log_hyperparams(self, params: Dict[str, Any]) -> None:
for logger in self.loggers:
logger.log_hyperparams(params)

def log_batched_dict_as_jsonl(
    self, to_log: BatchedDataDict | Dict[str, Any], filename: str
) -> None:
    """Log each sample of a batched dict as one JSON line in a JSONL file.

    Every torch.Tensor value is converted to a (nested) Python list so the
    row is JSON-serializable; each emitted line also carries an ``idx``
    field with the sample's position in the batch.

    Args:
        to_log: BatchedDataDict (or plain dict convertible to one) to log
        filename: Filename to log to (within the log directory)
    """
    if not isinstance(to_log, BatchedDataDict):
        to_log = BatchedDataDict(to_log)

    # Create full path within log directory
    filepath = os.path.join(self.base_log_dir, filename)
    os.makedirs(os.path.dirname(filepath), exist_ok=True)

    # Write one sample per line. Build a fresh row dict instead of mutating
    # the microbatch in place: mutating a dict while iterating .items() is
    # fragile, and the iterator may hand back views of the caller's data,
    # so in-place tensor->list conversion could clobber caller-owned values.
    with open(filepath, "w") as f:
        for i, sample in enumerate(to_log.make_microbatch_iterator(1)):
            row = {
                key: value.tolist() if isinstance(value, torch.Tensor) else value
                for key, value in sample.items()
            }
            # Set idx last so it wins over any pre-existing "idx" key,
            # matching the original {**sample, "idx": i} override order.
            row["idx"] = i
            f.write(json.dumps(row) + "\n")

    print(f"Logged data to {filepath}")

def __del__(self):
"""Clean up resources when the logger is destroyed."""
if self.gpu_monitor:
Expand Down
Loading