diff --git a/src/transformers/trainer.py b/src/transformers/trainer.py index a41d43edeb4f..615836e11604 100755 --- a/src/transformers/trainer.py +++ b/src/transformers/trainer.py @@ -370,6 +370,19 @@ def __init__( else: self.is_model_parallel = False + if ( + getattr(model, "hf_device_map", None) is not None + and len([device for device in set(model.hf_device_map.values()) if device not in ["cpu", "disk"]]) > 1 + and not self.is_model_parallel + ): + self.is_model_parallel = True + + # warn users + logger.info( + "You have loaded a model on multiple GPUs. `is_model_parallel` attribute will be force-set" + " to `True` to avoid any unexpected behavior such as device placement mismatching." + ) + # At this stage the model is already loaded if getattr(model, "is_loaded_in_8bit", False): if getattr(model, "_is_int8_training_enabled", False):