From 167c188eb540e959d0a79d2a656dbb37d4627633 Mon Sep 17 00:00:00 2001 From: "Lin, Fanli" Date: Wed, 6 Mar 2024 22:45:11 -0800 Subject: [PATCH 1/3] use torch_device --- tests/generation/test_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/generation/test_utils.py b/tests/generation/test_utils.py index cb224c3c6a9d..4d1546729cbf 100644 --- a/tests/generation/test_utils.py +++ b/tests/generation/test_utils.py @@ -1082,7 +1082,7 @@ def test_model_parallel_beam_search(self): model = model_class(config).eval() with tempfile.TemporaryDirectory() as tmp_dir: model.cpu().save_pretrained(tmp_dir) - new_model = model_class.from_pretrained(tmp_dir, device_map="auto") + new_model = model_class.from_pretrained(tmp_dir, device_map=torch_device) new_model.generate( input_ids, From c5e17af1c1d8cbf0bd47ab0fd9d48cf3e4723a68 Mon Sep 17 00:00:00 2001 From: "Lin, Fanli" Date: Fri, 8 Mar 2024 02:03:41 -0800 Subject: [PATCH 2/3] skip for XPU --- tests/generation/test_utils.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/tests/generation/test_utils.py b/tests/generation/test_utils.py index 4d1546729cbf..69d85ad8f865 100644 --- a/tests/generation/test_utils.py +++ b/tests/generation/test_utils.py @@ -1074,6 +1074,9 @@ def test_beam_search_generate_dict_outputs_use_cache(self): @require_torch_multi_accelerator def test_model_parallel_beam_search(self): for model_class in self.all_generative_model_classes: + if torch_device == "xpu": + return unittest.skip("device_map='auto' does not work with XPU devices") + if model_class._no_split_modules is None: continue @@ -1082,7 +1085,7 @@ def test_model_parallel_beam_search(self): model = model_class(config).eval() with tempfile.TemporaryDirectory() as tmp_dir: model.cpu().save_pretrained(tmp_dir) - new_model = model_class.from_pretrained(tmp_dir, device_map=torch_device) + new_model = model_class.from_pretrained(tmp_dir, device_map="auto") new_model.generate( input_ids, From 5e7f97907377cc695f1ee79c0736d683ad47239d Mon Sep 17 00:00:00 2001 From: Fanli Lin Date: Fri, 8 Mar 2024 18:19:47 +0800 Subject: [PATCH 3/3] Update tests/generation/test_utils.py Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com> --- tests/generation/test_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/generation/test_utils.py b/tests/generation/test_utils.py index 1b2009d0ffbd..f0676af638d2 100644 --- a/tests/generation/test_utils.py +++ b/tests/generation/test_utils.py @@ -1073,7 +1073,7 @@ def test_beam_search_generate_dict_outputs_use_cache(self): @require_torch_multi_accelerator def test_model_parallel_beam_search(self): for model_class in self.all_generative_model_classes: - if torch_device == "xpu": + if "xpu" in torch_device: return unittest.skip("device_map='auto' does not work with XPU devices") if model_class._no_split_modules is None: