diff --git a/examples/pytorch/multiple-choice/run_swag_no_trainer.py b/examples/pytorch/multiple-choice/run_swag_no_trainer.py
index 7d5d7588c694..a3868434b285 100755
--- a/examples/pytorch/multiple-choice/run_swag_no_trainer.py
+++ b/examples/pytorch/multiple-choice/run_swag_no_trainer.py
@@ -592,19 +592,11 @@ def preprocess_function(examples):
                 break
 
         model.eval()
-        samples_seen = 0
         for step, batch in enumerate(eval_dataloader):
             with torch.no_grad():
                 outputs = model(**batch)
             predictions = outputs.logits.argmax(dim=-1)
-            predictions, references = accelerator.gather((predictions, batch["labels"]))
-            # If we are in a multiprocess environment, the last batch has duplicates
-            if accelerator.num_processes > 1:
-                if step == len(eval_dataloader) - 1:
-                    predictions = predictions[: len(eval_dataloader.dataset) - samples_seen]
-                    references = references[: len(eval_dataloader.dataset) - samples_seen]
-                else:
-                    samples_seen += references.shape[0]
+            predictions, references = accelerator.gather_for_metrics((predictions, batch["labels"]))
             metric.add_batch(
                 predictions=predictions,
                 references=references,