
feat: log generation ISL/OSL histogram to wandb #1594

Merged
terrykong merged 5 commits into main from youngeunk/generation_histogram
Dec 5, 2025

Conversation

@youngeunkwon0405 (Contributor) commented Dec 3, 2025

What does this PR do ?

Sample wandb report with this PR: report

[Screenshot: wandb panels showing the logged ISL/OSL histograms]

You can also download the histogram data from wandb.

[Screenshot: wandb UI for downloading the histogram data]

Issues

List issues that this PR closes:

Closes #1493

Usage

  • You can potentially add a usage example below
# Add a code snippet demonstrating how to use this
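No user action is needed to enable the new logging, but for illustration, a minimal sketch of the helper this PR introduces might look like the following. The function name and the `wandb.Histogram` call follow the PR summary; the exact signature and the `wandb_logger` guard are assumptions based on the review discussion, not the real implementation.

```python
# Illustrative sketch only: the real helper lives in nemo_rl/algorithms/utils.py.
# The Logger.wandb_logger guard and exact signature are assumptions.
def log_histogram_metrics_to_wandb(name, values, step, logger):
    """Log a flat list of numeric values as a wandb histogram, if wandb is active."""
    if logger is None or getattr(logger, "wandb_logger", None) is None:
        return  # no-op when wandb logging is not configured
    import wandb  # imported lazily so non-wandb runs never need it

    wandb.log({name: wandb.Histogram(values)}, step=step)
```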

Before your PR is "Ready for review"

Pre checks:

  • Make sure you read and followed Contributor guidelines
  • Did you write any new necessary tests?
  • Did you run the unit tests and functional tests locally? Visit our Testing Guide for how to run tests
  • Did you add or update any necessary documentation? Visit our Document Development Guide for how to write, build and test the docs.

Additional Information

  • ...

Summary by CodeRabbit

Release Notes

  • New Features

    • Added histogram visualization of generation metrics (input, output, and total token lengths) in monitoring dashboards
    • Extended token length tracking across multi-turn conversation scenarios for improved visibility
  • Improvements

    • Enhanced robustness of metric logging with error handling to prevent workflow interruptions and ensure continuous operation


Signed-off-by: Youngeun Kwon <youngeunk@nvidia.com>
@youngeunkwon0405 self-assigned this on Dec 3, 2025
@youngeunkwon0405 requested review from a team as code owners on December 3, 2025 09:54
@youngeunkwon0405 added the labels Performance (Related to improving performance), ease of use, and CI:L1 (Run doctests, unit tests, and functional tests) on Dec 3, 2025
@coderabbitai bot (Contributor) commented Dec 3, 2025

📝 Walkthrough

This PR enhances observability of sequence length distributions during GRPO training by adding histogram-level logging to wandb. It introduces a new logging function for histogram metrics, integrates it into training loops, and extends token accounting to track per-turn input and total token lengths alongside existing generation metrics.

Changes

  • Histogram logging infrastructure — nemo_rl/algorithms/utils.py:
    Renamed the parameter vllm_logger_metrics to generation_logger_metrics in log_generation_metrics_to_wandb for consistency. Added a new public function log_histogram_metrics_to_wandb that logs histogram metrics to wandb with a guard for logger presence; it accepts a metric name, values, step, and logger instance.
  • Training integration — nemo_rl/algorithms/grpo.py:
    Imported log_histogram_metrics_to_wandb from utils. Added histogram logging calls in grpo_train and async_grpo_train to plot gen_tokens_lengths, input_tokens_lengths, and total_tokens_lengths after each training step; wrapped them in try-except blocks to handle plotting errors without interrupting training.
  • Token accounting — nemo_rl/experience/rollouts.py:
    Extended per-turn token tracking to include turn_input_tokens and turn_total_tokens alongside the existing turn_gen_tokens. Implemented metric aggregation using functools.reduce to compute per-sample token length sequences (gen_tokens_lengths, input_tokens_lengths, total_tokens_lengths) across multi-turn and single-sample rollout paths. Propagated the new metrics through batch and per-sample rollout output.
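The per-turn accounting described for rollouts.py can be sketched in isolation. This is a simplified, standalone loop with made-up token counts; the real code records these counts during rollout generation.

```python
# Simplified sketch of the per-turn token accounting added in rollouts.py.
# Each tuple is (input_length, generated_length) for one turn; values are made up.
turns = [(300, 120), (465, 45)]

turn_gen_tokens: list[int] = []
turn_input_tokens: list[int] = []
turn_total_tokens: list[int] = []

for input_length, gen_token_count in turns:
    turn_gen_tokens.append(gen_token_count)       # OSL per turn
    turn_input_tokens.append(input_length)        # ISL per turn
    turn_total_tokens.append(input_length + gen_token_count)  # ISL + OSL

print(turn_total_tokens)  # [420, 510]
```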

Estimated code review effort

🎯 3 (Moderate) | ⏱️ ~20 minutes

  • Token accounting extensions in rollouts.py: Verify that functools.reduce aggregation correctly computes token length sequences across all rollout paths (multi-turn, single-sample multi-turn, worker-based).
  • Metric propagation consistency: Ensure new token-length metrics maintain consistent ordering and naming conventions throughout the rollout pipeline.
  • Error handling in histogram logging: Confirm try-except blocks in training loops appropriately suppress plotting errors without masking real issues.

Possibly related PRs

Suggested reviewers

  • parthchadha
  • terrykong

Pre-merge checks and finishing touches

❌ Failed checks (1 warning)
  • Test Results For Major Changes — ⚠️ Warning: The PR introduces major changes, including a new public function and a parameter rename, without corresponding test additions or modifications, and PR checklist items remain unchecked. Resolution: add unit and integration tests for the new functionality, include performance metrics, and check all applicable PR checklist items.
✅ Passed checks (5 passed)

  • Title check ✅ — The title accurately describes the main change: adding histogram logging for generation ISL/OSL (input sequence length/output sequence length) metrics to wandb, which is the primary objective of this PR.
  • Linked Issues check ✅ — The PR addresses issue #1493 by implementing full sequence length distribution plotting to wandb via histogram metrics for gen_tokens, input_tokens, and total_tokens.
  • Out of Scope Changes check ✅ — All changes are directly scoped to the linked issue: the parameter rename in utils.py improves clarity, the new log_histogram_metrics_to_wandb function enables histogram logging, and rollouts.py tracks the metrics required for distribution visualization.
  • Docstring Coverage ✅ — Docstring coverage is 100.00%, which exceeds the required threshold of 80.00%.
  • Description Check ✅ — Check skipped; CodeRabbit's high-level summary is enabled.


@coderabbitai bot left a comment

Actionable comments posted: 0

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (1)
nemo_rl/experience/rollouts.py (1)

657-693: Convert input_lengths tensor to a Python scalar once per turn to avoid mixing tensor and int types in arithmetic and slicing

input_lengths is a torch.Tensor returned from async_generate_response_for_sample_turn, but it is handled inconsistently within the same loop iteration:

  • Line 689: turn_input_tokens.append(int(input_lengths)) — converts to Python int
  • Line 728: if input_lengths + gen_token_count + len(tokenized_obs) >= max_seq_len: — uses tensor directly in arithmetic
  • Line 729: max_env_tokens = max_seq_len - input_lengths - gen_token_count — tensor arithmetic
  • Line 730: tokenized_obs = tokenized_obs[:max_env_tokens] — slices with result of tensor arithmetic

This fragile mixing relies on implicit PyTorch broadcasting and the __index__ protocol for slicing. Convert to a scalar once per turn and reuse for consistency:

             current_message_log = updated_message_log

+            # Convert tensor to Python scalar for this turn
+            input_length = int(input_lengths.item())
+
             # Update token counts
             gen_token_count = len(generated_tokens)
             assistant_token_count += gen_token_count
             token_count += gen_token_count
             turn_gen_tokens.append(gen_token_count)
-            turn_input_tokens.append(int(input_lengths))
-            turn_total_tokens.append(int(input_lengths) + gen_token_count)
+            turn_input_tokens.append(input_length)
+            turn_total_tokens.append(input_length + gen_token_count)
             # Per-worker load accounting

And at the overflow check:

-        if input_lengths + gen_token_count + len(tokenized_obs) >= max_seq_len:
+        if input_length + gen_token_count + len(tokenized_obs) >= max_seq_len:
             # Truncate environment observation
-            max_env_tokens = max_seq_len - input_lengths - gen_token_count
+            max_env_tokens = max_seq_len - input_length - gen_token_count
🧹 Nitpick comments (2)
nemo_rl/experience/rollouts.py (1)

940-949: Flatten per‑sample turn length lists without reduce(lambda x, y: x + y, …)

The current aggregation:

rollout_metrics["gen_tokens_lengths"] = reduce(
    lambda x, y: x + y, [m["turn_gen_tokens"] for m in all_sample_metrics]
)

is correct given each turn_* is a list[int], but reduce with list concatenation is less readable and can be quadratic for long lists.

Consider a clearer, linear‑time flatten using itertools.chain.from_iterable:

+from itertools import chain
@@
-        rollout_metrics["gen_tokens_lengths"] = reduce(
-            lambda x, y: x + y, [m["turn_gen_tokens"] for m in all_sample_metrics]
-        )
-        rollout_metrics["input_tokens_lengths"] = reduce(
-            lambda x, y: x + y, [m["turn_input_tokens"] for m in all_sample_metrics]
-        )
-        rollout_metrics["total_tokens_lengths"] = reduce(
-            lambda x, y: x + y, [m["turn_total_tokens"] for m in all_sample_metrics]
-        )
+        rollout_metrics["gen_tokens_lengths"] = list(
+            chain.from_iterable(m["turn_gen_tokens"] for m in all_sample_metrics)
+        )
+        rollout_metrics["input_tokens_lengths"] = list(
+            chain.from_iterable(m["turn_input_tokens"] for m in all_sample_metrics)
+        )
+        rollout_metrics["total_tokens_lengths"] = list(
+            chain.from_iterable(m["turn_total_tokens"] for m in all_sample_metrics)
+        )

Functionally equivalent, but simpler and more efficient for large batches/turn counts.
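The equivalence of the two approaches can be checked with a small standalone example (made-up token counts):

```python
from functools import reduce
from itertools import chain

# Per-sample turn lengths, mirroring the turn_gen_tokens lists discussed above.
all_sample_metrics = [
    {"turn_gen_tokens": [12, 7]},
    {"turn_gen_tokens": [30]},
    {"turn_gen_tokens": []},  # a sample with no turns flattens away cleanly
]

# Quadratic in the worst case: each + builds a new list.
via_reduce = reduce(lambda x, y: x + y, [m["turn_gen_tokens"] for m in all_sample_metrics])

# Linear-time flatten.
via_chain = list(chain.from_iterable(m["turn_gen_tokens"] for m in all_sample_metrics))

assert via_reduce == via_chain == [12, 7, 30]
```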

nemo_rl/algorithms/grpo.py (1)

1570-1585: Make histogram logging more robust (avoid broad Exception and ensure metrics are sequences)

Two related points around the new ISL/OSL/ISL+OSL histogram logging:

  1. Blind except Exception and missing‑key handling

In both grpo_train and async_grpo_train you have:

# Plot ISL/OSL/ISL+OSL histograms to wandb
try:
    for hist_metrics in [
        "gen_tokens_lengths",
        "input_tokens_lengths",
        "total_tokens_lengths",
    ]:
        log_histogram_metrics_to_wandb(
            f"generation_metrics/{hist_metrics}",
            metrics[hist_metrics],
            ...,
        )
except Exception as e:
    print(f"❌ Error plotting histograms to wandb: {e}")
    pass

This pattern:

  • Relies on KeyError for missing metrics being swallowed by a broad except Exception, which is what Ruff BLE001 is flagging.
  • Also hides genuine wandb/logging bugs; training continues, but you only get a console print.

You can avoid both the broad catch and the missing‑key reliance by checking for the metric’s presence before logging:

-            # Plot ISL/OSL/ISL+OSL histograms to wandb
-            try:
-                for hist_metrics in [
-                    "gen_tokens_lengths",
-                    "input_tokens_lengths",
-                    "total_tokens_lengths",
-                ]:
-                    log_histogram_metrics_to_wandb(
-                        f"generation_metrics/{hist_metrics}",
-                        metrics[hist_metrics],
-                        total_steps + 1,
-                        logger,
-                    )
-            except Exception as e:
-                print(f"❌ Error plotting histograms to wandb: {e}")
-                pass
+            # Plot ISL/OSL/ISL+OSL histograms to wandb when available
+            for hist_metric in (
+                "gen_tokens_lengths",
+                "input_tokens_lengths",
+                "total_tokens_lengths",
+            ):
+                hist_values = metrics.get(hist_metric)
+                if hist_values is None:
+                    continue
+                log_histogram_metrics_to_wandb(
+                    f"generation_metrics/{hist_metric}",
+                    hist_values,
+                    total_steps + 1,
+                    logger,
+                )

and analogously in async_grpo_train (using step + 1 for the step). This satisfies the guideline to avoid blind except Exception and makes the control flow clearer.

  2. Async GRPO aggregation vs histogram expectations

In async_grpo_train, per‑trajectory rollout metrics are aggregated as:

rollout_metrics = {}
for t in trajectories:
    for k, v in t["rollout_metrics"].items():
        rollout_metrics.setdefault(k, []).append(v)
# ...
rollout_metrics = {
    k: (sum(v) / len(v) if isinstance(v[0], (int, float)) else v)
    for k, v in rollout_metrics.items()
}

If per‑trajectory ISL/OSL metrics are sequences (e.g., lists of token lengths per turn), the above will either:

  • Treat them as scalars and average (if v[0] is an int), giving a single float instead of a list, or
  • Leave you with a list‑of‑lists if each v is already a list.

Either way, this may not match what wandb.Histogram expects (a 1‑D list/array of values) when you later pass metrics["gen_tokens_lengths"] into log_histogram_metrics_to_wandb.

I recommend special‑casing these histogram metrics at aggregation time so they are flattened across trajectories instead of averaged:

@@
-            rollout_metrics = {
-                k: (sum(v) / len(v) if isinstance(v[0], (int, float)) else v)
-                for k, v in rollout_metrics.items()
-            }
+            rollout_metrics = {
+                k: (sum(v) / len(v) if isinstance(v[0], (int, float)) else v)
+                for k, v in rollout_metrics.items()
+            }
+
+            # Ensure histogram metrics are flattened sequences suitable for wandb.Histogram
+            for hist_key in (
+                "gen_tokens_lengths",
+                "input_tokens_lengths",
+                "total_tokens_lengths",
+            ):
+                if hist_key in rollout_metrics:
+                    values = rollout_metrics[hist_key]
+                    # If we accumulated per‑trajectory lists, flatten them
+                    if values and isinstance(values[0], (list, tuple)):
+                        rollout_metrics[hist_key] = [
+                            x for group in values for x in group
+                        ]

Please double‑check the exact structure of t["rollout_metrics"] in the async collector to confirm whether flattening like this is needed for your case; the important part is that by the time you call log_histogram_metrics_to_wandb, metrics[hist_key] should be a flat list/array of numeric values, not an average scalar or nested lists.

Also applies to: 2130-2138, 2510-2525
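Under that assumption, the flattening step behaves like this standalone sketch (illustrative data only; key names mirror the metrics discussed above):

```python
# Per-trajectory rollout metrics after setdefault-based accumulation:
# histogram keys hold one list per trajectory, scalar keys one number each.
rollout_metrics = {
    "gen_tokens_lengths": [[120, 45], [200]],       # list per trajectory
    "mean_gen_tokens_per_sample": [82.5, 200.0],    # scalar per trajectory
}

for hist_key in (
    "gen_tokens_lengths",
    "input_tokens_lengths",
    "total_tokens_lengths",
):
    values = rollout_metrics.get(hist_key)
    # Flatten list-of-lists into the flat sequence wandb.Histogram expects.
    if values and isinstance(values[0], (list, tuple)):
        rollout_metrics[hist_key] = [x for group in values for x in group]

print(rollout_metrics["gen_tokens_lengths"])  # [120, 45, 200]
```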

📜 Review details

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between edd5e7a and 1e67087.

📒 Files selected for processing (3)
  • nemo_rl/algorithms/grpo.py (3 hunks)
  • nemo_rl/algorithms/utils.py (1 hunks)
  • nemo_rl/experience/rollouts.py (5 hunks)
🧰 Additional context used
📓 Path-based instructions (4)
**/*.py

📄 CodeRabbit inference engine (CODING_GUIDELINES.md)

**/*.py: Conform code to Python 3.12+
Indent code with 4 spaces. Do not use tabs
Use snake_case for file names
Use PascalCase for class names
Use snake_case for function and method names
Use snake_case for local variables
Prefix variable names that start with a number with 'k' (e.g., k_99th_percentile)
Use upper snake_case with 'G' prefix for global variables (e.g., G_MY_GLOBAL)
Use upper snake_case for constants
Avoid shadowing variables declared in an outer scope
Initialize all externally visible members of a class in the constructor
Prefer docstrings over comments for interfaces that may be used outside a file
Reserve comments for code within a function or interfaces that are local to a file
If a piece of code is commented out, include a comment describing its usage and why it's commented out. Remove debug comments before merging
Use Google style docstrings for classes and functions in Python, which can be parsed by Sphinx
Avoid using reflection when functionality can be easily achieved without reflection
When using try-except blocks, limit the except clause to the smallest set of specific errors possible
When using try-except blocks for duck-typing, keep the body of the try as small as possible and use the else block for logic
YAML is the single source of truth for configuration defaults. Do not set non-None defaults in code for configuration values
For required configuration attributes, access config directly and expect presence (e.g., policy_cfg['precision']) without hidden defaults
Use typing.NotRequired to mark optional attributes in TypedDict for configuration
When adding a new config key to a TypedDict subclass, document the key's purpose, valid values/types, and recommended default, and reflect the default in exemplar YAMLs under examples/configs/*.yaml
Follow the Google Python Style Guide for Python code

Files:

  • nemo_rl/algorithms/grpo.py
  • nemo_rl/experience/rollouts.py
  • nemo_rl/algorithms/utils.py
nemo_rl/**/*.py

📄 CodeRabbit inference engine (CODING_GUIDELINES.md)

For any source file under nemo_rl/*.py that defines a class or function decorated with @ray.remote, add a coverage pragma (# pragma: no cover) because these run in separate Ray processes

Files:

  • nemo_rl/algorithms/grpo.py
  • nemo_rl/experience/rollouts.py
  • nemo_rl/algorithms/utils.py
!(**/tests/**|**/test_*.py|**/test_*.sh)

📄 CodeRabbit inference engine (CODING_GUIDELINES.md)

Add the NVIDIA copyright header to all Python files and shell scripts (excluding tests). The header should include the current year

Files:

  • nemo_rl/algorithms/grpo.py
  • nemo_rl/experience/rollouts.py
  • nemo_rl/algorithms/utils.py
**/*.{py,sh}

📄 CodeRabbit inference engine (CODING_GUIDELINES.md)

The NVIDIA copyright header should appear at the top of all Python files and shell scripts (excluding tests)

Files:

  • nemo_rl/algorithms/grpo.py
  • nemo_rl/experience/rollouts.py
  • nemo_rl/algorithms/utils.py
🧬 Code graph analysis (2)
nemo_rl/algorithms/grpo.py (1)
nemo_rl/algorithms/utils.py (1)
  • log_histogram_metrics_to_wandb (774-793)
nemo_rl/algorithms/utils.py (1)
nemo_rl/utils/logger.py (2)
  • Logger (805-1123)
  • log_plot_per_worker_timeline_metrics (939-1018)
🪛 Ruff (0.14.7)
nemo_rl/algorithms/grpo.py

1583-1583: Do not catch blind exception: Exception

(BLE001)


2523-2523: Do not catch blind exception: Exception

(BLE001)

⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (7)
  • GitHub Check: sphinx-build / Build docs
  • GitHub Check: Lint check
  • GitHub Check: Lint check
  • GitHub Check: Lint check
  • GitHub Check: Lint check
  • GitHub Check: Post submodule check comment / Comment on PR
  • GitHub Check: Post automodel integration comment / Comment on PR
🔇 Additional comments (1)
nemo_rl/algorithms/utils.py (1)

750-772: Generation and histogram logging helpers look consistent with Logger usage

The refactored log_generation_metrics_to_wandb and new log_histogram_metrics_to_wandb both align with how Logger is structured:

  • generation_logger_metrics: dict[str, dict[int, list[Any]]] matches what log_plot_per_worker_timeline_metrics expects (inner dict is per‑worker series).
  • The generation_metrics/… prefix keeps naming consistent with the new histogram metric names in grpo.py.
  • log_histogram_metrics_to_wandb is correctly guarded on logger.wandb_logger and imports wandb lazily, so it’s safe to call unconditionally from training code.

No functional issues from this diff; the helpers are ready to be used by the GRPO training loops.

Also applies to: 774-793

@youngeunkwon0405 (Contributor, Author):

Hi @guyueh1, do you think this implementation meets your expectations?
Here is a sample wandb report.

@youngeunkwon0405 (Contributor, Author):

Hi @terrykong, could you please review and merge this PR?

Signed-off-by: Youngeun Kwon <youngeunk@nvidia.com>
@youngeunkwon0405 requested a review from a team as a code owner on December 4, 2025 19:19
@youngeunkwon0405 added and removed the CI:L1 (Run doctests, unit tests, and functional tests) label on Dec 4, 2025
Signed-off-by: Youngeun Kwon <youngeunk@nvidia.com>
@youngeunkwon0405 added and removed the CI:L1 (Run doctests, unit tests, and functional tests) label on Dec 4, 2025
@youngeunkwon0405 (Contributor, Author):

Hi @terrykong, I have addressed most of your comments in the newer version. Could you please take a look?

@terrykong merged commit 140cd97 into main on Dec 5, 2025
40 of 42 checks passed
@terrykong deleted the youngeunk/generation_histogram branch on December 5, 2025 06:33
guyueh1 pushed a commit to guyueh1/NeMo-RL that referenced this pull request Dec 9, 2025
Signed-off-by: Youngeun Kwon <youngeunk@nvidia.com>
DeL-TaiseiOzaki pushed a commit to DeL-TaiseiOzaki/RL that referenced this pull request Jan 8, 2026
Signed-off-by: Youngeun Kwon <youngeunk@nvidia.com>
yuanhangsu1986 pushed a commit to yuanhangsu1986/RL-Nemontron-Edge-Omni that referenced this pull request Feb 12, 2026
Signed-off-by: Youngeun Kwon <youngeunk@nvidia.com>
Signed-off-by: yuanhangs <yuanhangs@nvidia.com>
yuanhangsu1986 pushed a commit to yuanhangsu1986/RL-Nemontron-Edge-Omni that referenced this pull request Feb 21, 2026
Signed-off-by: Youngeun Kwon <youngeunk@nvidia.com>
Signed-off-by: yuanhangs <yuanhangs@nvidia.com>
yuanhangsu1986 pushed a commit to yuanhangsu1986/RL-Nemontron-Edge-Omni that referenced this pull request Feb 21, 2026
Signed-off-by: Youngeun Kwon <youngeunk@nvidia.com>
Signed-off-by: yuanhangs <yuanhangs@nvidia.com>
seonjinn pushed a commit that referenced this pull request Mar 8, 2026
Signed-off-by: Youngeun Kwon <youngeunk@nvidia.com>
seonjinn pushed a commit that referenced this pull request Mar 8, 2026
Signed-off-by: Youngeun Kwon <youngeunk@nvidia.com>
seonjinn pushed a commit that referenced this pull request Mar 9, 2026
Signed-off-by: Youngeun Kwon <youngeunk@nvidia.com>

Labels

CI:L1 (Run doctests, unit tests, and functional tests), ease of use, Performance (Related to improving performance)

Projects

None yet

Development

Successfully merging this pull request may close these issues.

Add more detailed sequence length distribution logging

2 participants