-
Notifications
You must be signed in to change notification settings - Fork 261
Fix GRPO notebook logging and transformers v5 loss shape mismatch #543
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
c965c34
76a4c6b
50c7405
c24b6f9
d0362fd
ea7d3a1
d0cadd2
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||
|---|---|---|---|---|---|---|
|
|
@@ -30,7 +30,7 @@ | |||||
| "mean_token_accuracy", # SFT extras | ||||||
| "entropy", # TRL >= 0.22.0 | ||||||
| "aux_loss", # TRL >= 0.23.0 | ||||||
|
|
||||||
| # GRPO extras | ||||||
| "clip_ratio", | ||||||
| 'clip_ratio/low_mean', | ||||||
|
|
@@ -39,6 +39,9 @@ | |||||
| 'clip_ratio/high_max', | ||||||
| 'clip_ratio/region_mean', | ||||||
| 'frac_reward_zero_std', | ||||||
|
|
||||||
| # Regex false positive from self._metrics["train"]["step_time"] in TRL >= 0.26.0 | ||||||
| 'train', | ||||||
| ] | ||||||
| REMOVED_METRICS = frozenset(REMOVED_METRICS) | ||||||
|
|
||||||
|
|
@@ -65,28 +68,52 @@ def _NotebookProgressCallback_on_train_begin(self, args, state, control, **kwarg | |||||
| self.first_column = "Epoch" if args.eval_strategy == IntervalStrategy.EPOCH else "Step" | ||||||
| self.training_loss = 0 | ||||||
| self.last_log = 0 | ||||||
| # Don't pre-create metric columns. Start with just the essentials; | ||||||
| # columns are added dynamically by write_line as metrics actually appear. | ||||||
| # This prevents empty "0 then blank" columns for conditional metrics | ||||||
| # (kl when beta=0, sampling/* without importance sampling, etc.) | ||||||
| column_names = [self.first_column] + ["Training Loss"] | ||||||
| if args.eval_strategy != IntervalStrategy.NO: | ||||||
| column_names.append("Validation Loss") | ||||||
| column_names += [x.replace("/", " / ") for x in Trainer_metrics] | ||||||
| self.training_tracker = NotebookTrainingTracker(state.max_steps, column_names) | ||||||
| pass | ||||||
| return _NotebookProgressCallback_on_train_begin | ||||||
| pass | ||||||
|
|
||||||
|
|
||||||
| def NotebookProgressCallback_on_log(Trainer_metrics): | ||||||
| # Build an allowlist of known metrics (pre-extracted from TRL source). | ||||||
| # Only these + dynamic rewards/* metrics pass through. This blocks | ||||||
| # spurious keys injected at runtime (e.g. "train", "tools/*"). | ||||||
| set_Trainer_metrics = frozenset(Trainer_metrics) | ||||||
|
|
||||||
| def _NotebookProgressCallback_on_log(self, args, state, control, logs = None, **kwargs): | ||||||
| # Only for when there is no evaluation | ||||||
| if args.eval_strategy == IntervalStrategy.NO and "loss" in logs: | ||||||
| values = {"Training Loss": logs["loss"]} | ||||||
| values = {} | ||||||
|
|
||||||
| # 1) Pre-extracted metrics — only if actually present in logs | ||||||
| for metric in Trainer_metrics: | ||||||
| # Sometimes metric is not inside logs | ||||||
| try: values[metric.replace("/", " / ")] = logs[metric] | ||||||
| except: pass | ||||||
| if metric in logs: | ||||||
| values[metric.replace("/", " / ")] = logs[metric] | ||||||
| pass | ||||||
| # First column is necessarily Step since we're not in epoch eval strategy | ||||||
| values["Step"] = state.global_step | ||||||
|
|
||||||
| # 2) Dynamic per-reward-function metrics (rewards/*) | ||||||
| # These have user-defined names so can't be pre-extracted. | ||||||
| # Sort for stable column ordering across steps. | ||||||
| dynamic_reward_keys = sorted( | ||||||
| k for k in logs | ||||||
| if k.startswith("rewards/") and k not in set_Trainer_metrics | ||||||
| ) | ||||||
| for key in dynamic_reward_keys: | ||||||
| display_key = key.replace("/", " / ") | ||||||
| if display_key not in values: | ||||||
| values[display_key] = logs[key] | ||||||
| pass | ||||||
|
|
||||||
| # 3) Prepend Training Loss + Step (always first columns) | ||||||
| values = {"Training Loss": logs["loss"], **values} | ||||||
| values[self.first_column] = state.global_step | ||||||
| self.training_tracker.write_line(values) | ||||||
| pass | ||||||
| pass | ||||||
|
|
@@ -95,7 +122,6 @@ def _NotebookProgressCallback_on_log(self, args, state, control, logs = None, ** | |||||
|
|
||||||
|
|
||||||
| def NotebookTrainingTracker_write_line(Trainer_metrics): | ||||||
| set_Trainer_metrics = set(Trainer_metrics) | ||||||
| def _NotebookTrainingTracker_write_line(self, values): | ||||||
| """ | ||||||
| Write the values in the inner table. | ||||||
|
|
@@ -107,33 +133,34 @@ def _NotebookTrainingTracker_write_line(self, values): | |||||
| self.inner_table = [list(values.keys()), list(values.values())] | ||||||
| else: | ||||||
| columns = self.inner_table[0] | ||||||
| new_values = {} | ||||||
| for key, value in values.items(): | ||||||
| lowered = key.lower() | ||||||
| if lowered in set_Trainer_metrics: | ||||||
| new_values[lowered.replace("/", " / ")] = value | ||||||
| else: | ||||||
| new_values[key] = value | ||||||
|
|
||||||
| # Dynamically add new columns that appear in values | ||||||
| # (e.g. per-reward-func metrics discovered at step > 1) | ||||||
| for key in values: | ||||||
| if key not in columns: | ||||||
| columns.append(key) | ||||||
| # Back-fill previous rows with empty string | ||||||
| for row in self.inner_table[1:]: | ||||||
| row.append("") | ||||||
| pass | ||||||
| values = new_values | ||||||
|
|
||||||
| self.inner_table[0] = columns | ||||||
| first_column = columns[0] | ||||||
| if len(self.inner_table) > 1: | ||||||
| last_values = self.inner_table[-1] | ||||||
| first_column = self.inner_table[0][0] | ||||||
| if last_values[0] != values[first_column]: | ||||||
| # write new line | ||||||
| self.inner_table.append([values[c] if c in values else "No Log" for c in columns]) | ||||||
| self.inner_table.append([values[c] if c in values else "" for c in columns]) | ||||||
| else: | ||||||
| # update last line | ||||||
| # update last line — preserve existing values for missing keys | ||||||
| new_values = values | ||||||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Assigning
Suggested change
|
||||||
| for c in columns: | ||||||
| if c not in new_values.keys(): | ||||||
| if c not in new_values: | ||||||
| new_values[c] = last_values[columns.index(c)] | ||||||
| self.inner_table[-1] = [new_values[c] for c in columns] | ||||||
| else: | ||||||
| # Edit for evaluation purposes | ||||||
| self.inner_table.append([values[c] if c in values else 0 for c in columns]) | ||||||
| # First data row (after header) | ||||||
| self.inner_table.append([values[c] if c in values else "" for c in columns]) | ||||||
| pass | ||||||
| pass | ||||||
| pass | ||||||
|
|
@@ -166,10 +193,32 @@ def get_trl_metrics(): | |||||
| filepath = inspect.getfile(trl.trainer) | ||||||
| filepath = os.path.split(filepath)[0] | ||||||
|
|
||||||
| all_metrics = dict() | ||||||
| # TRL >= 0.26.0 moved many trainers to trl/experimental/*/ | ||||||
| # The old trl/trainer/ files become thin shims that re-export. | ||||||
| # Build a map of trainer_name -> source file path, preferring the | ||||||
| # experimental (real) file when both exist. | ||||||
| trl_root = os.path.split(filepath)[0] | ||||||
| exp_dir = os.path.join(trl_root, "experimental") | ||||||
| trainer_files = dict() | ||||||
| for trainer in trainers: | ||||||
| filename = os.path.join(filepath, f"{trainer}.py") | ||||||
| if not os.path.exists(filename): continue | ||||||
| candidates = [] | ||||||
| # 1) trl/trainer/{trainer}.py (original or shim) | ||||||
| c1 = os.path.join(filepath, f"{trainer}.py") | ||||||
| if os.path.exists(c1): | ||||||
| candidates.append(c1) | ||||||
| # 2) trl/experimental/{name}/{trainer}.py (real code in >= 0.26.0) | ||||||
| if os.path.isdir(exp_dir): | ||||||
| name = trainer.replace("_trainer", "") | ||||||
| c2 = os.path.join(exp_dir, name, f"{trainer}.py") | ||||||
| if os.path.exists(c2): | ||||||
| candidates.append(c2) | ||||||
| # Prefer the larger file (real code vs thin shim) | ||||||
| if candidates: | ||||||
| trainer_files[trainer] = max(candidates, key = os.path.getsize) | ||||||
| pass | ||||||
|
|
||||||
| all_metrics = dict() | ||||||
| for trainer, filename in trainer_files.items(): | ||||||
| with open(filename, "r", encoding = "utf-8") as file: file = file.read() | ||||||
|
|
||||||
| # Get metrics['kl'] or stats['kl'] | ||||||
|
|
@@ -182,12 +231,12 @@ def get_trl_metrics(): | |||||
| stats2 = re.findall(r"stats\[mode\]\[[\"\']([^\"\']{1,})[\"\']\]", file) | ||||||
| metrics = metrics + metrics2 + stats2 | ||||||
|
|
||||||
| # Get optional f-strings | ||||||
| # Get optional f-strings (variable at start: f"{var}suffix") | ||||||
| metrics_f = re.findall(r"_?metrics\[f[\"\']\{[^\}]{1,}\}([^\"\']{1,})[\"\']\]", file) | ||||||
| stats_f = re.findall(r"stats\[f[\"\']\{[^\}]{1,}\}([^\"\']{1,})[\"\']\]", file) | ||||||
| metrics_f = metrics_f + stats_f | ||||||
|
|
||||||
| # Get optional f-strings for new TRL [mode] | ||||||
| # Get optional f-strings for new TRL [mode] (variable at start) | ||||||
| metrics_f2 = re.findall(r"_?metrics\[mode\]\[f[\"\']\{[^\}]{1,}\}([^\"\']{1,})[\"\']\]", file) | ||||||
| stats_f2 = re.findall(r"stats\[mode\]\[f[\"\']\{[^\}]{1,}\}([^\"\']{1,})[\"\']\]", file) | ||||||
| metrics_f = metrics_f + metrics_f2 + stats_f2 | ||||||
|
|
||||||
| Original file line number | Diff line number | Diff line change | ||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
|
@@ -521,9 +521,9 @@ def compute_loss(new_logps, old_logps, ref_logps, sampling_per_token_logps, inpu | |||||||||||||
|
|
||||||||||||||
| device =_new_logps.device | ||||||||||||||
| grad_inputs = torch.empty_like(_new_logps) | ||||||||||||||
| accumulated_loss = torch.zeros(1, device = device) | ||||||||||||||
| accumulated_completion_length = torch.zeros(1, device = device) | ||||||||||||||
| accumulated_mean_kl = torch.zeros(1, device = device) | ||||||||||||||
| accumulated_loss = torch.zeros(1, device = device)[0] | ||||||||||||||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. We should probably do |
||||||||||||||
| accumulated_completion_length = torch.zeros(1, device = device)[0] | ||||||||||||||
| accumulated_mean_kl = torch.zeros(1, device = device)[0] | ||||||||||||||
|
Comment on lines
+524
to
+526
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. While
Suggested change
|
||||||||||||||
| accumulated_delta = [] | ||||||||||||||
| accumulated_flat_is_ratio = [] | ||||||||||||||
| accumulated_coef_1 = [] | ||||||||||||||
|
|
||||||||||||||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The logic for discovering dynamic metric keys only accounts for keys starting with
rewards/. Based on the pull request description and comments, it seems conditional metrics likesampling/*should also be handled dynamically. To ensure they appear in the logs when available, you could expand this condition. You might also want to renamedynamic_reward_keysto something more generic likedynamic_keysfor clarity.