Skip to content
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 15 additions & 7 deletions src/transformers/pipelines/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -1005,13 +1005,21 @@ def pipeline(
model_kwargs["device_map"] = device_map
if torch_dtype is not None:
if "torch_dtype" in model_kwargs:
raise ValueError(
'You cannot use both `pipeline(... torch_dtype=..., model_kwargs={"torch_dtype":...})` as those'
" arguments might conflict, use only one.)"
)
if isinstance(torch_dtype, str) and hasattr(torch, torch_dtype):
torch_dtype = getattr(torch, torch_dtype)
model_kwargs["torch_dtype"] = torch_dtype
# If the user did not explicitly provide `torch_dtype` (i.e. the function default "auto" is still
# present) but a value is supplied inside `model_kwargs`, we silently defer to the latter instead of
# raising. This prevents false positives like providing `torch_dtype` only via `model_kwargs` while the
# top-level argument keeps its default value "auto".
if torch_dtype == "auto":
torch_dtype = None
else:
raise ValueError(
'You cannot use both `pipeline(... torch_dtype=..., model_kwargs={"torch_dtype":...})` as those'
" arguments might conflict, use only one.)"
)
if torch_dtype is not None:
if isinstance(torch_dtype, str) and hasattr(torch, torch_dtype):
torch_dtype = getattr(torch, torch_dtype)
model_kwargs["torch_dtype"] = torch_dtype

model_name = model if isinstance(model, str) else None

Expand Down