diff --git a/tests/generation/test_utils.py b/tests/generation/test_utils.py index 8f7849ea970b..f0676af638d2 100644 --- a/tests/generation/test_utils.py +++ b/tests/generation/test_utils.py @@ -1073,6 +1073,9 @@ def test_beam_search_generate_dict_outputs_use_cache(self): @require_torch_multi_accelerator def test_model_parallel_beam_search(self): for model_class in self.all_generative_model_classes: + if "xpu" in torch_device: + return unittest.skip("device_map='auto' does not work with XPU devices") + if model_class._no_split_modules is None: continue