-
Notifications
You must be signed in to change notification settings - Fork 0
fix: use backend device type in GGUF merge path #14
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
91cde98
1b90f79
93f1a8c
2564f39
d631837
35dc451
e08c1df
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -23,6 +23,7 @@ | |
| from .peft_utils import get_lora_layer_modules | ||
| from .utils import _get_dtype | ||
| from .hf_utils import dtype_from_config | ||
| from .device_type import DEVICE_TYPE, DEVICE_TYPE_TORCH, device_empty_cache, device_synchronize | ||
| from .temporary_patches.common import UNSLOTH_ENABLE_LOGGING, logger | ||
| from collections import defaultdict | ||
|
|
||
|
|
@@ -156,11 +157,20 @@ def create_huggingface_repo( | |
| import os, shutil, re, functools | ||
|
|
||
|
|
||
| def _active_merge_device(W): | ||
| if getattr(W.device, "type", None) == DEVICE_TYPE_TORCH: | ||
| return W.device | ||
| if W.device.index is None: | ||
| return torch.device(DEVICE_TYPE_TORCH) | ||
| return torch.device(DEVICE_TYPE_TORCH, W.device.index) | ||
| pass | ||
|
Comment on lines
+160
to
+166
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. The |
||
|
|
||
| def _merge_lora(W, lora_stats, name): | ||
| if lora_stats.lora_A is None or lora_stats.lora_B is None: return W | ||
| W = W.to("cuda", dtype = torch.float32, non_blocking = True) | ||
| lora_B = lora_stats.lora_B.to("cuda", dtype = torch.float32, non_blocking = True) | ||
| lora_A = lora_stats.lora_A.to("cuda", dtype = torch.float32, non_blocking = True) | ||
| device = _active_merge_device(W) | ||
| W = W.to(device, dtype = torch.float32, non_blocking = True) | ||
| lora_B = lora_stats.lora_B.to(device, dtype = torch.float32, non_blocking = True) | ||
| lora_A = lora_stats.lora_A.to(device, dtype = torch.float32, non_blocking = True) | ||
| # Handle vocab resize: LoRA may have more rows than base safetensors weight | ||
| if lora_B.shape[0] != W.shape[0]: | ||
| new_size = lora_B.shape[0] | ||
|
|
@@ -623,9 +633,8 @@ def _merge_and_overwrite_lora( | |
| logger.info(f"[merge_debug] First shard key example: {key}") | ||
|
|
||
| # FORCE memory cleanup before processing each tensor | ||
| if torch.cuda.is_available(): | ||
| torch.cuda.empty_cache() | ||
| torch.cuda.synchronize() | ||
| device_empty_cache() | ||
| device_synchronize() | ||
|
|
||
| # ---------- Special handling for MoE stacked expert params ---------- | ||
| # gate_up_proj is stored fused in the model but sharded as gate_proj & up_proj per expert on disk. | ||
|
|
@@ -772,7 +781,7 @@ def _merge_and_overwrite_lora( | |
| ) | ||
|
|
||
| del W | ||
| torch.cuda.empty_cache() | ||
| device_empty_cache() | ||
| pass | ||
| # Success! Direct overwrite completed | ||
| pass | ||
|
|
@@ -800,8 +809,7 @@ def _merge_and_overwrite_lora( | |
| del tensors | ||
| else: | ||
| gc.collect() | ||
| if torch.cuda.is_available(): | ||
| torch.cuda.empty_cache() | ||
| device_empty_cache() | ||
|
|
||
| max_retries = 5 | ||
| base_delay = 0.2 # seconds | ||
|
|
@@ -825,8 +833,7 @@ def _merge_and_overwrite_lora( | |
| # Drop mmap refs before os.replace | ||
| del tensors | ||
| gc.collect() | ||
| if torch.cuda.is_available(): | ||
| torch.cuda.empty_cache() | ||
| device_empty_cache() | ||
|
|
||
| for attempt in range(max_retries): | ||
| try: | ||
|
|
@@ -877,8 +884,7 @@ def _merge_and_overwrite_lora( | |
| except OSError: | ||
| pass | ||
|
|
||
| if torch.cuda.is_available(): | ||
| torch.cuda.empty_cache() | ||
| device_empty_cache() | ||
| return count, safetensor_keys_seen | ||
|
|
||
| except RuntimeError: | ||
|
|
@@ -933,7 +939,7 @@ def _merge_moe_gate_expert(gate_W, lora_stats, expert_idx, num_experts, output_d | |
| # gate_proj corresponds to first half of A | ||
| gate_a = a_slice[:, :inter_dim] # (r, I) | ||
|
|
||
| device = gate_W.device if gate_W.is_cuda else ("cuda" if torch.cuda.is_available() else "cpu") | ||
| device = _active_merge_device(gate_W) | ||
| gate_delta = b_slice.to(device, dtype = torch.float32, non_blocking = True) @ gate_a.to(device, dtype = torch.float32, non_blocking = True) | ||
|
|
||
| gate_merged = gate_W.to(device, dtype = torch.float32, non_blocking = True) | ||
|
|
@@ -976,7 +982,7 @@ def _merge_moe_up_expert(up_W, lora_stats, expert_idx, num_experts, output_dtype | |
| # up_proj corresponds to second half of A | ||
| up_a = a_slice[:, inter_dim:] # (r, I) | ||
|
|
||
| device = up_W.device if up_W.is_cuda else ("cuda" if torch.cuda.is_available() else "cpu") | ||
| device = _active_merge_device(up_W) | ||
| up_delta = b_slice.to(device, dtype = torch.float32, non_blocking = True) @ up_a.to(device, dtype = torch.float32, non_blocking = True) | ||
|
|
||
| up_merged = up_W.to(device, dtype = torch.float32, non_blocking = True) | ||
|
|
@@ -1017,7 +1023,7 @@ def _merge_moe_down_proj_expert(down_W, lora_stats, expert_idx, num_experts, out | |
| a_slice = lora_stats.lora_A[start:end, :] # (r, H_out) | ||
| b_slice = lora_stats.lora_B[:, start:end] # (I_in, r) | ||
|
|
||
| device = down_W.device if down_W.is_cuda else ("cuda" if torch.cuda.is_available() else "cpu") | ||
| device = _active_merge_device(down_W) | ||
| delta = b_slice.to(device, dtype = torch.float32, non_blocking = True) @ a_slice.to(device, dtype = torch.float32, non_blocking = True) | ||
| merged = down_W.to(device, dtype = torch.float32, non_blocking = True) | ||
| merged = merged.add(delta.transpose(0, 1), alpha = lora_stats.alpha) | ||
|
|
@@ -1322,11 +1328,7 @@ def _merge_moe_fused_gate_up_expert(gate_up_W, lora_stats, output_dtype, is_tran | |
| else: | ||
| return gate_up_W | ||
|
|
||
| device = ( | ||
| gate_up_W.device | ||
| if gate_up_W.is_cuda | ||
| else ("cuda" if torch.cuda.is_available() else "cpu") | ||
| ) | ||
| device = _active_merge_device(gate_up_W) | ||
| gate_up_merged = gate_up_W.to(device, dtype=torch.float32, non_blocking=True) | ||
|
|
||
| for expert_idx in range(num_experts): | ||
|
|
@@ -1378,11 +1380,7 @@ def _merge_moe_fused_down_proj_expert(down_W, lora_stats, output_dtype, is_trans | |
| else: | ||
| return down_W | ||
|
|
||
| device = ( | ||
| down_W.device | ||
| if down_W.is_cuda | ||
| else ("cuda" if torch.cuda.is_available() else "cpu") | ||
| ) | ||
| device = _active_merge_device(down_W) | ||
| down_merged = down_W.to(device, dtype=torch.float32, non_blocking=True) | ||
|
|
||
| for expert_idx in range(num_experts): | ||
|
|
@@ -1440,9 +1438,8 @@ def _merge_and_overwrite_lora_mxfp4(save_directory, filename, lora_weights, outp | |
| continue | ||
|
|
||
| # FORCE memory cleanup before processing each tensor | ||
| if torch.cuda.is_available(): | ||
| torch.cuda.empty_cache() | ||
| torch.cuda.synchronize() | ||
| device_empty_cache() | ||
| device_synchronize() | ||
|
|
||
| W = None | ||
| output_key = key | ||
|
|
@@ -1463,9 +1460,8 @@ def _merge_and_overwrite_lora_mxfp4(save_directory, filename, lora_weights, outp | |
|
|
||
| blocks_tensor, scales_tensor = file.get_tensor(key), file.get_tensor(scales_key) | ||
|
|
||
| if torch.cuda.is_available(): | ||
| torch.cuda.synchronize() # Wait for previous operations to complete | ||
| torch.cuda.empty_cache() | ||
| device_synchronize() | ||
| device_empty_cache() | ||
|
|
||
| # Determine optimal device and chunk size for mxfp4 dequantization | ||
| device_type, device_id, rows_per_chunk = _choose_mxfp4_processing_strategy( | ||
|
|
@@ -1547,7 +1543,7 @@ def _merge_and_overwrite_lora_mxfp4(save_directory, filename, lora_weights, outp | |
| tensors[output_key] = W | ||
|
|
||
| # Free up VRAM after each merge | ||
| torch.cuda.empty_cache() | ||
| device_empty_cache() | ||
|
|
||
| # CRITICAL: Force cleanup to release file handles on Windows | ||
| if os.name == 'nt': | ||
|
|
@@ -2108,7 +2104,7 @@ def upload_items(filename = None): | |
|
|
||
| # Step 3: Conditional index handling | ||
| import subprocess | ||
| is_t4 = "Tesla T4" in str(torch.cuda.get_device_name(0)) | ||
| is_t4 = DEVICE_TYPE == "cuda" and "Tesla T4" in torch.cuda.get_device_name(0) | ||
| needs_splitting = should_split_shards(is_t4, config, safetensors_list, max_size_in_bytes) if save_method == "merged_16bit" else False | ||
| _hf_cache_dir = _get_hf_cache_dir() | ||
| copied_all_from_cache = False | ||
|
|
@@ -2265,7 +2261,7 @@ def upload_items(filename = None): | |
| ) | ||
| n_saved_modules += merged_count | ||
| safetensor_keys_seen.update(shard_keys) | ||
| torch.cuda.empty_cache() | ||
| device_empty_cache() | ||
|
|
||
| file_path = os.path.join(save_directory, filename) | ||
|
|
||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The `_active_merge_device` helper can be simplified. Since `torch.device` objects always have `type` and `index` attributes, and the constructor handles `index=None` gracefully (defaulting to the current device), the logic can be more concise while remaining safe.