diff --git a/analysis/plot_results.py b/analysis/plot_results.py
index ae885e8..c265863 100644
--- a/analysis/plot_results.py
+++ b/analysis/plot_results.py
@@ -95,6 +95,9 @@ def get_args():
     parser_translate.add_argument("--gtrans", type=Path, required=True, help="Path to the Google Translate results file.")
     parser_translate.add_argument("--nllb", type=Path, required=True, help="Path to the NLLB-3.3B results file.")
 
+    parser_histogram = subparsers.add_parser("reward_histogram", help="Plot reward histogram", parents=[shared_args])
+    parser_histogram.add_argument("--input_path", type=Path, required=True, help="Path to the input file containing raw rewards.")
+
     # fmt: on
     return parser.parse_args()
 
@@ -107,6 +110,7 @@ def main():
         "eng_drop_line": plot_eng_drop_line,
         "ling_dims": plot_ling_dims,
         "translate": plot_translate,
+        "reward_histogram": plot_reward_histogram,
     }
 
     def _filter_args(func, kwargs):
@@ -394,5 +398,54 @@ def plot_translate(
     fig.savefig(output_path, bbox_inches="tight")
 
 
+def plot_reward_histogram(
+    input_path: Path,
+    output_path: Path,
+    figsize: Optional[tuple[int, int]] = (18, 5),
+) -> None:
+    """Plot overlaid histograms of the chosen and rejected rewards.
+
+    Both series are binned over one shared range so their bars line up and
+    the two distributions can be compared directly.
+
+    Args:
+        input_path: JSON Lines file whose records carry ``chosen`` and
+            ``rejected`` reward fields.
+        output_path: Location the figure is saved to.
+        figsize: Figure size forwarded to ``plt.subplots``.
+    """
+    df = pd.read_json(input_path, lines=True)
+    fig, ax = plt.subplots(1, 1, figsize=figsize)
+    bins = 30
+    alpha = 0.7
+
+    # With an integer ``bins`` each hist() call derives its edges from its own
+    # data, giving the two series different bin widths. Pin both calls to the
+    # combined range so the edges are identical.
+    rewards = pd.concat([df["chosen"], df["rejected"]]).dropna()
+    shared_range = None if rewards.empty else (rewards.min(), rewards.max())
+
+    series_styles = (
+        ("chosen", "Chosen", "green"),
+        ("rejected", "Rejected", "orange"),
+    )
+    for column, label, color_key in series_styles:
+        color = COLORS.get(color_key)
+        df[column].hist(
+            ax=ax,
+            bins=bins,
+            range=shared_range,
+            grid=False,
+            label=label,
+            alpha=alpha,
+            edgecolor=color,
+            color=color,
+        )
+    ax.legend(frameon=False)
+
+    fig.tight_layout()
+    fig.savefig(output_path, bbox_inches="tight")
+
+
 if __name__ == "__main__":
     main()