diff --git a/tests/pipelines/test_pipelines_text_generation.py b/tests/pipelines/test_pipelines_text_generation.py index 92bda4f810e0..c0aee8b2db26 100644 --- a/tests/pipelines/test_pipelines_text_generation.py +++ b/tests/pipelines/test_pipelines_text_generation.py @@ -284,10 +284,10 @@ def test_small_model_pt_bloom_accelerate(self): ], ) - # torch_dtype not necessary + # torch_dtype will be automatically set to float32 if not provided - check: https://github.com/huggingface/transformers/pull/20602 pipe = pipeline(model="hf-internal-testing/tiny-random-bloom", device_map="auto") self.assertEqual(pipe.model.device, torch.device(0)) - self.assertEqual(pipe.model.lm_head.weight.dtype, torch.bfloat16) + self.assertEqual(pipe.model.lm_head.weight.dtype, torch.float32) out = pipe("This is a test") self.assertEqual( out,