diff --git a/unsloth/models/_utils.py b/unsloth/models/_utils.py index 49b8ba3953..7a6954c9f8 100644 --- a/unsloth/models/_utils.py +++ b/unsloth/models/_utils.py @@ -372,11 +372,6 @@ def prepare_n_gradient_checkpoints( pass -# Unsloth only works on NVIDIA GPUs for now -device_ids = os.environ.get("CUDA_VISIBLE_DEVICES", "0") + "," -device = device_ids[:device_ids.find(',')] # Unsloth only works on NVIDIA GPUs for now -device = f"cuda:{device if device.isdigit() else '0'}" - class Unsloth_Offloaded_Gradient_Checkpointer(torch.autograd.Function): """ Saves VRAM by smartly offloading to RAM. @@ -398,7 +393,7 @@ def forward(ctx, forward_function, hidden_states, *args): @torch.cuda.amp.custom_bwd def backward(ctx, dY): (hidden_states,) = ctx.saved_tensors - hidden_states = hidden_states.to(device, non_blocking = True).detach() + hidden_states = hidden_states.to("cuda:0", non_blocking = True).detach() hidden_states.requires_grad = True with torch.enable_grad(): (output,) = ctx.forward_function(hidden_states, *ctx.args) diff --git a/unsloth/models/gemma.py b/unsloth/models/gemma.py index 0cc047d214..99374891ab 100644 --- a/unsloth/models/gemma.py +++ b/unsloth/models/gemma.py @@ -38,9 +38,6 @@ GemmaFlashAttention2 = GemmaAttention pass -import os -device_ids = os.environ.get("CUDA_VISIBLE_DEVICES", "0") + "," -device = f"cuda:{device_ids[:device_ids.find(',')]}" # Unsloth only works on NVIDIA GPUs for now torch_nn_functional_gelu = torch.nn.functional.gelu def fast_geglu_inference(self, X): @@ -48,7 +45,7 @@ def fast_geglu_inference(self, X): # up = self.up_proj(X) bsz, _, hd = X.shape # mlp_size = self.config.intermediate_size - # temp = torch.empty((2, bsz, 1, mlp_size), dtype = X.dtype, device = device) + # temp = torch.empty((2, bsz, 1, mlp_size), dtype = X.dtype, device = "cuda:0") gate = fast_linear_forward(self.gate_proj, X)#, out = temp[0]) up = fast_linear_forward(self. up_proj, X)#, out = temp[1]) @@ -75,7 +72,7 @@ def GemmaDecoderLayer_fast_forward( *args, **kwargs, ): if use_cache and hasattr(self, "_flag_for_generation"): #past_key_value is not None: - out_weight = torch.empty(self.input_layernorm.weight.shape, dtype = torch.float32, device = device) + out_weight = torch.empty(self.input_layernorm.weight.shape, dtype = torch.float32, device = "cuda:0") # Self Attention residual = hidden_states @@ -137,7 +134,7 @@ def GemmaModel_fast_forward_inference( position_ids, attention_mask = None, ): - out_weight = torch.empty_like(self.model.layers[0].input_layernorm.weight, dtype = torch.float32, device = device) + out_weight = torch.empty_like(self.model.layers[0].input_layernorm.weight, dtype = torch.float32, device = "cuda:0") input_ids = input_ids[:,:self.max_seq_length] hidden_states = self.model.embed_tokens(input_ids) hidden_states = hidden_states.to(self.config.torch_dtype) @@ -220,8 +217,8 @@ def _set_cos_sin_cache(self, seq_len, device, dtype): emb = torch.cat((radians_new, radians_new), dim = -1) # We must do RoPE in float32! - cos = emb.cos().to(device = device, non_blocking = True)#, dtype = dtype) - sin = emb.sin().to(device = device, non_blocking = True)#, dtype = dtype) + cos = emb.cos().to(device = "cuda:0", non_blocking = True)#, dtype = dtype) + sin = emb.sin().to(device = "cuda:0", non_blocking = True)#, dtype = dtype) self.register_buffer("cos_cached", cos, persistent = False) self.register_buffer("sin_cached", sin, persistent = False) pass diff --git a/unsloth/models/llama.py b/unsloth/models/llama.py index 9db7fcf2db..2d8e6a0748 100644 --- a/unsloth/models/llama.py +++ b/unsloth/models/llama.py @@ -74,11 +74,6 @@ def original_apply_o(self, X): return O pass -import os # Unsloth only works on NVIDIA GPUs for now -device_ids = os.environ.get("CUDA_VISIBLE_DEVICES", "0") + "," -device = device_ids[:device_ids.find(',')] # Unsloth only works on NVIDIA GPUs for now -device = f"cuda:{device if device.isdigit() else '0'}" - from math import sqrt as math_sqrt KV_CACHE_INCREMENT = 256 # KV Cache update size torch_nn_functional_softmax = torch.nn.functional.softmax @@ -136,15 +131,15 @@ def LlamaAttention_fast_forward_inference( # Prefill phase # if not hasattr(self, "paged_attention"): if do_prefill: - self.paged_attention = torch.empty((KV_CACHE_INCREMENT+seq_len+1, 2, bsz, n_kv_heads, head_dim), dtype = dtype, device = device) + self.paged_attention = torch.empty((KV_CACHE_INCREMENT+seq_len+1, 2, bsz, n_kv_heads, head_dim), dtype = dtype, device = "cuda:0") self.paged_attention_K = self.paged_attention[:,0] self.paged_attention_V = self.paged_attention[:,1] self.paged_attention_K[:seq_len] = K1.permute(2, 0, 1, 3) self.paged_attention_V[:seq_len] = V1.permute(2, 0, 1, 3) - self.temp_QA = torch.empty((2, bsz, 1, attention_size), dtype = dtype, device = device) - self.temp_KV = torch.empty((2, bsz, 1, n_kv_heads*head_dim), dtype = dtype, device = device) - self.RH_Q = torch.empty((bsz, n_heads, 1, head_dim), dtype = dtype, device = device) - self.attention = torch.empty((bsz, n_heads, 1, KV_CACHE_INCREMENT+seq_len), dtype = dtype, device = device) + self.temp_QA = torch.empty((2, bsz, 1, attention_size), dtype = dtype, device = "cuda:0") + self.temp_KV = torch.empty((2, bsz, 1, n_kv_heads*head_dim), dtype = dtype, device = "cuda:0") + self.RH_Q = torch.empty((bsz, n_heads, 1, head_dim), dtype = dtype, device = "cuda:0") + self.attention = torch.empty((bsz, n_heads, 1, KV_CACHE_INCREMENT+seq_len), dtype = dtype, device = "cuda:0") self.scalar = 1.0 / math_sqrt(self.head_dim) self.half_head_dim = head_dim // 2 elif kv_seq_len >= self.paged_attention.shape[0]: @@ -174,7 +169,7 @@ def LlamaAttention_fast_forward_inference( Qn *= cos Qn.addcmul_(RH_Q, sin) - RH_K = RH_Q[:,:n_kv_heads,:,:] # torch.empty((n_kv_heads, 1, head_dim), dtype = dtype, device = device) + RH_K = RH_Q[:,:n_kv_heads,:,:] # torch.empty((n_kv_heads, 1, head_dim), dtype = dtype, device = "cuda:0") RH_K[:,:,:,:h] = Kn[:,:,:,h:] RH_K[:,:,:,h:] = Kn[:,:,:,:h] torch.neg(RH_K[:,:,:,:h], out = RH_K[:,:,:,:h]) @@ -236,7 +231,7 @@ def fast_swiglu_inference(self, X): # up = self.up_proj(X) bsz, _, hd = X.shape # mlp_size = self.config.intermediate_size - # temp = torch.empty((2, bsz, 1, mlp_size), dtype = X.dtype, device = device) + # temp = torch.empty((2, bsz, 1, mlp_size), dtype = X.dtype, device = "cuda:0") gate = fast_linear_forward(self.gate_proj, X)#, out = temp[0]) up = fast_linear_forward(self. up_proj, X)#, out = temp[1]) @@ -526,7 +521,7 @@ def LlamaModel_fast_forward( position_ids = torch.arange( past_key_values_length, seq_length + past_key_values_length, dtype = torch.int32, - device = device, + device = "cuda:0", ) position_ids = position_ids.unsqueeze(0).view(-1, seq_length) elif position_ids is not None: @@ -846,11 +841,8 @@ def _CausalLM_fast_forward( if labels is not None: shift_logits = logits if not hasattr(self, "extra_ignored_labels"): - device_ids = os.environ.get("CUDA_VISIBLE_DEVICES", "0") + "," - device = device_ids[:device_ids.find(',')] # Unsloth only works on NVIDIA GPUs for now - device = f"cuda:{device if device.isdigit() else '0'}" # Fixes https://github.com/unslothai/unsloth/issues/10 - self.extra_ignored_labels = torch.full((self.max_seq_length, 1), -100, device = device) + self.extra_ignored_labels = torch.full((self.max_seq_length, 1), -100, device = "cuda:0") pass shift_labels = torch.hstack((labels[..., 1:], self.extra_ignored_labels[:labels.shape[0]])) @@ -1471,7 +1463,7 @@ def get_peft_model( print("Unsloth: Casting embed_tokens to float32") model.model.model.embed_tokens.modules_to_save.default\ - .to(device = device, dtype = torch.float32, non_blocking = True) + .to(device = "cuda:0", dtype = torch.float32, non_blocking = True) model.model.model.embed_tokens.modules_to_save.default.requires_grad_(True) # [TODO] Move old embed_tokens to CPU - should be disk! @@ -1484,7 +1476,7 @@ def get_peft_model( print("Unsloth: Casting lm_head to float32") model.model.lm_head.modules_to_save.default\ - .to(device = device, dtype = torch.float32, non_blocking = True) + .to(device = "cuda:0", dtype = torch.float32, non_blocking = True) model.model.lm_head.modules_to_save.default.requires_grad_(True) # [TODO] Move old lm_head to CPU - should be disk! @@ -1713,7 +1705,7 @@ def get_peft_model( print("Unsloth: Casting embed_tokens to float32") assert(hasattr(model.model.model.embed_tokens, "modules_to_save")) model.model.model.embed_tokens.modules_to_save.default\ - .to(device = device, dtype = torch.float32, non_blocking = True) + .to(device = "cuda:0", dtype = torch.float32, non_blocking = True) model.model.model.embed_tokens.modules_to_save.default.requires_grad_(True) pass @@ -1721,7 +1713,7 @@ def get_peft_model( print("Unsloth: Casting lm_head to float32") assert(hasattr(model.model.lm_head, "modules_to_save")) model.model.lm_head.modules_to_save.default\ - .to(device = device, dtype = torch.float32, non_blocking = True) + .to(device = "cuda:0", dtype = torch.float32, non_blocking = True) model.model.lm_head.modules_to_save.default.requires_grad_(True) pass @@ -1902,10 +1894,7 @@ def patch_peft_model( # Patch cross entropy loss labels # Fixes https://github.com/unslothai/unsloth/issues/10 max_seq_length = model.max_seq_length - device_ids = os.environ.get("CUDA_VISIBLE_DEVICES", "0") + "," - device = device_ids[:device_ids.find(',')] # Unsloth only works on NVIDIA GPUs for now - device = f"cuda:{device if device.isdigit() else '0'}" - extra_ignored_labels = torch.full((max_seq_length, 1), -100, device = device) + extra_ignored_labels = torch.full((max_seq_length, 1), -100, device = "cuda:0") model.model.extra_ignored_labels = extra_ignored_labels internal_model = model while hasattr(internal_model, "model"): diff --git a/unsloth/models/mistral.py b/unsloth/models/mistral.py index 832189beea..d8bd85d478 100644 --- a/unsloth/models/mistral.py +++ b/unsloth/models/mistral.py @@ -239,11 +239,8 @@ def MistralForCausalLM_fast_forward( if labels is not None: shift_logits = logits if not hasattr(self, "extra_ignored_labels"): - device_ids = os.environ.get("CUDA_VISIBLE_DEVICES", "0") + "," - device = device_ids[:device_ids.find(',')] # Unsloth only works on NVIDIA GPUs for now - device = f"cuda:{device if device.isdigit() else '0'}" # Fixes https://github.com/unslothai/unsloth/issues/10 - self.extra_ignored_labels = torch.full((self.max_seq_length, 1), -100, device = device) + self.extra_ignored_labels = torch.full((self.max_seq_length, 1), -100, device = "cuda:0") pass shift_labels = torch.hstack((labels[..., 1:], self.extra_ignored_labels[:labels.shape[0]]))