feat: log generation ISL/OSL histogram to wandb #1594
Conversation
Signed-off-by: Youngeun Kwon <youngeunk@nvidia.com>
📝 Walkthrough

This PR enhances observability of sequence length distributions during GRPO training by adding histogram-level logging to wandb. It introduces a new logging function for histogram metrics, integrates it into the training loops, and extends token accounting to track per-turn input and total token lengths alongside the existing generation metrics.
Estimated code review effort: 🎯 3 (Moderate) | ⏱️ ~20 minutes
Possibly related PRs
Suggested reviewers
Pre-merge checks and finishing touches

❌ Failed checks (1 warning)
✅ Passed checks (5 passed)
✨ Finishing touches
Actionable comments posted: 0
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (1)
nemo_rl/experience/rollouts.py (1)
657-693: Convert the `input_lengths` tensor to a Python scalar once per turn to avoid mixing tensor and int types in arithmetic and slicing

`input_lengths` is a torch.Tensor returned from `async_generate_response_for_sample_turn`, but it is handled inconsistently within the same loop iteration:

- Line 689: `turn_input_tokens.append(int(input_lengths))` — converts to a Python int
- Line 728: `if input_lengths + gen_token_count + len(tokenized_obs) >= max_seq_len:` — uses the tensor directly in arithmetic
- Line 729: `max_env_tokens = max_seq_len - input_lengths - gen_token_count` — tensor arithmetic
- Line 730: `tokenized_obs = tokenized_obs[:max_env_tokens]` — slices with the result of tensor arithmetic

This fragile mixing relies on implicit PyTorch broadcasting and the `__index__` protocol for slicing. Convert to a scalar once per turn and reuse it for consistency:

```diff
 current_message_log = updated_message_log

+# Convert tensor to Python scalar for this turn
+input_length = int(input_lengths.item())
+
 # Update token counts
 gen_token_count = len(generated_tokens)
 assistant_token_count += gen_token_count
 token_count += gen_token_count
 turn_gen_tokens.append(gen_token_count)
-turn_input_tokens.append(int(input_lengths))
-turn_total_tokens.append(int(input_lengths) + gen_token_count)
+turn_input_tokens.append(input_length)
+turn_total_tokens.append(input_length + gen_token_count)

 # Per-worker load accounting
```

And at the overflow check:

```diff
-if input_lengths + gen_token_count + len(tokenized_obs) >= max_seq_len:
+if input_length + gen_token_count + len(tokenized_obs) >= max_seq_len:
     # Truncate environment observation
-    max_env_tokens = max_seq_len - input_lengths - gen_token_count
+    max_env_tokens = max_seq_len - input_length - gen_token_count
     tokenized_obs = tokenized_obs[:max_env_tokens]
```
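The implicit behavior being criticized can be demonstrated without torch. The stand-in class below is purely illustrative (not part of the codebase); it mimics how a 0-d tensor participates in arithmetic and, via `__index__`, in slicing:

```python
class FakeScalarTensor:
    """Stand-in for a 0-d tensor: supports arithmetic and __index__."""

    def __init__(self, value):
        self.value = value

    def __rsub__(self, other):
        # e.g. `max_seq_len - input_lengths` returns another "tensor"
        return FakeScalarTensor(other - self.value)

    def __index__(self):
        # lets the object be used as a slice bound
        return self.value


input_lengths = FakeScalarTensor(3)
tokenized_obs = list(range(10))

max_env_tokens = 8 - input_lengths          # FakeScalarTensor(5), not an int
truncated = tokenized_obs[:max_env_tokens]  # slicing works only via __index__
print(truncated)  # [0, 1, 2, 3, 4]
```

Converting once with `int(input_lengths.item())` removes this reliance on protocol fallbacks and keeps all downstream arithmetic in plain ints.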
🧹 Nitpick comments (2)
nemo_rl/experience/rollouts.py (1)
940-949: Flatten per‑sample turn length lists withoutreduce(lambda x, y: x + y, …)The current aggregation:
rollout_metrics["gen_tokens_lengths"] = reduce( lambda x, y: x + y, [m["turn_gen_tokens"] for m in all_sample_metrics] )is correct given each
turn_*is alist[int], butreducewith list concatenation is less readable and can be quadratic for long lists.Consider a clearer, linear‑time flatten using
itertools.chain.from_iterable:+from itertools import chain @@ - rollout_metrics["gen_tokens_lengths"] = reduce( - lambda x, y: x + y, [m["turn_gen_tokens"] for m in all_sample_metrics] - ) - rollout_metrics["input_tokens_lengths"] = reduce( - lambda x, y: x + y, [m["turn_input_tokens"] for m in all_sample_metrics] - ) - rollout_metrics["total_tokens_lengths"] = reduce( - lambda x, y: x + y, [m["turn_total_tokens"] for m in all_sample_metrics] - ) + rollout_metrics["gen_tokens_lengths"] = list( + chain.from_iterable(m["turn_gen_tokens"] for m in all_sample_metrics) + ) + rollout_metrics["input_tokens_lengths"] = list( + chain.from_iterable(m["turn_input_tokens"] for m in all_sample_metrics) + ) + rollout_metrics["total_tokens_lengths"] = list( + chain.from_iterable(m["turn_total_tokens"] for m in all_sample_metrics) + )Functionally equivalent, but simpler and more efficient for large batches/turn counts.
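The two forms can be compared directly on toy data (the sample metrics below are made up for illustration):

```python
from functools import reduce
from itertools import chain

# Hypothetical per-sample metrics mirroring the structure discussed above
all_sample_metrics = [
    {"turn_gen_tokens": [5, 7]},
    {"turn_gen_tokens": [3]},
    {"turn_gen_tokens": [9, 2, 4]},
]

# Quadratic in the worst case: each `+` builds a new intermediate list
flat_reduce = reduce(
    lambda x, y: x + y, [m["turn_gen_tokens"] for m in all_sample_metrics]
)

# Linear-time equivalent
flat_chain = list(
    chain.from_iterable(m["turn_gen_tokens"] for m in all_sample_metrics)
)

print(flat_chain)  # [5, 7, 3, 9, 2, 4]
```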
nemo_rl/algorithms/grpo.py (1)
1570-1585: Make histogram logging more robust (avoid broad `Exception` and ensure metrics are sequences)

Two related points around the new ISL/OSL/ISL+OSL histogram logging:

- Blind `except Exception` and missing-key handling

In both `grpo_train` and `async_grpo_train` you have:

```python
# Plot ISL/OSL/ISL+OSL histograms to wandb
try:
    for hist_metrics in [
        "gen_tokens_lengths",
        "input_tokens_lengths",
        "total_tokens_lengths",
    ]:
        log_histogram_metrics_to_wandb(
            f"generation_metrics/{hist_metrics}",
            metrics[hist_metrics],
            ...,
        )
except Exception as e:
    print(f"❌ Error plotting histograms to wandb: {e}")
    pass
```

This pattern:

- Relies on the `KeyError` for missing metrics being swallowed by the broad `except Exception`, which is what Ruff BLE001 is flagging.
- Hides genuine wandb/logging bugs; training continues, but you only get a console print.

You can avoid both the broad catch and the missing-key reliance by checking for the metric's presence before logging:

```diff
-            # Plot ISL/OSL/ISL+OSL histograms to wandb
-            try:
-                for hist_metrics in [
-                    "gen_tokens_lengths",
-                    "input_tokens_lengths",
-                    "total_tokens_lengths",
-                ]:
-                    log_histogram_metrics_to_wandb(
-                        f"generation_metrics/{hist_metrics}",
-                        metrics[hist_metrics],
-                        total_steps + 1,
-                        logger,
-                    )
-            except Exception as e:
-                print(f"❌ Error plotting histograms to wandb: {e}")
-                pass
+            # Plot ISL/OSL/ISL+OSL histograms to wandb when available
+            for hist_metric in (
+                "gen_tokens_lengths",
+                "input_tokens_lengths",
+                "total_tokens_lengths",
+            ):
+                hist_values = metrics.get(hist_metric)
+                if hist_values is None:
+                    continue
+                log_histogram_metrics_to_wandb(
+                    f"generation_metrics/{hist_metric}",
+                    hist_values,
+                    total_steps + 1,
+                    logger,
+                )
```

and analogously in `async_grpo_train` (using `step + 1` for the step). This satisfies the guideline to avoid blind `except Exception` and makes the control flow clearer.
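The suggested presence check can be sketched in isolation; here the metrics dict is hypothetical and the real logging call is replaced by collecting what would be logged:

```python
# Hypothetical metrics dict; in training, some keys may be absent
metrics = {"gen_tokens_lengths": [12, 34, 56]}

logged = []
for hist_metric in (
    "gen_tokens_lengths",
    "input_tokens_lengths",
    "total_tokens_lengths",
):
    hist_values = metrics.get(hist_metric)
    if hist_values is None:
        continue  # skip missing metrics instead of relying on a broad except
    # stand-in for log_histogram_metrics_to_wandb(...)
    logged.append((hist_metric, hist_values))

print(logged)  # [('gen_tokens_lengths', [12, 34, 56])]
```

Missing keys are skipped explicitly, so any exception that does escape points at a real logging bug rather than an absent metric.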
- Async GRPO aggregation vs histogram expectations

In `async_grpo_train`, per-trajectory rollout metrics are aggregated as:

```python
rollout_metrics = {}
for t in trajectories:
    for k, v in t["rollout_metrics"].items():
        rollout_metrics.setdefault(k, []).append(v)
# ...
rollout_metrics = {
    k: (sum(v) / len(v) if isinstance(v[0], (int, float)) else v)
    for k, v in rollout_metrics.items()
}
```

If per-trajectory ISL/OSL metrics are sequences (e.g., lists of token lengths per turn), the above will either:

- Treat them as scalars and average them (if `v[0]` is an `int`), giving a single float instead of a list, or
- Leave you with a list-of-lists if each `v` is already a list.

Either way, this may not match what `wandb.Histogram` expects (a 1-D list/array of values) when you later pass `metrics["gen_tokens_lengths"]` into `log_histogram_metrics_to_wandb`. I recommend special-casing these histogram metrics at aggregation time so they are flattened across trajectories instead of averaged:

```diff
@@
     rollout_metrics = {
         k: (sum(v) / len(v) if isinstance(v[0], (int, float)) else v)
         for k, v in rollout_metrics.items()
     }
+
+    # Ensure histogram metrics are flattened sequences suitable for wandb.Histogram
+    for hist_key in (
+        "gen_tokens_lengths",
+        "input_tokens_lengths",
+        "total_tokens_lengths",
+    ):
+        if hist_key in rollout_metrics:
+            values = rollout_metrics[hist_key]
+            # If we accumulated per-trajectory lists, flatten them
+            if values and isinstance(values[0], (list, tuple)):
+                rollout_metrics[hist_key] = [
+                    x for group in values for x in group
+                ]
```

Please double-check the exact structure of `t["rollout_metrics"]` in the async collector to confirm whether flattening like this is needed for your case; the important part is that by the time you call `log_histogram_metrics_to_wandb`, `metrics[hist_key]` should be a flat list/array of numeric values, not an averaged scalar or nested lists.

Also applies to: 2130-2138, 2510-2525
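A minimal, self-contained sketch of this special-casing, with assumed metric names and a made-up trajectory structure:

```python
# Hypothetical aggregated per-trajectory metrics: scalars get averaged,
# per-turn length lists get flattened for histogram consumption.
rollout_metrics = {
    "mean_gen_tokens_per_sample": [0.5, 0.7],   # one scalar per trajectory
    "gen_tokens_lengths": [[10, 20], [30]],     # one list per trajectory
}

aggregated = {
    k: (sum(v) / len(v) if isinstance(v[0], (int, float)) else v)
    for k, v in rollout_metrics.items()
}

# Special-case histogram metrics: flatten list-of-lists across trajectories
for hist_key in (
    "gen_tokens_lengths",
    "input_tokens_lengths",
    "total_tokens_lengths",
):
    values = aggregated.get(hist_key)
    if values and isinstance(values[0], (list, tuple)):
        aggregated[hist_key] = [x for group in values for x in group]

print(aggregated["gen_tokens_lengths"])  # [10, 20, 30]
```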
📜 Review details
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (3)
- `nemo_rl/algorithms/grpo.py` (3 hunks)
- `nemo_rl/algorithms/utils.py` (1 hunk)
- `nemo_rl/experience/rollouts.py` (5 hunks)
🧰 Additional context used
📓 Path-based instructions (4)
**/*.py
📄 CodeRabbit inference engine (CODING_GUIDELINES.md)
**/*.py: Conform code to Python 3.12+
Indent code with 4 spaces. Do not use tabs
Use snake_case for file names
Use PascalCase for class names
Use snake_case for function and method names
Use snake_case for local variables
Prefix variable names that start with a number with 'k' (e.g., k_99th_percentile)
Use upper snake_case with 'G' prefix for global variables (e.g., G_MY_GLOBAL)
Use upper snake_case for constants
Avoid shadowing variables declared in an outer scope
Initialize all externally visible members of a class in the constructor
Prefer docstrings over comments for interfaces that may be used outside a file
Reserve comments for code within a function or interfaces that are local to a file
If a piece of code is commented out, include a comment describing its usage and why it's commented out. Remove debug comments before merging
Use Google style docstrings for classes and functions in Python, which can be parsed by Sphinx
Avoid using reflection when functionality can be easily achieved without reflection
When using try-except blocks, limit the except clause to the smallest set of specific errors possible
When using try-except blocks for duck-typing, keep the body of the try as small as possible and use the else block for logic
YAML is the single source of truth for configuration defaults. Do not set non-None defaults in code for configuration values
For required configuration attributes, access config directly and expect presence (e.g., policy_cfg['precision']) without hidden defaults
Use typing.NotRequired to mark optional attributes in TypedDict for configuration
When adding a new config key to a TypedDict subclass, document the key's purpose, valid values/types, and recommended default, and reflect the default in exemplar YAMLs under examples/configs/*.yaml
Follow the Google Python Style Guide for Python code
Files:
- `nemo_rl/algorithms/grpo.py`
- `nemo_rl/experience/rollouts.py`
- `nemo_rl/algorithms/utils.py`
nemo_rl/**/*.py
📄 CodeRabbit inference engine (CODING_GUIDELINES.md)
For any source file under nemo_rl/*.py that defines a class or function decorated with @ray.remote, add a coverage pragma (# pragma: no cover) because these run in separate Ray processes
Files:
- `nemo_rl/algorithms/grpo.py`
- `nemo_rl/experience/rollouts.py`
- `nemo_rl/algorithms/utils.py`
!(**/tests/**|**/test_*.py|**/test_*.sh)
📄 CodeRabbit inference engine (CODING_GUIDELINES.md)
Add the NVIDIA copyright header to all Python files and shell scripts (excluding tests). The header should include the current year
Files:
- `nemo_rl/algorithms/grpo.py`
- `nemo_rl/experience/rollouts.py`
- `nemo_rl/algorithms/utils.py`
**/*.{py,sh}
📄 CodeRabbit inference engine (CODING_GUIDELINES.md)
The NVIDIA copyright header should appear at the top of all Python files and shell scripts (excluding tests)
Files:
- `nemo_rl/algorithms/grpo.py`
- `nemo_rl/experience/rollouts.py`
- `nemo_rl/algorithms/utils.py`
🧬 Code graph analysis (2)
nemo_rl/algorithms/grpo.py (1)
nemo_rl/algorithms/utils.py (1)
`log_histogram_metrics_to_wandb` (774-793)
nemo_rl/algorithms/utils.py (1)
nemo_rl/utils/logger.py (2)
`Logger` (805-1123), `log_plot_per_worker_timeline_metrics` (939-1018)
🪛 Ruff (0.14.7)
nemo_rl/algorithms/grpo.py
1583-1583: Do not catch blind exception: Exception
(BLE001)
2523-2523: Do not catch blind exception: Exception
(BLE001)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (7)
- GitHub Check: sphinx-build / Build docs
- GitHub Check: Lint check
- GitHub Check: Lint check
- GitHub Check: Lint check
- GitHub Check: Lint check
- GitHub Check: Post submodule check comment / Comment on PR
- GitHub Check: Post automodel integration comment / Comment on PR
🔇 Additional comments (1)
nemo_rl/algorithms/utils.py (1)
750-772: Generation and histogram logging helpers look consistent with Logger usage

The refactored `log_generation_metrics_to_wandb` and new `log_histogram_metrics_to_wandb` both align with how `Logger` is structured:

- `generation_logger_metrics: dict[str, dict[int, list[Any]]]` matches what `log_plot_per_worker_timeline_metrics` expects (the inner dict is a per-worker series).
- The `generation_metrics/…` prefix keeps naming consistent with the new histogram metric names in `grpo.py`.
- `log_histogram_metrics_to_wandb` is correctly guarded on `logger.wandb_logger` and imports wandb lazily, so it is safe to call unconditionally from training code.

No functional issues from this diff; the helpers are ready to be used by the GRPO training loops.
Also applies to: 774-793
Hi @guyueh1, do you think this implementation meets your expectations? |
Hi @terrykong, can I ask for your review and merge for this PR? |
Hi @terrykong, I have addressed most of your comments in the newer version. Could you please take a look? |
Signed-off-by: Youngeun Kwon <youngeunk@nvidia.com>
Signed-off-by: yuanhangs <yuanhangs@nvidia.com>
What does this PR do?
Sample wandb report with this PR: report
You can also download the histogram data from wandb.

Issues
List issues that this PR closes (syntax):
Closes #1493
Usage
# Add a code snippet demonstrating how to use this

Before your PR is "Ready for review"
Pre checks:
Additional Information
Summary by CodeRabbit
Release Notes
New Features
Improvements