diff --git a/src/transformers/trainer.py b/src/transformers/trainer.py index 157e65d18352..08a0dd1619ed 100755 --- a/src/transformers/trainer.py +++ b/src/transformers/trainer.py @@ -488,7 +488,11 @@ def __init__( else: self.label_smoother = None - self.state = TrainerState() + self.state = TrainerState( + is_local_process_zero=self.is_local_process_zero(), + is_world_process_zero=self.is_world_process_zero(), + ) + self.control = TrainerControl() # Internal variable to count flos in each process, will be accumulated in `self.state.total_flos` then # returned to 0 every time flos need to be logged