Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
165 changes: 0 additions & 165 deletions mteb/tasks/Reranking/eng/MindSmallReranking.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,168 +56,3 @@ class MindSmallReranking(AbsTaskRetrieval):
language models can effectively improve the performance of news recommendation. The MIND dataset will be
available at https://msnews.github.io}.", }""",
)

def process_example(
    self, example: dict, split: str, query_idx: int, subquery_idx: int
) -> dict:
    """Turn one (query, positives, negatives) record into ID-keyed reranking data.

    Every document text is mapped to a document ID. IDs are cached per split
    in ``self.doc_text_to_id[split]`` (keyed by the document text), so a text
    that was already seen in this split reuses the ID it was first given;
    otherwise a new ID of the form ``<prefix>_<md5 hex digest>_<5-digit index>``
    is created and stored in the cache.

    Args:
        example: Mapping with the keys ``"query"``, ``"positive"`` and
            ``"negative"``; the latter two are iterables of document texts.
        split: Name of the split; selects the ID cache to use and is part of
            the generated query ID.
        query_idx: Index of the source instance, embedded in the query ID.
        subquery_idx: Index of the query within its instance, embedded in the
            query ID.

    Returns:
        A dict with ``"query_id"``, ``"query"`` and three parallel lists,
        ``"doc_ids"``, ``"doc_texts"`` and ``"relevance_scores"``. Positive
        documents come first (relevance 1), followed by negatives (relevance 0).

    Raises:
        KeyError: If ``example`` lacks one of the three required keys.
    """
    import hashlib

    query = example["query"]
    # (ID prefix, documents, relevance label), in output order.
    # NOTE(review): the "apositive" spelling is kept exactly as it was;
    # presumably intentional -- confirm before changing, since it is part of
    # every generated positive document ID.
    labelled_groups = (
        ("apositive", example["positive"], 1),
        ("negative", example["negative"], 0),
    )

    def resolve_doc_id(text: str, position: int, prefix: str) -> str:
        """Return the cached ID for *text*, minting and caching one if needed."""
        known_ids = self.doc_text_to_id[split]
        try:
            return known_ids[text]
        except KeyError:
            # Hash only on a cache miss; the digest is not needed for
            # documents that already have an ID.
            digest = hashlib.md5(text.encode()).hexdigest()
            doc_id = f"{prefix}_{digest}_{str(position).zfill(5)}"
            known_ids[text] = doc_id
            return doc_id

    doc_ids = []
    doc_texts = []
    relevance_scores = []
    for prefix, documents, relevance in labelled_groups:
        for position, text in enumerate(documents):
            doc_ids.append(resolve_doc_id(text, position, prefix))
            doc_texts.append(text)
            relevance_scores.append(relevance)

    return {
        # The sub-query index keeps IDs unique when one instance has several queries.
        "query_id": f"{split}_query{query_idx}_{subquery_idx}",
        "query": query,
        "doc_ids": doc_ids,
        "doc_texts": doc_texts,
        "relevance_scores": relevance_scores,
    }

