Skip to content

Commit

Permalink
fix embedding random
Browse files Browse the repository at this point in the history
  • Loading branch information
DesmonDay committed Dec 31, 2024
1 parent 3aa9f4c commit ffc4a32
Showing 1 changed file with 8 additions and 7 deletions.
15 changes: 8 additions & 7 deletions paddlenlp/datasets/embedding_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,8 @@ class Pair:
class EmbeddingDatasetMixin:
"""EmbeddingDatasetMixin."""

rng = random.Random(10) # random seed

def convert_example(tokenizer, example):
"""Convert raw json format example to Example."""

Expand Down Expand Up @@ -120,7 +122,8 @@ def _process_truncation(self, tokens, text_type):
def _postprocess_sequence(self, example: Example):
"""Post process sequence: tokenization & truncation."""
query = example.query
pos_passage = random.choice(example.pos_passage)
pos_passage = self.rng.choice(example.pos_passage)
pos_passage = example.pos_passage[0]
neg_passage = example.neg_passage
if len(neg_passage) > 0:
if len(neg_passage) < self.group_size - 1:
Expand All @@ -132,12 +135,12 @@ def _postprocess_sequence(self, example: Example):
selected_neg_passage = neg_passage * full_sets_needed

# Ensure the remainder part is filled; randomly select from neg_passage
selected_neg_passage += random.sample(neg_passage, remainder)
selected_neg_passage += self.rng.sample(neg_passage, remainder)

# Shuffle the result to ensure randomness
random.shuffle(selected_neg_passage)
self.rng.shuffle(selected_neg_passage)
else:
selected_neg_passage = random.sample(neg_passage, self.group_size - 1)
selected_neg_passage = self.rng.sample(neg_passage, self.group_size - 1)
else:
selected_neg_passage = []
# Process query tokens
Expand Down Expand Up @@ -240,13 +243,11 @@ def __iter__(self):
def iter_one_epoch(self):
"""Iterates through one epoch of the dataset."""

num_sequences = 0
for index, example in enumerate(self.example_dataset):
for _, example in enumerate(self.example_dataset):
example = self.convert_example(example)
sequence = self._postprocess_sequence(example)
if sequence is None:
continue
num_sequences += 1
yield [sequence]

self.epoch_index += 1

0 comments on commit ffc4a32

Please sign in to comment.