Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 28 additions & 0 deletions .github/workflows/dataset_loading_pr.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
# Runs the dataset-loading check on pull requests that modify task
# definitions, verifying that the referenced HuggingFace datasets exist.
name: Datasets available on HuggingFace - PR

on:
  pull_request:
    # Only task files can introduce or change dataset paths/revisions.
    paths:
      - "mteb/tasks/**.py"

jobs:
  run-pr-datasets-loading-check:
    runs-on: ubuntu-latest

    steps:
      - name: Checkout repository
        uses: actions/checkout@v4

      - name: Set up Python
        uses: actions/setup-python@v4
        with:
          python-version: '3.11'
          cache: 'pip'

      - name: Install dependencies
        run: |
          make install-for-tests
      # BASE_BRANCH is the PR's target branch; the make target diffs against
      # it to test only the datasets this PR actually touches.
      - name: Run dataset loading tests
        run: |
          make dataset-load-test-pr BASE_BRANCH=${{ github.event.pull_request.base.ref }}
4 changes: 4 additions & 0 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,10 @@ dataset-load-test:
@echo "--- 🚀 Running dataset load test ---"
pytest -m test_datasets

# Run the dataset load test scoped to datasets changed relative to BASE_BRANCH.
# scripts.extract_datasets prints an `export CUSTOM_DATASET_REVISIONS=...`
# line which is eval'd so pytest only checks the PR's datasets.
dataset-load-test-pr:
	@echo "--- 🚀 Running dataset load test for PR ---"
	eval "$$(python -m scripts.extract_datasets $(BASE_BRANCH))" && pytest -m test_datasets

