diff --git a/tests/pipelines/test_pipelines_text_generation.py b/tests/pipelines/test_pipelines_text_generation.py index 766f2a462a19..ada04c7dbeda 100644 --- a/tests/pipelines/test_pipelines_text_generation.py +++ b/tests/pipelines/test_pipelines_text_generation.py @@ -450,7 +450,9 @@ def test_small_model_fp16(self): def test_pipeline_accelerate_top_p(self): import torch - pipe = pipeline(model="hf-internal-testing/tiny-random-bloom", device_map="auto", torch_dtype=torch.float16) + pipe = pipeline( + model="hf-internal-testing/tiny-random-bloom", device_map=torch_device, torch_dtype=torch.float16 + ) pipe("This is a test", do_sample=True, top_p=0.5) def test_pipeline_length_setting_warning(self):