Patch summary: in src/transformers/training_args.py, the return expression of
`_no_sync_in_gradient_accumulation` gains one more disabling condition. After
this change it evaluates to False when `is_torch_neuroncore_available()` is
truthy, in addition to the existing `self.deepspeed`,
`is_sagemaker_dp_enabled()` and `is_sagemaker_mp_enabled()` checks.

NOTE(review): this hunk does not add an import for
`is_torch_neuroncore_available`; confirm the name is already in scope in
training_args.py before applying.

diff --git a/src/transformers/training_args.py b/src/transformers/training_args.py
index 67a796e713e6..32b1f0059063 100644
--- a/src/transformers/training_args.py
+++ b/src/transformers/training_args.py
@@ -1801,7 +1801,9 @@ def _no_sync_in_gradient_accumulation(self):
         """
         Whether or not to use no_sync for the gradients when doing gradient accumulation.
         """
-        return not (self.deepspeed or is_sagemaker_dp_enabled() or is_sagemaker_mp_enabled())
+        return not (
+            self.deepspeed or is_sagemaker_dp_enabled() or is_sagemaker_mp_enabled() or is_torch_neuroncore_available()
+        )
 
     @contextlib.contextmanager
     def main_process_first(self, local=True, desc="work"):