leaderboard-build-test:
@echo "--- 🚀 Running leaderboard build test ---"
pytest -n auto -m leaderboard_stability
Expand Down
2 changes: 1 addition & 1 deletion mteb/tasks/Image/ZeroShotClassification/eng/Birdsnap.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
class BirdsnapZeroShotClassification(AbsTaskZeroShotClassification):
metadata = TaskMetadata(
name="BirdsnapZeroShot",
description="Classifying bird images from 500 species.",
description="Classifying bird images from 500 species. ",
reference="https://openaccess.thecvf.com/content_cvpr_2014/html/Berg_Birdsnap_Large-scale_Fine-grained_2014_CVPR_paper.html",
dataset={
"path": "isaacchung/birdsnap",
Expand Down
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,7 @@ test = [
"pytest-coverage>=0.0",
"pytest-rerunfailures>=15.0,<16.0",
"iso639>=0.1.4", # used for tests/scripts/test_generate_model_meta.py
"GitPython>=3.0.0",
]
dev = [
{include-group = "lint"},
Expand Down
130 changes: 130 additions & 0 deletions scripts/extract_datasets.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,130 @@
from __future__ import annotations

import argparse
import ast
import logging
import os

from .extract_model_names import get_changed_files

logging.basicConfig(level=logging.INFO)


def extract_datasets(files: list[str]) -> list[tuple[str, str]]:
    """Extract dataset (path, revision) tuples from task class files.

    Walks each file's AST looking for (a) task classes that assign
    ``metadata = TaskMetadata(...)`` and (b) bare ``dataset = {...}``
    assignments. Prints an ``export CUSTOM_DATASET_REVISIONS=...`` line so
    that a calling shell can ``eval`` the output.

    Args:
        files: Paths of changed Python files to scan.

    Returns:
        Deduplicated (path, revision) tuples in first-seen order.
    """
    datasets: list[tuple[str, str]] = []

    for file in files:
        # Changed files reported by git may have been deleted or renamed on
        # this branch; skip unreadable paths instead of crashing the run.
        try:
            with open(file) as f:
                source = f.read()
        except OSError as e:
            logging.warning(f"Could not read {file}: {e}")
            continue

        # Keep the try body minimal: only ast.parse raises SyntaxError here.
        try:
            tree = ast.parse(source)
        except SyntaxError as e:
            logging.warning(f"Could not parse {file}: {e}")
            continue

        for node in ast.walk(tree):
            # Look for class definitions (task classes)
            if isinstance(node, ast.ClassDef):
                # Check if it's a task class by looking for TaskMetadata assignment
                for class_node in ast.walk(node):
                    if isinstance(class_node, ast.Assign):
                        for target in class_node.targets:
                            if (
                                isinstance(target, ast.Name)
                                and target.id == "metadata"
                                and isinstance(class_node.value, ast.Call)
                            ):
                                # Extract dataset info from TaskMetadata
                                dataset_info = extract_dataset_from_metadata(
                                    class_node.value
                                )
                                if dataset_info:
                                    datasets.append(dataset_info)

            # Also look for direct dataset dictionary assignments
            elif isinstance(node, ast.Assign):
                for target in node.targets:
                    if (
                        isinstance(target, ast.Name)
                        and target.id == "dataset"
                        and isinstance(node.value, ast.Dict)
                    ):
                        dataset_info = extract_dataset_from_dict(node.value)
                        if dataset_info:
                            datasets.append(dataset_info)

    # Remove duplicates while preserving order
    unique_datasets = list(dict.fromkeys(datasets))

    # Emit the result in the format "path1:revision1,path2:revision2,..."
    if unique_datasets:
        custom_revisions = ",".join(
            f"{path}:{revision}" for path, revision in unique_datasets
        )
        # NOTE: setting os.environ only affects this process; the printed
        # `export` line below is what propagates the value via `eval`.
        os.environ["CUSTOM_DATASET_REVISIONS"] = custom_revisions
        logging.debug(f"Set CUSTOM_DATASET_REVISIONS={custom_revisions}")

        print(f'export CUSTOM_DATASET_REVISIONS="{custom_revisions}"')
    return unique_datasets


def extract_dataset_from_metadata(call_node: ast.Call) -> tuple[str, str] | None:
    """Return (path, revision) from the ``dataset`` keyword of a TaskMetadata call.

    Returns None when no ``dataset`` keyword with a dict literal is present.
    """
    dict_nodes = (
        kw.value
        for kw in call_node.keywords
        if kw.arg == "dataset" and isinstance(kw.value, ast.Dict)
    )
    dataset_dict = next(dict_nodes, None)
    if dataset_dict is None:
        return None
    return extract_dataset_from_dict(dataset_dict)


def extract_dataset_from_dict(dict_node: ast.Dict) -> tuple[str, str] | None:
    """Extract path and revision from a dataset dictionary literal.

    Args:
        dict_node: AST node of a ``{"path": ..., "revision": ...}`` literal.

    Returns:
        (path, revision) when both keys map to constant values, else None.
    """
    path = None
    revision = None

    for key, value in zip(dict_node.keys, dict_node.values):
        # Since Python 3.8, ast.parse only emits ast.Constant for literals;
        # the deprecated ast.Str is never produced (and is removed in 3.12),
        # so checking ast.Constant alone is sufficient. Non-constant values
        # (variables, calls) are skipped.
        if not (isinstance(key, ast.Constant) and isinstance(value, ast.Constant)):
            continue
        if key.value == "path":
            path = value.value
        elif key.value == "revision":
            revision = value.value

    if path and revision:
        return (path, revision)
    return None


def parse_args():
    """Parse CLI arguments; the optional positional arg is the base branch."""
    argument_parser = argparse.ArgumentParser()
    base_branch_options = {
        "nargs": "?",
        "default": "main",
        "help": "Base branch to compare changes with",
    }
    argument_parser.add_argument("base_branch", **base_branch_options)
    parsed = argument_parser.parse_args()
    return parsed


if __name__ == "__main__":
    """
    Extract datasets from changed task files compared to a base branch.

    Can pass in base branch as an argument. Defaults to 'main'.
    e.g. python -m scripts.extract_datasets mieb
    """
    logging.basicConfig(level=logging.INFO)

    args = parse_args()

    base_branch = args.base_branch
    # Only task definition files can change dataset paths/revisions.
    changed_files = get_changed_files(base_branch, startswith="mteb/tasks/")
    dataset_tuples = extract_datasets(changed_files)

    # Use INFO so the summary is actually visible at the configured level
    # (the previous debug call was silently suppressed).
    logging.info(f"Found {len(dataset_tuples)} unique datasets.")
4 changes: 2 additions & 2 deletions scripts/extract_model_names.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
logging.basicConfig(level=logging.INFO)


def get_changed_files(base_branch="main"):
def get_changed_files(base_branch="main", startswith="mteb/models/") -> list[str]:
repo_path = Path(__file__).parent.parent
repo = Repo(repo_path)
repo.remotes.origin.fetch(base_branch)
Expand All @@ -24,7 +24,7 @@ def get_changed_files(base_branch="main"):
return [
f
for f in changed_files
if f.startswith("mteb/models/")
if f.startswith(startswith)
and f.endswith(".py")
and "overview" not in f
and "init" not in f
Expand Down
12 changes: 11 additions & 1 deletion tests/test_tasks/test_all_abstasks.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from __future__ import annotations

import logging
import os
from unittest.mock import Mock, patch

import huggingface_hub
Expand Down Expand Up @@ -37,7 +38,7 @@
]


dataset_revisions = list(
_original_dataset_revisions = list(
{ # deduplicate as multiple tasks rely on the same dataset (save us at least 100 test cases)
(t.metadata.dataset["path"], t.metadata.dataset["revision"])
for t in mteb.get_tasks(exclude_superseded=False)
Expand All @@ -47,6 +48,15 @@
}
)

# Optional override set by `scripts/extract_datasets.py` (via the
# dataset-load-test-pr make target) to restrict the test matrix to the
# datasets touched by the current PR.
custom_revisions = os.getenv("CUSTOM_DATASET_REVISIONS")
if custom_revisions:
    # Parse comma-separated list of "path:revision" pairs.
    # split(":", 1) keeps revisions that themselves contain ":" intact;
    # malformed pairs without any ":" are silently skipped.
    dataset_revisions = [
        tuple(pair.split(":", 1)) for pair in custom_revisions.split(",") if ":" in pair
    ]
else:
    # No override: test every (path, revision) from the full task registry.
    dataset_revisions = _original_dataset_revisions


@pytest.mark.parametrize("task", tasks)
@patch("datasets.load_dataset")
Expand Down
Loading