Commit 518bd02

[Generation] Fix Transition probs (huggingface#17311)
* [Draft] fix transition probs
* up
* up
* up
* make it work
* fix
* finish
* update
1 parent e8714c0 commit 518bd02
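
Note: this commit threads a `beam_indices` tensor through `BeamSearchScorer.process`/`finalize` and `BeamHypotheses.add`, so beam search can report, for every generated token, which beam it was extended from; that is what makes the per-step transition probabilities recoverable. Below is a minimal usage sketch of how the resulting `beam_indices` output can be consumed after this change; the model name, generation arguments, and exact output shapes are illustrative assumptions, not part of this commit.

```python
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM

tokenizer = AutoTokenizer.from_pretrained("t5-small")
model = AutoModelForSeq2SeqLM.from_pretrained("t5-small")

inputs = tokenizer("translate English to German: How old are you?", return_tensors="pt")
out = model.generate(
    **inputs,
    num_beams=4,
    num_return_sequences=2,
    max_length=20,
    output_scores=True,
    return_dict_in_generate=True,
)

# `out.beam_indices` holds, for each returned sequence and each generated position,
# the index of the beam the token was taken from (positions past the end of a
# hypothesis are padded with -1), so the per-step scores in `out.scores` can be
# aligned with the finished sequences.
print(out.sequences.shape, out.beam_indices.shape)
```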

File tree

4 files changed: +197 −75 lines


src/transformers/generation_beam_search.py

Lines changed: 39 additions & 6 deletions
@@ -212,6 +212,7 @@ def process(
         next_indices: torch.LongTensor,
         pad_token_id: Optional[int] = None,
         eos_token_id: Optional[int] = None,
+        beam_indices: Optional[torch.LongTensor] = None,
     ) -> Tuple[torch.Tensor]:
         cur_len = input_ids.shape[-1]
         batch_size = len(self._beam_hyps)
@@ -256,9 +257,16 @@ def process(
                     is_beam_token_worse_than_top_num_beams = beam_token_rank >= self.group_size
                     if is_beam_token_worse_than_top_num_beams:
                         continue
+                    if beam_indices is not None:
+                        beam_index = beam_indices[batch_beam_idx]
+                        beam_index = beam_index + (next_index,)
+                    else:
+                        beam_index = None
+
                     beam_hyp.add(
                         input_ids[batch_beam_idx].clone(),
                         next_score.item(),
+                        beam_indices=beam_index,
                     )
                 else:
                     # add next predicted token since it is not eos_token
@@ -299,6 +307,7 @@ def finalize(
             max_length: int,
             pad_token_id: Optional[int] = None,
             eos_token_id: Optional[int] = None,
+            beam_indices: Optional[torch.LongTensor] = None,
         ) -> Tuple[torch.LongTensor]:
             batch_size = len(self._beam_hyps)

@@ -313,11 +322,13 @@ def finalize(
                 batch_beam_idx = batch_idx * self.num_beams + beam_id
                 final_score = final_beam_scores[batch_beam_idx].item()
                 final_tokens = input_ids[batch_beam_idx]
-                beam_hyp.add(final_tokens, final_score)
+                beam_index = beam_indices[batch_beam_idx] if beam_indices is not None else None
+                beam_hyp.add(final_tokens, final_score, beam_indices=beam_index)

         # select the best hypotheses
         sent_lengths = input_ids.new(batch_size * self.num_beam_hyps_to_keep)
         best = []
+        best_indices = []
         best_scores = torch.zeros(batch_size * self.num_beam_hyps_to_keep, device=self.device, dtype=torch.float32)

         # retrieve best hypotheses
@@ -327,30 +338,50 @@ def finalize(
                 best_hyp_tuple = sorted_hyps.pop()
                 best_score = best_hyp_tuple[0]
                 best_hyp = best_hyp_tuple[1]
+                best_index = best_hyp_tuple[2]
                 sent_lengths[self.num_beam_hyps_to_keep * i + j] = len(best_hyp)

-                # append to lists
+                # append hyp to lists
                 best.append(best_hyp)
+
+                # append indices to list
+                best_indices.append(best_index)
+
                 best_scores[i * self.num_beam_hyps_to_keep + j] = best_score

         # prepare for adding eos
         sent_lengths_max = sent_lengths.max().item() + 1
         sent_max_len = min(sent_lengths_max, max_length) if max_length is not None else sent_lengths_max
         decoded: torch.LongTensor = input_ids.new(batch_size * self.num_beam_hyps_to_keep, sent_max_len)
+
+        if len(best_indices) > 0 and best_indices[0] is not None:
+            indices: torch.LongTensor = input_ids.new(batch_size * self.num_beam_hyps_to_keep, sent_max_len)
+        else:
+            indices = None
+
         # shorter batches are padded if needed
         if sent_lengths.min().item() != sent_lengths.max().item():
             assert pad_token_id is not None, "`pad_token_id` has to be defined"
             decoded.fill_(pad_token_id)
+
+            if indices is not None:
+                indices.fill_(-1)
+
         # fill with hypotheses and eos_token_id if the latter fits in
-        for i, hypo in enumerate(best):
+        for i, (hypo, best_idx) in enumerate(zip(best, best_indices)):
             decoded[i, : sent_lengths[i]] = hypo
+
+            if indices is not None:
+                indices[i, : len(best_idx)] = torch.tensor(best_idx)
+
             if sent_lengths[i] < sent_max_len:
                 decoded[i, sent_lengths[i]] = eos_token_id

         return UserDict(
             {
                 "sequences": decoded,
                 "sequence_scores": best_scores,
+                "beam_indices": indices,
             }
         )

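
Note: the `finalize` hunk above builds an `indices` tensor parallel to `decoded`, pre-filled with -1 so that positions beyond the end of a hypothesis can be told apart from real beam indices. The following standalone sketch, with made-up toy values rather than code from the diff, shows that padding pattern in isolation.

```python
import torch

# per-hypothesis tuples of parent-beam ids (toy values)
best_indices = [(0, 1, 1), (2,)]
sent_lengths = torch.tensor([3, 1])
sent_max_len = int(sent_lengths.max().item()) + 1

# -1-filled tensor: real indices are written in, the rest stays as padding
indices = torch.full((len(best_indices), sent_max_len), -1, dtype=torch.long)
for i, best_idx in enumerate(best_indices):
    indices[i, : len(best_idx)] = torch.tensor(best_idx)

print(indices)
# tensor([[ 0,  1,  1, -1],
#         [ 2, -1, -1, -1]])
```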
@@ -789,6 +820,7 @@ def finalize(

         # prepare for adding eos
         sent_lengths_max = sent_lengths.max().item() + 1
+
         sent_max_len = min(sent_lengths_max, max_length) if max_length is not None else sent_lengths_max
         decoded: torch.LongTensor = input_ids.new(batch_size * self.num_beam_hyps_to_keep, sent_max_len)
         # shorter batches are padded if needed
@@ -801,6 +833,7 @@ def finalize(
             decoded[i, : sent_lengths[i]] = hypo
             if sent_lengths[i] < sent_max_len:
                 decoded[i, sent_lengths[i]] = eos_token_id
+
         return UserDict(
             {
                 "sequences": decoded,
@@ -826,15 +859,15 @@ def __len__(self):
         """
         return len(self.beams)

-    def add(self, hyp: torch.LongTensor, sum_logprobs: float):
+    def add(self, hyp: torch.LongTensor, sum_logprobs: float, beam_indices: Optional[torch.LongTensor] = None):
         """
         Add a new hypothesis to the list.
         """
         score = sum_logprobs / (hyp.shape[-1] ** self.length_penalty)
         if len(self) < self.num_beams or score > self.worst_score:
-            self.beams.append((score, hyp))
+            self.beams.append((score, hyp, beam_indices))
             if len(self) > self.num_beams:
-                sorted_next_scores = sorted([(s, idx) for idx, (s, _) in enumerate(self.beams)])
+                sorted_next_scores = sorted([(s, idx) for idx, (s, _, _) in enumerate(self.beams)])
                 del self.beams[sorted_next_scores[0][1]]
                 self.worst_score = sorted_next_scores[1][0]
         else:
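
Note: with this change each stored hypothesis in `BeamHypotheses` is a `(score, tokens, beam_indices)` triple, where `beam_indices` is a tuple of parent-beam ids accumulated step by step during decoding. The snippet below is a minimal illustration of how such a tuple grows (it is not code from the diff; beam count and parent indices are toy values).

```python
# one empty path per beam, toy num_beams=4
beam_indices = tuple(tuple() for _ in range(4))

# suppose at this step the surviving beams were extended from parents 0, 0, 2, 1:
# each beam inherits its parent's path and appends the parent's index
next_indices = [0, 0, 2, 1]
beam_indices = tuple(beam_indices[next_indices[i]] + (next_indices[i],) for i in range(4))

# after a second step with parents 1, 0, 3, 2
next_indices = [1, 0, 3, 2]
beam_indices = tuple(beam_indices[next_indices[i]] + (next_indices[i],) for i in range(4))

print(beam_indices)   # ((0, 1), (0, 0), (1, 3), (2, 2))
```

A finished hypothesis therefore carries the full path of beams it was built from, which `finalize` turns into the `beam_indices` tensor returned alongside `sequences` and `sequence_scores`.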
