Skip to content
Merged
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
109 changes: 109 additions & 0 deletions nemo_reinforcer/utils/logger.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
import time
import threading
import requests
import subprocess
from abc import ABC, abstractmethod
import logging
from typing import List, Any, Dict, Optional, TypedDict, Union
Expand Down Expand Up @@ -125,10 +126,118 @@ class WandbLogger(LoggerInterface):

def __init__(self, cfg: WandbConfig, log_dir: Optional[str] = None):
    """Start a wandb run and snapshot the current code state into it.

    Args:
        cfg: Keyword arguments forwarded unchanged to ``wandb.init``.
        log_dir: Passed to ``wandb.init`` as ``dir``.
    """
    run = wandb.init(dir=log_dir, **cfg)
    # The run must be stored before the snapshot helpers execute, since
    # both of them read ``self.run``.
    self.run = run

    # Tracked source files are uploaded first, then the git diffs.
    for snapshot in (self._log_code, self._log_diffs):
        snapshot()

    project = cfg.get("project")
    run_name = cfg.get("name")
    print(
        f"Initialized WandbLogger for project {project}, run {run_name} at {log_dir}"
    )

def _log_diffs(self):
    """Log git diffs to wandb.

    Captures up to two diffs and stores each as a text file in a single
    ``git-diffs`` artifact attached to the current run:

    1. Uncommitted changes (working tree diff against HEAD), saved as
       ``uncommitted_changes_diff.txt``.
    2. All changes (including uncommitted ones) against the ``main``
       branch, saved as ``main_diff.txt``. Skipped when the current
       branch is ``main``.

    The diff against ``main`` is collected independently of the first one,
    so a failure there (for example, no local ``main`` ref) does not
    discard an uncommitted-changes diff that was already captured. The
    artifact is only logged when at least one diff is non-empty.

    Git commands run in the current working directory. This is
    best-effort: every error is reported with ``print`` and never
    propagates to the caller.
    """

    def run_git(*args):
        # Decode explicitly as UTF-8 with replacement so a diff containing
        # undecodable bytes cannot abort the whole method.
        completed = subprocess.run(
            ["git", *args],
            capture_output=True,
            text=True,
            encoding="utf-8",
            errors="replace",
            check=True,
        )
        return completed.stdout

    try:
        current_branch = run_git("rev-parse", "--abbrev-ref", "HEAD").strip()

        diff_artifact = wandb.Artifact(
            name=f"git-diffs-{self.run.project}-{self.run.id}", type="git-diffs"
        )
        # Diff files are written next to the other run files when a wandb
        # run is active, otherwise into the current directory.
        out_dir = wandb.run.dir if wandb.run else "."

        def add_diff(diff_text, file_name):
            # Write the diff to disk and register it in the artifact.
            diff_path = os.path.join(out_dir, file_name)
            with open(diff_path, "w", encoding="utf-8") as f:
                f.write(diff_text)
            diff_artifact.add_file(diff_path, name=file_name)

        has_files = False

        # 1. Uncommitted changes (working tree diff against HEAD).
        uncommitted_diff = run_git("diff", "HEAD")
        if uncommitted_diff:
            add_diff(uncommitted_diff, "uncommitted_changes_diff.txt")
            has_files = True
            print("Logged uncommitted changes diff to wandb")
        else:
            print("No uncommitted changes found")

        # 2. Diff between main and the working tree (includes uncommitted
        #    changes). Pointless when we are already on main.
        if current_branch != "main":
            try:
                working_diff = run_git("diff", "main")
            except subprocess.CalledProcessError as e:
                # Keep going: the uncommitted diff above must still be
                # uploaded even if ``main`` cannot be resolved.
                print(f"Error during git operations: {e}")
            else:
                if working_diff:
                    add_diff(working_diff, "main_diff.txt")
                    has_files = True
                    print("Logged diff against main branch")
                else:
                    print("No differences found between main and working tree")

        # Avoid cluttering the run with an empty artifact.
        if has_files:
            self.run.log_artifact(diff_artifact)

    except subprocess.CalledProcessError as e:
        print(f"Error during git operations: {e}")
    except Exception as e:
        print(f"Unexpected error during git diff logging: {e}")

def _log_code(self):
    """Log code that is tracked by git to wandb.

    Runs ``git ls-files`` in the current working directory, adds every
    listed path that exists as a regular file to a ``code`` artifact named
    ``source-code-<project>``, and logs that artifact to the current run.

    Paths that git reports but that are not regular files on disk are
    skipped. A file that cannot be added is reported and skipped without
    aborting the upload. This is best-effort: every error is reported with
    ``print`` and never propagates to the caller.
    """
    try:
        # -z yields NUL-terminated, unquoted paths; without it git
        # C-quotes paths containing unusual characters and they would no
        # longer match files on disk.
        result = subprocess.run(
            ["git", "ls-files", "-z"], capture_output=True, text=True, check=True
        )

        # Drop empty entries so empty output really produces an empty list
        # (splitting "" would otherwise give [""], which is truthy).
        tracked_files = [path for path in result.stdout.split("\0") if path]

        if not tracked_files:
            print("No git-tracked files found")
            return

        code_artifact = wandb.Artifact(
            name=f"source-code-{self.run.project}", type="code"
        )

        added_count = 0
        for file_path in tracked_files:
            if not os.path.isfile(file_path):
                continue
            try:
                code_artifact.add_file(file_path, name=file_path)
            except Exception as e:
                print(f"Error adding file {file_path}: {e}")
            else:
                added_count += 1

        self.run.log_artifact(code_artifact)
        # Report what was actually added, not what git listed.
        print(f"Logged {added_count} git-tracked files to wandb")

    except subprocess.CalledProcessError as e:
        print(f"Error getting git-tracked files: {e}")
    except Exception as e:
        print(f"Unexpected error during git code logging: {e}")

def define_metric(
self,
name: str,
Expand Down