Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
111 changes: 111 additions & 0 deletions nemo_rl/utils/logger.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
import logging
import os
import re
import subprocess
import threading
import time
from abc import ABC, abstractmethod
Expand Down Expand Up @@ -138,10 +139,120 @@ class WandbLogger(LoggerInterface):

def __init__(self, cfg: WandbConfig, log_dir: Optional[str] = None):
    """Start a wandb run and snapshot the repository state into it.

    Args:
        cfg: Keyword arguments forwarded to ``wandb.init``.
        log_dir: Value passed to ``wandb.init`` as ``dir``.
    """
    self.run = wandb.init(dir=log_dir, **cfg)

    # Upload the git-tracked sources first, then the git diffs.
    self._log_code()
    self._log_diffs()

    banner = f"Initialized WandbLogger for project {cfg.get('project')}, run {cfg.get('name')} at {log_dir}"
    print(banner)

def _log_diffs(self):
    """Log git diffs to wandb.

    Captures up to two diffs and uploads them as text files inside a single
    ``git-diffs`` artifact on this logger's run:

    1. Uncommitted changes (working tree diff against ``HEAD``), stored as
       ``uncommitted_changes_diff.txt``.
    2. All changes, including uncommitted ones, against the ``main`` branch,
       stored as ``main_diff.txt``. Skipped when the current branch is
       ``main``.

    The operation is best-effort: errors are printed and never raised. A
    failure while diffing against ``main`` (for example when no local
    ``main`` ref exists) does not discard an uncommitted-changes diff that
    was already collected. Nothing is uploaded when both diffs are empty.
    """

    def git_stdout(*args):
        # Decode explicitly as UTF-8 and replace undecodable bytes, so a diff
        # with text outside the locale's encoding cannot abort the capture.
        completed = subprocess.run(
            ["git", *args],
            capture_output=True,
            text=True,
            encoding="utf-8",
            errors="replace",
            check=True,
        )
        return completed.stdout

    try:
        current_branch = git_stdout("rev-parse", "--abbrev-ref", "HEAD").strip()

        diff_artifact = wandb.Artifact(
            name=f"git-diffs-{self.run.project}-{self.run.id}", type="git-diffs"
        )

        # Write next to the files of the run this logger owns; fall back to
        # the current directory when the run exposes no directory.
        output_dir = getattr(self.run, "dir", None) or "."
        attached = []

        def attach(diff_text, file_name):
            # Persist the diff and register it in the artifact.
            diff_path = os.path.join(output_dir, file_name)
            with open(diff_path, "w", encoding="utf-8") as f:
                f.write(diff_text)
            diff_artifact.add_file(diff_path, name=file_name)
            attached.append(file_name)

        # 1. Uncommitted changes (working tree diff against HEAD).
        uncommitted_diff = git_stdout("diff", "HEAD")
        if uncommitted_diff:
            attach(uncommitted_diff, "uncommitted_changes_diff.txt")
            print("Logged uncommitted changes diff to wandb")
        else:
            print("No uncommitted changes found")

        # 2. Diff between main and the working tree (includes uncommitted
        #    changes). Handled separately so that a missing ``main`` ref does
        #    not prevent the diff from step 1 from being uploaded.
        if current_branch != "main":
            try:
                working_diff = git_stdout("diff", "main")
            except subprocess.CalledProcessError as e:
                print(f"Error during git operations: {e}")
            else:
                if working_diff:
                    attach(working_diff, "main_diff.txt")
                    print("Logged diff against main branch")
                else:
                    print("No differences found between main and working tree")

        # Avoid creating an empty artifact when there was nothing to record.
        if attached:
            self.run.log_artifact(diff_artifact)

    except subprocess.CalledProcessError as e:
        print(f"Error during git operations: {e}")
    except Exception as e:
        print(f"Unexpected error during git diff logging: {e}")

def _log_code(self):
    """Log code that is tracked by git to wandb.

    Runs ``git ls-files`` in the current working directory and uploads every
    listed path that exists as a regular file to this logger's run, bundled
    in a single ``code`` artifact.

    The operation is best-effort: errors are printed and never raised. A
    file that cannot be added is reported and skipped; the remaining files
    are still uploaded.
    """
    try:
        # -z yields NUL-separated, unquoted paths. Without it git C-quotes
        # names containing non-ASCII or special characters, and those quoted
        # names would then fail the os.path.isfile check below.
        result = subprocess.run(
            ["git", "ls-files", "-z"], capture_output=True, text=True, check=True
        )

        # Drop empty entries: splitting empty output would otherwise produce
        # [""] and the emptiness check below could never trigger.
        tracked_files = [path for path in result.stdout.split("\0") if path]

        if not tracked_files:
            print(
                "Warning: No git repository found. Wandb logs will not track code changes for reproducibility."
            )
            return

        code_artifact = wandb.Artifact(
            name=f"source-code-{self.run.project}", type="code"
        )

        added_count = 0
        for file_path in tracked_files:
            # Tracked paths may be missing from the working tree (deleted
            # but not yet committed) or may not be regular files.
            if not os.path.isfile(file_path):
                continue
            try:
                code_artifact.add_file(file_path, name=file_path)
            except Exception as e:
                print(f"Error adding file {file_path}: {e}")
            else:
                added_count += 1

        self.run.log_artifact(code_artifact)
        # Report what was actually added, not what git listed.
        print(f"Logged {added_count} git-tracked files to wandb")

    except subprocess.CalledProcessError as e:
        print(f"Error getting git-tracked files: {e}")
    except Exception as e:
        print(f"Unexpected error during git code logging: {e}")

def define_metric(
self,
name: str,
Expand Down
Loading