Skip to content

Commit 3a3bb65

Browse files
authored
[Embedding] Fix embedding random (#9721)
1 parent 3aa9f4c commit 3a3bb65

File tree

1 file changed

+9
-7
lines changed

1 file changed

+9
-7
lines changed

paddlenlp/datasets/embedding_dataset.py

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -117,10 +117,10 @@ def _process_truncation(self, tokens, text_type):
117117
pos_ids = list(range(len(token_ids)))
118118
return token_ids, pos_ids
119119

120-
def _postprocess_sequence(self, example: Example):
120+
def _postprocess_sequence(self, example: Example, rng):
121121
"""Post process sequence: tokenization & truncation."""
122122
query = example.query
123-
pos_passage = random.choice(example.pos_passage)
123+
pos_passage = rng.choice(example.pos_passage)
124124
neg_passage = example.neg_passage
125125
if len(neg_passage) > 0:
126126
if len(neg_passage) < self.group_size - 1:
@@ -132,12 +132,12 @@ def _postprocess_sequence(self, example: Example):
132132
selected_neg_passage = neg_passage * full_sets_needed
133133

134134
# Ensure the remainder part is filled; randomly select from neg_passage
135-
selected_neg_passage += random.sample(neg_passage, remainder)
135+
selected_neg_passage += rng.sample(neg_passage, remainder)
136136

137137
# Shuffle the result to ensure randomness
138-
random.shuffle(selected_neg_passage)
138+
rng.shuffle(selected_neg_passage)
139139
else:
140-
selected_neg_passage = random.sample(neg_passage, self.group_size - 1)
140+
selected_neg_passage = rng.sample(neg_passage, self.group_size - 1)
141141
else:
142142
selected_neg_passage = []
143143
# Process query tokens
@@ -241,9 +241,11 @@ def iter_one_epoch(self):
241241
"""Iterates through one epoch of the dataset."""
242242

243243
num_sequences = 0
244-
for index, example in enumerate(self.example_dataset):
244+
rng = random.Random()
245+
for _, example in enumerate(self.example_dataset):
245246
example = self.convert_example(example)
246-
sequence = self._postprocess_sequence(example)
247+
rng.seed(num_sequences)
248+
sequence = self._postprocess_sequence(example, rng)
247249
if sequence is None:
248250
continue
249251
num_sequences += 1

0 commit comments

Comments
 (0)