diff --git a/unsloth/models/_utils.py b/unsloth/models/_utils.py index f883d466f0..8684991431 100644 --- a/unsloth/models/_utils.py +++ b/unsloth/models/_utils.py @@ -1473,7 +1473,7 @@ def get_statistics(local_files_only = False): ) exec(BitsAndBytesConfig__init__, globals()) -if DEVICE_COUNT == 1: +if DEVICE_COUNT == 1 and int(os.environ.get("WORLD_SIZE", "1")) <= 1: from accelerate.utils.dataclasses import DistributedType def _prepare_backend(self, *args, **kwargs):