diff --git a/.github/workflows/dataset_loading_pr.yml b/.github/workflows/dataset_loading_pr.yml
new file mode 100644
index 0000000000..c9de1d2300
--- /dev/null
+++ b/.github/workflows/dataset_loading_pr.yml
@@ -0,0 +1,28 @@
+name: Datasets available on HuggingFace - PR
+
+on:
+  pull_request:
+    paths:
+      - "mteb/tasks/**.py"
+
+jobs:
+  run-pr-datasets-loading-check:
+    runs-on: ubuntu-latest
+
+    steps:
+      - name: Checkout repository
+        uses: actions/checkout@v4
+
+      - name: Set up Python
+        uses: actions/setup-python@v4
+        with:
+          python-version: '3.11'
+          cache: 'pip'
+
+      - name: Install dependencies
+        run: |
+          make install-for-tests
+
+      - name: Run dataset loading tests
+        run: |
+          make dataset-load-test-pr BASE_BRANCH=${{ github.event.pull_request.base.ref }}
diff --git a/Makefile b/Makefile
index 11e0c85da1..e9c5de3308 100644
--- a/Makefile
+++ b/Makefile
@@ -52,6 +52,10 @@ dataset-load-test:
 	@echo "--- 🚀 Running dataset load test ---"
 	pytest -m test_datasets
 
+dataset-load-test-pr:
+	@echo "--- 🚀 Running dataset load test for PR ---"
+	eval "$$(python -m scripts.extract_datasets $(BASE_BRANCH))" && pytest -m test_datasets
+
 leaderboard-build-test:
 	@echo "--- 🚀 Running leaderboard build test ---"
 	pytest -n auto -m leaderboard_stability
diff --git a/mteb/tasks/Image/ZeroShotClassification/eng/Birdsnap.py b/mteb/tasks/Image/ZeroShotClassification/eng/Birdsnap.py
index 280c2f2ee5..0fd2abcb36 100644
--- a/mteb/tasks/Image/ZeroShotClassification/eng/Birdsnap.py
+++ b/mteb/tasks/Image/ZeroShotClassification/eng/Birdsnap.py
@@ -9,7 +9,7 @@
 class BirdsnapZeroShotClassification(AbsTaskZeroShotClassification):
     metadata = TaskMetadata(
         name="BirdsnapZeroShot",
-        description="Classifying bird images from 500 species.",
+        description="Classifying bird images from 500 species. ",
", reference="https://openaccess.thecvf.com/content_cvpr_2014/html/Berg_Birdsnap_Large-scale_Fine-grained_2014_CVPR_paper.html", dataset={ "path": "isaacchung/birdsnap", diff --git a/pyproject.toml b/pyproject.toml index 4a80e8607d..a4a63d2c20 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -100,6 +100,7 @@ test = [ "pytest-coverage>=0.0", "pytest-rerunfailures>=15.0,<16.0", "iso639>=0.1.4", # used for tests/scripts/test_generate_model_meta.py + "GitPython>=3.0.0", ] dev = [ {include-group = "lint"}, diff --git a/scripts/extract_datasets.py b/scripts/extract_datasets.py new file mode 100644 index 0000000000..c65a0ba153 --- /dev/null +++ b/scripts/extract_datasets.py @@ -0,0 +1,130 @@ +from __future__ import annotations + +import argparse +import ast +import logging +import os + +from .extract_model_names import get_changed_files + +logging.basicConfig(level=logging.INFO) + + +def extract_datasets(files: list[str]) -> list[tuple[str, str]]: + """Extract dataset (path, revision) tuples from task class files.""" + datasets = [] + + for file in files: + with open(file) as f: + try: + tree = ast.parse(f.read()) + for node in ast.walk(tree): + # Look for class definitions (task classes) + if isinstance(node, ast.ClassDef): + # Check if it's a task class by looking for TaskMetadata assignment + for class_node in ast.walk(node): + if isinstance(class_node, ast.Assign): + for target in class_node.targets: + if ( + isinstance(target, ast.Name) + and target.id == "metadata" + and isinstance(class_node.value, ast.Call) + ): + # Extract dataset info from TaskMetadata + dataset_info = extract_dataset_from_metadata( + class_node.value + ) + if dataset_info: + datasets.append(dataset_info) + + # Also look for direct dataset dictionary assignments + elif isinstance(node, ast.Assign): + for target in node.targets: + if ( + isinstance(target, ast.Name) + and target.id == "dataset" + and isinstance(node.value, ast.Dict) + ): + dataset_info = extract_dataset_from_dict(node.value) + if dataset_info: + datasets.append(dataset_info) + + except SyntaxError as e: + logging.warning(f"Could not parse {file}: {e}") + continue + + # Remove duplicates while preserving order + unique_datasets = list(dict.fromkeys(datasets)) + + # Set environment variable in format "path1:revision1,path2:revision2,..." 
+    if unique_datasets:
+        custom_revisions = ",".join(
+            f"{path}:{revision}" for path, revision in unique_datasets
+        )
+        os.environ["CUSTOM_DATASET_REVISIONS"] = custom_revisions
+        logging.debug(f"Set CUSTOM_DATASET_REVISIONS={custom_revisions}")
+
+        print(f'export CUSTOM_DATASET_REVISIONS="{custom_revisions}"')
+    return unique_datasets
+
+
+def extract_dataset_from_metadata(call_node: ast.Call) -> tuple[str, str] | None:
+    """Extract dataset info from a TaskMetadata call."""
+    for keyword in call_node.keywords:
+        if keyword.arg == "dataset" and isinstance(keyword.value, ast.Dict):
+            return extract_dataset_from_dict(keyword.value)
+    return None
+
+
+def extract_dataset_from_dict(dict_node: ast.Dict) -> tuple[str, str] | None:
+    """Extract path and revision from a dataset dictionary."""
+    path = None
+    revision = None
+
+    for key, value in zip(dict_node.keys, dict_node.values):
+        if isinstance(key, ast.Constant) and key.value == "path":
+            if isinstance(value, ast.Constant):
+                path = value.value
+        elif isinstance(key, ast.Constant) and key.value == "revision":
+            if isinstance(value, ast.Constant):
+                revision = value.value
+        # Handle older Python versions with ast.Str
+        elif isinstance(key, ast.Str) and key.s == "path":
+            if isinstance(value, ast.Str):
+                path = value.s
+        elif isinstance(key, ast.Str) and key.s == "revision":
+            if isinstance(value, ast.Str):
+                revision = value.s
+
+    if path and revision:
+        return (path, revision)
+    return None
+
+
+def parse_args():
+    parser = argparse.ArgumentParser()
+    parser.add_argument(
+        "base_branch",
+        nargs="?",
+        default="main",
+        help="Base branch to compare changes with",
+    )
+    return parser.parse_args()
+
+
+if __name__ == "__main__":
+    """
+    Extract datasets from changed task files compared to a base branch.
+
+    Can pass in the base branch as an argument. Defaults to 'main'.
+    e.g. python -m scripts.extract_datasets mieb
+    """
+    logging.basicConfig(level=logging.INFO)
+
+    args = parse_args()
+
+    base_branch = args.base_branch
+    changed_files = get_changed_files(base_branch, startswith="mteb/tasks/")
+    dataset_tuples = extract_datasets(changed_files)
+
+    logging.debug(f"Found {len(dataset_tuples)} unique datasets.")
diff --git a/scripts/extract_model_names.py b/scripts/extract_model_names.py
index 4d8c5d96f0..18115fdc4a 100644
--- a/scripts/extract_model_names.py
+++ b/scripts/extract_model_names.py
@@ -10,7 +10,7 @@
 logging.basicConfig(level=logging.INFO)
 
 
-def get_changed_files(base_branch="main"):
+def get_changed_files(base_branch="main", startswith="mteb/models/") -> list[str]:
     repo_path = Path(__file__).parent.parent
     repo = Repo(repo_path)
     repo.remotes.origin.fetch(base_branch)
@@ -24,7 +24,7 @@ def get_changed_files(base_branch="main"):
     return [
         f
         for f in changed_files
-        if f.startswith("mteb/models/")
+        if f.startswith(startswith)
         and f.endswith(".py")
         and "overview" not in f
         and "init" not in f
diff --git a/tests/test_tasks/test_all_abstasks.py b/tests/test_tasks/test_all_abstasks.py
index 3f59e0a8cb..c76e05df8a 100644
--- a/tests/test_tasks/test_all_abstasks.py
+++ b/tests/test_tasks/test_all_abstasks.py
@@ -1,6 +1,7 @@
 from __future__ import annotations
 
 import logging
+import os
 from unittest.mock import Mock, patch
 
 import huggingface_hub
@@ -37,7 +38,7 @@
 ]
 
 
-dataset_revisions = list(
+_original_dataset_revisions = list(
     {  # deduplicate as multiple tasks rely on the same dataset (save us at least 100 test cases)
         (t.metadata.dataset["path"], t.metadata.dataset["revision"])
         for t in mteb.get_tasks(exclude_superseded=False)
...
     }
 )
 
+custom_revisions = os.getenv("CUSTOM_DATASET_REVISIONS")
+if custom_revisions:
+    # Parse comma-separated list of "path:revision" pairs
+    dataset_revisions = [
+        tuple(pair.split(":", 1)) for pair in custom_revisions.split(",") if ":" in pair
+    ]
+else:
+    dataset_revisions = _original_dataset_revisions
+
 
 @pytest.mark.parametrize("task", tasks)
 @patch("datasets.load_dataset")
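Reviewer note: below is a minimal round-trip sketch of the CUSTOM_DATASET_REVISIONS handoff wired up above. The Makefile target eval's the export line that scripts/extract_datasets.py prints, and tests/test_tasks/test_all_abstasks.py parses the variable back into (path, revision) tuples. The dataset paths and revisions used here are illustrative placeholders, not real entries.

import os

# What the script would print to stdout for two changed task files
# (values illustrative): export CUSTOM_DATASET_REVISIONS="org/dataset-a:rev1,org/dataset-b:rev2"
# The Makefile's eval turns that into an environment variable for pytest:
os.environ["CUSTOM_DATASET_REVISIONS"] = "org/dataset-a:rev1,org/dataset-b:rev2"

# What the test module reconstructs from it (same parsing as in the diff):
custom_revisions = os.getenv("CUSTOM_DATASET_REVISIONS")
dataset_revisions = [
    tuple(pair.split(":", 1)) for pair in custom_revisions.split(",") if ":" in pair
]
assert dataset_revisions == [("org/dataset-a", "rev1"), ("org/dataset-b", "rev2")]

With this in place, a PR touching only a handful of task files runs the dataset loading test against just those datasets instead of the full registry.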