@@ -117,10 +117,10 @@ def _process_truncation(self, tokens, text_type):
117117 pos_ids = list (range (len (token_ids )))
118118 return token_ids , pos_ids
119119
120- def _postprocess_sequence (self , example : Example ):
120+ def _postprocess_sequence (self , example : Example , rng ):
121121 """Post process sequence: tokenization & truncation."""
122122 query = example .query
123- pos_passage = random .choice (example .pos_passage )
123+ pos_passage = rng .choice (example .pos_passage )
124124 neg_passage = example .neg_passage
125125 if len (neg_passage ) > 0 :
126126 if len (neg_passage ) < self .group_size - 1 :
@@ -132,12 +132,12 @@ def _postprocess_sequence(self, example: Example):
132132 selected_neg_passage = neg_passage * full_sets_needed
133133
134134 # Ensure the remainder part is filled; randomly select from neg_passage
135- selected_neg_passage += random .sample (neg_passage , remainder )
135+ selected_neg_passage += rng .sample (neg_passage , remainder )
136136
137137 # Shuffle the result to ensure randomness
138- random .shuffle (selected_neg_passage )
138+ rng .shuffle (selected_neg_passage )
139139 else :
140- selected_neg_passage = random .sample (neg_passage , self .group_size - 1 )
140+ selected_neg_passage = rng .sample (neg_passage , self .group_size - 1 )
141141 else :
142142 selected_neg_passage = []
143143 # Process query tokens
@@ -241,9 +241,11 @@ def iter_one_epoch(self):
241241 """Iterates through one epoch of the dataset."""
242242
243243 num_sequences = 0
244- for index , example in enumerate (self .example_dataset ):
244+ rng = random .Random ()
245+ for _ , example in enumerate (self .example_dataset ):
245246 example = self .convert_example (example )
246- sequence = self ._postprocess_sequence (example )
247+ rng .seed (num_sequences )
248+ sequence = self ._postprocess_sequence (example , rng )
247249 if sequence is None :
248250 continue
249251 num_sequences += 1
0 commit comments