Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 33 additions & 1 deletion unsloth_zoo/device_type.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,8 @@
"ALLOW_PREQUANTIZED_MODELS",
"ALLOW_BITSANDBYTES",
"device_synchronize",
"device_empty_cache",
"device_is_bf16_supported",
]

import torch
Expand Down Expand Up @@ -271,5 +273,35 @@ def device_synchronize():
torch.cuda.synchronize()
elif DEVICE_TYPE == "xpu":
if hasattr(torch, "xpu") and torch.xpu.is_available():
torch.xpu.synchronize()
if hasattr(torch.xpu, "synchronize"):
torch.xpu.synchronize()
pass

def device_empty_cache():
"""
Empty the active device cache (CUDA, XPU, or HIP).
Cross-platform replacement for torch.cuda.empty_cache().
"""
if DEVICE_TYPE in ("cuda", "hip"):
if torch.cuda.is_available():
torch.cuda.empty_cache()
elif DEVICE_TYPE == "xpu":
if hasattr(torch, "xpu") and torch.xpu.is_available():
if hasattr(torch.xpu, "empty_cache"):
torch.xpu.empty_cache()
pass

def device_is_bf16_supported():
"""
Whether the active device (CUDA, XPU, or HIP) supports bfloat16.
Cross-platform replacement for torch.cuda.is_bf16_supported().
"""
if DEVICE_TYPE in ("cuda", "hip"):
if torch.cuda.is_available():
return torch.cuda.is_bf16_supported()
elif DEVICE_TYPE == "xpu":
if hasattr(torch, "xpu") and torch.xpu.is_available():
if hasattr(torch.xpu, "is_bf16_supported"):
return torch.xpu.is_bf16_supported()
return False
pass
7 changes: 4 additions & 3 deletions unsloth_zoo/llama_cpp.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@
import torch
from pathlib import Path
import psutil
from .device_type import device_is_bf16_supported

