Further modify tis #3

Merged
zhaochenyang20 merged 2 commits into zhaochenyang20:importance_sampling from guapisolo:tis on Oct 15, 2025

Conversation

@guapisolo (Collaborator)

Add kl metrics and helper function

@gemini-code-assist

Summary of Changes

Hello @guapisolo, I'm Gemini Code Assist¹! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request enhances the train_infer_is utility by introducing new helper functions for masked tensor operations and integrating a dedicated function to compute and log a suite of perplexity and KL divergence metrics. These additions aim to provide more detailed insights into the behavior of training and rollout policies during importance sampling, facilitating better analysis and understanding of model performance.

Highlights

  • New Helper Functions: Introduced masked_sum and masked_mean helper functions to perform robust masked sum and mean calculations on PyTorch tensors, handling potential division by zero.
  • KL and Perplexity Metrics: Added a comprehensive add_ppl_metrics function that calculates and logs various perplexity (PPL) and Kullback-Leibler (KL) divergence metrics for both training and rollout policies, including direct KL, K3 KL estimator, and perplexity ratios.
  • Integration and Refactoring: The compute_train_infer_is_weights function was refactored to utilize the new masked_sum and masked_mean functions for aggregating log ratios and now integrates the add_ppl_metrics function to log detailed policy-related metrics.
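The helpers and the K3 estimator highlighted above can be sketched as follows. This is a minimal sketch, not the PR's implementation: the names `masked_sum` and `masked_mean` match the PR, but the exact signatures, the `eps` guard, and the sample tensors are assumptions for illustration.

```python
import torch


def masked_sum(x: torch.Tensor, mask: torch.Tensor, dim: int = -1) -> torch.Tensor:
    # Sum x over positions where mask == 1.
    return (x * mask).sum(dim=dim)


def masked_mean(x: torch.Tensor, mask: torch.Tensor, dim: int = -1, eps: float = 1e-8) -> torch.Tensor:
    # Mean over masked positions; the clamp guards against an all-zero mask.
    return masked_sum(x, mask, dim) / mask.sum(dim=dim).clamp(min=eps)


# K3 KL estimator: with log_ratio = train_log_prob - rollout_log_prob for tokens
# sampled under the rollout policy, k3 = exp(log_ratio) - 1 - log_ratio is a
# low-variance, always non-negative per-token estimate of KL(rollout || train).
train_lp = torch.tensor([[-1.0, -2.0, -1.5, 0.0]])
rollout_lp = torch.tensor([[-1.1, -1.9, -1.4, 0.0]])
mask = torch.tensor([[1.0, 1.0, 1.0, 0.0]])  # last position is padding

log_ratio = train_lp - rollout_lp
k3 = torch.exp(log_ratio) - 1.0 - log_ratio     # per-token K3 estimate, always >= 0
kl_k3 = masked_mean(k3, mask)                   # per-sequence K3 KL
# Since ppl = exp(-mean log_prob), log(train_ppl / rollout_ppl) = -mean(log_ratio):
log_ppl_diff = -masked_mean(log_ratio, mask)
```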
Using Gemini Code Assist

The full guide for Gemini Code Assist can be found on our documentation page; here are some quick tips.

Invoking Gemini

You can request assistance from Gemini at any point by creating a comment using either /gemini <command> or @gemini-code-assist <command>. Below is a summary of the supported commands on the current page.

| Feature | Command | Description |
| --- | --- | --- |
| Code Review | `/gemini review` | Performs a code review for the current pull request in its current state. |
| Pull Request Summary | `/gemini summary` | Provides a summary of the current pull request in its current state. |
| Comment | `@gemini-code-assist` | Responds in comments when explicitly tagged, both in pull request comments and review comments. |
| Help | `/gemini help` | Displays a list of available commands. |

Customization

To customize the Gemini Code Assist for GitHub experience, repository maintainers can create a configuration file and/or provide a custom code review style guide (such as PEP-8 for Python) by creating and adding files to a .gemini/ folder in the base of the repository. Detailed instructions can be found here.

Limitations & Feedback

Gemini Code Assist may make mistakes. Please leave feedback on any instances where its feedback is incorrect or counterproductive. You can react with 👍 and 👎 on @gemini-code-assist comments. If you're interested in giving feedback about your experience with Gemini Code Assist for GitHub and other Google products, sign up here.

You can also get AI-powered code generation, chat, as well as code reviews directly in the IDE at no cost with the Gemini Code Assist IDE Extension.

Footnotes

  1. Review the Privacy Notices, Generative AI Prohibited Use Policy, Terms of Service, and learn how to configure Gemini Code Assist in GitHub here. Gemini can make mistakes, so double check it and use code with caution.

@gemini-code-assist bot left a comment
Code Review

This pull request introduces helper functions for masked sum and mean operations, refactors the importance sampling weight calculation to use these new helpers, and adds a new function to compute various KL-divergence and perplexity-related metrics. These changes improve code structure and add valuable monitoring capabilities. My review focuses on improving the new metrics function by addressing a logical error in metric calculation, removing a redundant operation, and adding a missing type hint for better code quality and correctness.

Comment on lines +311 to +312
```python
# Since ppl = exp(-log_prob), we have:
# log(ppl_ratio) = log(training_ppl/rollout_ppl) = log_ppl_diff
```
Severity: high

The metrics log_ppl_diff_max and log_ppl_diff_min are incorrectly calculated. You are appending the same log_ppl_diff tensor for both. Since log_ppl_diff contains identical values for a single sequence (due to expand=True), this will not compute the maximum or minimum across sequences. The aggregation framework will then compute the mean, resulting in identical values for log_ppl_diff, log_ppl_diff_max, and log_ppl_diff_min, which is not the intended behavior.

I recommend removing these two lines until a proper mechanism for computing and logging max/min values across the batch is implemented.
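If the extrema are worth keeping, one possible mechanism is to reduce across the batch before appending, so the aggregation framework's later mean runs over per-batch extrema instead of duplicated per-token values. This is a hedged sketch, not the PR's code: the helper name `append_log_ppl_extrema` is hypothetical, and it assumes `metrics` maps names to lists of tensors.

```python
import torch
from typing import Dict, List


def append_log_ppl_extrema(
    metrics: Dict[str, List[torch.Tensor]],
    log_ppl_diff_per_seq: torch.Tensor,  # shape [batch], one value per sequence
) -> None:
    # Taking .max()/.min() across the batch *before* appending yields 0-dim
    # tensors, so a downstream mean over appended values averages per-batch
    # extrema rather than collapsing identical expanded values.
    metrics.setdefault("log_ppl_diff_max", []).append(log_ppl_diff_per_seq.max().detach())
    metrics.setdefault("log_ppl_diff_min", []).append(log_ppl_diff_per_seq.min().detach())


metrics: Dict[str, List[torch.Tensor]] = {}
append_log_ppl_extrema(metrics, torch.tensor([0.2, -0.1, 0.05]))
```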

Comment on lines +274 to +279
```python
def add_ppl_metrics(
    train_log_prob: torch.Tensor,
    rollout_log_prob: torch.Tensor,
    loss_mask: torch.Tensor,
    metrics: Dict[str, list[torch.Tensor]],
):
```
Severity: medium

The function add_ppl_metrics is missing a return type hint. According to PEP 484, functions should have type hints for arguments and return values. Since this function does not return a value, you should add -> None to its signature for better code clarity and static analysis.

Suggested change

```diff
 def add_ppl_metrics(
     train_log_prob: torch.Tensor,
     rollout_log_prob: torch.Tensor,
     loss_mask: torch.Tensor,
     metrics: Dict[str, list[torch.Tensor]],
-):
+) -> None:
```

```python
    loss_mask: torch.Tensor,
    metrics: Dict[str, list[torch.Tensor]],
):
    loss_mask = loss_mask.float()
```
Severity: medium

The loss_mask is already converted to a float tensor in the calling function compute_train_infer_is_weights at line 151. This conversion here is redundant. It's best to remove it to avoid unnecessary operations and rely on the caller to provide the correct data type. Consider adding a note to the function's docstring to clarify that loss_mask is expected to be a float tensor.

```python
loss_mask = loss_mask.float()
add_ppl_metrics(train_log_prob, rollout_log_prob, loss_mask, metrics)
raw_log_ratio = train_log_prob - rollout_log_prob
```
@zhaochenyang20 (Owner)

Change the naming of `raw_log_ratio`, maybe to `raw_log_ratio_diff`.


```python
log_ratio_safe = torch.clamp(log_ratio_for_metrics, min=-SAFETY_BOUND, max=SAFETY_BOUND)
weights = torch.exp(log_ratio_safe)
metrics_append(metrics, "ratio_mean_before_tis", weights)
```
@zhaochenyang20 (Owner)

`mean_is_weight_before_clip`
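The clamp-then-exponentiate step being discussed can be sketched as below. This is an illustrative sketch under assumptions, not the PR's code: the `SAFETY_BOUND` value, the wrapper function, and the use of a plain dict-of-lists for `metrics` are all made up here; the metric is logged under the reviewer-suggested name.

```python
import torch

SAFETY_BOUND = 10.0  # assumed value for illustration; the PR defines its own bound


def is_weights_with_metric(log_ratio_for_metrics: torch.Tensor, metrics: dict) -> torch.Tensor:
    # Clamp in log space so exp() cannot overflow, then exponentiate to get the
    # raw importance-sampling weights (i.e. before any later TIS truncation).
    log_ratio_safe = torch.clamp(log_ratio_for_metrics, min=-SAFETY_BOUND, max=SAFETY_BOUND)
    weights = torch.exp(log_ratio_safe)
    # Logged under the reviewer-suggested metric name.
    metrics.setdefault("mean_is_weight_before_clip", []).append(weights.mean().detach())
    return weights
```

Clamping the log ratio rather than the weight keeps the operation numerically safe for extreme ratios in both directions, since `exp(±SAFETY_BOUND)` is finite by construction.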

@zhaochenyang20 zhaochenyang20 merged commit 1d35f45 into zhaochenyang20:importance_sampling Oct 15, 2025
1 check passed
