diff --git a/unsloth/kernels/utils.py b/unsloth/kernels/utils.py index 5c955a3c8d..1c65246b31 100644 --- a/unsloth/kernels/utils.py +++ b/unsloth/kernels/utils.py @@ -78,7 +78,7 @@ def calculate_settings(n : int) -> (int, int,): # INTEL GPU specific logic if DEVICE_TYPE == "xpu": # TODO: Changed here after adding XPU BNB support - HAS_XPU_STREAM = False + HAS_XPU_STREAM = True def get_ptr(x: Optional[torch.Tensor]): raise RuntimeError("XPU BNB support is not implemented yet. This function should not be called.") else: diff --git a/unsloth/models/_utils.py b/unsloth/models/_utils.py index 602f7ee90f..1af631f965 100644 --- a/unsloth/models/_utils.py +++ b/unsloth/models/_utils.py @@ -142,6 +142,12 @@ import logging logging.getLogger("transformers.tokenization_utils_base").setLevel(logging.CRITICAL+1) +def get_device_num(): + if DEVICE_TYPE == "xpu": + return torch.xpu.device_count() + else: + return torch.cuda.device_count() + # Ignore logging messages class HideLoggingMessage(logging.Filter): __slots__ = "text", @@ -738,7 +744,7 @@ def get_statistics(): pass pass try: - devices = torch.cuda.device_count() + devices = get_device_num() _get_statistics(f"{devices if devices <= 8 else 9}") except: pass @@ -765,7 +771,7 @@ def get_statistics(): ) exec(BitsAndBytesConfig__init__, globals()) -if torch.cuda.device_count() == 1: +if get_device_num() == 1: from accelerate.utils.dataclasses import DistributedType def _prepare_backend(self, *args, **kwargs): return None, DistributedType.NO import accelerate.state diff --git a/unsloth/models/llama.py b/unsloth/models/llama.py index 7d56dac2ec..f08b4762eb 100644 --- a/unsloth/models/llama.py +++ b/unsloth/models/llama.py @@ -85,6 +85,11 @@ HAS_XFORMERS = xformers is not None BlockDiagonalCausalMask = xformers.attn_bias.BlockDiagonalCausalMask if HAS_XFORMERS else None +def clean_gpu_cache(): + if DEVICE_TYPE == "xpu": + torch.xpu.empty_cache() + else: + torch.cuda.empty_cache() def original_apply_qkv(self, X): Q = self.q_proj(X) @@ -1752,10 +1757,11 @@ def from_pretrained( if not is_vLLM_available(): print("Unsloth: vLLM is not installed! Will use Unsloth inference!") fast_inference = False - major_version, minor_version = torch.cuda.get_device_capability() - if major_version < 7: - print("Unsloth: vLLM does not work on older GPUs - will switch to Unsloth inference!") - fast_inference = False + if DEVICE_TYPE == "cuda": + major_version, minor_version = torch.cuda.get_device_capability() + if major_version < 7: + print("Unsloth: vLLM does not work on older GPUs - will switch to Unsloth inference!") + fast_inference = False if unsloth_vllm_standby and os.environ.get("UNSLOTH_VLLM_STANDBY", "0") == "0": raise RuntimeError("Unsloth: `unsloth_vllm_standby` is True, but environment variable `UNSLOTH_VLLM_STANDBY` is not set to 1!") pass @@ -1779,8 +1785,8 @@ def from_pretrained( num_gpus = torch.xpu.device_count() gpu_stats_snippet = f"Intel Toolkit: {gpu_version}." - # TODO: After adding vLLM support for XPU, changed this - vllm_version = "" + try: vllm_version = f" vLLM: {importlib_version('vllm')}." + except: vllm_version = "" else: raise ValueError(f"Unsloth: Unsupported device type: {DEVICE_TYPE}") @@ -2020,7 +2026,10 @@ def from_pretrained( import gc for _ in range(3): gc.collect() - torch.cuda.empty_cache()""" + if DEVICE_TYPE == "xpu": + torch.xpu.empty_cache() + else: + torch.cuda.empty_cache()""" debug_info = debug_info.split('\n') debug_info = "\n".join([debug_info[0]] + [spaces + x[8:] for x in debug_info[1:]]) @@ -2508,7 +2517,7 @@ def get_peft_model( # Remove old items to save VRAM for _ in range(3): gc.collect() - torch.cuda.empty_cache() + clean_gpu_cache() pass if train_lm_head: @@ -2519,7 +2528,7 @@ def get_peft_model( # Remove old items to save VRAM for _ in range(3): gc.collect() - torch.cuda.empty_cache() + clean_gpu_cache() pass pass @@ -2580,7 +2589,7 @@ def get_peft_model( # Clear deleted GPU items for _ in range(3): gc.collect() - torch.cuda.empty_cache() + clean_gpu_cache() pass # Patch for fast inference @@ -2796,7 +2805,7 @@ def patch_peft_model( # Clear deleted GPU items for _ in range(3): gc.collect() - torch.cuda.empty_cache() + clean_gpu_cache() pass # Patch for fast inference diff --git a/unsloth/models/rl_replacements.py b/unsloth/models/rl_replacements.py index 9b0f4e4aef..bd5c58384e 100644 --- a/unsloth/models/rl_replacements.py +++ b/unsloth/models/rl_replacements.py @@ -26,6 +26,8 @@ import inspect from collections import defaultdict from unsloth_zoo.rl_replacements import RL_REPLACEMENTS +from unsloth import DEVICE_TYPE + RL_EXTRA_ARGS = defaultdict(list) RL_FUNCTIONS = defaultdict(list) RL_PRE_ITEMS = defaultdict(list) @@ -258,7 +260,7 @@ def _get_per_token_logps(self, model, input_ids, attention_mask, logits_to_keep) if os.environ.get('UNSLOTH_FORCE_FLOAT32', '0') == '1': self._autocast_dtype = torch.float16 os.environ["UNSLOTH_RETURN_HIDDEN_STATES"] = "1" - with torch.amp.autocast(device_type = 'cuda', dtype = self._autocast_dtype): + with torch.amp.autocast(device_type = DEVICE_TYPE, dtype = self._autocast_dtype): # We add 1 to `logits_to_keep` because the last logits of the sequence is later excluded logits = model( input_ids = input_ids,