# Patch summary (text before the `diff --git` header is ignored by patch tools).
#
# Scope: one hunk inside an `__init__` in src/transformers/pipelines/base.py.
#
# What the hunk changes:
#   * The `hf_device_map = getattr(self.model, "hf_device_map", None)` lookup is
#     moved out of the `if device is None:` branch so that it runs before any
#     device handling.
#   * A new guard raises ValueError when `hf_device_map` is not None and a
#     `device` argument was also given; per the error message, such a model was
#     loaded with `accelerate` and must not be moved to a specific device.
#   * The existing `self.model.to(device)` call is unchanged: it still runs only
#     when the framework is "pt", `device` is not None, and `device` is not a
#     negative int. With the new guard in front of it, it can no longer be
#     reached when `hf_device_map` is set and `device` is given.
#   * When `device` is None and `hf_device_map` is present, `device` is still
#     taken from the first value in `hf_device_map`.
#
# NOTE(review): the line breaks and indentation below were reconstructed from
# the hunk header counts (12 old lines / 20 new lines) and from Python nesting;
# confirm against the upstream commit before applying.
diff --git a/src/transformers/pipelines/base.py b/src/transformers/pipelines/base.py
index 0911b24912be..153e8e9f6b42 100644
--- a/src/transformers/pipelines/base.py
+++ b/src/transformers/pipelines/base.py
@@ -775,12 +775,20 @@ def __init__(
         self.modelcard = modelcard
         self.framework = framework
 
+        # `accelerate` device map
+        hf_device_map = getattr(self.model, "hf_device_map", None)
+
+        if hf_device_map is not None and device is not None:
+            raise ValueError(
+                "The model has been loaded with `accelerate` and therefore cannot be moved to a specific device. Please "
+                "discard the `device` argument when creating your pipeline object."
+            )
+
+        # We shouldn't call `model.to()` for models loaded with accelerate
         if self.framework == "pt" and device is not None and not (isinstance(device, int) and device < 0):
             self.model.to(device)
 
         if device is None:
-            # `accelerate` device map
-            hf_device_map = getattr(self.model, "hf_device_map", None)
             if hf_device_map is not None:
                 # Take the first device used by `accelerate`.
                 device = next(iter(hf_device_map.values()))