# Get a logger instance
logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -1308,11 +1309,11 @@ def convert_to_gguf(
base_name = model_name[:-5]
text_output = model_name
# Fix: mmproj should always include dtype since it's not quantized
mmproj_dtype = model_dtype if model_dtype else ("bf16" if torch.cuda.is_bf16_supported() else "f16")
mmproj_dtype = model_dtype if model_dtype else ("bf16" if device_is_bf16_supported() else "f16")
mmproj_output = f"{base_name}.{mmproj_dtype.upper()}-mmproj.gguf"
else:
text_output = f"{model_name}.{quantization_type.upper()}.gguf"
mmproj_dtype = model_dtype if model_dtype else ("bf16" if torch.cuda.is_bf16_supported() else "f16")
mmproj_dtype = model_dtype if model_dtype else ("bf16" if device_is_bf16_supported() else "f16")
mmproj_output = f"{model_name}.{mmproj_dtype.upper()}-mmproj.gguf"

# Text model conversion
Expand All @@ -1332,7 +1333,7 @@ def convert_to_gguf(
# Vision projector conversion
mmproj_args = {
"--outfile" : mmproj_output,
"--outtype" : model_dtype if model_dtype else "bf16" if torch.cuda.is_bf16_supported() else "f16",
"--outtype" : model_dtype if model_dtype else "bf16" if device_is_bf16_supported() else "f16",
"--mmproj" : "",
"--split-max-size" : max_shard_size,
}
Expand Down
66 changes: 31 additions & 35 deletions unsloth_zoo/saving_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
from .peft_utils import get_lora_layer_modules
from .utils import _get_dtype
from .hf_utils import dtype_from_config
from .device_type import DEVICE_TYPE, DEVICE_TYPE_TORCH, device_empty_cache, device_synchronize
from .temporary_patches.common import UNSLOTH_ENABLE_LOGGING, logger
from collections import defaultdict

Expand Down Expand Up @@ -156,11 +157,20 @@ def create_huggingface_repo(
import os, shutil, re, functools


def _active_merge_device(W):
if getattr(W.device, "type", None) == DEVICE_TYPE_TORCH:
return W.device
if W.device.index is None:
return torch.device(DEVICE_TYPE_TORCH)
return torch.device(DEVICE_TYPE_TORCH, W.device.index)
Comment on lines +161 to +165

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

low

The _active_merge_device helper can be simplified. Since torch.device objects always have type and index attributes, and the constructor handles index=None gracefully (defaulting to the current device), the logic can be more concise while remaining safe.

Suggested change
if getattr(W.device, "type", None) == DEVICE_TYPE_TORCH:
return W.device
if W.device.index is None:
return torch.device(DEVICE_TYPE_TORCH)
return torch.device(DEVICE_TYPE_TORCH, W.device.index)
def _active_merge_device(W):
if W.device.type == DEVICE_TYPE_TORCH:
return W.device
return torch.device(DEVICE_TYPE_TORCH, index = W.device.index)
pass

pass
Comment on lines +160 to +166

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

low

The _active_merge_device helper is a great addition for cross-platform support. However, consider that W.device.type on ROCm/HIP systems might sometimes be reported as "hip" depending on the PyTorch version, while DEVICE_TYPE_TORCH is explicitly set to "cuda" for HIP in device_type.py. While torch.device("cuda") works on HIP, the direct string comparison in line 161 might fail to recognize an existing HIP device as the target, causing an unnecessary torch.device object creation and potential move (though to() is usually a no-op if the device matches). Given the current DEVICE_TYPE_TORCH logic, this is likely safe, but worth noting if "hip" device types start appearing in W.device.type.


def _merge_lora(W, lora_stats, name):
if lora_stats.lora_A is None or lora_stats.lora_B is None: return W
W = W.to("cuda", dtype = torch.float32, non_blocking = True)
lora_B = lora_stats.lora_B.to("cuda", dtype = torch.float32, non_blocking = True)
lora_A = lora_stats.lora_A.to("cuda", dtype = torch.float32, non_blocking = True)
device = _active_merge_device(W)
W = W.to(device, dtype = torch.float32, non_blocking = True)
lora_B = lora_stats.lora_B.to(device, dtype = torch.float32, non_blocking = True)
lora_A = lora_stats.lora_A.to(device, dtype = torch.float32, non_blocking = True)
# Handle vocab resize: LoRA may have more rows than base safetensors weight
if lora_B.shape[0] != W.shape[0]:
new_size = lora_B.shape[0]
Expand Down Expand Up @@ -623,9 +633,8 @@ def _merge_and_overwrite_lora(
logger.info(f"[merge_debug] First shard key example: {key}")

# FORCE memory cleanup before processing each tensor
if torch.cuda.is_available():
torch.cuda.empty_cache()
torch.cuda.synchronize()
device_empty_cache()
device_synchronize()

# ---------- Special handling for MoE stacked expert params ----------
# gate_up_proj is stored fused in the model but sharded as gate_proj & up_proj per expert on disk.
Expand Down Expand Up @@ -772,7 +781,7 @@ def _merge_and_overwrite_lora(
)

del W
torch.cuda.empty_cache()
device_empty_cache()
pass
# Success! Direct overwrite completed
pass
Expand Down Expand Up @@ -800,8 +809,7 @@ def _merge_and_overwrite_lora(
del tensors
else:
gc.collect()
if torch.cuda.is_available():
torch.cuda.empty_cache()
device_empty_cache()

max_retries = 5
base_delay = 0.2 # seconds
Expand All @@ -825,8 +833,7 @@ def _merge_and_overwrite_lora(
# Drop mmap refs before os.replace
del tensors
gc.collect()
if torch.cuda.is_available():
torch.cuda.empty_cache()
device_empty_cache()

for attempt in range(max_retries):
try:
Expand Down Expand Up @@ -877,8 +884,7 @@ def _merge_and_overwrite_lora(
except OSError:
pass

if torch.cuda.is_available():
torch.cuda.empty_cache()
device_empty_cache()
return count, safetensor_keys_seen

except RuntimeError:
Expand Down Expand Up @@ -933,7 +939,7 @@ def _merge_moe_gate_expert(gate_W, lora_stats, expert_idx, num_experts, output_d
# gate_proj corresponds to first half of A
gate_a = a_slice[:, :inter_dim] # (r, I)

device = gate_W.device if gate_W.is_cuda else ("cuda" if torch.cuda.is_available() else "cpu")
device = _active_merge_device(gate_W)
gate_delta = b_slice.to(device, dtype = torch.float32, non_blocking = True) @ gate_a.to(device, dtype = torch.float32, non_blocking = True)

gate_merged = gate_W.to(device, dtype = torch.float32, non_blocking = True)
Expand Down Expand Up @@ -976,7 +982,7 @@ def _merge_moe_up_expert(up_W, lora_stats, expert_idx, num_experts, output_dtype
# up_proj corresponds to second half of A
up_a = a_slice[:, inter_dim:] # (r, I)

device = up_W.device if up_W.is_cuda else ("cuda" if torch.cuda.is_available() else "cpu")
device = _active_merge_device(up_W)
up_delta = b_slice.to(device, dtype = torch.float32, non_blocking = True) @ up_a.to(device, dtype = torch.float32, non_blocking = True)

up_merged = up_W.to(device, dtype = torch.float32, non_blocking = True)
Expand Down Expand Up @@ -1017,7 +1023,7 @@ def _merge_moe_down_proj_expert(down_W, lora_stats, expert_idx, num_experts, out
a_slice = lora_stats.lora_A[start:end, :] # (r, H_out)
b_slice = lora_stats.lora_B[:, start:end] # (I_in, r)

device = down_W.device if down_W.is_cuda else ("cuda" if torch.cuda.is_available() else "cpu")
device = _active_merge_device(down_W)
delta = b_slice.to(device, dtype = torch.float32, non_blocking = True) @ a_slice.to(device, dtype = torch.float32, non_blocking = True)
merged = down_W.to(device, dtype = torch.float32, non_blocking = True)
merged = merged.add(delta.transpose(0, 1), alpha = lora_stats.alpha)
Expand Down Expand Up @@ -1322,11 +1328,7 @@ def _merge_moe_fused_gate_up_expert(gate_up_W, lora_stats, output_dtype, is_tran
else:
return gate_up_W

device = (
gate_up_W.device
if gate_up_W.is_cuda
else ("cuda" if torch.cuda.is_available() else "cpu")
)
device = _active_merge_device(gate_up_W)
gate_up_merged = gate_up_W.to(device, dtype=torch.float32, non_blocking=True)

for expert_idx in range(num_experts):
Expand Down Expand Up @@ -1378,11 +1380,7 @@ def _merge_moe_fused_down_proj_expert(down_W, lora_stats, output_dtype, is_trans
else:
return down_W

device = (
down_W.device
if down_W.is_cuda
else ("cuda" if torch.cuda.is_available() else "cpu")
)
device = _active_merge_device(down_W)
down_merged = down_W.to(device, dtype=torch.float32, non_blocking=True)

for expert_idx in range(num_experts):
Expand Down Expand Up @@ -1440,9 +1438,8 @@ def _merge_and_overwrite_lora_mxfp4(save_directory, filename, lora_weights, outp
continue

# FORCE memory cleanup before processing each tensor
if torch.cuda.is_available():
torch.cuda.empty_cache()
torch.cuda.synchronize()
device_empty_cache()
device_synchronize()

W = None
output_key = key
Expand All @@ -1463,9 +1460,8 @@ def _merge_and_overwrite_lora_mxfp4(save_directory, filename, lora_weights, outp

blocks_tensor, scales_tensor = file.get_tensor(key), file.get_tensor(scales_key)

if torch.cuda.is_available():
torch.cuda.synchronize() # Wait for previous operations to complete
torch.cuda.empty_cache()
device_synchronize()
device_empty_cache()

# Determine optimal device and chunk size for mxfp4 dequantization
device_type, device_id, rows_per_chunk = _choose_mxfp4_processing_strategy(
Expand Down Expand Up @@ -1547,7 +1543,7 @@ def _merge_and_overwrite_lora_mxfp4(save_directory, filename, lora_weights, outp
tensors[output_key] = W

# Free up VRAM after each merge
torch.cuda.empty_cache()
device_empty_cache()

# CRITICAL: Force cleanup to release file handles on Windows
if os.name == 'nt':
Expand Down Expand Up @@ -2108,7 +2104,7 @@ def upload_items(filename = None):

# Step 3: Conditional index handling
import subprocess
is_t4 = "Tesla T4" in str(torch.cuda.get_device_name(0))
is_t4 = DEVICE_TYPE == "cuda" and "Tesla T4" in torch.cuda.get_device_name(0)
needs_splitting = should_split_shards(is_t4, config, safetensors_list, max_size_in_bytes) if save_method == "merged_16bit" else False
_hf_cache_dir = _get_hf_cache_dir()
copied_all_from_cache = False
Expand Down Expand Up @@ -2265,7 +2261,7 @@ def upload_items(filename = None):
)
n_saved_modules += merged_count
safetensor_keys_seen.update(shard_keys)
torch.cuda.empty_cache()
device_empty_cache()

file_path = os.path.join(save_directory, filename)

Expand Down