Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
103 changes: 61 additions & 42 deletions scripts/format_citations.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,13 @@
from __future__ import annotations

import argparse
import ast
import logging
from pathlib import Path

import bibtexparser
import typer
from bibtexparser.bwriter import BibTexWriter

app = typer.Typer()

logging.basicConfig(
level=logging.INFO,
format="%(levelname)s: %(message)s",
Expand Down Expand Up @@ -106,6 +104,7 @@ def format_bibtex(bibtex_str: str) -> str | None:
try:
bib_database = bibtexparser.loads(bibtex_str, parser=parser)
if not bib_database.entries:
logger.warning(f"No entries found in BibTeX string. {bibtex_str}")
return None
bib_database.comments = []

Expand All @@ -115,7 +114,8 @@ def format_bibtex(bibtex_str: str) -> str | None:
writer.add_trailing_comma = True

return writer.write(bib_database).strip()
except Exception:
except Exception as e:
logger.warning(f"Failed to parse BibTeX: {e}")
return None


Expand Down Expand Up @@ -234,28 +234,16 @@ def process_file(
)


@app.command()
def tasks(
tasks_dir: Path = typer.Argument(
Path("mteb/tasks"),
exists=True,
file_okay=False,
dir_okay=True,
readable=True,
help="Directory containing MTEB task Python files.",
),
dry_run: bool = typer.Option(
True,
"--dry-run",
help="Perform parsing and formatting but do not modify files.",
),
):
def tasks(args):
tasks_dir = Path(args.tasks_dir)
dry_run = args.dry_run

modified_files = error_files = skipped_files = processed_files = bibtex_modified = 0
task_files = sorted(tasks_dir.rglob("*.py"))

if not task_files:
logger.error(f"No Python files found in {tasks_dir}")
raise typer.Exit(code=1)
raise RuntimeError

logger.info(f"Found {len(task_files)} Python files in {tasks_dir}. Processing...")

Expand Down Expand Up @@ -288,25 +276,13 @@ def tasks(

if error_files > 0:
logger.warning("Errors occurred during processing. Check logs above.")
raise typer.Exit(code=1)
raise RuntimeError


@app.command()
def benchmarks(
benchmarks_file: Path = typer.Argument(
Path("mteb/benchmarks/benchmarks.py"),
exists=True,
file_okay=True,
dir_okay=False,
readable=True,
help="Path to the benchmarks.py file.",
),
dry_run: bool = typer.Option(
True,
"--dry-run",
help="Perform parsing and formatting but do not modify the file.",
),
):
def benchmarks(args):
benchmarks_file = Path(args.benchmarks_file)
dry_run = args.dry_run

logger.info(f"Processing {benchmarks_file}...")

file_modified, file_error, num_modified, no_keyword, no_locations = process_file(
Expand All @@ -315,12 +291,12 @@ def benchmarks(

if no_keyword:
logger.info(f"SKIPPED: No 'citation' keyword found in {benchmarks_file.name}.")
raise typer.Exit()
return
if no_locations:
logger.info(
f"SKIPPED: 'citation' keyword found, but no valid string literals detected in {benchmarks_file.name}."
)
raise typer.Exit()
return

logger.info("\n--- Summary ---")
logger.info(f"Processed File: {benchmarks_file.name}")
Expand All @@ -333,10 +309,53 @@ def benchmarks(

if file_error:
logger.warning("Errors occurred during processing. Check logs above.")
raise typer.Exit(code=1)
return
elif not file_modified and not file_error:
logger.info("No changes needed.")


def main():
    """Parse command-line arguments and dispatch to the selected subcommand.

    Two subcommands are registered:

    * ``tasks`` -- accepts ``--tasks_dir`` (default ``mteb/tasks``) and
      ``--dry-run``; dispatches to ``tasks(args)``.
    * ``benchmarks`` -- accepts ``--benchmarks_file`` (default
      ``mteb/benchmarks/benchmarks.py``) and ``--dry-run``; dispatches to
      ``benchmarks(args)``.

    On both subcommands ``--dry-run`` is a ``store_true`` flag: it is
    ``False`` unless the flag is passed. When no subcommand is given the
    top-level help text is printed instead.
    """
    parser = argparse.ArgumentParser(
        description="Format BibTeX citations in MTEB task files or the benchmarks file."
    )
    subparsers = parser.add_subparsers()

    tasks_parser = subparsers.add_parser("tasks", help="Process tasks directory")
    tasks_parser.add_argument(
        "--tasks_dir",
        type=str,
        default=str(Path("mteb/tasks")),
        help="Directory containing MTEB task Python files.",
    )
    tasks_parser.add_argument(
        "--dry-run",
        action="store_true",
        help="Perform parsing and formatting but do not modify files.",
    )
    tasks_parser.set_defaults(func=tasks)

    benchmarks_parser = subparsers.add_parser(
        "benchmarks", help="Process benchmarks file"
    )
    benchmarks_parser.add_argument(
        "--benchmarks_file",
        type=str,
        default=str(Path("mteb/benchmarks/benchmarks.py")),
        help="Path to the benchmarks.py file.",
    )
    benchmarks_parser.add_argument(
        "--dry-run",
        # Must be store_true: with store_false, passing --dry-run would set
        # dry_run=False, the opposite of what the help text promises and of
        # how the `tasks` subcommand treats the same flag.
        action="store_true",
        help="Perform parsing and formatting but do not modify the file.",
    )
    benchmarks_parser.set_defaults(func=benchmarks)

    args = parser.parse_args()
    # `func` is only set when a subcommand was chosen; without one there is
    # nothing to run, so show usage instead of failing on a missing attribute.
    if hasattr(args, "func"):
        args.func(args)
    else:
        parser.print_help()


if __name__ == "__main__":
app()
main()