diff --git a/src/transformers/trainer.py b/src/transformers/trainer.py
index 319c36e7c874..1e68fcc4bd53 100755
--- a/src/transformers/trainer.py
+++ b/src/transformers/trainer.py
@@ -378,14 +378,17 @@ def __init__(
             devices = [device for device in set(model.hf_device_map.values()) if device not in ["cpu", "disk"]]
             if len(devices) > 1:
                 self.is_model_parallel = True
-            else:
+            elif len(devices) == 1:
                 self.is_model_parallel = self.args.device != torch.device(devices[0])
+            else:
+                self.is_model_parallel = False
 
             # warn users
-            logger.info(
-                "You have loaded a model on multiple GPUs. `is_model_parallel` attribute will be force-set"
-                " to `True` to avoid any unexpected behavior such as device placement mismatching."
-            )
+            if self.is_model_parallel:
+                logger.info(
+                    "You have loaded a model on multiple GPUs. `is_model_parallel` attribute will be force-set"
+                    " to `True` to avoid any unexpected behavior such as device placement mismatching."
+                )
 
         # At this stage the model is already loaded
         if getattr(model, "is_quantized", False):
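
For reference, here is a minimal standalone sketch of the branching logic this diff introduces (the helper name `infer_is_model_parallel` is illustrative, not part of the patch or the library API): an empty device list, e.g. a model fully offloaded to `"cpu"`/`"disk"`, now falls through to `False` instead of indexing `devices[0]`, and the multi-GPU info message is only emitted when `is_model_parallel` is actually `True`.

```python
import torch

def infer_is_model_parallel(hf_device_map: dict, args_device: torch.device) -> bool:
    # Mirror of the patched logic in Trainer.__init__ (illustrative helper only).
    devices = [d for d in set(hf_device_map.values()) if d not in ["cpu", "disk"]]
    if len(devices) > 1:
        # Weights are spread over several accelerators.
        return True
    elif len(devices) == 1:
        # Single accelerator: parallel only if it differs from the training device.
        return args_device != torch.device(devices[0])
    else:
        # No accelerator in the map (everything on "cpu"/"disk"):
        # before this patch, `torch.device(devices[0])` would raise IndexError here.
        return False

# A model fully offloaded to CPU/disk no longer errors and is not treated as model-parallel:
print(infer_is_model_parallel({"model.embed": "cpu", "lm_head": "disk"}, torch.device("cpu")))  # False
```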