Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion docs/create_tasks_table.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ def create_task_lang_table(tasks: list[mteb.AbsTask], sort_by_sum=False) -> str:
if lang in PROGRAMMING_LANGS:
lang = "code"
if table_dict.get(lang) is None:
table_dict[lang] = {k: 0 for k in sorted(get_args(TASK_TYPE))}
table_dict[lang] = dict.fromkeys(sorted(get_args(TASK_TYPE)), 0)
table_dict[lang][task.metadata.type] += 1

## Wrangle for polars
Expand Down
2 changes: 1 addition & 1 deletion mteb/abstasks/stratification.py
Original file line number Diff line number Diff line change
Expand Up @@ -216,7 +216,7 @@ def _prepare_stratification(self, y: np.ndarray) -> tuple:
[self.percentage_per_fold[i] * self.n_samples for i in range(self.n_splits)]
)
rows = sp.lil_matrix(y).rows
rows_used = {i: False for i in range(self.n_samples)}
rows_used = dict.fromkeys(range(self.n_samples), False)
all_combinations = []
per_row_combinations = [[] for i in range(self.n_samples)]
samples_with_combination = {}
Expand Down
2 changes: 1 addition & 1 deletion mteb/evaluation/evaluators/RetrievalEvaluator.py
Original file line number Diff line number Diff line change
Expand Up @@ -261,7 +261,7 @@ def search_cross_encoder(
logging.info(
f"previous_results is None. Using all the documents to rerank: {len(corpus)}"
)
q_results = {doc_id: 0.0 for doc_id in corpus.keys()}
q_results = dict.fromkeys(corpus.keys(), 0.0)
else:
q_results = self.previous_results[qid]
# take the top-k only
Expand Down
2 changes: 1 addition & 1 deletion mteb/leaderboard/table.py
Original file line number Diff line number Diff line change
Expand Up @@ -258,7 +258,7 @@ def apply_styling(
joint_table[score_columns] = joint_table[score_columns].map(format_scores)
joint_table_style = joint_table.style.format(
{
**{column: "{:.2f}" for column in score_columns},
**dict.fromkeys(score_columns, "{:.2f}"),
"Rank (Borda)": "{:.0f}",
},
na_rep="",
Expand Down
3 changes: 1 addition & 2 deletions mteb/task_aggregation.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,8 +109,7 @@ def borda_count(
results = results.to_legacy_dict()
n_candidates = sum(len(revs) for revs in results.values())
candidate_scores = {
model: {revision: 0.0 for revision in revisions}
for model, revisions in results.items()
model: dict.fromkeys(revisions, 0.0) for model, revisions in results.items()
}

tasks = defaultdict(list) # {task_name: [(model, revision, score), ...]}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,9 +24,9 @@
def _load_wit_data(
path: str, langs: list, splits: str, cache_dir: str = None, revision: str = None
):
corpus = {lang: {split: None for split in splits} for lang in langs}
queries = {lang: {split: None for split in splits} for lang in langs}
relevant_docs = {lang: {split: None for split in splits} for lang in langs}
corpus = {lang: dict.fromkeys(splits) for lang in langs}
queries = {lang: dict.fromkeys(splits) for lang in langs}
relevant_docs = {lang: dict.fromkeys(splits) for lang in langs}

split = "test"

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,9 +21,9 @@
def _load_xflickrco_data(
path: str, langs: list, splits: str, cache_dir: str = None, revision: str = None
):
corpus = {lang: {split: None for split in splits} for lang in langs}
queries = {lang: {split: None for split in splits} for lang in langs}
relevant_docs = {lang: {split: None for split in splits} for lang in langs}
corpus = {lang: dict.fromkeys(splits) for lang in langs}
queries = {lang: dict.fromkeys(splits) for lang in langs}
relevant_docs = {lang: dict.fromkeys(splits) for lang in langs}

split = "test"

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -49,9 +49,9 @@
def _load_xm3600_data(
path: str, langs: list, splits: str, cache_dir: str = None, revision: str = None
):
corpus = {lang: {split: None for split in splits} for lang in langs}
queries = {lang: {split: None for split in splits} for lang in langs}
relevant_docs = {lang: {split: None for split in splits} for lang in langs}
corpus = {lang: dict.fromkeys(splits) for lang in langs}
queries = {lang: dict.fromkeys(splits) for lang in langs}
relevant_docs = {lang: dict.fromkeys(splits) for lang in langs}

split = "test"

Expand Down
4 changes: 1 addition & 3 deletions mteb/tasks/Retrieval/dan/TwitterHjerneRetrieval.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,9 +87,7 @@ def dataset_transform(self) -> None:
answer_id = str(text2id[a])
answer_ids.append(answer_id)

self.relevant_docs[split][query_id] = {
answer_id: 1 for answer_id in answer_ids
}
self.relevant_docs[split][query_id] = dict.fromkeys(answer_ids, 1)


def answers_to_list(example: dict) -> dict:
Expand Down
3 changes: 2 additions & 1 deletion mteb/tasks/Retrieval/deu/GerDaLIRRetrieval.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,8 @@ def load_data(self, **kwargs):
self.corpus = {self._EVAL_SPLIT: {row["_id"]: row for row in corpus_rows}}
self.relevant_docs = {
self._EVAL_SPLIT: {
row["_id"]: {v: 1 for v in row["text"].split(" ")} for row in qrels_rows
row["_id"]: dict.fromkeys(row["text"].split(" "), 1)
for row in qrels_rows
}
}

Expand Down
2 changes: 1 addition & 1 deletion mteb/tasks/Retrieval/deu/GermanDPRRetrieval.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ def load_data(self, **kwargs):
existing_docs=all_docs,
)
corpus.update(neg_docs)
relevant_docs[q_id] = {k: 1 for k in pos_docs}
relevant_docs[q_id] = dict.fromkeys(pos_docs, 1)
corpus = {
key: doc.get("title", "") + " " + doc["text"] for key, doc in corpus.items()
}
Expand Down
8 changes: 3 additions & 5 deletions mteb/tasks/Retrieval/eng/BrightRetrieval.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,11 +37,9 @@ def load_bright_data(
cache_dir: str | None = None,
revision: str | None = None,
):
corpus = {domain: {split: None for split in eval_splits} for domain in DOMAINS}
queries = {domain: {split: None for split in eval_splits} for domain in DOMAINS}
relevant_docs = {
domain: {split: None for split in eval_splits} for domain in DOMAINS
}
corpus = {domain: dict.fromkeys(eval_splits) for domain in DOMAINS}
queries = {domain: dict.fromkeys(eval_splits) for domain in DOMAINS}
relevant_docs = {domain: dict.fromkeys(eval_splits) for domain in DOMAINS}

for domain in domains:
domain_corpus = datasets.load_dataset(
Expand Down
12 changes: 3 additions & 9 deletions mteb/tasks/Retrieval/multilingual/CUREv1Retrieval.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,15 +120,9 @@ def load_data(self, **kwargs):
cache_dir = kwargs.get("cache_dir", None)

# Iterate over splits and languages
corpus = {
language: {split: None for split in eval_splits} for language in languages
}
queries = {
language: {split: None for split in eval_splits} for language in languages
}
relevant_docs = {
language: {split: None for split in eval_splits} for language in languages
}
corpus = {language: dict.fromkeys(eval_splits) for language in languages}
queries = {language: dict.fromkeys(eval_splits) for language in languages}
relevant_docs = {language: dict.fromkeys(eval_splits) for language in languages}
for split in eval_splits:
# Since this is a cross-lingual dataset, the corpus and the relevant documents do not depend on the language
split_corpus = self._load_corpus(split=split, cache_dir=cache_dir)
Expand Down
12 changes: 6 additions & 6 deletions mteb/tasks/Retrieval/multilingual/MIRACLRetrieval.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,9 +39,9 @@ def _load_miracl_data(
revision: str | None = None,
trust_remote_code: bool = False,
):
corpus = {lang: {split: None for split in splits} for lang in langs}
queries = {lang: {split: None for split in splits} for lang in langs}
relevant_docs = {lang: {split: None for split in splits} for lang in langs}
corpus = {lang: dict.fromkeys(splits) for lang in langs}
queries = {lang: dict.fromkeys(splits) for lang in langs}
relevant_docs = {lang: dict.fromkeys(splits) for lang in langs}

split = _EVAL_SPLIT

Expand Down Expand Up @@ -170,9 +170,9 @@ def _load_miracl_data_hard_negatives(
revision: str | None = None,
trust_remote_code: bool = False,
) -> tuple:
corpus = {lang: {split: None for split in splits} for lang in langs}
queries = {lang: {split: None for split in splits} for lang in langs}
relevant_docs = {lang: {split: None for split in splits} for lang in langs}
corpus = {lang: dict.fromkeys(splits) for lang in langs}
queries = {lang: dict.fromkeys(splits) for lang in langs}
relevant_docs = {lang: dict.fromkeys(splits) for lang in langs}

split = _EVAL_SPLIT

Expand Down
6 changes: 3 additions & 3 deletions mteb/tasks/Retrieval/multilingual/MultiLongDocRetrieval.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,9 +32,9 @@ def load_mldr_data(
cache_dir: str = None,
revision: str = None,
):
corpus = {lang: {split: None for split in eval_splits} for lang in langs}
queries = {lang: {split: None for split in eval_splits} for lang in langs}
relevant_docs = {lang: {split: None for split in eval_splits} for lang in langs}
corpus = {lang: dict.fromkeys(eval_splits) for lang in langs}
queries = {lang: dict.fromkeys(eval_splits) for lang in langs}
relevant_docs = {lang: dict.fromkeys(eval_splits) for lang in langs}

for lang in langs:
lang_corpus = datasets.load_dataset(
Expand Down
12 changes: 6 additions & 6 deletions mteb/tasks/Retrieval/multilingual/NeuCLIR2022Retrieval.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,9 +24,9 @@ def load_neuclir_data(
cache_dir: str | None = None,
revision: str | None = None,
):
corpus = {lang: {split: None for split in eval_splits} for lang in langs}
queries = {lang: {split: None for split in eval_splits} for lang in langs}
relevant_docs = {lang: {split: None for split in eval_splits} for lang in langs}
corpus = {lang: dict.fromkeys(eval_splits) for lang in langs}
queries = {lang: dict.fromkeys(eval_splits) for lang in langs}
relevant_docs = {lang: dict.fromkeys(eval_splits) for lang in langs}

for lang in langs:
lang_corpus = datasets.load_dataset(
Expand Down Expand Up @@ -112,9 +112,9 @@ def load_neuclir_data_hard_negatives(
revision: str | None = None,
):
split = "test"
corpus = {lang: {split: None for split in eval_splits} for lang in langs}
queries = {lang: {split: None for split in eval_splits} for lang in langs}
relevant_docs = {lang: {split: None for split in eval_splits} for lang in langs}
corpus = {lang: dict.fromkeys(eval_splits) for lang in langs}
queries = {lang: dict.fromkeys(eval_splits) for lang in langs}
relevant_docs = {lang: dict.fromkeys(eval_splits) for lang in langs}

for lang in langs:
corpus_identifier = f"corpus-{lang}"
Expand Down
12 changes: 6 additions & 6 deletions mteb/tasks/Retrieval/multilingual/NeuCLIR2023Retrieval.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,9 +24,9 @@ def load_neuclir_data(
cache_dir: str | None = None,
revision: str | None = None,
):
corpus = {lang: {split: None for split in eval_splits} for lang in langs}
queries = {lang: {split: None for split in eval_splits} for lang in langs}
relevant_docs = {lang: {split: None for split in eval_splits} for lang in langs}
corpus = {lang: dict.fromkeys(eval_splits) for lang in langs}
queries = {lang: dict.fromkeys(eval_splits) for lang in langs}
relevant_docs = {lang: dict.fromkeys(eval_splits) for lang in langs}

for lang in langs:
lang_corpus = datasets.load_dataset(
Expand Down Expand Up @@ -113,9 +113,9 @@ def load_neuclir_data_hard_negatives(
revision: str | None = None,
):
split = "test"
corpus = {lang: {split: None for split in eval_splits} for lang in langs}
queries = {lang: {split: None for split in eval_splits} for lang in langs}
relevant_docs = {lang: {split: None for split in eval_splits} for lang in langs}
corpus = {lang: dict.fromkeys(eval_splits) for lang in langs}
queries = {lang: dict.fromkeys(eval_splits) for lang in langs}
relevant_docs = {lang: dict.fromkeys(eval_splits) for lang in langs}

for lang in langs:
corpus_identifier = f"corpus-{lang}"
Expand Down
6 changes: 3 additions & 3 deletions mteb/tasks/Retrieval/multilingual/WebFAQRetrieval.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,9 +65,9 @@
def _load_webfaq_data(
path: str, langs: list, splits: str, cache_dir: str = None, revision: str = None
):
corpus = {lang: {split: None for split in splits} for lang in langs}
queries = {lang: {split: None for split in splits} for lang in langs}
relevant_docs = {lang: {split: None for split in splits} for lang in langs}
corpus = {lang: dict.fromkeys(splits) for lang in langs}
queries = {lang: dict.fromkeys(splits) for lang in langs}
relevant_docs = {lang: dict.fromkeys(splits) for lang in langs}

split = _EVAL_SPLIT

Expand Down
2 changes: 1 addition & 1 deletion mteb/tasks/Retrieval/multilingual/XMarketRetrieval.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ def _load_xmarket_data(
corpus[lang][split] = {row["_id"]: row for row in corpus_rows}
queries[lang][split] = {row["_id"]: row["text"] for row in query_rows}
relevant_docs[lang][split] = {
row["_id"]: {v: 1 for v in row["text"].split(" ")} for row in qrels_rows
row["_id"]: dict.fromkeys(row["text"].split(" "), 1) for row in qrels_rows
}

corpus = datasets.DatasetDict(corpus)
Expand Down
3 changes: 2 additions & 1 deletion mteb/tasks/Retrieval/spa/SpanishPassageRetrievalS2P.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,8 @@ def load_data(self, **kwargs):
self.corpus = {"test": {row["_id"]: row for row in corpus_rows}}
self.relevant_docs = {
"test": {
row["_id"]: {v: 1 for v in row["text"].split(" ")} for row in qrels_rows
row["_id"]: dict.fromkeys(row["text"].split(" "), 1)
for row in qrels_rows
}
}

Expand Down
3 changes: 2 additions & 1 deletion mteb/tasks/Retrieval/spa/SpanishPassageRetrievalS2S.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,8 @@ def load_data(self, **kwargs):
self.corpus = {"test": {row["_id"]: row for row in corpus_rows}}
self.relevant_docs = {
"test": {
row["_id"]: {v: 1 for v in row["text"].split(" ")} for row in qrels_rows
row["_id"]: dict.fromkeys(row["text"].split(" "), 1)
for row in qrels_rows
}
}

Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ mteb = "mteb.cli:main"
[project.optional-dependencies]
image = ["torchvision>0.2.1"]
dev = [
"ruff==0.9.7", # locked so we don't get PRs which fail only due to a lint update
"ruff==0.11.13", # locked so we don't get PRs which fail only due to a lint update
"pytest>=8.3.4",
"pytest-xdist>=3.6.1",
"pytest-coverage>=0.0",
Expand Down
2 changes: 1 addition & 1 deletion tests/test_reproducible_workflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ def test_validate_task_to_prompt_name(task_name: str | mteb.AbsTask):
else:
task_names = [task_name]

model_prompts = {task_name: "prompt_name" for task_name in task_names}
model_prompts = dict.fromkeys(task_names, "prompt_name")
model_prompts |= {task_name + "-query": "prompt_name" for task_name in task_names}
model_prompts |= {task_name + "-passage": "prompt_name" for task_name in task_names}
model_prompts |= {
Expand Down