def load_data(self, **kwargs):
    """Load the raw dataset and reshape it into per-split retrieval structures.

    Does nothing when ``self.data_loaded`` is already truthy. Otherwise the
    raw data is loaded through the parent chain, and for every split in
    ``self.dataset`` each instance is expanded into one example per query
    (all sharing that instance's positives and negatives). Examples with an
    empty positive or negative list are dropped. The remaining examples are
    passed through ``self.process_example`` and recorded in ``self.queries``,
    ``self.corpus``, ``self.top_ranked`` and ``self.relevant_docs``.

    Also (re)initialises ``self.doc_text_to_id``, sets ``self.instructions``
    to ``None`` and finally marks ``self.data_loaded`` as ``True``.

    Args:
        **kwargs: Forwarded unchanged to the inherited ``load_data``.
    """
    if self.data_loaded:
        return

    # The MRO lookup starts *after* AbsTaskRetrieval, so that class's own
    # load_data is bypassed and the next implementation in line is used.
    super(AbsTaskRetrieval, self).load_data(**kwargs)

    logging.info(
        f"Transforming old format to standard format for {self.metadata.name}"
    )

    self.corpus = defaultdict(lambda: defaultdict(dict))
    self.queries = defaultdict(lambda: defaultdict(dict))
    self.relevant_docs = defaultdict(lambda: defaultdict(dict))
    self.top_ranked = defaultdict(lambda: defaultdict(list))
    self.doc_text_to_id = defaultdict(dict)

    chunk_size = 1000

    for split in self.dataset:
        logging.info(f"Processing split {split}")

        # Pass 1: one row per (instance, query) pair, laid out as
        # (instance index, sub-query index, query, positives, negatives).
        expanded = []
        for instance_idx, instance in enumerate(tqdm.tqdm(self.dataset[split])):
            instance_queries = instance["query"]
            positives = instance.get("positive", [])
            negatives = instance.get("negative", [])
            for subquery_idx, query in enumerate(instance_queries):
                expanded.append(
                    (instance_idx, subquery_idx, query, positives, negatives)
                )

        # Keep only rows that have at least one positive AND one negative.
        usable = [
            row for row in expanded if len(row[3]) > 0 and len(row[4]) > 0
        ]

        total_instances = len({row[0] for row in expanded})
        valid_unique_instances = len({row[0] for row in usable})
        dropped_count = len(expanded) - len(usable)
        kept_count = len(usable)
        logging.info(
            f"Found {total_instances} total instances, {valid_unique_instances} valid instances"
        )
        logging.info(
            f"Filtered {dropped_count} invalid examples. {kept_count} remaining."
        )

        # Pass 2: walk the usable rows chunk by chunk (progress is reported
        # per chunk) and fill the output structures.
        for chunk_start in tqdm.tqdm(range(0, len(usable), chunk_size)):
            chunk = usable[chunk_start : chunk_start + chunk_size]
            for instance_idx, subquery_idx, query, positives, negatives in chunk:
                example_data = self.process_example(
                    {
                        "query": query,
                        "positive": positives,
                        "negative": negatives,
                    },
                    split,
                    instance_idx,
                    subquery_idx,
                )

                query_id = example_data["query_id"]
                self.queries[split][query_id] = example_data["query"]

                triples = zip(
                    example_data["doc_ids"],
                    example_data["doc_texts"],
                    example_data["relevance_scores"],
                )
                for doc_id, doc_text, relevance in triples:
                    # The first text seen for a doc_id wins; later
                    # occurrences never overwrite the corpus entry.
                    if doc_id not in self.corpus[split]:
                        self.corpus[split][doc_id] = {
                            "text": doc_text,
                            "_id": doc_id,
                        }

                    self.top_ranked[split][query_id].append(doc_id)
                    self.relevant_docs[split][query_id][doc_id] = relevance

    self.instructions = None
    self.data_loaded = True
18 changes: 12 additions & 6 deletions mteb/tasks/Retrieval/ara/SadeemQuestionRetrieval.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,17 +44,23 @@ def load_data(self, **kwargs):
if self.data_loaded:
return

query_list = datasets.load_dataset(**self.metadata.dataset)["queries"]
query_list = datasets.load_dataset(**self.metadata.dataset, split="queries")
queries = {row["query-id"]: row["text"] for row in query_list}

corpus_list = datasets.load_dataset(**self.metadata.dataset)["corpus"]
corpus_list = datasets.load_dataset(**self.metadata.dataset, split="corpus")
corpus = {row["corpus-id"]: {"text": row["text"]} for row in corpus_list}

qrels_list = datasets.load_dataset(**self.metadata.dataset)["qrels"]
qrels_list = datasets.load_dataset(**self.metadata.dataset, split="qrels")
qrels = {row["query-id"]: {row["corpus-id"]: 1} for row in qrels_list}

self.corpus = {self._EVAL_SPLIT: corpus}
self.queries = {self._EVAL_SPLIT: queries}
self.relevant_docs = {self._EVAL_SPLIT: qrels}
self.dataset = {
"default": {
self._EVAL_SPLIT: {
"corpus": corpus,
"queries": queries,
"relevant_docs": qrels,
}
}
}

self.data_loaded = True