diff --git a/examples/pytorch/language-modeling/run_clm.py b/examples/pytorch/language-modeling/run_clm.py
index ae50bd2ce90a..5534e6901fb6 100755
--- a/examples/pytorch/language-modeling/run_clm.py
+++ b/examples/pytorch/language-modeling/run_clm.py
@@ -43,6 +43,7 @@
     Trainer,
     TrainingArguments,
     default_data_collator,
+    is_torch_tpu_available,
     set_seed,
 )
 from transformers.testing_utils import CaptureLogger
@@ -479,8 +480,10 @@ def compute_metrics(eval_preds):
         tokenizer=tokenizer,
         # Data collator will default to DataCollatorWithPadding, so we change it.
         data_collator=default_data_collator,
-        compute_metrics=compute_metrics if training_args.do_eval else None,
-        preprocess_logits_for_metrics=preprocess_logits_for_metrics if training_args.do_eval else None,
+        compute_metrics=compute_metrics if training_args.do_eval and not is_torch_tpu_available() else None,
+        preprocess_logits_for_metrics=preprocess_logits_for_metrics
+        if training_args.do_eval and not is_torch_tpu_available()
+        else None,
     )
 
     # Training
diff --git a/examples/pytorch/language-modeling/run_mlm.py b/examples/pytorch/language-modeling/run_mlm.py
index 9926cccfae3a..7ceae8b17a8c 100755
--- a/examples/pytorch/language-modeling/run_mlm.py
+++ b/examples/pytorch/language-modeling/run_mlm.py
@@ -43,6 +43,7 @@
     HfArgumentParser,
     Trainer,
     TrainingArguments,
+    is_torch_tpu_available,
     set_seed,
 )
 from transformers.trainer_utils import get_last_checkpoint
@@ -513,8 +514,10 @@ def compute_metrics(eval_preds):
         eval_dataset=eval_dataset if training_args.do_eval else None,
         tokenizer=tokenizer,
         data_collator=data_collator,
-        compute_metrics=compute_metrics if training_args.do_eval else None,
-        preprocess_logits_for_metrics=preprocess_logits_for_metrics if training_args.do_eval else None,
+        compute_metrics=compute_metrics if training_args.do_eval and not is_torch_tpu_available() else None,
+        preprocess_logits_for_metrics=preprocess_logits_for_metrics
+        if training_args.do_eval and not is_torch_tpu_available()
+        else None,
     )
 
     # Training
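For context, a minimal sketch of the gating pattern both hunks apply. `metric_hooks` is a hypothetical helper, not part of this PR: it factors out the shared condition that the inline conditionals above repeat, namely that the metric hooks are passed to `Trainer` only when evaluation is requested and the run is not on TPU, presumably so the extra logits post-processing these hooks perform is skipped under torch_xla.

# Illustrative sketch only; `metric_hooks` is a hypothetical helper
# mirroring the conditionals in the diff, not code from the PR.
from transformers import is_torch_tpu_available


def metric_hooks(do_eval, compute_metrics, preprocess_logits_for_metrics):
    # Both hooks share one enable condition, so compute it once:
    # metrics are only computed when evaluating off-TPU.
    enabled = do_eval and not is_torch_tpu_available()
    return (
        compute_metrics if enabled else None,
        preprocess_logits_for_metrics if enabled else None,
    )

With such a helper, both scripts could unpack a single call, e.g. `metrics_fn, logits_fn = metric_hooks(training_args.do_eval, compute_metrics, preprocess_logits_for_metrics)`, and pass the pair to `Trainer`; the diff instead keeps the condition inline at each call site, which stays closer to the scripts' existing style.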