Skip to content

Commit

Permalink
Add code for getting the histogram (#49)
Browse files Browse the repository at this point in the history
  • Loading branch information
ljvmiranda921 authored Oct 13, 2024
1 parent 64c72c0 commit bf54993
Showing 1 changed file with 37 additions and 0 deletions.
37 changes: 37 additions & 0 deletions analysis/plot_results.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,9 @@ def get_args():
parser_translate.add_argument("--gtrans", type=Path, required=True, help="Path to the Google Translate results file.")
parser_translate.add_argument("--nllb", type=Path, required=True, help="Path to the NLLB-3.3B results file.")

parser_histogram = subparsers.add_parser("reward_histogram", help="Plot reward histogram", parents=[shared_args])
parser_histogram.add_argument("--input_path", type=Path, required=True, help="Path to the input file containing raw rewards.")

# fmt: on
return parser.parse_args()

Expand All @@ -107,6 +110,7 @@ def main():
"eng_drop_line": plot_eng_drop_line,
"ling_dims": plot_ling_dims,
"translate": plot_translate,
"reward_histogram": plot_reward_histogram,
}

def _filter_args(func, kwargs):
Expand Down Expand Up @@ -394,5 +398,38 @@ def plot_translate(
fig.savefig(output_path, bbox_inches="tight")


def plot_reward_histogram(
input_path: Path,
output_path: Path,
figsize: Optional[tuple[int, int]] = (18, 5),
):
df = pd.read_json(input_path, lines=True)
fig, ax = plt.subplots(1, 1, figsize=figsize)
bins = 30
alpha = 0.7
df["chosen"].hist(
ax=ax,
bins=bins,
grid=False,
label="Chosen",
alpha=alpha,
edgecolor=COLORS.get("green"),
color=COLORS.get("green"),
)
df["rejected"].hist(
ax=ax,
bins=bins,
grid=False,
label="Rejected",
alpha=alpha,
edgecolor=COLORS.get("orange"),
color=COLORS.get("orange"),
)
ax.legend(frameon=False)

plt.tight_layout()
fig.savefig(output_path, bbox_inches="tight")


if __name__ == "__main__":
main()

0 comments on commit bf54993

Please sign in to comment.