Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 5 additions & 2 deletions src/transformers/pipelines/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -693,7 +693,7 @@ def predict(self, X):
Reference to the object in charge of parsing supplied pipeline parameters.
device (`int`, *optional*, defaults to -1):
Device ordinal for CPU/GPU supports. Setting this to -1 will leverage CPU, a positive will run the model on
the associated CUDA device id.
the associated CUDA device id. You can pass native `torch.device` too.
binary_output (`bool`, *optional*, defaults to `False`):
Flag indicating if the output the pipeline should happen in a binary format (i.e., pickle) or as raw text.
"""
Expand Down Expand Up @@ -750,7 +750,10 @@ def __init__(
self.feature_extractor = feature_extractor
self.modelcard = modelcard
self.framework = framework
self.device = device if framework == "tf" else torch.device("cpu" if device < 0 else f"cuda:{device}")
if is_torch_available() and isinstance(device, torch.device):
self.device = device
else:
self.device = device if framework == "tf" else torch.device("cpu" if device < 0 else f"cuda:{device}")
self.binary_output = binary_output

# Special handling
Expand Down
14 changes: 14 additions & 0 deletions tests/pipelines/test_pipelines_text_classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,20 @@ def test_small_model_pt(self):
outputs = text_classifier("This is great !")
self.assertEqual(nested_simplify(outputs), [{"label": "LABEL_0", "score": 0.504}])

@require_torch
def test_accepts_torch_device(self):
import torch

text_classifier = pipeline(
task="text-classification",
model="hf-internal-testing/tiny-random-distilbert",
framework="pt",
device=torch.device("cpu"),
)

outputs = text_classifier("This is great !")
self.assertEqual(nested_simplify(outputs), [{"label": "LABEL_0", "score": 0.504}])

@require_tf
def test_small_model_tf(self):
text_classifier = pipeline(
Expand Down