@@ -212,6 +212,7 @@ def process(
         next_indices: torch.LongTensor,
         pad_token_id: Optional[int] = None,
         eos_token_id: Optional[int] = None,
+        beam_indices: Optional[torch.LongTensor] = None,
     ) -> Tuple[torch.Tensor]:
         cur_len = input_ids.shape[-1]
         batch_size = len(self._beam_hyps)
@@ -256,9 +257,16 @@ def process(
                     is_beam_token_worse_than_top_num_beams = beam_token_rank >= self.group_size
                     if is_beam_token_worse_than_top_num_beams:
                         continue
+                    if beam_indices is not None:
+                        beam_index = beam_indices[batch_beam_idx]
+                        beam_index = beam_index + (next_index,)
+                    else:
+                        beam_index = None
+
                     beam_hyp.add(
                         input_ids[batch_beam_idx].clone(),
                         next_score.item(),
+                        beam_indices=beam_index,
                     )
                 else:
                     # add next predicted token since it is not eos_token
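
The two `process` hunks above thread a new `beam_indices` argument through the scorer: when a beam reaches EOS, the hypothesis is stored together with its chain of parent-beam indices. Note that `beam_index + (next_index,)` is tuple concatenation, so at runtime each entry of `beam_indices` is a growing tuple of parent indices, one per generation step. A minimal standalone sketch of that bookkeeping, with a hypothetical beam count and parent choices:

    # Standalone sketch of the backpointer bookkeeping; num_beams and the
    # chosen parents are made up for illustration.
    num_beams = 4
    beam_indices = tuple(() for _ in range(num_beams))  # one empty tuple per beam

    # One decoding step keeps these parent beams (the role of next_indices above):
    chosen_parents = [0, 1, 0, 1]
    beam_indices = tuple(beam_indices[p] + (p,) for p in chosen_parents)

    print(beam_indices)  # ((0,), (1,), (0,), (1,)) -- one parent appended per step
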
@@ -299,6 +307,7 @@ def finalize(
         max_length: int,
         pad_token_id: Optional[int] = None,
         eos_token_id: Optional[int] = None,
+        beam_indices: Optional[torch.LongTensor] = None,
     ) -> Tuple[torch.LongTensor]:
         batch_size = len(self._beam_hyps)

@@ -313,11 +322,13 @@ def finalize(
                 batch_beam_idx = batch_idx * self.num_beams + beam_id
                 final_score = final_beam_scores[batch_beam_idx].item()
                 final_tokens = input_ids[batch_beam_idx]
-                beam_hyp.add(final_tokens, final_score)
+                beam_index = beam_indices[batch_beam_idx] if beam_indices is not None else None
+                beam_hyp.add(final_tokens, final_score, beam_indices=beam_index)

         # select the best hypotheses
         sent_lengths = input_ids.new(batch_size * self.num_beam_hyps_to_keep)
         best = []
+        best_indices = []
         best_scores = torch.zeros(batch_size * self.num_beam_hyps_to_keep, device=self.device, dtype=torch.float32)

         # retrieve best hypotheses
@@ -327,30 +338,50 @@ def finalize(
                 best_hyp_tuple = sorted_hyps.pop()
                 best_score = best_hyp_tuple[0]
                 best_hyp = best_hyp_tuple[1]
+                best_index = best_hyp_tuple[2]
                 sent_lengths[self.num_beam_hyps_to_keep * i + j] = len(best_hyp)

-                # append to lists
+                # append hyp to lists
                 best.append(best_hyp)
+
+                # append indices to list
+                best_indices.append(best_index)
+
                 best_scores[i * self.num_beam_hyps_to_keep + j] = best_score

         # prepare for adding eos
         sent_lengths_max = sent_lengths.max().item() + 1
         sent_max_len = min(sent_lengths_max, max_length) if max_length is not None else sent_lengths_max
         decoded: torch.LongTensor = input_ids.new(batch_size * self.num_beam_hyps_to_keep, sent_max_len)
+
+        if len(best_indices) > 0 and best_indices[0] is not None:
+            indices: torch.LongTensor = input_ids.new(batch_size * self.num_beam_hyps_to_keep, sent_max_len)
+        else:
+            indices = None
+
         # shorter batches are padded if needed
         if sent_lengths.min().item() != sent_lengths.max().item():
             assert pad_token_id is not None, "`pad_token_id` has to be defined"
             decoded.fill_(pad_token_id)
+
+            if indices is not None:
+                indices.fill_(-1)
+
         # fill with hypotheses and eos_token_id if the latter fits in
-        for i, hypo in enumerate(best):
+        for i, (hypo, best_idx) in enumerate(zip(best, best_indices)):
             decoded[i, : sent_lengths[i]] = hypo
+
+            if indices is not None:
+                indices[i, : len(best_idx)] = torch.tensor(best_idx)
+
             if sent_lengths[i] < sent_max_len:
                 decoded[i, sent_lengths[i]] = eos_token_id

         return UserDict(
             {
                 "sequences": decoded,
                 "sequence_scores": best_scores,
+                "beam_indices": indices,
             }
         )

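
`finalize` now materializes those per-hypothesis tuples into an `indices` tensor of shape `(batch_size * num_beam_hyps_to_keep, sent_max_len)`, filling unused positions with -1 when hypotheses have different lengths, and returns it under the new `"beam_indices"` key. A hypothetical sketch of how a caller might mask that padding before using the indices:

    import torch

    # Hypothetical "beam_indices" row for one kept hypothesis of length 3,
    # padded with -1 up to sent_max_len = 4.
    beam_indices = torch.tensor([[0, 1, 1, -1]])

    valid = beam_indices != -1   # mask out padding positions
    print(valid.sum(dim=-1))     # tensor([3]) -> three real generation steps
    print(beam_indices[valid])   # tensor([0, 1, 1]) -> parent beam at each step
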
@@ -789,6 +820,7 @@ def finalize(

         # prepare for adding eos
         sent_lengths_max = sent_lengths.max().item() + 1
+
         sent_max_len = min(sent_lengths_max, max_length) if max_length is not None else sent_lengths_max
         decoded: torch.LongTensor = input_ids.new(batch_size * self.num_beam_hyps_to_keep, sent_max_len)
         # shorter batches are padded if needed
@@ -801,6 +833,7 @@ def finalize(
             decoded[i, : sent_lengths[i]] = hypo
             if sent_lengths[i] < sent_max_len:
                 decoded[i, sent_lengths[i]] = eos_token_id
+
         return UserDict(
             {
                 "sequences": decoded,
@@ -826,15 +859,15 @@ def __len__(self):
         """
         return len(self.beams)

-    def add(self, hyp: torch.LongTensor, sum_logprobs: float):
+    def add(self, hyp: torch.LongTensor, sum_logprobs: float, beam_indices: Optional[torch.LongTensor] = None):
         """
         Add a new hypothesis to the list.
         """
         score = sum_logprobs / (hyp.shape[-1] ** self.length_penalty)
         if len(self) < self.num_beams or score > self.worst_score:
-            self.beams.append((score, hyp))
+            self.beams.append((score, hyp, beam_indices))
             if len(self) > self.num_beams:
-                sorted_next_scores = sorted([(s, idx) for idx, (s, _) in enumerate(self.beams)])
+                sorted_next_scores = sorted([(s, idx) for idx, (s, _, _) in enumerate(self.beams)])
                 del self.beams[sorted_next_scores[0][1]]
                 self.worst_score = sorted_next_scores[1][0]
         else:
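
With this change `BeamHypotheses.add` stores `(score, hyp, beam_indices)` 3-tuples, which is why the eviction comprehension unpacks `(s, _, _)` instead of `(s, _)`. A self-contained sketch of that eviction pattern with made-up scores:

    import torch

    # Made-up beam contents: (score, hyp, beam_indices) triples.
    beams = [
        (-1.2, torch.tensor([5, 8, 3]), (0, 1, 1)),
        (-0.7, torch.tensor([5, 9]), (0, 0)),
    ]

    # Same pattern as the diff: rank hypotheses by score, keeping list indices.
    sorted_next_scores = sorted([(s, idx) for idx, (s, _, _) in enumerate(beams)])
    print(sorted_next_scores[0][1])  # 0 -> the -1.2 hypothesis is evicted first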