Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
105 changes: 77 additions & 28 deletions unsloth_zoo/logging_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@
"mean_token_accuracy", # SFT extras
"entropy", # TRL >= 0.22.0
"aux_loss", # TRL >= 0.23.0

# GRPO extras
"clip_ratio",
'clip_ratio/low_mean',
Expand All @@ -39,6 +39,9 @@
'clip_ratio/high_max',
'clip_ratio/region_mean',
'frac_reward_zero_std',

# Regex false positive from self._metrics["train"]["step_time"] in TRL >= 0.26.0
'train',
]
REMOVED_METRICS = frozenset(REMOVED_METRICS)

Expand All @@ -65,28 +68,52 @@ def _NotebookProgressCallback_on_train_begin(self, args, state, control, **kwarg
self.first_column = "Epoch" if args.eval_strategy == IntervalStrategy.EPOCH else "Step"
self.training_loss = 0
self.last_log = 0
# Don't pre-create metric columns. Start with just the essentials;
# columns are added dynamically by write_line as metrics actually appear.
# This prevents empty "0 then blank" columns for conditional metrics
# (kl when beta=0, sampling/* without importance sampling, etc.)
column_names = [self.first_column] + ["Training Loss"]
if args.eval_strategy != IntervalStrategy.NO:
column_names.append("Validation Loss")
column_names += [x.replace("/", " / ") for x in Trainer_metrics]
self.training_tracker = NotebookTrainingTracker(state.max_steps, column_names)
pass
return _NotebookProgressCallback_on_train_begin
pass


def NotebookProgressCallback_on_log(Trainer_metrics):
# Build an allowlist of known metrics (pre-extracted from TRL source).
# Only these + dynamic rewards/* metrics pass through. This blocks
# spurious keys injected at runtime (e.g. "train", "tools/*").
set_Trainer_metrics = frozenset(Trainer_metrics)

def _NotebookProgressCallback_on_log(self, args, state, control, logs = None, **kwargs):
# Only for when there is no evaluation
if args.eval_strategy == IntervalStrategy.NO and "loss" in logs:
values = {"Training Loss": logs["loss"]}
values = {}

# 1) Pre-extracted metrics — only if actually present in logs
for metric in Trainer_metrics:
# Sometimes metric is not inside logs
try: values[metric.replace("/", " / ")] = logs[metric]
except: pass
if metric in logs:
values[metric.replace("/", " / ")] = logs[metric]
pass
# First column is necessarily Step since we're not in epoch eval strategy
values["Step"] = state.global_step

# 2) Dynamic per-reward-function metrics (rewards/*)
# These have user-defined names so can't be pre-extracted.
# Sort for stable column ordering across steps.
dynamic_reward_keys = sorted(
k for k in logs
if k.startswith("rewards/") and k not in set_Trainer_metrics
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The logic for discovering dynamic metric keys only accounts for keys starting with rewards/. Based on the pull request description and comments, it seems conditional metrics like sampling/* should also be handled dynamically. To ensure they appear in the logs when available, you could expand this condition. You might also want to rename dynamic_reward_keys to something more generic like dynamic_keys for clarity.

Suggested change
if k.startswith("rewards/") and k not in set_Trainer_metrics
if (k.startswith("rewards/") or k.startswith("sampling/")) and k not in set_Trainer_metrics

)
for key in dynamic_reward_keys:
display_key = key.replace("/", " / ")
if display_key not in values:
values[display_key] = logs[key]
pass

# 3) Prepend Training Loss + Step (always first columns)
values = {"Training Loss": logs["loss"], **values}
values[self.first_column] = state.global_step
self.training_tracker.write_line(values)
pass
pass
Expand All @@ -95,7 +122,6 @@ def _NotebookProgressCallback_on_log(self, args, state, control, logs = None, **


def NotebookTrainingTracker_write_line(Trainer_metrics):
set_Trainer_metrics = set(Trainer_metrics)
def _NotebookTrainingTracker_write_line(self, values):
"""
Write the values in the inner table.
Expand All @@ -107,33 +133,34 @@ def _NotebookTrainingTracker_write_line(self, values):
self.inner_table = [list(values.keys()), list(values.values())]
else:
columns = self.inner_table[0]
new_values = {}
for key, value in values.items():
lowered = key.lower()
if lowered in set_Trainer_metrics:
new_values[lowered.replace("/", " / ")] = value
else:
new_values[key] = value

# Dynamically add new columns that appear in values
# (e.g. per-reward-func metrics discovered at step > 1)
for key in values:
if key not in columns:
columns.append(key)
# Back-fill previous rows with empty string
for row in self.inner_table[1:]:
row.append("")
pass
values = new_values

self.inner_table[0] = columns
first_column = columns[0]
if len(self.inner_table) > 1:
last_values = self.inner_table[-1]
first_column = self.inner_table[0][0]
if last_values[0] != values[first_column]:
# write new line
self.inner_table.append([values[c] if c in values else "No Log" for c in columns])
self.inner_table.append([values[c] if c in values else "" for c in columns])
else:
# update last line
# update last line — preserve existing values for missing keys
new_values = values
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

Assigning values directly to new_values creates an alias, not a copy. The subsequent loop (lines 157-159) modifies new_values, which unexpectedly mutates the original values dictionary passed into this function. This side effect can lead to subtle bugs and makes the code harder to reason about. It's better to explicitly create a copy to avoid modifying the input argument.

Suggested change
new_values = values
new_values = values.copy()

for c in columns:
if c not in new_values.keys():
if c not in new_values:
new_values[c] = last_values[columns.index(c)]
self.inner_table[-1] = [new_values[c] for c in columns]
else:
# Edit for evaluation purposes
self.inner_table.append([values[c] if c in values else 0 for c in columns])
# First data row (after header)
self.inner_table.append([values[c] if c in values else "" for c in columns])
pass
pass
pass
Expand Down Expand Up @@ -166,10 +193,32 @@ def get_trl_metrics():
filepath = inspect.getfile(trl.trainer)
filepath = os.path.split(filepath)[0]

all_metrics = dict()
# TRL >= 0.26.0 moved many trainers to trl/experimental/*/
# The old trl/trainer/ files become thin shims that re-export.
# Build a map of trainer_name -> source file path, preferring the
# experimental (real) file when both exist.
trl_root = os.path.split(filepath)[0]
exp_dir = os.path.join(trl_root, "experimental")
trainer_files = dict()
for trainer in trainers:
filename = os.path.join(filepath, f"{trainer}.py")
if not os.path.exists(filename): continue
candidates = []
# 1) trl/trainer/{trainer}.py (original or shim)
c1 = os.path.join(filepath, f"{trainer}.py")
if os.path.exists(c1):
candidates.append(c1)
# 2) trl/experimental/{name}/{trainer}.py (real code in >= 0.26.0)
if os.path.isdir(exp_dir):
name = trainer.replace("_trainer", "")
c2 = os.path.join(exp_dir, name, f"{trainer}.py")
if os.path.exists(c2):
candidates.append(c2)
# Prefer the larger file (real code vs thin shim)
if candidates:
trainer_files[trainer] = max(candidates, key = os.path.getsize)
pass

all_metrics = dict()
for trainer, filename in trainer_files.items():
with open(filename, "r", encoding = "utf-8") as file: file = file.read()

# Get metrics['kl'] or stats['kl']
Expand All @@ -182,12 +231,12 @@ def get_trl_metrics():
stats2 = re.findall(r"stats\[mode\]\[[\"\']([^\"\']{1,})[\"\']\]", file)
metrics = metrics + metrics2 + stats2

# Get optional f-strings
# Get optional f-strings (variable at start: f"{var}suffix")
metrics_f = re.findall(r"_?metrics\[f[\"\']\{[^\}]{1,}\}([^\"\']{1,})[\"\']\]", file)
stats_f = re.findall(r"stats\[f[\"\']\{[^\}]{1,}\}([^\"\']{1,})[\"\']\]", file)
metrics_f = metrics_f + stats_f

# Get optional f-strings for new TRL [mode]
# Get optional f-strings for new TRL [mode] (variable at start)
metrics_f2 = re.findall(r"_?metrics\[mode\]\[f[\"\']\{[^\}]{1,}\}([^\"\']{1,})[\"\']\]", file)
stats_f2 = re.findall(r"stats\[mode\]\[f[\"\']\{[^\}]{1,}\}([^\"\']{1,})[\"\']\]", file)
metrics_f = metrics_f + metrics_f2 + stats_f2
Expand Down
6 changes: 3 additions & 3 deletions unsloth_zoo/rl_replacements.py
Original file line number Diff line number Diff line change
Expand Up @@ -521,9 +521,9 @@ def compute_loss(new_logps, old_logps, ref_logps, sampling_per_token_logps, inpu

device =_new_logps.device
grad_inputs = torch.empty_like(_new_logps)
accumulated_loss = torch.zeros(1, device = device)
accumulated_completion_length = torch.zeros(1, device = device)
accumulated_mean_kl = torch.zeros(1, device = device)
accumulated_loss = torch.zeros(1, device = device)[0]
Copy link
Copy Markdown
Collaborator

@Datta0 Datta0 Mar 15, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We should probably do
accumulated_loss = torch.zeros((), device = device) instead?
Ref

accumulated_completion_length = torch.zeros(1, device = device)[0]
accumulated_mean_kl = torch.zeros(1, device = device)[0]
Comment on lines +524 to +526
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

While torch.zeros(1, device=device)[0] correctly creates a scalar tensor, it's a bit indirect. Using torch.tensor(0.0, device=device) is more idiomatic and clearly expresses the intent to create a scalar. This would improve code readability and maintainability.

Suggested change
accumulated_loss = torch.zeros(1, device = device)[0]
accumulated_completion_length = torch.zeros(1, device = device)[0]
accumulated_mean_kl = torch.zeros(1, device = device)[0]
accumulated_loss = torch.tensor(0.0, device = device)
accumulated_completion_length = torch.tensor(0.0, device = device)
accumulated_mean_kl = torch.tensor(0.0, device = device)

accumulated_delta = []
accumulated_flat_is_ratio = []
accumulated_coef_1 = []
Expand Down