Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 15 additions & 7 deletions src/transformers/pipelines/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -1005,13 +1005,21 @@ def pipeline(
model_kwargs["device_map"] = device_map
if torch_dtype is not None:
if "torch_dtype" in model_kwargs:
raise ValueError(
'You cannot use both `pipeline(... torch_dtype=..., model_kwargs={"torch_dtype":...})` as those'
" arguments might conflict, use only one.)"
)
if isinstance(torch_dtype, str) and hasattr(torch, torch_dtype):
torch_dtype = getattr(torch, torch_dtype)
model_kwargs["torch_dtype"] = torch_dtype
# If the user did not explicitly provide `torch_dtype` (i.e. the function default "auto" is still
# present) but a value is supplied inside `model_kwargs`, we silently defer to the latter instead of
# raising. This prevents false positives like providing `torch_dtype` only via `model_kwargs` while the
# top-level argument keeps its default value "auto".
if torch_dtype == "auto":
torch_dtype = None
else:
raise ValueError(
'You cannot use both `pipeline(..., torch_dtype=...)` and `model_kwargs={"torch_dtype": ...}` as those'
" arguments might conflict, use only one."
)
if torch_dtype is not None:
if isinstance(torch_dtype, str) and hasattr(torch, torch_dtype):
torch_dtype = getattr(torch, torch_dtype)
model_kwargs["torch_dtype"] = torch_dtype

model_name = model if isinstance(model, str) else None

Expand Down
4 changes: 2 additions & 2 deletions tests/pipelines/test_pipelines_image_text_to_text.py
Original file line number Diff line number Diff line change
Expand Up @@ -161,11 +161,11 @@ def test_small_model_pt_token(self):
[
{
"input_text": "<image> What this is? Assistant: This is",
"generated_text": "<image> What this is? Assistant: This is a photo of two cats lying on a pink blanket. The cats are sleeping and appear to be comfortable",
"generated_text": "<image> What this is? Assistant: This is a photo of two cats lying on a pink blanket. The cats are facing the camera, and they",
},
{
"input_text": "<image> What this is? Assistant: This is",
"generated_text": "<image> What this is? Assistant: This is a photo of two cats lying on a pink blanket. The cats are sleeping and appear to be comfortable",
"generated_text": "<image> What this is? Assistant: This is a photo of two cats lying on a pink blanket. The cats are facing the camera, and they",
},
],
)
Expand Down
4 changes: 2 additions & 2 deletions tests/pipelines/test_pipelines_text_generation.py
Original file line number Diff line number Diff line change
Expand Up @@ -441,11 +441,11 @@ def test_small_model_pt_bloom_accelerate(self):
[{"generated_text": ("This is a test test test test test test")}],
)

# torch_dtype will be automatically set to float32 if not provided - check: https://github.com/huggingface/transformers/pull/20602
# torch_dtype will be automatically set to torch.bfloat16 if not provided - check: https://github.com/huggingface/transformers/pull/38882
pipe = pipeline(
model="hf-internal-testing/tiny-random-bloom", device_map="auto", max_new_tokens=5, do_sample=False
)
self.assertEqual(pipe.model.lm_head.weight.dtype, torch.float32)
self.assertEqual(pipe.model.lm_head.weight.dtype, torch.bfloat16)
out = pipe("This is a test")
self.assertEqual(
out,
Expand Down