Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -41,5 +41,5 @@ build-docs:
model-load-test:
@echo "--- 🚀 Running model load test ---"
pip install ".[dev, speedtask, pylate,gritlm,xformers,model2vec]"
python scripts/extract_model_names.py $(BASE_BRANCH)
python scripts/extract_model_names.py $(BASE_BRANCH) --return_one_model_name_per_file
python tests/test_models/model_loading.py --model_name_file scripts/model_names.txt
41 changes: 37 additions & 4 deletions scripts/extract_model_names.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,14 @@
from __future__ import annotations

import argparse
import ast
import sys
import logging
from pathlib import Path

from git import Repo

logging.basicConfig(level=logging.INFO)


def get_changed_files(base_branch="main"):
repo_path = Path(__file__).parent.parent
Expand All @@ -28,8 +31,11 @@ def get_changed_files(base_branch="main"):
]


def extract_model_names(files: list[str]) -> list[str]:
def extract_model_names(
files: list[str], return_one_model_name_per_file=False
) -> list[str]:
model_names = []
first_model_found = False
for file in files:
with open(file) as f:
tree = ast.parse(f.read())
Expand All @@ -52,17 +58,44 @@ def extract_model_names(files: list[str]) -> list[str]:
)
if model_name:
model_names.append(model_name)
first_model_found = True
if return_one_model_name_per_file and first_model_found:
logging.info(f"Found model name {model_name} in file {file}")
break # NOTE: Only take the first model_name per file to avoid disk out of space issue.
return model_names


def parse_args():
parser = argparse.ArgumentParser()
parser.add_argument(
"base_branch",
nargs="?",
default="main",
help="Base branch to compare changes with",
)
parser.add_argument(
"--return_one_model_name_per_file",
action="store_true",
default=False,
help="Only return one model name per file.",
)
return parser.parse_args()


if __name__ == "__main__":
"""
Can pass in base branch as an argument. Defaults to 'main'.
e.g. python extract_model_names.py mieb
"""
base_branch = sys.argv[1] if len(sys.argv) > 1 else "main"

args = parse_args()

base_branch = args.base_branch
changed_files = get_changed_files(base_branch)
model_names = extract_model_names(changed_files)
model_names = extract_model_names(
changed_files,
return_one_model_name_per_file=args.return_one_model_name_per_file,
)
output_file = Path(__file__).parent / "model_names.txt"
with output_file.open("w") as f:
f.write(" ".join(model_names))