diff --git a/unsloth_zoo/device_type.py b/unsloth_zoo/device_type.py index ff319eaa4..e677cb770 100644 --- a/unsloth_zoo/device_type.py +++ b/unsloth_zoo/device_type.py @@ -21,6 +21,8 @@ "ALLOW_PREQUANTIZED_MODELS", "ALLOW_BITSANDBYTES", "device_synchronize", + "device_empty_cache", + "device_is_bf16_supported", ] import torch @@ -271,5 +273,35 @@ def device_synchronize(): torch.cuda.synchronize() elif DEVICE_TYPE == "xpu": if hasattr(torch, "xpu") and torch.xpu.is_available(): - torch.xpu.synchronize() + if hasattr(torch.xpu, "synchronize"): + torch.xpu.synchronize() +pass + +def device_empty_cache(): + """ + Empty the active device cache (CUDA, XPU, or HIP). + Cross-platform replacement for torch.cuda.empty_cache(). + """ + if DEVICE_TYPE in ("cuda", "hip"): + if torch.cuda.is_available(): + torch.cuda.empty_cache() + elif DEVICE_TYPE == "xpu": + if hasattr(torch, "xpu") and torch.xpu.is_available(): + if hasattr(torch.xpu, "empty_cache"): + torch.xpu.empty_cache() +pass + +def device_is_bf16_supported(): + """ + Whether the active device (CUDA, XPU, or HIP) supports bfloat16. + Cross-platform replacement for torch.cuda.is_bf16_supported(). + """ + if DEVICE_TYPE in ("cuda", "hip"): + if torch.cuda.is_available(): + return torch.cuda.is_bf16_supported() + elif DEVICE_TYPE == "xpu": + if hasattr(torch, "xpu") and torch.xpu.is_available(): + if hasattr(torch.xpu, "is_bf16_supported"): + return torch.xpu.is_bf16_supported() + return False pass diff --git a/unsloth_zoo/llama_cpp.py b/unsloth_zoo/llama_cpp.py index 17bd6a25c..03844d11a 100644 --- a/unsloth_zoo/llama_cpp.py +++ b/unsloth_zoo/llama_cpp.py @@ -45,6 +45,7 @@ import torch from pathlib import Path import psutil +from .device_type import device_is_bf16_supported # Get a logger instance logger = logging.getLogger(__name__) @@ -1308,11 +1309,11 @@ def convert_to_gguf( base_name = model_name[:-5] text_output = model_name # Fix: mmproj should always include dtype since it's not quantized - mmproj_dtype = model_dtype if model_dtype else ("bf16" if torch.cuda.is_bf16_supported() else "f16") + mmproj_dtype = model_dtype if model_dtype else ("bf16" if device_is_bf16_supported() else "f16") mmproj_output = f"{base_name}.{mmproj_dtype.upper()}-mmproj.gguf" else: text_output = f"{model_name}.{quantization_type.upper()}.gguf" - mmproj_dtype = model_dtype if model_dtype else ("bf16" if torch.cuda.is_bf16_supported() else "f16") + mmproj_dtype = model_dtype if model_dtype else ("bf16" if device_is_bf16_supported() else "f16") mmproj_output = f"{model_name}.{mmproj_dtype.upper()}-mmproj.gguf" # Text model conversion @@ -1332,7 +1333,7 @@ def convert_to_gguf( # Vision projector conversion mmproj_args = { "--outfile" : mmproj_output, - "--outtype" : model_dtype if model_dtype else "bf16" if torch.cuda.is_bf16_supported() else "f16", + "--outtype" : model_dtype if model_dtype else "bf16" if device_is_bf16_supported() else "f16", "--mmproj" : "", "--split-max-size" : max_shard_size, } diff --git a/unsloth_zoo/saving_utils.py b/unsloth_zoo/saving_utils.py index e441f6e5b..f1066f1f0 100644 --- a/unsloth_zoo/saving_utils.py +++ b/unsloth_zoo/saving_utils.py @@ -23,6 +23,7 @@ from .peft_utils import get_lora_layer_modules from .utils import _get_dtype from .hf_utils import dtype_from_config +from .device_type import DEVICE_TYPE, DEVICE_TYPE_TORCH, device_empty_cache, device_synchronize from .temporary_patches.common import UNSLOTH_ENABLE_LOGGING, logger from collections import defaultdict @@ -156,11 +157,20 @@ def create_huggingface_repo( import os, shutil, re, functools +def _active_merge_device(W): + if getattr(W.device, "type", None) == DEVICE_TYPE_TORCH: + return W.device + if W.device.index is None: + return torch.device(DEVICE_TYPE_TORCH) + return torch.device(DEVICE_TYPE_TORCH, W.device.index) +pass + def _merge_lora(W, lora_stats, name): if lora_stats.lora_A is None or lora_stats.lora_B is None: return W - W = W.to("cuda", dtype = torch.float32, non_blocking = True) - lora_B = lora_stats.lora_B.to("cuda", dtype = torch.float32, non_blocking = True) - lora_A = lora_stats.lora_A.to("cuda", dtype = torch.float32, non_blocking = True) + device = _active_merge_device(W) + W = W.to(device, dtype = torch.float32, non_blocking = True) + lora_B = lora_stats.lora_B.to(device, dtype = torch.float32, non_blocking = True) + lora_A = lora_stats.lora_A.to(device, dtype = torch.float32, non_blocking = True) # Handle vocab resize: LoRA may have more rows than base safetensors weight if lora_B.shape[0] != W.shape[0]: new_size = lora_B.shape[0] @@ -623,9 +633,8 @@ def _merge_and_overwrite_lora( logger.info(f"[merge_debug] First shard key example: {key}") # FORCE memory cleanup before processing each tensor - if torch.cuda.is_available(): - torch.cuda.empty_cache() - torch.cuda.synchronize() + device_empty_cache() + device_synchronize() # ---------- Special handling for MoE stacked expert params ---------- # gate_up_proj is stored fused in the model but sharded as gate_proj & up_proj per expert on disk. @@ -772,7 +781,7 @@ def _merge_and_overwrite_lora( ) del W - torch.cuda.empty_cache() + device_empty_cache() pass # Success! Direct overwrite completed pass @@ -800,8 +809,7 @@ def _merge_and_overwrite_lora( del tensors else: gc.collect() - if torch.cuda.is_available(): - torch.cuda.empty_cache() + device_empty_cache() max_retries = 5 base_delay = 0.2 # seconds @@ -825,8 +833,7 @@ def _merge_and_overwrite_lora( # Drop mmap refs before os.replace del tensors gc.collect() - if torch.cuda.is_available(): - torch.cuda.empty_cache() + device_empty_cache() for attempt in range(max_retries): try: @@ -877,8 +884,7 @@ def _merge_and_overwrite_lora( except OSError: pass - if torch.cuda.is_available(): - torch.cuda.empty_cache() + device_empty_cache() return count, safetensor_keys_seen except RuntimeError: @@ -933,7 +939,7 @@ def _merge_moe_gate_expert(gate_W, lora_stats, expert_idx, num_experts, output_d # gate_proj corresponds to first half of A gate_a = a_slice[:, :inter_dim] # (r, I) - device = gate_W.device if gate_W.is_cuda else ("cuda" if torch.cuda.is_available() else "cpu") + device = _active_merge_device(gate_W) gate_delta = b_slice.to(device, dtype = torch.float32, non_blocking = True) @ gate_a.to(device, dtype = torch.float32, non_blocking = True) gate_merged = gate_W.to(device, dtype = torch.float32, non_blocking = True) @@ -976,7 +982,7 @@ def _merge_moe_up_expert(up_W, lora_stats, expert_idx, num_experts, output_dtype # up_proj corresponds to second half of A up_a = a_slice[:, inter_dim:] # (r, I) - device = up_W.device if up_W.is_cuda else ("cuda" if torch.cuda.is_available() else "cpu") + device = _active_merge_device(up_W) up_delta = b_slice.to(device, dtype = torch.float32, non_blocking = True) @ up_a.to(device, dtype = torch.float32, non_blocking = True) up_merged = up_W.to(device, dtype = torch.float32, non_blocking = True) @@ -1017,7 +1023,7 @@ def _merge_moe_down_proj_expert(down_W, lora_stats, expert_idx, num_experts, out a_slice = lora_stats.lora_A[start:end, :] # (r, H_out) b_slice = lora_stats.lora_B[:, start:end] # (I_in, r) - device = down_W.device if down_W.is_cuda else ("cuda" if torch.cuda.is_available() else "cpu") + device = _active_merge_device(down_W) delta = b_slice.to(device, dtype = torch.float32, non_blocking = True) @ a_slice.to(device, dtype = torch.float32, non_blocking = True) merged = down_W.to(device, dtype = torch.float32, non_blocking = True) merged = merged.add(delta.transpose(0, 1), alpha = lora_stats.alpha) @@ -1322,11 +1328,7 @@ def _merge_moe_fused_gate_up_expert(gate_up_W, lora_stats, output_dtype, is_tran else: return gate_up_W - device = ( - gate_up_W.device - if gate_up_W.is_cuda - else ("cuda" if torch.cuda.is_available() else "cpu") - ) + device = _active_merge_device(gate_up_W) gate_up_merged = gate_up_W.to(device, dtype=torch.float32, non_blocking=True) for expert_idx in range(num_experts): @@ -1378,11 +1380,7 @@ def _merge_moe_fused_down_proj_expert(down_W, lora_stats, output_dtype, is_trans else: return down_W - device = ( - down_W.device - if down_W.is_cuda - else ("cuda" if torch.cuda.is_available() else "cpu") - ) + device = _active_merge_device(down_W) down_merged = down_W.to(device, dtype=torch.float32, non_blocking=True) for expert_idx in range(num_experts): @@ -1440,9 +1438,8 @@ def _merge_and_overwrite_lora_mxfp4(save_directory, filename, lora_weights, outp continue # FORCE memory cleanup before processing each tensor - if torch.cuda.is_available(): - torch.cuda.empty_cache() - torch.cuda.synchronize() + device_empty_cache() + device_synchronize() W = None output_key = key @@ -1463,9 +1460,8 @@ def _merge_and_overwrite_lora_mxfp4(save_directory, filename, lora_weights, outp blocks_tensor, scales_tensor = file.get_tensor(key), file.get_tensor(scales_key) - if torch.cuda.is_available(): - torch.cuda.synchronize() # Wait for previous operations to complete - torch.cuda.empty_cache() + device_synchronize() + device_empty_cache() # Determine optimal device and chunk size for mxfp4 dequantization device_type, device_id, rows_per_chunk = _choose_mxfp4_processing_strategy( @@ -1547,7 +1543,7 @@ def _merge_and_overwrite_lora_mxfp4(save_directory, filename, lora_weights, outp tensors[output_key] = W # Free up VRAM after each merge - torch.cuda.empty_cache() + device_empty_cache() # CRITICAL: Force cleanup to release file handles on Windows if os.name == 'nt': @@ -2108,7 +2104,7 @@ def upload_items(filename = None): # Step 3: Conditional index handling import subprocess - is_t4 = "Tesla T4" in str(torch.cuda.get_device_name(0)) + is_t4 = DEVICE_TYPE == "cuda" and "Tesla T4" in torch.cuda.get_device_name(0) needs_splitting = should_split_shards(is_t4, config, safetensors_list, max_size_in_bytes) if save_method == "merged_16bit" else False _hf_cache_dir = _get_hf_cache_dir() copied_all_from_cache = False @@ -2265,7 +2261,7 @@ def upload_items(filename = None): ) n_saved_modules += merged_count safetensor_keys_seen.update(shard_keys) - torch.cuda.empty_cache() + device_empty_cache() file_path = os.path.join(save_directory, filename)