diff --git a/tests/mixed_int8/test_mixed_int8.py b/tests/mixed_int8/test_mixed_int8.py index a459ffa84d0e..6e8e7b842ac3 100644 --- a/tests/mixed_int8/test_mixed_int8.py +++ b/tests/mixed_int8/test_mixed_int8.py @@ -215,23 +215,8 @@ def test_multi_gpu_loading(self): self.model_name, load_in_8bit=True, max_memory=memory_mapping, device_map="auto" ) - def get_list_devices(model): - list_devices = [] - for _, module in model.named_children(): - if len(list(module.children())) > 0: - list_devices.extend(get_list_devices(module)) - else: - # Do a try except since we can encounter Dropout modules that does not - # have any device set - try: - list_devices.append(next(module.parameters()).device.index) - except BaseException: - continue - return list_devices - - list_devices = get_list_devices(model_parallel) - # Check that we have dispatched the model into 2 separate devices - self.assertTrue((1 in list_devices) and (0 in list_devices)) + # Check correct device map + self.assertEqual(set(model_parallel.hf_device_map.values()), {0, 1}) # Check that inference pass works on the model encoded_input = self.tokenizer(self.input_text, return_tensors="pt")