diff --git a/src/transformers/pipelines/base.py b/src/transformers/pipelines/base.py index 3dca2d33d157..cdb597ef9661 100644 --- a/src/transformers/pipelines/base.py +++ b/src/transformers/pipelines/base.py @@ -769,8 +769,8 @@ def __init__( self.modelcard = modelcard self.framework = framework - if self.framework == "pt" and device is not None: - self.model = self.model.to(device=device) + if self.framework == "pt" and device is not None and not (isinstance(device, int) and device < 0): + self.model.to(device) if device is None: # `accelerate` device map diff --git a/tests/pipelines/test_pipelines_common.py b/tests/pipelines/test_pipelines_common.py index f43f439ac279..1a3ddace2871 100644 --- a/tests/pipelines/test_pipelines_common.py +++ b/tests/pipelines/test_pipelines_common.py @@ -484,6 +484,14 @@ def add(number, extra=0): outputs = list(dataset) self.assertEqual(outputs, [[{"id": 2}, {"id": 3}, {"id": 4}, {"id": 5}]]) + def test_pipeline_negative_device(self): + # To avoid regressing, pipeline used to accept device=-1 + classifier = pipeline("text-generation", "hf-internal-testing/tiny-random-bert", device=-1) + + expected_output = [{"generated_text": ANY(str)}] + actual_output = classifier("Test input.") + self.assertEqual(expected_output, actual_output) + @slow @require_torch def test_load_default_pipelines_pt(self):