
Commit 215303a

muellerzr authored and elusenji committed

Fix multiproc metrics in no_trainer examples (huggingface#16865)

1 parent 152738e

7 files changed: +80 additions, -15 deletions


examples/pytorch/image-classification/run_image_classification_no_trainer.py

Lines changed: 11 additions & 2 deletions
@@ -457,12 +457,21 @@ def collate_fn(examples):
                 break
 
         model.eval()
+        samples_seen = 0
         for step, batch in enumerate(eval_dataloader):
             outputs = model(**batch)
             predictions = outputs.logits.argmax(dim=-1)
+            predictions, references = accelerator.gather((predictions, batch["labels"]))
+            # If we are in a multiprocess environment, the last batch has duplicates
+            if accelerator.num_processes > 1:
+                if step == len(eval_dataloader):
+                    predictions = predictions[: len(eval_dataloader.dataset) - samples_seen]
+                    references = references[: len(eval_dataloader.dataset) - samples_seen]
+                else:
+                    samples_seen += references.shape[0]
             metric.add_batch(
-                predictions=accelerator.gather(predictions),
-                references=accelerator.gather(batch["labels"]),
+                predictions=predictions,
+                references=references,
             )
 
         eval_metric = metric.compute()
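A note on the pattern these hunks introduce: when Accelerate shards an evaluation dataloader across processes, the sampler pads the final batch so every process receives the same number of samples, and accelerator.gather therefore returns duplicate samples on the last step. The samples_seen counter tracks how many real samples have been consumed so the surplus can be sliced off before metric.add_batch. Below is a minimal standalone sketch of that bookkeeping in plain Python (the process count, batch size, and dataset are made-up values; no GPUs required). One caveat: because enumerate is zero-based, the last batch satisfies step == len(eval_dataloader) - 1, not step == len(eval_dataloader) as these hunks test, so the sketch uses the former.

# Plain-Python sketch of the gather-and-trim bookkeeping (assumed setup:
# 2 processes, per-process batch size 4, a 10-sample dataset).
dataset = list(range(10))

# What accelerator.gather would return per step: the sharded sampler wraps
# around, so the final gathered batch repeats samples 0 and 1.
gathered_batches = [
    dataset[0:8],            # step 0: samples 0-7, no padding needed
    dataset[8:10] + [0, 1],  # step 1: samples 8-9 plus wrap-around duplicates
]

samples_seen = 0
collected = []
for step, references in enumerate(gathered_batches):
    if step == len(gathered_batches) - 1:  # zero-based: this is the last step
        # Trim the duplicates so exactly len(dataset) samples are counted
        references = references[: len(dataset) - samples_seen]
    else:
        samples_seen += len(references)
    collected.extend(references)

assert collected == dataset  # every sample scored exactly once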

examples/pytorch/multiple-choice/run_swag_no_trainer.py

Lines changed: 11 additions & 2 deletions
@@ -559,13 +559,22 @@ def preprocess_function(examples):
                 break
 
         model.eval()
+        samples_seen = 0
         for step, batch in enumerate(eval_dataloader):
             with torch.no_grad():
                 outputs = model(**batch)
             predictions = outputs.logits.argmax(dim=-1)
+            predictions, references = accelerator.gather((predictions, batch["labels"]))
+            # If we are in a multiprocess environment, the last batch has duplicates
+            if accelerator.num_processes > 1:
+                if step == len(eval_dataloader):
+                    predictions = predictions[: len(eval_dataloader.dataset) - samples_seen]
+                    references = references[: len(eval_dataloader.dataset) - samples_seen]
+                else:
+                    samples_seen += references.shape[0]
             metric.add_batch(
-                predictions=accelerator.gather(predictions),
-                references=accelerator.gather(batch["labels"]),
+                predictions=predictions,
+                references=references,
             )
 
         eval_metric = metric.compute()

examples/pytorch/semantic-segmentation/run_semantic_segmentation_no_trainer.py

Lines changed: 13 additions & 2 deletions
@@ -567,6 +567,7 @@ def preprocess_val(example_batch):
 
         logger.info("***** Running evaluation *****")
         model.eval()
+        samples_seen = 0
         for step, batch in enumerate(tqdm(eval_dataloader, disable=not accelerator.is_local_main_process)):
             outputs = model(**batch)
 
@@ -575,9 +576,19 @@ def preprocess_val(example_batch):
             )
             predictions = upsampled_logits.argmax(dim=1)
 
+            predictions, references = accelerator.gather((predictions, batch["labels"]))
+
+            # If we are in a multiprocess environment, the last batch has duplicates
+            if accelerator.num_processes > 1:
+                if step == len(eval_dataloader):
+                    predictions = predictions[: len(eval_dataloader.dataset) - samples_seen]
+                    references = references[: len(eval_dataloader.dataset) - samples_seen]
+                else:
+                    samples_seen += references.shape[0]
+
             metric.add_batch(
-                predictions=accelerator.gather(predictions),
-                references=accelerator.gather(batch["labels"]),
+                predictions=predictions,
+                references=references,
             )
 
         eval_metrics = metric.compute(
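The other refactor running through these diffs is collapsing two accelerator.gather calls into one: Accelerate's gather accepts nested tuples (as well as lists and dicts) of tensors and gathers each leaf across processes. A hedged sketch of the call shape (the tensor values are placeholders, and the gather is only interesting under a multi-process launch):

import torch
from accelerate import Accelerator

accelerator = Accelerator()
predictions = torch.tensor([1, 2, 3], device=accelerator.device)
references = torch.tensor([1, 0, 3], device=accelerator.device)

# One gather over a tuple instead of two separate calls; with N processes
# each returned tensor is N times longer along dim 0.
predictions, references = accelerator.gather((predictions, references))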

examples/pytorch/summarization/run_summarization_no_trainer.py

Lines changed: 16 additions & 4 deletions
@@ -628,6 +628,7 @@ def postprocess_text(preds, labels):
             "max_length": args.val_max_target_length if args is not None else config.max_length,
             "num_beams": args.num_beams,
         }
+        samples_seen = 0
         for step, batch in enumerate(eval_dataloader):
             with torch.no_grad():
                 generated_tokens = accelerator.unwrap_model(model).generate(
@@ -644,8 +645,9 @@ def postprocess_text(preds, labels):
                 # If we did not pad to max length, we need to pad the labels too
                 labels = accelerator.pad_across_processes(batch["labels"], dim=1, pad_index=tokenizer.pad_token_id)
 
-            generated_tokens = accelerator.gather(generated_tokens).cpu().numpy()
-            labels = accelerator.gather(labels).cpu().numpy()
+            generated_tokens, labels = accelerator.gather((generated_tokens, labels))
+            generated_tokens = generated_tokens.cpu().numpy()
+            labels = labels.cpu().numpy()
 
             if args.ignore_pad_token_for_loss:
                 # Replace -100 in the labels as we can't decode them.
@@ -656,8 +658,18 @@ def postprocess_text(preds, labels):
             decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)
 
             decoded_preds, decoded_labels = postprocess_text(decoded_preds, decoded_labels)
-
-            metric.add_batch(predictions=decoded_preds, references=decoded_labels)
+            # If we are in a multiprocess environment, the last batch has duplicates
+            if accelerator.num_processes > 1:
+                if step == len(eval_dataloader):
+                    decoded_preds = decoded_preds[: len(eval_dataloader.dataset) - samples_seen]
+                    decoded_labels = decoded_labels[: len(eval_dataloader.dataset) - samples_seen]
+                else:
+                    samples_seen += decoded_labels.shape[0]
+
+            metric.add_batch(
+                predictions=decoded_preds,
+                references=decoded_labels,
+            )
         result = metric.compute(use_stemmer=True)
         # Extract a few results from ROUGE
         result = {key: value.mid.fmeasure * 100 for key, value in result.items()}
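One thing to watch in the summarization (and translation) hunks: decoded_labels is a plain Python list after tokenizer.batch_decode, and lists have no .shape attribute, so the samples_seen += decoded_labels.shape[0] branch would raise AttributeError on any non-final batch of a multi-process run. Counting with len() is the safe form for lists:

# decoded_labels is a list of strings here, not a tensor, so use len():
samples_seen += len(decoded_labels)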

examples/pytorch/text-classification/run_glue_no_trainer.py

Lines changed: 11 additions & 2 deletions
@@ -506,12 +506,21 @@ def preprocess_function(examples):
                 break
 
         model.eval()
+        samples_seen = 0
         for step, batch in enumerate(eval_dataloader):
             outputs = model(**batch)
             predictions = outputs.logits.argmax(dim=-1) if not is_regression else outputs.logits.squeeze()
+            predictions, references = accelerator.gather((predictions, batch["labels"]))
+            # If we are in a multiprocess environment, the last batch has duplicates
+            if accelerator.num_processes > 1:
+                if step == len(eval_dataloader):
+                    predictions = predictions[: len(eval_dataloader.dataset) - samples_seen]
+                    references = references[: len(eval_dataloader.dataset) - samples_seen]
+                else:
+                    samples_seen += references.shape[0]
             metric.add_batch(
-                predictions=accelerator.gather(predictions),
-                references=accelerator.gather(batch["labels"]),
+                predictions=predictions,
+                references=references,
             )
 
         eval_metric = metric.compute()

examples/pytorch/token-classification/run_ner_no_trainer.py

Lines changed: 9 additions & 3 deletions
@@ -658,6 +658,7 @@ def compute_metrics():
                 break
 
         model.eval()
+        samples_seen = 0
         for step, batch in enumerate(eval_dataloader):
             with torch.no_grad():
                 outputs = model(**batch)
@@ -666,9 +667,14 @@ def compute_metrics():
             if not args.pad_to_max_length:  # necessary to pad predictions and labels for being gathered
                 predictions = accelerator.pad_across_processes(predictions, dim=1, pad_index=-100)
                 labels = accelerator.pad_across_processes(labels, dim=1, pad_index=-100)
-
-            predictions_gathered = accelerator.gather(predictions)
-            labels_gathered = accelerator.gather(labels)
+            predictions_gathered, labels_gathered = accelerator.gather((predictions, labels))
+            # If we are in a multiprocess environment, the last batch has duplicates
+            if accelerator.num_processes > 1:
+                if step == len(eval_dataloader):
+                    predictions_gathered = predictions_gathered[: len(eval_dataloader.dataset) - samples_seen]
+                    labels_gathered = labels_gathered[: len(eval_dataloader.dataset) - samples_seen]
+                else:
+                    samples_seen += labels_gathered.shape[0]
             preds, refs = get_labels(predictions_gathered, labels_gathered)
             metric.add_batch(
                 predictions=preds,
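The token-classification hunk also shows why pad_across_processes precedes the gather: an all-gather requires identical tensor shapes on every process, and without --pad_to_max_length each process may hold sequences of different lengths. Padding with -100 is harmless because get_labels already treats -100 as the ignore index. A hedged sketch of the shape mechanics (the concrete sizes are invented for illustration, assuming a 2-process launch):

import torch
from accelerate import Accelerator

accelerator = Accelerator()
# Suppose process 0 holds a (4, 7) label tensor and process 1 a (4, 9) one:
# a bare gather would fail on the shape mismatch.
seq_len = 7 if accelerator.process_index == 0 else 9
labels = torch.ones((4, seq_len), dtype=torch.long, device=accelerator.device)

labels = accelerator.pad_across_processes(labels, dim=1, pad_index=-100)  # (4, 9) on both processes
labels_gathered = accelerator.gather(labels)  # (4 * num_processes, 9)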

examples/pytorch/translation/run_translation_no_trainer.py

Lines changed: 9 additions & 0 deletions
@@ -613,6 +613,7 @@ def postprocess_text(preds, labels):
             "max_length": args.val_max_target_length if args is not None else config.max_length,
             "num_beams": args.num_beams,
         }
+        samples_seen = 0
         for step, batch in enumerate(eval_dataloader):
             with torch.no_grad():
                 generated_tokens = accelerator.unwrap_model(model).generate(
@@ -641,6 +642,14 @@ def postprocess_text(preds, labels):
 
             decoded_preds, decoded_labels = postprocess_text(decoded_preds, decoded_labels)
 
+            # If we are in a multiprocess environment, the last batch has duplicates
+            if accelerator.num_processes > 1:
+                if step == len(eval_dataloader):
+                    decoded_preds = decoded_preds[: len(eval_dataloader.dataset) - samples_seen]
+                    decoded_labels = decoded_labels[: len(eval_dataloader.dataset) - samples_seen]
+                else:
+                    samples_seen += decoded_labels.shape[0]
+
             metric.add_batch(predictions=decoded_preds, references=decoded_labels)
         eval_metric = metric.compute()
         logger.info({"bleu": eval_metric["score"]})
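For context: newer Accelerate releases expose accelerator.gather_for_metrics, which performs this gather-and-deduplicate bookkeeping internally for dataloaders prepared by the Accelerator, so the manual samples_seen logic above is no longer needed there. A hedged sketch of what the eval loop reduces to (assuming a current Accelerate version and the same variable names as these scripts):

model.eval()
for step, batch in enumerate(eval_dataloader):
    with torch.no_grad():
        outputs = model(**batch)
    predictions = outputs.logits.argmax(dim=-1)
    # gather_for_metrics trims the last batch's duplicate samples automatically
    predictions, references = accelerator.gather_for_metrics((predictions, batch["labels"]))
    metric.add_batch(predictions=predictions, references=references)
eval_metric = metric.compute()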
