diff --git a/src/transformers/pipelines/base.py b/src/transformers/pipelines/base.py index a33089547f5a..4712eaba5794 100644 --- a/src/transformers/pipelines/base.py +++ b/src/transformers/pipelines/base.py @@ -693,7 +693,7 @@ def predict(self, X): Reference to the object in charge of parsing supplied pipeline parameters. device (`int`, *optional*, defaults to -1): Device ordinal for CPU/GPU supports. Setting this to -1 will leverage CPU, a positive will run the model on - the associated CUDA device id. + the associated CUDA device id. You can pass native `torch.device` too. binary_output (`bool`, *optional*, defaults to `False`): Flag indicating if the output the pipeline should happen in a binary format (i.e., pickle) or as raw text. """ @@ -750,7 +750,10 @@ def __init__( self.feature_extractor = feature_extractor self.modelcard = modelcard self.framework = framework - self.device = device if framework == "tf" else torch.device("cpu" if device < 0 else f"cuda:{device}") + if is_torch_available() and isinstance(device, torch.device): + self.device = device + else: + self.device = device if framework == "tf" else torch.device("cpu" if device < 0 else f"cuda:{device}") self.binary_output = binary_output # Special handling diff --git a/tests/pipelines/test_pipelines_text_classification.py b/tests/pipelines/test_pipelines_text_classification.py index 39deed9bee55..4cb72be4c2bf 100644 --- a/tests/pipelines/test_pipelines_text_classification.py +++ b/tests/pipelines/test_pipelines_text_classification.py @@ -39,6 +39,20 @@ def test_small_model_pt(self): outputs = text_classifier("This is great !") self.assertEqual(nested_simplify(outputs), [{"label": "LABEL_0", "score": 0.504}]) + @require_torch + def test_accepts_torch_device(self): + import torch + + text_classifier = pipeline( + task="text-classification", + model="hf-internal-testing/tiny-random-distilbert", + framework="pt", + device=torch.device("cpu"), + ) + + outputs = text_classifier("This is great !") + self.assertEqual(nested_simplify(outputs), [{"label": "LABEL_0", "score": 0.504}]) + @require_tf def test_small_model_tf(self): text_classifier = pipeline(