74 changes: 74 additions & 0 deletions .github/workflows/model-results-comparison.yaml
@@ -0,0 +1,74 @@
name: Model Results Comparison

on:
pull_request:
types: [opened, synchronize, edited]
paths:
- 'results/**/*.json'
workflow_dispatch:
inputs:
reference_models:
description: 'Space-separated list of reference models for comparison'
required: true
type: string
default: 'intfloat/multilingual-e5-large google/gemini-embedding-001'
pull_request_number:
description: 'The pull request number to comment on (needed to post a comment when triggered manually)'
required: false # Optional, so the workflow can still run without a PR context
type: string

permissions:
contents: read
pull-requests: write

jobs:
compare-results:
runs-on: ubuntu-latest

steps:
- name: Checkout code
uses: actions/checkout@v4
with:
fetch-depth: 0

- name: Fetch origin main
run: git fetch origin main

- name: Set up Python
uses: actions/setup-python@v5
with:
python-version: '3.10'

- name: Install dependencies
run: |
pip install git+https://github.com/embeddings-benchmark/mteb.git tabulate

- name: Generate model comparison
env:
REFERENCE_MODELS: ${{ github.event.inputs.reference_models || 'intfloat/multilingual-e5-large google/gemini-embedding-001' }}
run: |
python scripts/create_pr_results_comment.py --reference-models $REFERENCE_MODELS --output model-comparison.md

- name: Determine PR Number
id: pr_info
run: |
if [ "${{ github.event_name }}" == "pull_request" ]; then
echo "pr_number=${{ github.event.number }}" >> $GITHUB_OUTPUT
elif [ "${{ github.event_name }}" == "workflow_dispatch" ] && [ -n "${{ github.event.inputs.pull_request_number }}" ]; then
echo "pr_number=${{ github.event.inputs.pull_request_number }}" >> $GITHUB_OUTPUT
else
echo "pr_number=" >> $GITHUB_OUTPUT
fi

- name: Post PR comment
# Runs when a PR number is available, either from the PR event or from the workflow_dispatch input
if: steps.pr_info.outputs.pr_number != ''
env:
GITHUB_TOKEN: ${{ github.token }}
run: gh pr comment ${{ steps.pr_info.outputs.pr_number }} --body-file model-comparison.md --create-if-none --edit-last

- name: Upload comparison report
uses: actions/upload-artifact@v4
with:
name: model-comparison
path: model-comparison.md
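
For local testing, the comparison step of this workflow can be reproduced before opening a PR. The sketch below is an illustration rather than part of the PR: it assumes `mteb`, `pandas`, and `tabulate` are installed and that the repository root is the current working directory; the output file name simply mirrors the workflow default.

```python
# Local dry run of the comparison step the workflow performs (hypothetical setup).
import subprocess
from pathlib import Path

# The script diffs against origin/main, so make sure the ref is up to date.
subprocess.run(["git", "fetch", "origin", "main"], check=True)

subprocess.run(
    [
        "python",
        "scripts/create_pr_results_comment.py",
        "--reference-models",
        "intfloat/multilingual-e5-large",
        "google/gemini-embedding-001",
        "--output",
        "model-comparison.md",
    ],
    check=True,
)

# Preview the report that the workflow would post as a PR comment.
print(Path("model-comparison.md").read_text())
```

If the preview looks right, the workflow will post the same markdown as a PR comment via `gh pr comment`.
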
@@ -134,4 +134,4 @@
},
"evaluation_time": 47.84240365028381,
"kg_co2_emissions": null
}
}
166 changes: 113 additions & 53 deletions scripts/create_pr_results_comment.py
@@ -3,19 +3,20 @@

Usage:
gh pr checkout {pr-number}
scripts/create_pr_results_comment.py [--models MODEL1 MODEL2 ...]
python scripts/create_pr_results_comment.py [--models MODEL1 MODEL2 ...] [--output OUTPUT_FILE]

Description:
- Compares new model results (added in the current PR) against reference models.
- Outputs a Markdown table with results for each new model and highlights the best scores.
- Outputs a Markdown file with results for each new model and highlights the best scores.
- By default, compares against: intfloat/multilingual-e5-large and google/gemini-embedding-001.
- You can specify reference models with the --reference-models argument.

Arguments:
--models: List of reference models to compare against (default: intfloat/multilingual-e5-large google/gemini-embedding-001)
--reference-models: List of reference models to compare against (default: intfloat/multilingual-e5-large google/gemini-embedding-001)
--output: Output markdown file path (default: model-comparison.md)

Example:
scripts/create_pr_results_comment.py --models intfloat/multilingual-e5-large myorg/my-new-model
python scripts/create_pr_results_comment.py --reference-models intfloat/multilingual-e5-large myorg/my-new-model
"""

from __future__ import annotations
@@ -24,6 +25,7 @@
import json
import os
import subprocess
import logging
from collections import defaultdict
from pathlib import Path

@@ -32,33 +34,23 @@

TaskName, ModelName = str, str


repo_path = Path(__file__).parents[1]
results_path = repo_path / "results"

os.environ["MTEB_CACHE"] = str(repo_path.parent)


default_reference_models = [
# Default reference models to compare against
REFERENCE_MODELS: list[str] = [
"intfloat/multilingual-e5-large",
"google/gemini-embedding-001",
]

logging.basicConfig(
level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s"
)
logger = logging.getLogger(__name__)

def get_diff_from_main() -> list[str]:
current_rev, origin_rev = subprocess.run(
["git", "rev-parse", "main", "origin/main"],
cwd=repo_path,
capture_output=True,
check=True,
text=True,
).stdout.splitlines()
repo_path = Path(__file__).parents[1]

os.environ["MTEB_CACHE"] = str(repo_path.parent)

if current_rev != origin_rev:
raise ValueError(
f"Your main branch is not up-to-date ({current_rev} != {origin_rev}), please run `git fetch origin main`"
)

def get_diff_from_main() -> list[str]:
differences = subprocess.run(
["git", "diff", "--name-only", "origin/main...HEAD"],
cwd=repo_path,
@@ -91,66 +83,134 @@ def extract_new_models_and_tasks(
return models


def create_comparison_table(models: list[str], tasks: list[str]) -> pd.DataFrame:
def create_comparison_table(
model: str, tasks: list[str], reference_models: list[str]
) -> pd.DataFrame:
models = [model] + reference_models
max_col_name = "Max result"
task_col_name = "task_name"
results = mteb.load_results(models=models, tasks=tasks, download_latest=False)

results = results.join_revisions()
df = results.to_dataframe()

# compute average pr. columns
model_names = [c for c in df.columns if c != "task_name"]
if df.empty:
raise ValueError(f"No results found for models {models} on tasks {tasks}")

row = pd.DataFrame(
df[max_col_name] = None
task_results = mteb.load_results(tasks=tasks, download_latest=False)
task_results = task_results.join_revisions()
max_dataframe = (
task_results.to_dataframe(format="long").groupby(task_col_name).max()
)
if not max_dataframe.empty:
for task_name, row in max_dataframe.iterrows():
df.loc[df[task_col_name] == task_name, max_col_name] = (
row["score"] / 100
) # scores are in percentage

averages: dict[str, float | None] = {}
for col in models + [max_col_name]:
numeric = pd.to_numeric(df[col], errors="coerce")
avg = numeric.mean()
averages[col] = avg if not pd.isna(avg) else None

avg_row = pd.DataFrame(
{
"task_name": ["**Average**"],
**{
model: df[model].mean() if model != "task_name" else None
for model in model_names
},
task_col_name: ["**Average**"],
**{col: [val] for col, val in averages.items()},
}
)
df = pd.concat([df, row], ignore_index=True)
return df
return pd.concat([df, avg_row], ignore_index=True)


def highlight_max_bold(df, exclude_cols=["task_name"]):
# result_df = df.copy().astype(str)
# only 2 decimal places except for the excluded columns
def highlight_max_bold(
df: pd.DataFrame, exclude_cols: list[str] = ["task_name"]
) -> pd.DataFrame:
result_df = df.copy()
result_df = result_df.applymap(lambda x: f"{x:.2f}" if isinstance(x, float) else x)
tmp_df = df.copy()
tmp_df = tmp_df.drop(columns=exclude_cols)
for col in result_df.columns:
if col not in exclude_cols:
result_df[col] = result_df[col].apply(
lambda x: f"{x:.2f}"
if isinstance(x, (int, float)) and pd.notna(x)
else x
)

tmp = df.drop(columns=exclude_cols)
for idx in df.index:
max_col = tmp_df.loc[idx].idxmax()
result_df.loc[idx, max_col] = f"**{result_df.loc[idx, max_col]}**"
row = pd.to_numeric(tmp.loc[idx], errors="coerce")
if row.isna().all():
continue
max_col = row.idxmax()
if pd.notna(row[max_col]):
result_df.at[idx, max_col] = f"**{result_df.at[idx, max_col]}**"

return result_df


def generate_markdown_content(
model_tasks: dict[str, list[str]], reference_models: list[str]
) -> str:
if not model_tasks:
return "# Model Results Comparison\n\nNo new model results found in this PR."

all_tasks = sorted({t for tasks in model_tasks.values() for t in tasks})
new_models = list(model_tasks.keys())

parts: list[str] = [
"# Model Results Comparison",
"",
f"**Reference models:** {', '.join(f'`{m}`' for m in reference_models)}",
f"**New models evaluated:** {', '.join(f'`{m}`' for m in new_models)}",
f"**Tasks:** {', '.join(f'`{t}`' for t in all_tasks)}",
"",
]

for model_name, tasks in model_tasks.items():
parts.append(f"## Results for `{model_name}`")

df = create_comparison_table(model_name, tasks, reference_models)
bold_df = highlight_max_bold(df)
parts.append(bold_df.to_markdown(index=False))

parts.extend(["", "---", ""])

return "\n".join(parts)


def create_argparse() -> argparse.ArgumentParser:
"""Create the argument parser for the script."""
parser = argparse.ArgumentParser(
description="Create PR comment with results comparison."
)
parser.add_argument(
"--models",
"--reference-models",
nargs="+",
default=default_reference_models,
default=REFERENCE_MODELS,
help="List of reference models to compare against (default: %(default)s)",
)
parser.add_argument(
"--output",
type=Path,
default=Path("model-comparison.md"),
help="Output markdown file path",
)
return parser


def main(reference_models: list[str]):
def main(reference_models: list[str], output_path: Path) -> None:
logger.info("Starting to create PR results comment...")
logger.info(f"Using reference models: {', '.join(reference_models)}")
diff = get_diff_from_main()
new_additions = extract_new_models_and_tasks(diff)

for model, tasks in new_additions.items():
print(f"**Results for `{model}`**")
df = create_comparison_table(models=reference_models + [model], tasks=tasks)
bold_df = highlight_max_bold(df)
print(bold_df.to_markdown(index=False))
model_tasks = extract_new_models_and_tasks(diff)
markdown = generate_markdown_content(model_tasks, reference_models)

output_path.parent.mkdir(parents=True, exist_ok=True)
output_path.write_text(markdown)


if __name__ == "__main__":
parser = create_argparse()
args = parser.parse_args()
main(reference_models=args.models)
main(args.reference_models, args.output)
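
To see what the per-row bolding in `highlight_max_bold` produces, a self-contained toy check can help. This is a sketch rather than part of the PR: the model names and scores are made up, and it assumes the script's dependencies (`mteb`, `pandas`, `tabulate`) are installed and that it is run from inside `scripts/` so the module imports.

```python
# Toy example of the per-row bolding contract in highlight_max_bold.
import pandas as pd

from create_pr_results_comment import highlight_max_bold

toy = pd.DataFrame(
    {
        "task_name": ["TaskA", "TaskB"],
        "myorg/my-new-model": [0.71, 0.64],              # hypothetical scores
        "intfloat/multilingual-e5-large": [0.69, 0.66],  # hypothetical scores
    }
)

# Each row gets its best score wrapped in ** **: 0.71 in row one, 0.66 in row two.
print(highlight_max_bold(toy).to_markdown(index=False))
```

Each row ends up with exactly one bolded cell, which is what the generated markdown tables rely on to highlight the best score per task.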