diff --git a/unsloth/models/_utils.py b/unsloth/models/_utils.py index a693389355..1b122fc8e1 100644 --- a/unsloth/models/_utils.py +++ b/unsloth/models/_utils.py @@ -372,6 +372,10 @@ def prepare_n_gradient_checkpoints( pass +# Unsloth only works on NVIDIA GPUs for now +device_ids = os.environ.get("CUDA_VISIBLE_DEVICES", "0") + "," +device = f"cuda:{device_ids[:device_ids.find(',')]}" + class Unsloth_Offloaded_Gradient_Checkpointer(torch.autograd.Function): """ Saves VRAM by smartly offloading to RAM. @@ -393,7 +397,7 @@ def forward(ctx, forward_function, hidden_states, *args): @torch.cuda.amp.custom_bwd def backward(ctx, dY): (hidden_states,) = ctx.saved_tensors - hidden_states = hidden_states.to("cuda", non_blocking = True).detach() + hidden_states = hidden_states.to(device, non_blocking = True).detach() hidden_states.requires_grad = True with torch.enable_grad(): (output,) = ctx.forward_function(hidden_states, *ctx.args) @@ -457,7 +461,6 @@ def _prepare_backend( # Offloading to disk for modules (lm_head, embed_tokens) -import os import pickle def offload_to_disk(W, model, name, temporary_location : str = "_unsloth_temporary_saved_buffers"): diff --git a/unsloth/models/gemma.py b/unsloth/models/gemma.py index 5dd2a5abd5..985028364b 100644 --- a/unsloth/models/gemma.py +++ b/unsloth/models/gemma.py @@ -38,6 +38,11 @@ GemmaFlashAttention2 = GemmaAttention pass +# Unsloth currently only works on one GPU +import os +device_ids = os.environ.get("CUDA_VISIBLE_DEVICES", "0") + "," +device = f"cuda:{device_ids[:device_ids.find(',')]}" +# Please obtain a commercial license torch_nn_functional_gelu = torch.nn.functional.gelu def fast_geglu_inference(self, X): @@ -45,7 +50,7 @@ def fast_geglu_inference(self, X): # up = self.up_proj(X) bsz, _, hd = X.shape # mlp_size = self.config.intermediate_size - # temp = torch.empty((2, bsz, 1, mlp_size), dtype = X.dtype, device = "cuda") + # temp = torch.empty((2, bsz, 1, mlp_size), dtype = X.dtype, device = device) gate = fast_linear_forward(self.gate_proj, X)#, out = temp[0]) up = fast_linear_forward(self. up_proj, X)#, out = temp[1]) @@ -72,7 +77,7 @@ def GemmaDecoderLayer_fast_forward( *args, **kwargs, ): if use_cache and hasattr(self, "_flag_for_generation"): #past_key_value is not None: - out_weight = torch.empty(self.input_layernorm.weight.shape, dtype = torch.float32, device = "cuda") + out_weight = torch.empty(self.input_layernorm.weight.shape, dtype = torch.float32, device = device) # Self Attention residual = hidden_states @@ -134,7 +139,7 @@ def GemmaModel_fast_forward_inference( position_ids, attention_mask = None, ): - out_weight = torch.empty_like(self.model.layers[0].input_layernorm.weight, dtype = torch.float32, device = "cuda") + out_weight = torch.empty_like(self.model.layers[0].input_layernorm.weight, dtype = torch.float32, device = device) input_ids = input_ids[:,:self.max_seq_length] hidden_states = self.model.embed_tokens(input_ids) hidden_states = hidden_states.to(self.config.torch_dtype) diff --git a/unsloth/models/llama.py b/unsloth/models/llama.py index 7cbdcfbda7..9327b1bb45 100644 --- a/unsloth/models/llama.py +++ b/unsloth/models/llama.py @@ -74,6 +74,9 @@ def original_apply_o(self, X): return O pass +import os # Unsloth only works on NVIDIA GPUs for now +device_ids = os.environ.get("CUDA_VISIBLE_DEVICES", "0") + "," +device = f"cuda:{device_ids[:device_ids.find(',')]}" from math import sqrt as math_sqrt KV_CACHE_INCREMENT = 256 # KV Cache update size @@ -132,15 +135,15 @@ def LlamaAttention_fast_forward_inference( # Prefill phase # if not hasattr(self, "paged_attention"): if do_prefill: - self.paged_attention = torch.empty((KV_CACHE_INCREMENT+seq_len+1, 2, bsz, n_kv_heads, head_dim), dtype = dtype, device = "cuda") + self.paged_attention = torch.empty((KV_CACHE_INCREMENT+seq_len+1, 2, bsz, n_kv_heads, head_dim), dtype = dtype, device = device) self.paged_attention_K = self.paged_attention[:,0] self.paged_attention_V = self.paged_attention[:,1] self.paged_attention_K[:seq_len] = K1.permute(2, 0, 1, 3) self.paged_attention_V[:seq_len] = V1.permute(2, 0, 1, 3) - self.temp_QA = torch.empty((2, bsz, 1, attention_size), dtype = dtype, device = "cuda") - self.temp_KV = torch.empty((2, bsz, 1, n_kv_heads*head_dim), dtype = dtype, device = "cuda") - self.RH_Q = torch.empty((bsz, n_heads, 1, head_dim), dtype = dtype, device = "cuda") - self.attention = torch.empty((bsz, n_heads, 1, KV_CACHE_INCREMENT+seq_len), dtype = dtype, device = "cuda") + self.temp_QA = torch.empty((2, bsz, 1, attention_size), dtype = dtype, device = device) + self.temp_KV = torch.empty((2, bsz, 1, n_kv_heads*head_dim), dtype = dtype, device = device) + self.RH_Q = torch.empty((bsz, n_heads, 1, head_dim), dtype = dtype, device = device) + self.attention = torch.empty((bsz, n_heads, 1, KV_CACHE_INCREMENT+seq_len), dtype = dtype, device = device) self.scalar = 1.0 / math_sqrt(self.head_dim) self.half_head_dim = head_dim // 2 elif kv_seq_len >= self.paged_attention.shape[0]: @@ -170,7 +173,7 @@ def LlamaAttention_fast_forward_inference( Qn *= cos Qn.addcmul_(RH_Q, sin) - RH_K = RH_Q[:,:n_kv_heads,:,:] # torch.empty((n_kv_heads, 1, head_dim), dtype = dtype, device = "cuda") + RH_K = RH_Q[:,:n_kv_heads,:,:] # torch.empty((n_kv_heads, 1, head_dim), dtype = dtype, device = device) RH_K[:,:,:,:h] = Kn[:,:,:,h:] RH_K[:,:,:,h:] = Kn[:,:,:,:h] torch.neg(RH_K[:,:,:,:h], out = RH_K[:,:,:,:h]) @@ -232,7 +235,7 @@ def fast_swiglu_inference(self, X): # up = self.up_proj(X) bsz, _, hd = X.shape # mlp_size = self.config.intermediate_size - # temp = torch.empty((2, bsz, 1, mlp_size), dtype = X.dtype, device = "cuda") + # temp = torch.empty((2, bsz, 1, mlp_size), dtype = X.dtype, device = device) gate = fast_linear_forward(self.gate_proj, X)#, out = temp[0]) up = fast_linear_forward(self. up_proj, X)#, out = temp[1]) @@ -522,7 +525,7 @@ def LlamaModel_fast_forward( position_ids = torch.arange( past_key_values_length, seq_length + past_key_values_length, dtype = torch.int32, - device = "cuda", + device = device, ) position_ids = position_ids.unsqueeze(0).view(-1, seq_length) elif position_ids is not None: @@ -842,8 +845,10 @@ def _CausalLM_fast_forward( if labels is not None: shift_logits = logits if not hasattr(self, "extra_ignored_labels"): + device_ids = os.environ.get("CUDA_VISIBLE_DEVICES", "0") + "," + device = f"cuda:{device_ids[:device_ids.find(',')]}" # Unsloth only works on NVIDIA GPUs for now # Fixes https://github.com/unslothai/unsloth/issues/10 - self.extra_ignored_labels = torch.full((self.max_seq_length, 1), -100, device = "cuda") + self.extra_ignored_labels = torch.full((self.max_seq_length, 1), -100, device = device) pass shift_labels = torch.hstack((labels[..., 1:], self.extra_ignored_labels[:labels.shape[0]])) @@ -1822,7 +1827,9 @@ def patch_peft_model( # Patch cross entropy loss labels # Fixes https://github.com/unslothai/unsloth/issues/10 max_seq_length = model.max_seq_length - extra_ignored_labels = torch.full((max_seq_length, 1), -100, device = "cuda") + device_ids = os.environ.get("CUDA_VISIBLE_DEVICES", "0") + "," + device = f"cuda:{device_ids[:device_ids.find(',')]}" # Unsloth only works on NVIDIA GPUs for now + extra_ignored_labels = torch.full((max_seq_length, 1), -100, device = device) model.model.extra_ignored_labels = extra_ignored_labels internal_model = model while hasattr(internal_model, "model"): diff --git a/unsloth/models/mistral.py b/unsloth/models/mistral.py index 291f0aa502..e147f21568 100644 --- a/unsloth/models/mistral.py +++ b/unsloth/models/mistral.py @@ -239,8 +239,10 @@ def MistralForCausalLM_fast_forward( if labels is not None: shift_logits = logits if not hasattr(self, "extra_ignored_labels"): + device_ids = os.environ.get("CUDA_VISIBLE_DEVICES", "0") + "," + device = f"cuda:{device_ids[:device_ids.find(',')]}" # Unsloth only works on NVIDIA GPUs for now # Fixes https://github.com/unslothai/unsloth/issues/10 - self.extra_ignored_labels = torch.full((self.max_seq_length, 1), -100, device = "cuda") + self.extra_ignored_labels = torch.full((self.max_seq_length, 1), -100, device = device) pass shift_labels = torch.hstack((labels[..., 1:], self.extra_ignored_labels[:labels.shape[0]]))