Further modify tis #3
Code Review
This pull request introduces helper functions for masked sum and mean operations, refactors the importance sampling weight calculation to use these new helpers, and adds a new function to compute various KL-divergence and perplexity-related metrics. These changes improve code structure and add valuable monitoring capabilities. My review focuses on improving the new metrics function by addressing a logical error in metric calculation, removing a redundant operation, and adding a missing type hint for better code quality and correctness.
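The masked sum and mean helpers mentioned above might look like the following minimal sketch. The names `masked_sum`/`masked_mean` and the `eps` guard are assumptions for illustration, not the PR's actual code:

```python
import torch

def masked_sum(values: torch.Tensor, mask: torch.Tensor, dim: int = -1) -> torch.Tensor:
    # Sum `values` only at positions where `mask` is 1
    # (mask is expected to already be a float tensor).
    return (values * mask).sum(dim=dim)

def masked_mean(values: torch.Tensor, mask: torch.Tensor,
                dim: int = -1, eps: float = 1e-8) -> torch.Tensor:
    # Mean over unmasked positions; `eps` guards against an all-zero mask.
    return masked_sum(values, mask, dim) / (mask.sum(dim=dim) + eps)
```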
```python
# Since ppl = exp(-log_prob), we have:
# log(ppl_ratio) = log(training_ppl/rollout_ppl) = log_ppl_diff
```
The metrics log_ppl_diff_max and log_ppl_diff_min are incorrectly calculated. You are appending the same log_ppl_diff tensor for both. Since log_ppl_diff contains identical values for a single sequence (due to expand=True), this will not compute the maximum or minimum across sequences. The aggregation framework will then compute the mean, resulting in identical values for log_ppl_diff, log_ppl_diff_max, and log_ppl_diff_min, which is not the intended behavior.
I recommend removing these two lines until a proper mechanism for computing and logging max/min values across the batch is implemented.
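One way such a mechanism could work: first reduce `log_ppl_diff` to a single scalar per sequence, then take the max/min across the batch dimension. The helper name and `eps` guard below are illustrative assumptions, not the PR's code:

```python
import torch

def seq_log_ppl_diff(train_log_prob: torch.Tensor,
                     rollout_log_prob: torch.Tensor,
                     loss_mask: torch.Tensor,
                     eps: float = 1e-8) -> torch.Tensor:
    # Per-sequence log(training_ppl / rollout_ppl). Since log_ppl is the
    # negative masked mean of token log-probs, the difference reduces to the
    # masked mean of (rollout_log_prob - train_log_prob) over each sequence.
    gap = (rollout_log_prob - train_log_prob) * loss_mask
    return gap.sum(dim=-1) / (loss_mask.sum(dim=-1) + eps)
```

Scalars such as `seq.max()` and `seq.min()` on the result are then genuine per-batch extremes and can be logged directly, instead of being averaged away by the aggregation framework.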
```python
def add_ppl_metrics(
    train_log_prob: torch.Tensor,
    rollout_log_prob: torch.Tensor,
    loss_mask: torch.Tensor,
    metrics: Dict[str, list[torch.Tensor]],
):
```
The function add_ppl_metrics is missing a return type hint. According to PEP 484, functions should have type hints for arguments and return values. Since this function does not return a value, you should add -> None to its signature for better code clarity and static analysis.
Suggested change:

```diff
 def add_ppl_metrics(
     train_log_prob: torch.Tensor,
     rollout_log_prob: torch.Tensor,
     loss_mask: torch.Tensor,
     metrics: Dict[str, list[torch.Tensor]],
-):
+) -> None:
```
```python
    loss_mask: torch.Tensor,
    metrics: Dict[str, list[torch.Tensor]],
):
    loss_mask = loss_mask.float()
```
The loss_mask is already converted to a float tensor in the calling function compute_train_infer_is_weights at line 151. This conversion here is redundant. It's best to remove it to avoid unnecessary operations and rely on the caller to provide the correct data type. Consider adding a note to the function's docstring to clarify that loss_mask is expected to be a float tensor.
```python
raw_log_ratio = train_log_prob - rollout_log_prob
loss_mask = loss_mask.float()
add_ppl_metrics(train_log_prob, rollout_log_prob, loss_mask, metrics)
raw_log_ratio = train_log_prob - rollout_log_prob
```
Change the naming of raw_log_ratio, maybe to raw_log_ratio_diff.
```python
log_ratio_safe = torch.clamp(log_ratio_for_metrics, min=-SAFETY_BOUND, max=SAFETY_BOUND)
weights = torch.exp(log_ratio_safe)
metrics_append(metrics, "ratio_mean_before_tis", weights)
```
Suggestion: rename the metric to mean_is_weight_before_clip.
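For context, the pre-clip importance-sampling weight in the snippet above can be sketched as follows. The `SAFETY_BOUND` value of 20.0 and the function name are placeholders; the PR defines its own constant and variable names:

```python
import torch

SAFETY_BOUND = 20.0  # placeholder value; the actual PR defines its own bound

def is_weights_before_clip(train_log_prob: torch.Tensor,
                           rollout_log_prob: torch.Tensor) -> torch.Tensor:
    # Importance-sampling weight p_train / p_rollout, computed in log space
    # and clamped before exponentiation so exp() cannot overflow.
    log_ratio = train_log_prob - rollout_log_prob
    log_ratio_safe = torch.clamp(log_ratio, min=-SAFETY_BOUND, max=SAFETY_BOUND)
    return torch.exp(log_ratio_safe)
```

Clamping in log space rather than clipping the exponentiated weights keeps extreme ratios finite while preserving their ordering, which is why the mean of these weights is a useful diagnostic before any truncation is applied.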
Merged commit 1d35f45 into zhaochenyang20:importance_sampling
Add kl metrics and helper function