From 1d427181cefeddc704cb4f959ffdcfcd13742629 Mon Sep 17 00:00:00 2001 From: Sylvain Gugger Date: Wed, 15 Mar 2023 13:53:07 -0400 Subject: [PATCH 1/2] Fix regression in pipeline when device=-1 is passed --- src/transformers/pipelines/base.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/transformers/pipelines/base.py b/src/transformers/pipelines/base.py index 3dca2d33d157..cdb597ef9661 100644 --- a/src/transformers/pipelines/base.py +++ b/src/transformers/pipelines/base.py @@ -769,8 +769,8 @@ def __init__( self.modelcard = modelcard self.framework = framework - if self.framework == "pt" and device is not None: - self.model = self.model.to(device=device) + if self.framework == "pt" and device is not None and not (isinstance(device, int) and device < 0): + self.model.to(device) if device is None: # `accelerate` device map From f8fe7cff4c70bc1524c5499d44c3298444977bbb Mon Sep 17 00:00:00 2001 From: Sylvain Gugger Date: Wed, 15 Mar 2023 14:02:23 -0400 Subject: [PATCH 2/2] Add regression test --- tests/pipelines/test_pipelines_common.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/tests/pipelines/test_pipelines_common.py b/tests/pipelines/test_pipelines_common.py index f43f439ac279..1a3ddace2871 100644 --- a/tests/pipelines/test_pipelines_common.py +++ b/tests/pipelines/test_pipelines_common.py @@ -484,6 +484,14 @@ def add(number, extra=0): outputs = list(dataset) self.assertEqual(outputs, [[{"id": 2}, {"id": 3}, {"id": 4}, {"id": 5}]]) + def test_pipeline_negative_device(self): + # To avoid regressing, pipeline used to accept device=-1 + classifier = pipeline("text-generation", "hf-internal-testing/tiny-random-bert", device=-1) + + expected_output = [{"generated_text": ANY(str)}] + actual_output = classifier("Test input.") + self.assertEqual(expected_output, actual_output) + @slow @require_torch def test_load_default_pipelines_pt(self):