From 2b95cc7c4dce99136fb4451087df852fb051465f Mon Sep 17 00:00:00 2001
From: Datta Nimmaturi
Date: Wed, 2 Jul 2025 13:22:34 +0000
Subject: [PATCH 1/8] Move tensors to the right devices

---
 unsloth/models/llama.py   | 16 ++++++++++++++--
 unsloth/models/mistral.py |  5 +++++
 2 files changed, 19 insertions(+), 2 deletions(-)

diff --git a/unsloth/models/llama.py b/unsloth/models/llama.py
index d0ff413925..44961204ab 100644
--- a/unsloth/models/llama.py
+++ b/unsloth/models/llama.py
@@ -209,6 +209,8 @@ def LlamaAttention_fast_forward_inference(
         This means we can pass in a row of Q, but we need to remember K and V,
         which are called the KV cache.
     """
+    if position_ids is not None:
+        position_ids = position_ids.to('cpu')
     Xn = hidden_states
     bsz, _, hd = hidden_states.size()
     K1, V1 = past_key_value
@@ -269,8 +271,8 @@ def LlamaAttention_fast_forward_inference(
     # or else error
     self.rotary_emb.extend_rope_embedding(Vn, seq_len + 2)
     cos, sin = self.rotary_emb.get_cached(kv_seq_len)
-    cos = cos[position_ids].unsqueeze(1)
-    sin = sin[position_ids].unsqueeze(1)
+    cos = cos[position_ids].unsqueeze(1).to(Qn.device)
+    sin = sin[position_ids].unsqueeze(1).to(Qn.device)
 
     h = self.half_head_dim
     RH_Q = self.RH_Q
@@ -1019,6 +1021,16 @@ def LlamaModel_fast_forward_inference_custom(
 
     next_decoder_cache = []
     for idx, decoder_layer in enumerate(self.model.layers):
+        print(f'Inference through layer {idx}')
+        decoder_device = decoder_layer.self_attn.q_proj.weight.device
+        if X.device != decoder_device:
+            X = X.to(decoder_device)
+        if residual.device != decoder_device:
+            residual = residual.to(decoder_device)
+        if temp_gate.device != decoder_device:
+            temp_gate = temp_gate.to(decoder_device)
+        if temp_up.device != decoder_device:
+            temp_up = temp_up.to(decoder_device)
         residual.copy_(X) # residual = X
         X = fast_rms_layernorm_inference(
             decoder_layer.input_layernorm,
diff --git a/unsloth/models/mistral.py b/unsloth/models/mistral.py
index ef71c89c4a..78a8a792b9 100644
--- a/unsloth/models/mistral.py
+++ b/unsloth/models/mistral.py
@@ -89,6 +89,11 @@ def MistralAttention_fast_forward(
         Q, K = fast_rope_embedding(Q, K, cos, sin)
     else:
         cos, sin = self.rotary_emb(V, seq_len = kv_seq_len)
+        if cos.device != Q.device:
+            cos = cos.to(Q.device)
+            sin = sin.to(Q.device)
+        if Q.device != K.device:
+            raise ValueError("Q and K must be on the same device")
         Q, K = inplace_rope_embedding(Q, K, cos, sin, position_ids)
     pass
 

From 2ca7875433a490daa079f9b4f5379668c39e8001 Mon Sep 17 00:00:00 2001
From: Datta Nimmaturi
Date: Wed, 2 Jul 2025 15:29:52 +0000
Subject: [PATCH 2/8] Fix multi GPU for non-Mistral models

---
 unsloth/models/cohere.py    | 6 ++++--
 unsloth/models/falcon_h1.py | 6 ++++--
 unsloth/models/gemma2.py    | 4 ++--
 unsloth/models/llama.py     | 3 ++-
 unsloth/models/mistral.py   | 5 +++--
 unsloth/models/qwen3.py     | 6 ++++--
 6 files changed, 19 insertions(+), 11 deletions(-)

diff --git a/unsloth/models/cohere.py b/unsloth/models/cohere.py
index 25704301d1..1d32f62c86 100644
--- a/unsloth/models/cohere.py
+++ b/unsloth/models/cohere.py
@@ -254,6 +254,8 @@ def CohereAttention_fast_forward_inference(
     do_prefill = False,
     attention_mask = None,
 ):
+    if position_ids is not None:
+        position_ids = position_ids.to('cpu')
     Xn = hidden_states
     bsz, _, hd = hidden_states.size()
     K1, V1 = past_key_value
@@ -321,8 +323,8 @@ def CohereAttention_fast_forward_inference(
     # cos, sin = self.rotary_emb(Vn, seq_len = kv_seq_len)
     # Qn, Kn = inplace_rope_embedding(Qn, Kn, cos, sin, position_ids)
     cos, sin = self.rotary_emb.get_cached(kv_seq_len)
-    cos = cos[position_ids].unsqueeze(1)
-    sin = sin[position_ids].unsqueeze(1)
+    cos = cos[position_ids].unsqueeze(1).to(Qn.device)
+    sin = sin[position_ids].unsqueeze(1).to(Qn.device)
 
     h = self.half_head_dim
     RH_Q = self.RH_Q
diff --git a/unsloth/models/falcon_h1.py b/unsloth/models/falcon_h1.py
index 8978e4db0d..7889f673b9 100644
--- a/unsloth/models/falcon_h1.py
+++ b/unsloth/models/falcon_h1.py
@@ -116,6 +116,8 @@ def FalconH1Attention_fast_forward(
         cos, sin = rotary_emb.get_cached(kv_seq_len)
     else:
         cos, sin = rotary_emb(V, seq_len = kv_seq_len)
+        cos = cos.to(Q.device)
+        sin = sin.to(Q.device)
     Q, K = fast_rope_embedding(Q, K, cos, sin)
 
     if past_key_value is not None:
@@ -281,8 +283,8 @@ def FalconH1Attention_fast_forward_inference(
     # or else error
     self.rotary_emb.extend_rope_embedding(Vn, seq_len + 2)
     cos, sin = self.rotary_emb.get_cached(kv_seq_len)
-    cos = cos[position_ids].unsqueeze(1)
-    sin = sin[position_ids].unsqueeze(1)
+    cos = cos[position_ids].unsqueeze(1).to(Qn.device)
+    sin = sin[position_ids].unsqueeze(1).to(Qn.device)
 
     h = self.half_head_dim
     RH_Q = self.RH_Q
diff --git a/unsloth/models/gemma2.py b/unsloth/models/gemma2.py
index 23b91ff6f3..94635beb4c 100644
--- a/unsloth/models/gemma2.py
+++ b/unsloth/models/gemma2.py
@@ -307,8 +307,8 @@ def Gemma2Attention_fast_forward_inference(
 
     # cos, sin = self.rotary_emb(Vn, seq_len = kv_seq_len)
     # Qn, Kn = inplace_rope_embedding(Qn, Kn, cos, sin, position_ids)
-    cos = self.rotary_emb.cos_cached[position_ids].unsqueeze(1)
-    sin = self.rotary_emb.sin_cached[position_ids].unsqueeze(1)
+    cos = self.rotary_emb.cos_cached[position_ids].unsqueeze(1).to(Qn.device)
+    sin = self.rotary_emb.sin_cached[position_ids].unsqueeze(1).to(Qn.device)
 
     h = self.half_head_dim
     RH_Q = self.RH_Q
diff --git a/unsloth/models/llama.py b/unsloth/models/llama.py
index 44961204ab..ca009f9fae 100644
--- a/unsloth/models/llama.py
+++ b/unsloth/models/llama.py
@@ -1021,7 +1021,6 @@ def LlamaModel_fast_forward_inference_custom(
 
     next_decoder_cache = []
     for idx, decoder_layer in enumerate(self.model.layers):
-        print(f'Inference through layer {idx}')
         decoder_device = decoder_layer.self_attn.q_proj.weight.device
         if X.device != decoder_device:
             X = X.to(decoder_device)
@@ -1031,6 +1030,8 @@ def LlamaModel_fast_forward_inference_custom(
             temp_gate = temp_gate.to(decoder_device)
         if temp_up.device != decoder_device:
             temp_up = temp_up.to(decoder_device)
+        if position_ids.device != decoder_device:
+            position_ids = position_ids.to(decoder_device)
         residual.copy_(X) # residual = X
         X = fast_rms_layernorm_inference(
             decoder_layer.input_layernorm,
diff --git a/unsloth/models/mistral.py b/unsloth/models/mistral.py
index 78a8a792b9..f6538c65ac 100644
--- a/unsloth/models/mistral.py
+++ b/unsloth/models/mistral.py
@@ -90,10 +90,11 @@ def MistralAttention_fast_forward(
     else:
         cos, sin = self.rotary_emb(V, seq_len = kv_seq_len)
         if cos.device != Q.device:
+            # without this, even though V is on GPU0, by the time we reach the forward function
+            # the argument x is on GPU1, and hence cos, sin end up on GPU1
+            # this is a hack to get around this quirk
             cos = cos.to(Q.device)
             sin = sin.to(Q.device)
-        if Q.device != K.device:
-            raise ValueError("Q and K must be on the same device")
         Q, K = inplace_rope_embedding(Q, K, cos, sin, position_ids)
     pass
 
diff --git a/unsloth/models/qwen3.py b/unsloth/models/qwen3.py
index 83c9dbea0a..21e52548f8 100644
--- a/unsloth/models/qwen3.py
+++ b/unsloth/models/qwen3.py
@@ -232,6 +232,8 @@ def Qwen3Attention_fast_forward_inference(
         This means we can pass in a row of Q, but we need to remember K and V,
         which are called the KV cache.
     """
+    if position_ids is not None:
+        position_ids = position_ids.to('cpu')
     Xn = hidden_states
     bsz, _, hd = hidden_states.size()
     K1, V1 = past_key_value
@@ -298,8 +300,8 @@ def Qwen3Attention_fast_forward_inference(
     # or else error
     self.rotary_emb.extend_rope_embedding(Vn, seq_len + 2)
     cos, sin = self.rotary_emb.get_cached(kv_seq_len)
-    cos = cos[position_ids].unsqueeze(1)
-    sin = sin[position_ids].unsqueeze(1)
+    cos = cos[position_ids].unsqueeze(1).to(Qn.device)
+    sin = sin[position_ids].unsqueeze(1).to(Qn.device)
 
     h = self.half_head_dim
     RH_Q = self.RH_Q

From 03c57c1c47c560cd6f56c58cf0843be0e8a65402 Mon Sep 17 00:00:00 2001
From: Datta Nimmaturi
Date: Thu, 3 Jul 2025 15:57:29 +0000
Subject: [PATCH 3/8] Multi-GPU RoPE for gemma2

---
 unsloth/models/gemma.py  | 33 ++++++++++++++++++++-----------
 unsloth/models/gemma2.py | 24 ++++++++++++++++++------
 2 files changed, 40 insertions(+), 17 deletions(-)

diff --git a/unsloth/models/gemma.py b/unsloth/models/gemma.py
index 8ad1c7e62d..9138baf78c 100644
--- a/unsloth/models/gemma.py
+++ b/unsloth/models/gemma.py
@@ -224,9 +224,16 @@ def __init__(self, dim = None, max_position_embeddings=2048, base=10000, device=
         self.base = base
         # Dynamic RoPE we first set it to a max of 4 * 8192 tokens then we iteratively grow this
         self.current_rope_size = min(4 * 8192, self.max_position_embeddings)
+        self.multi_gpu_cos_cached = {}
+        self.multi_gpu_sin_cached = {}
 
         # Build here to make `torch.jit.trace` work.
-        self._set_cos_sin_cache(seq_len=self.current_rope_size, device=device, dtype=torch.get_default_dtype())
+        for device in range(torch.cuda.device_count()):
+            self._set_cos_sin_cache(seq_len=self.current_rope_size, device=torch.device(device), dtype=torch.get_default_dtype())
+
+        # dummy so that patch_utils doesn't fail for now
+        self.cos_cached = torch.empty(1, device=torch.cuda.current_device(), dtype=torch.get_default_dtype())
+        self.sin_cached = torch.empty(1, device=torch.cuda.current_device(), dtype=torch.get_default_dtype())
     pass
 
     def _set_cos_sin_cache(self, seq_len, device, dtype):
@@ -245,32 +252,36 @@ def _set_cos_sin_cache(self, seq_len, device, dtype):
         emb = torch.cat((radians_new, radians_new), dim = -1)
 
         # We must do RoPE in float32!
- cos = emb.cos().to(device = "cuda", non_blocking = True)#, dtype = dtype) - sin = emb.sin().to(device = "cuda", non_blocking = True)#, dtype = dtype) - self.register_buffer("cos_cached", cos, persistent = False) - self.register_buffer("sin_cached", sin, persistent = False) + cos = emb.cos().to(device = device, non_blocking = True)#, dtype = dtype) + sin = emb.sin().to(device = device, non_blocking = True)#, dtype = dtype) + self.multi_gpu_cos_cached[device] = cos + self.multi_gpu_sin_cached[device] = sin + return cos, sin pass def forward(self, x, position_ids=None, seq_len=None): # x: [bs, num_attention_heads, seq_len, head_size] - if seq_len > self.current_rope_size: + if seq_len is not None and seq_len > self.current_rope_size: self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype) return ( - self.cos_cached[:seq_len].to(dtype=x.dtype), - self.sin_cached[:seq_len].to(dtype=x.dtype), + self.multi_gpu_cos_cached[x.device][:seq_len].to(dtype=x.dtype), + self.multi_gpu_sin_cached[x.device][:seq_len].to(dtype=x.dtype), ) pass - def get_cached(self, seq_len = None): - return self.cos_cached, self.sin_cached + def get_cached(self, seq_len = None, device = None): + if device is None: + device = torch.cuda.current_device() + return self.multi_gpu_cos_cached[device], self.multi_gpu_sin_cached[device] pass def extend_rope_embedding(self, x, seq_len): if seq_len <= self.current_rope_size: return # Iteratively grow by increments of 8192 self.current_rope_size = math.ceil(seq_len / 8192) * 8192 - self._set_cos_sin_cache(self.current_rope_size, device = "cuda", dtype = x.dtype) + for device in range(torch.cuda.device_count()): + self._set_cos_sin_cache(self.current_rope_size, device = torch.device(device), dtype = x.dtype) pass pass diff --git a/unsloth/models/gemma2.py b/unsloth/models/gemma2.py index 94635beb4c..ebb06de8e4 100644 --- a/unsloth/models/gemma2.py +++ b/unsloth/models/gemma2.py @@ -114,11 +114,11 @@ def Gemma2Attention_fast_forward( kv_seq_len += past_key_value[0].shape[-2] if position_ids is None: - cos = self.rotary_emb.cos_cached - sin = self.rotary_emb.sin_cached + cos = self.rotary_emb.multi_gpu_cos_cached[Q.device] + sin = self.rotary_emb.multi_gpu_sin_cached[Q.device] Q, K = fast_rope_embedding(Q, K, cos, sin) else: - cos, sin = self.rotary_emb(V, seq_len = kv_seq_len) + cos, sin = self.rotary_emb.get_cached(seq_len = kv_seq_len, device = Q.device) Q, K = inplace_rope_embedding(Q, K, cos, sin, position_ids) pass @@ -307,8 +307,9 @@ def Gemma2Attention_fast_forward_inference( # cos, sin = self.rotary_emb(Vn, seq_len = kv_seq_len) # Qn, Kn = inplace_rope_embedding(Qn, Kn, cos, sin, position_ids) - cos = self.rotary_emb.cos_cached[position_ids].unsqueeze(1).to(Qn.device) - sin = self.rotary_emb.sin_cached[position_ids].unsqueeze(1).to(Qn.device) + cos, sin = self.rotary_emb.get_cached(seq_len = kv_seq_len, device = Qn.device) + cos = cos[position_ids].unsqueeze(1) + sin = sin[position_ids].unsqueeze(1) h = self.half_head_dim RH_Q = self.RH_Q @@ -385,7 +386,7 @@ def Gemma2Model_fast_forward_inference( past_key_values, position_ids, attention_mask = None, -): +): out_weight = torch.empty_like(self.model.layers[0].input_layernorm.weight, dtype = torch.float32, device = "cuda:0") input_ids = input_ids[:,:self.max_seq_length] hidden_states = self.model.embed_tokens(input_ids) @@ -422,6 +423,17 @@ def Gemma2Model_fast_forward_inference( next_decoder_cache = [] for idx, decoder_layer in enumerate(self.model.layers): + # For pipeline parallelism, we need to move all 
tensors to the same device + # note that this movement is once per GPU in PP + layer_device = decoder_layer.self_attn.q_proj.weight.device + if hidden_states.device != layer_device: + hidden_states = hidden_states.to(layer_device) + if out_weight.device != layer_device: + out_weight = out_weight.to(layer_device) + if position_ids.device != layer_device: + position_ids = position_ids.to(layer_device) + pass + use_sliding_window = idx % 2 == 0 residual = hidden_states From a937baa2f3da1228d8691d587c1aac1c337a6372 Mon Sep 17 00:00:00 2001 From: Datta Nimmaturi Date: Thu, 3 Jul 2025 17:22:53 +0000 Subject: [PATCH 4/8] Finish up multi GPU inference --- unsloth/models/_utils.py | 20 ++++ unsloth/models/cohere.py | 10 +- unsloth/models/falcon_h1.py | 12 +-- unsloth/models/gemma.py | 17 +++- unsloth/models/gemma2.py | 10 +- unsloth/models/granite.py | 7 +- unsloth/models/llama.py | 190 ++++++++++++++++++++++-------------- unsloth/models/mistral.py | 10 +- unsloth/models/qwen3.py | 8 +- 9 files changed, 173 insertions(+), 111 deletions(-) diff --git a/unsloth/models/_utils.py b/unsloth/models/_utils.py index 5da6ea67fe..f716e50043 100644 --- a/unsloth/models/_utils.py +++ b/unsloth/models/_utils.py @@ -773,6 +773,26 @@ def _prepare_backend(self, *args, **kwargs): return None, DistributedType.NO accelerate.accelerator.Accelerator.distributed_type = lambda *args, **kwargs: DistributedType.NO pass +# to move multiple tensors to the same device +def move_to_device(target_device, *tensors): + """ + Move multiple tensors to target device if they're not already there. + + Args: + target_device: The target device to move tensors to + *tensors: Variable number of tensors to potentially move + + Returns: + tuple: The tensors on the target device (same objects if already on device, new if moved) + """ + moved_tensors = [] + for tensor in tensors: + if tensor.device != target_device: + moved_tensors.append(tensor.to(target_device)) + else: + moved_tensors.append(tensor) + return tuple(moved_tensors) if len(moved_tensors) > 1 else moved_tensors[0] + import transformers.utils.quantization_config transformers.utils.quantization_config.BitsAndBytesConfig.__init__ = _BitsAndBytesConfig__init__ # ============================================= diff --git a/unsloth/models/cohere.py b/unsloth/models/cohere.py index 1d32f62c86..8562895201 100644 --- a/unsloth/models/cohere.py +++ b/unsloth/models/cohere.py @@ -322,9 +322,9 @@ def CohereAttention_fast_forward_inference( # cos, sin = self.rotary_emb(Vn, seq_len = kv_seq_len) # Qn, Kn = inplace_rope_embedding(Qn, Kn, cos, sin, position_ids) - cos, sin = self.rotary_emb.get_cached(kv_seq_len) - cos = cos[position_ids].unsqueeze(1).to(Qn.device) - sin = sin[position_ids].unsqueeze(1).to(Qn.device) + cos, sin = self.rotary_emb.get_cached(kv_seq_len, device = Qn.device) + cos = cos[position_ids].unsqueeze(1) + sin = sin[position_ids].unsqueeze(1) h = self.half_head_dim RH_Q = self.RH_Q @@ -419,6 +419,10 @@ def CohereModel_fast_forward_inference( next_decoder_cache = [] for idx, decoder_layer in enumerate(self.model.layers): + decoder_device = decoder_layer.self_attn.q_proj.weight.device + hidden_states, out_weight, position_ids = move_to_device( + decoder_device, hidden_states, out_weight, position_ids + ) residual = hidden_states hidden_states = fast_layernorm_inference(decoder_layer.input_layernorm, hidden_states, out_weight) hidden_states_attention, present_key_value = CohereAttention_fast_forward_inference( diff --git a/unsloth/models/falcon_h1.py 
b/unsloth/models/falcon_h1.py index 7889f673b9..58e8c94d9d 100644 --- a/unsloth/models/falcon_h1.py +++ b/unsloth/models/falcon_h1.py @@ -113,11 +113,9 @@ def FalconH1Attention_fast_forward( if position_ids is None: # Useful for LongRoPE - cos, sin = rotary_emb.get_cached(kv_seq_len) + cos, sin = rotary_emb.get_cached(kv_seq_len, device=Q.device) else: - cos, sin = rotary_emb(V, seq_len = kv_seq_len) - cos = cos.to(Q.device) - sin = sin.to(Q.device) + cos, sin = rotary_emb.get_cached(seq_len = kv_seq_len, device=Q.device) Q, K = fast_rope_embedding(Q, K, cos, sin) if past_key_value is not None: @@ -282,9 +280,9 @@ def FalconH1Attention_fast_forward_inference( # Need to do it prior 2 steps before hitting full on short KV cache # or else error self.rotary_emb.extend_rope_embedding(Vn, seq_len + 2) - cos, sin = self.rotary_emb.get_cached(kv_seq_len) - cos = cos[position_ids].unsqueeze(1).to(Qn.device) - sin = sin[position_ids].unsqueeze(1).to(Qn.device) + cos, sin = self.rotary_emb.get_cached(kv_seq_len, device = Qn.device) + cos = cos[position_ids].unsqueeze(1) + sin = sin[position_ids].unsqueeze(1) h = self.half_head_dim RH_Q = self.RH_Q diff --git a/unsloth/models/gemma.py b/unsloth/models/gemma.py index 9138baf78c..8c992aa274 100644 --- a/unsloth/models/gemma.py +++ b/unsloth/models/gemma.py @@ -170,6 +170,12 @@ def GemmaModel_fast_forward_inference( next_decoder_cache = [] for idx, decoder_layer in enumerate(self.model.layers): + + decoder_device = decoder_layer.self_attn.q_proj.weight.device + hidden_states, out_weight, position_ids = move_to_device( + decoder_device, hidden_states, out_weight, position_ids + ) + residual = hidden_states hidden_states = fast_rms_layernorm_inference_gemma(decoder_layer.input_layernorm, hidden_states, out_weight) hidden_states, present_key_value = LlamaAttention_fast_forward_inference( @@ -299,7 +305,7 @@ def __init__(self, dim = None, max_position_embeddings=2048, base=10000, device= pass def _set_cos_sin_cache(self, seq_len, device, dtype): -# Note: on the original Llama codebase, these tensors are created on the target device (and not on CPU) and + # Note: on the original Llama codebase, these tensors are created on the target device (and not on CPU) and # in FP32. They are applied (multiplied) in FP32 as well. self.current_rope_size = seq_len @@ -315,10 +321,11 @@ def _set_cos_sin_cache(self, seq_len, device, dtype): emb = torch.cat((radians_new, radians_new), dim = -1) # We must do RoPE in float32! 
- cos = emb.cos().to(device = "cuda", non_blocking = True)#, dtype = dtype) - sin = emb.sin().to(device = "cuda", non_blocking = True)#, dtype = dtype) - self.register_buffer("cos_cached", cos, persistent = False) - self.register_buffer("sin_cached", sin, persistent = False) + cos = emb.cos().to(device = device, non_blocking = True)#, dtype = dtype) + sin = emb.sin().to(device = device, non_blocking = True)#, dtype = dtype) + self.multi_gpu_cos_cached[device] = cos + self.multi_gpu_sin_cached[device] = sin + return cos, sin pass pass diff --git a/unsloth/models/gemma2.py b/unsloth/models/gemma2.py index ebb06de8e4..c4eb5a12c6 100644 --- a/unsloth/models/gemma2.py +++ b/unsloth/models/gemma2.py @@ -426,13 +426,9 @@ def Gemma2Model_fast_forward_inference( # For pipeline parallelism, we need to move all tensors to the same device # note that this movement is once per GPU in PP layer_device = decoder_layer.self_attn.q_proj.weight.device - if hidden_states.device != layer_device: - hidden_states = hidden_states.to(layer_device) - if out_weight.device != layer_device: - out_weight = out_weight.to(layer_device) - if position_ids.device != layer_device: - position_ids = position_ids.to(layer_device) - pass + hidden_states, out_weight, position_ids = move_to_device( + layer_device, hidden_states, out_weight, position_ids + ) use_sliding_window = idx % 2 == 0 diff --git a/unsloth/models/granite.py b/unsloth/models/granite.py index 0f88bbc55e..32e8dbc866 100644 --- a/unsloth/models/granite.py +++ b/unsloth/models/granite.py @@ -395,11 +395,16 @@ def GraniteModel_fast_forward_inference( attention_mask = None pass - position_embeddings = self.model.rotary_emb(hidden_states, position_ids, self.max_seq_length) + position_embeddings = self.model.rotary_emb.get_cached(seq_len = self.max_seq_length, device = hidden_states.device) next_decoder_cache = [] for idx, decoder_layer in enumerate(self.model.layers): + decoder_device = decoder_layer.self_attn.q_proj.weight.device + hidden_states, position_ids = move_to_device( + decoder_device, hidden_states, position_ids + ) + residual = hidden_states hidden_states = fast_rms_layernorm_inference(decoder_layer.input_layernorm, hidden_states) hidden_states, present_key_value = GraniteAttention_fast_forward_inference( diff --git a/unsloth/models/llama.py b/unsloth/models/llama.py index ca009f9fae..72664362d4 100644 --- a/unsloth/models/llama.py +++ b/unsloth/models/llama.py @@ -20,6 +20,7 @@ from ._utils import * from ._utils import patch_unsloth_smart_gradient_checkpointing from ._utils import __version__ +from ._utils import move_to_device from torch.nn.functional import scaled_dot_product_attention from transformers import __version__ as transformers_version from unsloth_zoo.utils import Version, _get_dtype @@ -270,9 +271,9 @@ def LlamaAttention_fast_forward_inference( # Need to do it prior 2 steps before hitting full on short KV cache # or else error self.rotary_emb.extend_rope_embedding(Vn, seq_len + 2) - cos, sin = self.rotary_emb.get_cached(kv_seq_len) - cos = cos[position_ids].unsqueeze(1).to(Qn.device) - sin = sin[position_ids].unsqueeze(1).to(Qn.device) + cos, sin = self.rotary_emb.get_cached(kv_seq_len, device = Qn.device) + cos = cos[position_ids].unsqueeze(1) + sin = sin[position_ids].unsqueeze(1) h = self.half_head_dim RH_Q = self.RH_Q @@ -472,11 +473,12 @@ def LlamaAttention_fast_forward( rotary_emb = self.rotary_emb rotary_emb.extend_rope_embedding(V, seq_len = kv_seq_len) - if position_ids is None: - # Useful for LongRoPE - cos, sin = 
rotary_emb.get_cached(kv_seq_len) - else: - cos, sin = rotary_emb(V, seq_len = kv_seq_len) + # if position_ids is None: + # # Useful for LongRoPE + # cos, sin = rotary_emb.get_cached(kv_seq_len, device = Q.device) + # else: + # cos, sin = rotary_emb.get_cached(seq_len = kv_seq_len, device = Q.device) + cos, sin = rotary_emb.get_cached(seq_len = kv_seq_len, device = Q.device) # Q, K = ( # fast_rope_embedding(Q, K, cos, sin) @@ -893,7 +895,7 @@ def LlamaModel_fast_forward( # Also, transformers 4.45.0 supports granite but with the attention refactor (it always had the refactor) # unsloth's check for granite too has "version >= 4.45.0 (rightly so)". # so let granite always use the attention refactor implementation. - position_embeddings = self.rotary_emb(hidden_states, position_ids, self.config.max_position_embeddings) + position_embeddings = self.rotary_emb.get_cached(seq_len = self.config.max_position_embeddings, device = hidden_states.device) else: position_embeddings = None @@ -1022,16 +1024,9 @@ def LlamaModel_fast_forward_inference_custom( for idx, decoder_layer in enumerate(self.model.layers): decoder_device = decoder_layer.self_attn.q_proj.weight.device - if X.device != decoder_device: - X = X.to(decoder_device) - if residual.device != decoder_device: - residual = residual.to(decoder_device) - if temp_gate.device != decoder_device: - temp_gate = temp_gate.to(decoder_device) - if temp_up.device != decoder_device: - temp_up = temp_up.to(decoder_device) - if position_ids.device != decoder_device: - position_ids = position_ids.to(decoder_device) + X, residual, temp_gate, temp_up, position_ids = move_to_device( + decoder_device, X, residual, temp_gate, temp_up, position_ids + ) residual.copy_(X) # residual = X X = fast_rms_layernorm_inference( decoder_layer.input_layernorm, @@ -1345,9 +1340,16 @@ def __init__(self, dim = None, max_position_embeddings=2048, base=10000, device= self.base = base # Dynamic RoPE we first set it to a max of 4 * 8192 tokens then we iteratively grow this self.current_rope_size = min(4 * 8192, self.max_position_embeddings) + self.multi_gpu_cos_cached = {} + self.multi_gpu_sin_cached = {} # Build here to make `torch.jit.trace` work. 
-        self._set_cos_sin_cache(seq_len=self.current_rope_size, device=device, dtype=torch.get_default_dtype())
+        for device_idx in range(torch.cuda.device_count()):
+            self._set_cos_sin_cache(seq_len=self.current_rope_size, device=torch.device(device_idx), dtype=torch.get_default_dtype())
+
+        # dummy so that patch_utils doesn't fail for now
+        self.cos_cached = torch.empty(1, device=torch.cuda.current_device(), dtype=torch.get_default_dtype())
+        self.sin_cached = torch.empty(1, device=torch.cuda.current_device(), dtype=torch.get_default_dtype())
     pass
 
     def _set_cos_sin_cache(self, seq_len, device, dtype):
@@ -1362,30 +1364,36 @@ def _set_cos_sin_cache(self, seq_len, device, dtype):
         freqs = torch.outer(t, inv_freq)
         # Different from paper, but it uses a different permutation in order to obtain the same calculation
         emb = torch.cat((freqs, freqs), dim=-1)
-        self.register_buffer("cos_cached", emb.cos().to(dtype=dtype, device=device, non_blocking=True), persistent=False)
-        self.register_buffer("sin_cached", emb.sin().to(dtype=dtype, device=device, non_blocking=True), persistent=False)
+        cos = emb.cos().to(dtype=dtype, device=device, non_blocking=True)
+        sin = emb.sin().to(dtype=dtype, device=device, non_blocking=True)
+        self.multi_gpu_cos_cached[device] = cos
+        self.multi_gpu_sin_cached[device] = sin
+        return cos, sin
     pass
 
     def forward(self, x, position_ids=None, seq_len=None):
         # x: [bs, num_attention_heads, seq_len, head_size]
-        if seq_len > self.current_rope_size:
+        if seq_len is not None and seq_len > self.current_rope_size:
             self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype)
 
         return (
-            self.cos_cached[:seq_len].to(dtype = x.dtype),
-            self.sin_cached[:seq_len].to(dtype = x.dtype),
+            self.multi_gpu_cos_cached[x.device][:seq_len].to(dtype = x.dtype),
+            self.multi_gpu_sin_cached[x.device][:seq_len].to(dtype = x.dtype),
         )
     pass
 
-    def get_cached(self, seq_len = None):
-        return self.cos_cached, self.sin_cached
+    def get_cached(self, seq_len = None, device = None):
+        if device is None:
+            device = torch.cuda.current_device()
+        return self.multi_gpu_cos_cached[device], self.multi_gpu_sin_cached[device]
     pass
 
     def extend_rope_embedding(self, x, seq_len):
         if seq_len <= self.current_rope_size: return
         # Iteratively grow by increments of 8192
         self.current_rope_size = ((seq_len // 8192) + ((seq_len % 8192) != 0)) * 8192
-        self._set_cos_sin_cache(self.current_rope_size, device = DEVICE_TYPE, dtype = x.dtype)
+        for device_idx in range(torch.cuda.device_count()):
+            self._set_cos_sin_cache(self.current_rope_size, device = torch.device(device_idx), dtype = x.dtype)
     pass
 pass
 
@@ -1413,8 +1421,11 @@ def _set_cos_sin_cache(self, seq_len, device, dtype):
         freqs = torch.outer(t, inv_freq)
         # Different from paper, but it uses a different permutation in order to obtain the same calculation
         emb = torch.cat((freqs, freqs), dim=-1)
-        self.register_buffer("cos_cached", emb.cos().to(dtype=dtype, device=device, non_blocking=True), persistent=False)
-        self.register_buffer("sin_cached", emb.sin().to(dtype=dtype, device=device, non_blocking=True), persistent=False)
+        cos = emb.cos().to(dtype=dtype, device=device, non_blocking=True)
+        sin = emb.sin().to(dtype=dtype, device=device, non_blocking=True)
+        self.multi_gpu_cos_cached[device] = cos
+        self.multi_gpu_sin_cached[device] = sin
+        return cos, sin
     pass
 pass
 
@@ -1440,6 +1451,8 @@ def __init__(self, dim = None, max_position_embeddings=2048, base=10000, device=
         self.base = base
         # Dynamic RoPE we first set it to a max of 4 * 8192 tokens then we iteratively grow this
         self.current_rope_size = min(4 * 8192, self.max_position_embeddings)
+        self.multi_gpu_cos_cached = {}
+        self.multi_gpu_sin_cached = {}
 
         # Normal Llama-3 RoPE
         inv_freq = 1.0 / (
@@ -1449,7 +1462,12 @@ def __init__(self, dim = None, max_position_embeddings=2048, base=10000, device=
         self.register_buffer("inv_freq", inv_freq, persistent = False)
 
         # Build here to make `torch.jit.trace` work.
-        self._set_cos_sin_cache(seq_len=self.current_rope_size, device=device, dtype=torch.get_default_dtype())
+        for device_idx in range(torch.cuda.device_count()):
+            self._set_cos_sin_cache(seq_len=self.current_rope_size, device=torch.device(device_idx), dtype=torch.get_default_dtype())
+
+        # dummy so that patch_utils doesn't fail for now
+        self.cos_cached = torch.empty(1, device=torch.cuda.current_device(), dtype=torch.get_default_dtype())
+        self.sin_cached = torch.empty(1, device=torch.cuda.current_device(), dtype=torch.get_default_dtype())
     pass
 
     def _set_cos_sin_cache(self, seq_len, device, dtype):
@@ -1462,8 +1480,36 @@ def _set_cos_sin_cache(self, seq_len, device, dtype):
         freqs = torch.outer(t, self.inv_freq)
         # Different from paper, but it uses a different permutation in order to obtain the same calculation
         emb = torch.cat((freqs, freqs), dim=-1)
-        self.register_buffer("cos_cached", emb.cos().to(dtype=dtype, device=device, non_blocking=True), persistent=False)
-        self.register_buffer("sin_cached", emb.sin().to(dtype=dtype, device=device, non_blocking=True), persistent=False)
+        cos = emb.cos().to(dtype=dtype, device=device, non_blocking=True)
+        sin = emb.sin().to(dtype=dtype, device=device, non_blocking=True)
+        self.multi_gpu_cos_cached[device] = cos
+        self.multi_gpu_sin_cached[device] = sin
+        return cos, sin
+    pass
+
+    def forward(self, x, position_ids=None, seq_len=None):
+        # x: [bs, num_attention_heads, seq_len, head_size]
+        if seq_len is not None and seq_len > self.current_rope_size:
+            self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype)
+
+        return (
+            self.multi_gpu_cos_cached[x.device][:seq_len].to(dtype = x.dtype),
+            self.multi_gpu_sin_cached[x.device][:seq_len].to(dtype = x.dtype),
+        )
+    pass
+
+    def get_cached(self, seq_len = None, device = None):
+        if device is None:
+            device = torch.cuda.current_device()
+        return self.multi_gpu_cos_cached[device], self.multi_gpu_sin_cached[device]
+    pass
+
+    def extend_rope_embedding(self, x, seq_len):
+        if seq_len <= self.current_rope_size: return
+        # Iteratively grow by increments of 8192
+        self.current_rope_size = ((seq_len // 8192) + ((seq_len % 8192) != 0)) * 8192
+        for device_idx in range(torch.cuda.device_count()):
+            self._set_cos_sin_cache(self.current_rope_size, device = torch.device(device_idx), dtype = x.dtype)
     pass
 
     # From https://github.com/meta-llama/llama-models/blob/main/models/llama3_1/api/model.py#L41
@@ -1491,28 +1537,6 @@ def apply_scaling(self, freqs: torch.Tensor):
             new_freqs.append((1 - smooth) * freq / scale_factor + smooth * freq)
         return torch.tensor(new_freqs, dtype=freqs.dtype, device=freqs.device)
     pass
-
-    def forward(self, x, position_ids=None, seq_len=None):
-        # x: [bs, num_attention_heads, seq_len, head_size]
-        if seq_len > self.current_rope_size:
-            self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype)
-
-        return (
-            self.cos_cached[:seq_len].to(dtype = x.dtype),
-            self.sin_cached[:seq_len].to(dtype = x.dtype),
-        )
-    pass
-
-    def get_cached(self, seq_len = None):
-        return self.cos_cached, self.sin_cached
-    pass
-
-    def extend_rope_embedding(self, x, seq_len):
-        if seq_len <= self.current_rope_size: return
-        # Iteratively grow by increments of 8192
-        self.current_rope_size = ((seq_len // 8192) + ((seq_len % 8192) != 0)) * 8192
-        self._set_cos_sin_cache(self.current_rope_size, device = DEVICE_TYPE, dtype = x.dtype)
-    pass
 pass
 
 
@@ -1548,6 +1572,10 @@ def __init__(self,
         self.base = base
         # Dynamic RoPE we first set it to a max of 4 * 8192 tokens then we iteratively grow this
         self.current_rope_size = min(original_max_position_embeddings, self.max_position_embeddings)
+        self.multi_gpu_short_cos_cached = {}
+        self.multi_gpu_short_sin_cached = {}
+        self.multi_gpu_long_cos_cached = {}
+        self.multi_gpu_long_sin_cached = {}
 
         # Long RoPE similar to RoPE except short sequences have 1 cos / sin
         # and long sequences have another cos / sin
@@ -1569,18 +1597,26 @@ def __init__(self,
         # Short and long inv_freq
         self.register_buffer("short_inv_freq", short_inv_freq, persistent = False)
         self.register_buffer("long_inv_freq", long_inv_freq, persistent = False)
+
         # Build here to make `torch.jit.trace` work.
-        # self._set_cos_sin_cache(seq_len=self.current_rope_size, device=device, dtype=torch.get_default_dtype())
-
-        # Short sequences
+        # Initialize short sequences cache for all devices
         dtype = torch.bfloat16 if is_bfloat16_supported() else torch.float16
         t = torch.arange(original_max_position_embeddings, device=self.short_inv_freq.device, dtype=torch.int64).float()
         freqs = torch.outer(t, self.short_inv_freq)
         emb = torch.cat((freqs, freqs), dim=-1)
-        cos_cached = (emb.cos() * self.scaling_factor).to(dtype=dtype, device=device, non_blocking=True)
-        sin_cached = (emb.sin() * self.scaling_factor).to(dtype=dtype, device=device, non_blocking=True)
-        self.register_buffer("short_cos_cached", cos_cached, persistent=False)
-        self.register_buffer("short_sin_cached", sin_cached, persistent=False)
+
+        for device_idx in range(torch.cuda.device_count()):
+            device_obj = torch.device(device_idx)
+            cos_cached = (emb.cos() * self.scaling_factor).to(dtype=dtype, device=device_obj, non_blocking=True)
+            sin_cached = (emb.sin() * self.scaling_factor).to(dtype=dtype, device=device_obj, non_blocking=True)
+            self.multi_gpu_short_cos_cached[device_obj] = cos_cached
+            self.multi_gpu_short_sin_cached[device_obj] = sin_cached
+
+        # dummy so that patch_utils doesn't fail for now
+        self.short_cos_cached = torch.empty(1, device=torch.cuda.current_device(), dtype=torch.get_default_dtype())
+        self.short_sin_cached = torch.empty(1, device=torch.cuda.current_device(), dtype=torch.get_default_dtype())
+        self.long_cos_cached = torch.empty(1, device=torch.cuda.current_device(), dtype=torch.get_default_dtype())
+        self.long_sin_cached = torch.empty(1, device=torch.cuda.current_device(), dtype=torch.get_default_dtype())
     pass
 
     def _set_cos_sin_cache(self, seq_len, device, dtype):
@@ -1594,39 +1630,43 @@ def _set_cos_sin_cache(self, seq_len, device, dtype):
         emb = torch.cat((freqs, freqs), dim=-1)
         cos_cached = (emb.cos() * self.scaling_factor).to(dtype=dtype, device=device, non_blocking=True)
         sin_cached = (emb.sin() * self.scaling_factor).to(dtype=dtype, device=device, non_blocking=True)
-        self.register_buffer("long_cos_cached", cos_cached, persistent=False)
-        self.register_buffer("long_sin_cached", sin_cached, persistent=False)
+        self.multi_gpu_long_cos_cached[device] = cos_cached
+        self.multi_gpu_long_sin_cached[device] = sin_cached
+        return cos_cached, sin_cached
     pass
 
     def forward(self, x, position_ids=None, seq_len=None):
         # x: [bs, num_attention_heads, seq_len, head_size]
-        if seq_len > self.current_rope_size:
+        if seq_len is not None and seq_len > self.current_rope_size:
            self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype)
 
-        if seq_len < self.original_max_position_embeddings:
+        if seq_len is not None and seq_len < self.original_max_position_embeddings:
             return (
-                self.short_cos_cached[:seq_len].to(dtype = x.dtype),
-                self.short_sin_cached[:seq_len].to(dtype = x.dtype),
+                self.multi_gpu_short_cos_cached[x.device][:seq_len].to(dtype = x.dtype),
+                self.multi_gpu_short_sin_cached[x.device][:seq_len].to(dtype = x.dtype),
             )
         else:
             return (
-                self.long_cos_cached[:seq_len].to(dtype = x.dtype),
-                self.long_sin_cached[:seq_len].to(dtype = x.dtype),
+                self.multi_gpu_long_cos_cached[x.device][:seq_len].to(dtype = x.dtype),
+                self.multi_gpu_long_sin_cached[x.device][:seq_len].to(dtype = x.dtype),
             )
         pass
     pass
 
-    def get_cached(self, seq_len = None):
-        if seq_len < self.original_max_position_embeddings:
-            return self.short_cos_cached, self.short_sin_cached
-        return self.long_cos_cached, self.long_sin_cached
+    def get_cached(self, seq_len = None, device = None):
+        if device is None:
+            device = torch.cuda.current_device()
+        if seq_len is not None and seq_len < self.original_max_position_embeddings:
+            return self.multi_gpu_short_cos_cached[device], self.multi_gpu_short_sin_cached[device]
+        return self.multi_gpu_long_cos_cached[device], self.multi_gpu_long_sin_cached[device]
     pass
 
     def extend_rope_embedding(self, x, seq_len):
         if seq_len <= self.current_rope_size: return
         # Iteratively grow by increments of 8192
         self.current_rope_size = ((seq_len // 8192) + ((seq_len % 8192) != 0)) * 8192
-        self._set_cos_sin_cache(self.current_rope_size, device = DEVICE_TYPE, dtype = x.dtype)
+        for device_idx in range(torch.cuda.device_count()):
+            self._set_cos_sin_cache(self.current_rope_size, device = torch.device(device_idx), dtype = x.dtype)
     pass
 pass
diff --git a/unsloth/models/mistral.py b/unsloth/models/mistral.py
index f6538c65ac..4aee2cabdf 100644
--- a/unsloth/models/mistral.py
+++ b/unsloth/models/mistral.py
@@ -83,18 +83,10 @@ def MistralAttention_fast_forward(
     # Extend RoPE dynamically to fit in VRAM
     self.rotary_emb.extend_rope_embedding(V, seq_len = kv_seq_len)
+    cos, sin = self.rotary_emb.get_cached(seq_len = kv_seq_len, device = Q.device)
 
     if position_ids is None:
-        cos = self.rotary_emb.cos_cached
-        sin = self.rotary_emb.sin_cached
         Q, K = fast_rope_embedding(Q, K, cos, sin)
     else:
-        cos, sin = self.rotary_emb(V, seq_len = kv_seq_len)
-        if cos.device != Q.device:
-            # without this, even though V is on GPU0, by the time we reach the forward function
-            # the argument x is on GPU1, and hence cos, sin end up on GPU1
-            # this is a hack to get around this quirk
-            cos = cos.to(Q.device)
-            sin = sin.to(Q.device)
         Q, K = inplace_rope_embedding(Q, K, cos, sin, position_ids)
     pass
 
diff --git a/unsloth/models/qwen3.py b/unsloth/models/qwen3.py
index 21e52548f8..ffd60b9668 100644
--- a/unsloth/models/qwen3.py
+++ b/unsloth/models/qwen3.py
@@ -116,7 +116,7 @@ def Qwen3Attention_fast_forward(
         # Useful for LongRoPE
         cos, sin = rotary_emb.get_cached(kv_seq_len)
     else:
-        cos, sin = rotary_emb(V, seq_len = kv_seq_len)
+        cos, sin = rotary_emb.get_cached(seq_len = kv_seq_len, device = Q.device)
     Q, K = fast_rope_embedding(Q, K, cos, sin)
 
     if past_key_value is not None:
@@ -299,9 +299,9 @@ def Qwen3Attention_fast_forward_inference(
     # Need to do it prior 2 steps before hitting full on short KV cache
     # or else error
     self.rotary_emb.extend_rope_embedding(Vn, seq_len + 2)
-    cos, sin = self.rotary_emb.get_cached(kv_seq_len)
-    cos = cos[position_ids].unsqueeze(1).to(Qn.device)
-    sin = sin[position_ids].unsqueeze(1).to(Qn.device)
+    cos, sin = self.rotary_emb.get_cached(kv_seq_len, device = Qn.device)
+    cos = cos[position_ids].unsqueeze(1)
+    sin = sin[position_ids].unsqueeze(1)
 
     h = self.half_head_dim
     RH_Q = self.RH_Q

From 44955c321d2241c12084f91e2bd35cd5673f86c2 Mon Sep 17 00:00:00 2001
From: Datta Nimmaturi
Date: Fri, 4 Jul 2025 05:04:51 +0000
Subject: [PATCH 5/8] Make multi-GPU RoPE a list

---
 unsloth/models/gemma.py  | 18 ++++++------
 unsloth/models/gemma2.py |  4 +--
 unsloth/models/llama.py  | 61 ++++++++++++++++++++--------------------
 3 files changed, 42 insertions(+), 41 deletions(-)

diff --git a/unsloth/models/gemma.py b/unsloth/models/gemma.py
index 8c992aa274..4d32c42f7b 100644
--- a/unsloth/models/gemma.py
+++ b/unsloth/models/gemma.py
@@ -230,8 +230,8 @@ def __init__(self, dim = None, max_position_embeddings=2048, base=10000, device=
         self.base = base
         # Dynamic RoPE we first set it to a max of 4 * 8192 tokens then we iteratively grow this
         self.current_rope_size = min(4 * 8192, self.max_position_embeddings)
-        self.multi_gpu_cos_cached = {}
-        self.multi_gpu_sin_cached = {}
+        self.multi_gpu_cos_cached = [None]*torch.cuda.device_count()
+        self.multi_gpu_sin_cached = [None]*torch.cuda.device_count()
 
         # Build here to make `torch.jit.trace` work.
         for device in range(torch.cuda.device_count()):
@@ -260,8 +260,8 @@ def _set_cos_sin_cache(self, seq_len, device, dtype):
         # We must do RoPE in float32!
         cos = emb.cos().to(device = device, non_blocking = True)#, dtype = dtype)
         sin = emb.sin().to(device = device, non_blocking = True)#, dtype = dtype)
-        self.multi_gpu_cos_cached[device] = cos
-        self.multi_gpu_sin_cached[device] = sin
+        self.multi_gpu_cos_cached[device.index] = cos
+        self.multi_gpu_sin_cached[device.index] = sin
         return cos, sin
     pass
 
@@ -271,15 +271,15 @@ def forward(self, x, position_ids=None, seq_len=None):
             self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype)
 
         return (
-            self.multi_gpu_cos_cached[x.device][:seq_len].to(dtype=x.dtype),
-            self.multi_gpu_sin_cached[x.device][:seq_len].to(dtype=x.dtype),
+            self.multi_gpu_cos_cached[x.device.index][:seq_len].to(dtype=x.dtype),
+            self.multi_gpu_sin_cached[x.device.index][:seq_len].to(dtype=x.dtype),
         )
     pass
 
     def get_cached(self, seq_len = None, device = None):
         if device is None:
             device = torch.cuda.current_device()
-        return self.multi_gpu_cos_cached[device], self.multi_gpu_sin_cached[device]
+        return self.multi_gpu_cos_cached[device.index if hasattr(device, 'index') else device], self.multi_gpu_sin_cached[device.index if hasattr(device, 'index') else device]
     pass
 
     def extend_rope_embedding(self, x, seq_len):
@@ -323,8 +323,8 @@ def _set_cos_sin_cache(self, seq_len, device, dtype):
         # We must do RoPE in float32!
         cos = emb.cos().to(device = device, non_blocking = True)#, dtype = dtype)
         sin = emb.sin().to(device = device, non_blocking = True)#, dtype = dtype)
-        self.multi_gpu_cos_cached[device] = cos
-        self.multi_gpu_sin_cached[device] = sin
+        self.multi_gpu_cos_cached[device.index] = cos
+        self.multi_gpu_sin_cached[device.index] = sin
         return cos, sin
     pass
 pass
diff --git a/unsloth/models/gemma2.py b/unsloth/models/gemma2.py
index c4eb5a12c6..0b3541ea7a 100644
--- a/unsloth/models/gemma2.py
+++ b/unsloth/models/gemma2.py
@@ -114,8 +114,8 @@ def Gemma2Attention_fast_forward(
         kv_seq_len += past_key_value[0].shape[-2]
 
     if position_ids is None:
-        cos = self.rotary_emb.multi_gpu_cos_cached[Q.device]
-        sin = self.rotary_emb.multi_gpu_sin_cached[Q.device]
+        cos = self.rotary_emb.multi_gpu_cos_cached[Q.device.index]
+        sin = self.rotary_emb.multi_gpu_sin_cached[Q.device.index]
         Q, K = fast_rope_embedding(Q, K, cos, sin)
     else:
         cos, sin = self.rotary_emb.get_cached(seq_len = kv_seq_len, device = Q.device)
diff --git a/unsloth/models/llama.py b/unsloth/models/llama.py
index 72664362d4..d89faf15b4 100644
--- a/unsloth/models/llama.py
+++ b/unsloth/models/llama.py
@@ -1340,8 +1340,8 @@ def __init__(self, dim = None, max_position_embeddings=2048, base=10000, device=
         self.base = base
         # Dynamic RoPE we first set it to a max of 4 * 8192 tokens then we iteratively grow this
         self.current_rope_size = min(4 * 8192, self.max_position_embeddings)
-        self.multi_gpu_cos_cached = {}
-        self.multi_gpu_sin_cached = {}
+        self.multi_gpu_cos_cached = [None]*torch.cuda.device_count()
+        self.multi_gpu_sin_cached = [None]*torch.cuda.device_count()
 
         # Build here to make `torch.jit.trace` work.
         for device_idx in range(torch.cuda.device_count()):
@@ -1366,8 +1366,8 @@ def _set_cos_sin_cache(self, seq_len, device, dtype):
         emb = torch.cat((freqs, freqs), dim=-1)
         cos = emb.cos().to(dtype=dtype, device=device, non_blocking=True)
         sin = emb.sin().to(dtype=dtype, device=device, non_blocking=True)
-        self.multi_gpu_cos_cached[device] = cos
-        self.multi_gpu_sin_cached[device] = sin
+        self.multi_gpu_cos_cached[device.index] = cos
+        self.multi_gpu_sin_cached[device.index] = sin
         return cos, sin
     pass
 
@@ -1377,15 +1377,15 @@ def forward(self, x, position_ids=None, seq_len=None):
             self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype)
 
         return (
-            self.multi_gpu_cos_cached[x.device][:seq_len].to(dtype = x.dtype),
-            self.multi_gpu_sin_cached[x.device][:seq_len].to(dtype = x.dtype),
+            self.multi_gpu_cos_cached[x.device.index][:seq_len].to(dtype = x.dtype),
+            self.multi_gpu_sin_cached[x.device.index][:seq_len].to(dtype = x.dtype),
         )
     pass
 
     def get_cached(self, seq_len = None, device = None):
         if device is None:
             device = torch.cuda.current_device()
-        return self.multi_gpu_cos_cached[device], self.multi_gpu_sin_cached[device]
+        return self.multi_gpu_cos_cached[device.index if hasattr(device, 'index') else device], self.multi_gpu_sin_cached[device.index if hasattr(device, 'index') else device]
     pass
 
     def extend_rope_embedding(self, x, seq_len):
@@ -1423,8 +1423,8 @@ def _set_cos_sin_cache(self, seq_len, device, dtype):
         emb = torch.cat((freqs, freqs), dim=-1)
         cos = emb.cos().to(dtype=dtype, device=device, non_blocking=True)
         sin = emb.sin().to(dtype=dtype, device=device, non_blocking=True)
-        self.multi_gpu_cos_cached[device] = cos
-        self.multi_gpu_sin_cached[device] = sin
+        self.multi_gpu_cos_cached[device.index] = cos
+        self.multi_gpu_sin_cached[device.index] = sin
         return cos, sin
     pass
 pass
@@ -1451,8 +1451,8 @@ def __init__(self, dim = None, max_position_embeddings=2048, base=10000, device=
         self.base = base
         # Dynamic RoPE we first set it to a max of 4 * 8192 tokens then we iteratively grow this
         self.current_rope_size = min(4 * 8192, self.max_position_embeddings)
-        self.multi_gpu_cos_cached = {}
-        self.multi_gpu_sin_cached = {}
+        self.multi_gpu_cos_cached = [None]*torch.cuda.device_count()
+        self.multi_gpu_sin_cached = [None]*torch.cuda.device_count()
 
         # Normal Llama-3 RoPE
         inv_freq = 1.0 / (
@@ -1482,8 +1482,8 @@ def _set_cos_sin_cache(self, seq_len, device, dtype):
         emb = torch.cat((freqs, freqs), dim=-1)
         cos = emb.cos().to(dtype=dtype, device=device, non_blocking=True)
         sin = emb.sin().to(dtype=dtype, device=device, non_blocking=True)
-        self.multi_gpu_cos_cached[device] = cos
-        self.multi_gpu_sin_cached[device] = sin
+        self.multi_gpu_cos_cached[device.index] = cos
+        self.multi_gpu_sin_cached[device.index] = sin
         return cos, sin
     pass
 
@@ -1493,15 +1493,15 @@ def forward(self, x, position_ids=None, seq_len=None):
             self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype)
 
         return (
-            self.multi_gpu_cos_cached[x.device][:seq_len].to(dtype = x.dtype),
-            self.multi_gpu_sin_cached[x.device][:seq_len].to(dtype = x.dtype),
+            self.multi_gpu_cos_cached[x.device.index][:seq_len].to(dtype = x.dtype),
+            self.multi_gpu_sin_cached[x.device.index][:seq_len].to(dtype = x.dtype),
         )
     pass
 
     def get_cached(self, seq_len = None, device = None):
         if device is None:
             device = torch.cuda.current_device()
-        return self.multi_gpu_cos_cached[device], self.multi_gpu_sin_cached[device]
+        return self.multi_gpu_cos_cached[device.index if hasattr(device, 'index') else device], self.multi_gpu_sin_cached[device.index if hasattr(device, 'index') else device]
     pass
 
     def extend_rope_embedding(self, x, seq_len):
@@ -1572,10 +1572,10 @@ def __init__(self,
         self.base = base
         # Dynamic RoPE we first set it to a max of 4 * 8192 tokens then we iteratively grow this
         self.current_rope_size = min(original_max_position_embeddings, self.max_position_embeddings)
-        self.multi_gpu_short_cos_cached = {}
-        self.multi_gpu_short_sin_cached = {}
-        self.multi_gpu_long_cos_cached = {}
-        self.multi_gpu_long_sin_cached = {}
+        self.multi_gpu_short_cos_cached = [None]*torch.cuda.device_count()
+        self.multi_gpu_short_sin_cached = [None]*torch.cuda.device_count()
+        self.multi_gpu_long_cos_cached = [None]*torch.cuda.device_count()
+        self.multi_gpu_long_sin_cached = [None]*torch.cuda.device_count()
 
         # Long RoPE similar to RoPE except short sequences have 1 cos / sin
         # and long sequences have another cos / sin
@@ -1609,8 +1609,8 @@ def __init__(self,
             device_obj = torch.device(device_idx)
             cos_cached = (emb.cos() * self.scaling_factor).to(dtype=dtype, device=device_obj, non_blocking=True)
             sin_cached = (emb.sin() * self.scaling_factor).to(dtype=dtype, device=device_obj, non_blocking=True)
-            self.multi_gpu_short_cos_cached[device_obj] = cos_cached
-            self.multi_gpu_short_sin_cached[device_obj] = sin_cached
+            self.multi_gpu_short_cos_cached[device_idx] = cos_cached
+            self.multi_gpu_short_sin_cached[device_idx] = sin_cached
 
         # dummy so that patch_utils doesn't fail for now
         self.short_cos_cached = torch.empty(1, device=torch.cuda.current_device(), dtype=torch.get_default_dtype())
@@ -1630,8 +1630,8 @@ def _set_cos_sin_cache(self, seq_len, device, dtype):
         emb = torch.cat((freqs, freqs), dim=-1)
         cos_cached = (emb.cos() * self.scaling_factor).to(dtype=dtype, device=device, non_blocking=True)
         sin_cached = (emb.sin() * self.scaling_factor).to(dtype=dtype, device=device, non_blocking=True)
-        self.multi_gpu_long_cos_cached[device] = cos_cached
-        self.multi_gpu_long_sin_cached[device] = sin_cached
+        self.multi_gpu_long_cos_cached[device.index] = cos_cached
+        self.multi_gpu_long_sin_cached[device.index] = sin_cached
         return cos_cached, sin_cached
     pass
 
@@ -1642,13 +1642,13 @@ def forward(self, x, position_ids=None, seq_len=None):
 
         if seq_len is not None and seq_len < self.original_max_position_embeddings:
             return (
-                self.multi_gpu_short_cos_cached[x.device][:seq_len].to(dtype = x.dtype),
-                self.multi_gpu_short_sin_cached[x.device][:seq_len].to(dtype = x.dtype),
+                self.multi_gpu_short_cos_cached[x.device.index][:seq_len].to(dtype = x.dtype),
+                self.multi_gpu_short_sin_cached[x.device.index][:seq_len].to(dtype = x.dtype),
             )
         else:
             return (
-                self.multi_gpu_long_cos_cached[x.device][:seq_len].to(dtype = x.dtype),
-                self.multi_gpu_long_sin_cached[x.device][:seq_len].to(dtype = x.dtype),
+                self.multi_gpu_long_cos_cached[x.device.index][:seq_len].to(dtype = x.dtype),
+                self.multi_gpu_long_sin_cached[x.device.index][:seq_len].to(dtype = x.dtype),
             )
         pass
     pass
@@ -1656,9 +1656,10 @@ def forward(self, x, position_ids=None, seq_len=None):
     def get_cached(self, seq_len = None, device = None):
         if device is None:
             device = torch.cuda.current_device()
+        device_index = device.index if hasattr(device, 'index') else device
         if seq_len is not None and seq_len < self.original_max_position_embeddings:
-            return self.multi_gpu_short_cos_cached[device], self.multi_gpu_short_sin_cached[device]
-        return self.multi_gpu_long_cos_cached[device], self.multi_gpu_long_sin_cached[device]
+            return self.multi_gpu_short_cos_cached[device_index], self.multi_gpu_short_sin_cached[device_index]
+        return self.multi_gpu_long_cos_cached[device_index], self.multi_gpu_long_sin_cached[device_index]
     pass
 
     def extend_rope_embedding(self, x, seq_len):

From e5158da39fb12a5af1c3a9eb973f92399d7b4289 Mon Sep 17 00:00:00 2001
From: Datta Nimmaturi
Date: Fri, 4 Jul 2025 10:54:29 +0000
Subject: [PATCH 6/8] Remove unnecessary transfer to CPU

---
 unsloth/models/cohere.py | 3 +--
 1 file changed, 1 insertion(+), 2 deletions(-)

diff --git a/unsloth/models/cohere.py b/unsloth/models/cohere.py
index 8562895201..3ce6d1260d 100644
--- a/unsloth/models/cohere.py
+++ b/unsloth/models/cohere.py
@@ -254,8 +254,7 @@ def CohereAttention_fast_forward_inference(
     do_prefill = False,
     attention_mask = None,
 ):
-    if position_ids is not None:
-        position_ids = position_ids.to('cpu')
+
     Xn = hidden_states
     bsz, _, hd = hidden_states.size()
     K1, V1 = past_key_value

From de38ecea317e7d44a927889f639e8c09aa269899 Mon Sep 17 00:00:00 2001
From: Datta Nimmaturi
Date: Thu, 10 Jul 2025 05:13:47 +0000
Subject: [PATCH 7/8] Remove unnecessary move to CPU

---
 unsloth/models/llama.py |  2 --
 unsloth/models/qwen3.py | 10 ++++------
 2 files changed, 4 insertions(+), 8 deletions(-)

diff --git a/unsloth/models/llama.py b/unsloth/models/llama.py
index d89faf15b4..035e04b170 100644
--- a/unsloth/models/llama.py
+++ b/unsloth/models/llama.py
@@ -210,8 +210,6 @@ def LlamaAttention_fast_forward_inference(
         This means we can pass in a row of Q, but we need to remember K and V,
         which are called the KV cache.
""" - if position_ids is not None: - position_ids = position_ids.to('cpu') Xn = hidden_states bsz, _, hd = hidden_states.size() K1, V1 = past_key_value diff --git a/unsloth/models/qwen3.py b/unsloth/models/qwen3.py index ffd60b9668..a60fd8ab8c 100644 --- a/unsloth/models/qwen3.py +++ b/unsloth/models/qwen3.py @@ -66,7 +66,7 @@ def Qwen3Attention_fast_forward( position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, *args, **kwargs, ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: - + # Clear inference if hasattr(self, "paged_attention"): del self.paged_attention_K @@ -232,8 +232,6 @@ def Qwen3Attention_fast_forward_inference( This means we can pass in a row of Q, but we need to remember K and V, which are called the KV cache. """ - if position_ids is not None: - position_ids = position_ids.to('cpu') Xn = hidden_states bsz, _, hd = hidden_states.size() K1, V1 = past_key_value @@ -262,14 +260,14 @@ def Qwen3Attention_fast_forward_inference( self.temp_QA = torch.empty((2, bsz, 1, attention_size), dtype = dtype, device = device) self.temp_KV = torch.empty((2, bsz, 1, n_kv_heads*head_dim), dtype = dtype, device = device) self.RH_Q = torch.empty((bsz, n_heads, 1, head_dim), dtype = dtype, device = device) - + # Mistral Nemo 12b has weird dimensions if attention_size != hidden_size: self.temp_O = torch.empty((1, bsz, hidden_size), dtype = dtype, device = device) else: self.temp_O = self.temp_QA[1][:,:,:hidden_size] pass - + self.attention = torch.empty((bsz, n_heads, 1, KV_CACHE_INCREMENT+seq_len), dtype = dtype, device = device) self.scalar = 1.0 / math_sqrt(self.head_dim) self.half_head_dim = head_dim // 2 @@ -317,7 +315,7 @@ def Qwen3Attention_fast_forward_inference( RH_K[:,:,:,:h].neg_() #torch.neg(RH_K[:,:,:,:h], out = RH_K[:,:,:,:h]) Kn *= cos Kn.addcmul_(RH_K, sin) - + # New KV cache # Kn = torch.cat([K1, Kn], dim = 2) # Vn = torch.cat([V1, Vn], dim = 2) From 324b392587d9ad0739c1be54913ed90a3b787229 Mon Sep 17 00:00:00 2001 From: Datta Nimmaturi Date: Thu, 10 Jul 2025 05:17:30 +0000 Subject: [PATCH 8/8] Donot move inputs to device yet will be handled separately in another PR --- unsloth/models/_utils.py | 19 ---------- unsloth/models/cohere.py | 14 +++----- unsloth/models/gemma.py | 8 +---- unsloth/models/gemma2.py | 17 +++------ unsloth/models/granite.py | 18 ++++------ unsloth/models/llama.py | 75 ++++++++++++++++++--------------------- 6 files changed, 52 insertions(+), 99 deletions(-) diff --git a/unsloth/models/_utils.py b/unsloth/models/_utils.py index f716e50043..e1511df18a 100644 --- a/unsloth/models/_utils.py +++ b/unsloth/models/_utils.py @@ -773,25 +773,6 @@ def _prepare_backend(self, *args, **kwargs): return None, DistributedType.NO accelerate.accelerator.Accelerator.distributed_type = lambda *args, **kwargs: DistributedType.NO pass -# to move multiple tensors to the same device -def move_to_device(target_device, *tensors): - """ - Move multiple tensors to target device if they're not already there. 
- - Args: - target_device: The target device to move tensors to - *tensors: Variable number of tensors to potentially move - - Returns: - tuple: The tensors on the target device (same objects if already on device, new if moved) - """ - moved_tensors = [] - for tensor in tensors: - if tensor.device != target_device: - moved_tensors.append(tensor.to(target_device)) - else: - moved_tensors.append(tensor) - return tuple(moved_tensors) if len(moved_tensors) > 1 else moved_tensors[0] import transformers.utils.quantization_config transformers.utils.quantization_config.BitsAndBytesConfig.__init__ = _BitsAndBytesConfig__init__ diff --git a/unsloth/models/cohere.py b/unsloth/models/cohere.py index 3ce6d1260d..168824d635 100644 --- a/unsloth/models/cohere.py +++ b/unsloth/models/cohere.py @@ -78,7 +78,7 @@ def CohereAttention_fast_forward( position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, *args, **kwargs, ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: - + # Clear inference if hasattr(self, "paged_attention"): del self.paged_attention_K @@ -282,14 +282,14 @@ def CohereAttention_fast_forward_inference( self.temp_QA = torch.empty((2, bsz, 1, attention_size), dtype = dtype, device = "cuda:0") self.temp_KV = torch.empty((2, bsz, 1, n_kv_heads*head_dim), dtype = dtype, device = "cuda:0") self.RH_Q = torch.empty((bsz, n_heads, 1, head_dim), dtype = dtype, device = "cuda:0") - + # Mistral Nemo 12b has weird dimensions if attention_size != hidden_size: self.temp_O = torch.empty((1, bsz, hidden_size), dtype = dtype, device = "cuda:0") else: self.temp_O = self.temp_QA[1][:,:,:hidden_size] pass - + self.attention = torch.empty((bsz, n_heads, 1, KV_CACHE_INCREMENT+seq_len), dtype = dtype, device = "cuda:0") self.scalar = 1.0 / math_sqrt(self.head_dim) self.half_head_dim = head_dim // 2 @@ -339,7 +339,7 @@ def CohereAttention_fast_forward_inference( torch.neg(RH_K[:,:,:,:h], out = RH_K[:,:,:,:h]) Kn *= cos Kn.addcmul_(RH_K, sin) - + # New KV cache # Kn = torch.cat([K1, Kn], dim = 2) # Vn = torch.cat([V1, Vn], dim = 2) @@ -418,10 +418,6 @@ def CohereModel_fast_forward_inference( next_decoder_cache = [] for idx, decoder_layer in enumerate(self.model.layers): - decoder_device = decoder_layer.self_attn.q_proj.weight.device - hidden_states, out_weight, position_ids = move_to_device( - decoder_device, hidden_states, out_weight, position_ids - ) residual = hidden_states hidden_states = fast_layernorm_inference(decoder_layer.input_layernorm, hidden_states, out_weight) hidden_states_attention, present_key_value = CohereAttention_fast_forward_inference( @@ -473,7 +469,7 @@ def pre_patch(): CohereForCausalLM .forward = CausalLM_fast_forward(CohereModel_fast_forward_inference) PeftModelForCausalLM .forward = PeftModel_fast_forward fix_prepare_inputs_for_generation(CohereForCausalLM) - + import transformers.models.cohere.modeling_cohere transformers.models.cohere.modeling_cohere.CohereRotaryEmbedding = LlamaRotaryEmbedding return diff --git a/unsloth/models/gemma.py b/unsloth/models/gemma.py index 4d32c42f7b..66bd23af95 100644 --- a/unsloth/models/gemma.py +++ b/unsloth/models/gemma.py @@ -170,12 +170,6 @@ def GemmaModel_fast_forward_inference( next_decoder_cache = [] for idx, decoder_layer in enumerate(self.model.layers): - - decoder_device = decoder_layer.self_attn.q_proj.weight.device - hidden_states, out_weight, position_ids = move_to_device( - decoder_device, hidden_states, out_weight, position_ids - ) - residual = hidden_states hidden_states = 
diff --git a/unsloth/models/gemma.py b/unsloth/models/gemma.py
index 4d32c42f7b..66bd23af95 100644
--- a/unsloth/models/gemma.py
+++ b/unsloth/models/gemma.py
@@ -170,12 +170,6 @@ def GemmaModel_fast_forward_inference(
     next_decoder_cache = []
     for idx, decoder_layer in enumerate(self.model.layers):
-
-        decoder_device = decoder_layer.self_attn.q_proj.weight.device
-        hidden_states, out_weight, position_ids = move_to_device(
-            decoder_device, hidden_states, out_weight, position_ids
-        )
-
         residual = hidden_states
         hidden_states = fast_rms_layernorm_inference_gemma(decoder_layer.input_layernorm, hidden_states, out_weight)
         hidden_states, present_key_value = LlamaAttention_fast_forward_inference(
@@ -236,7 +230,7 @@ def __init__(self, dim = None, max_position_embeddings=2048, base=10000, device=
         # Build here to make `torch.jit.trace` work.
         for device in range(torch.cuda.device_count()):
             self._set_cos_sin_cache(seq_len=self.current_rope_size, device=torch.device(device), dtype=torch.get_default_dtype())
-        
+
         # dummy so that patch_utils doesn't fail for now
         self.cos_cached = torch.empty(1, device=torch.cuda.current_device(), dtype=torch.get_default_dtype())
         self.sin_cached = torch.empty(1, device=torch.cuda.current_device(), dtype=torch.get_default_dtype())
diff --git a/unsloth/models/gemma2.py b/unsloth/models/gemma2.py
index 0b3541ea7a..7b8866f705 100644
--- a/unsloth/models/gemma2.py
+++ b/unsloth/models/gemma2.py
@@ -84,7 +84,7 @@ def Gemma2Attention_fast_forward(
     padding_mask: Optional[torch.LongTensor] = None,
     *args, **kwargs,
 ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
-    
+
     # Clear inference
     if hasattr(self, "paged_attention"):
         del self.paged_attention_K
@@ -281,7 +281,7 @@ def Gemma2Attention_fast_forward_inference(
         # Only for Gemma2
         self.temp_O = torch.empty((1, bsz, hidden_size), dtype = dtype, device = device)
         self.attention = torch.empty((bsz, n_heads, 1, KV_CACHE_INCREMENT+seq_len), dtype = dtype, device = device)
-        
+
         # See https://github.com/google/gemma_pytorch/commit/03e657582d17cb5a8617ebf333c1c16f3694670e
         # Gemma 9b should use 256 and not 224 (hs / nah). 27b uses the below
         # We default to using the config file itself
@@ -325,7 +325,7 @@ def Gemma2Attention_fast_forward_inference(
     torch.neg(RH_K[:,:,:,:h], out = RH_K[:,:,:,:h])
     Kn *= cos
     Kn.addcmul_(RH_K, sin)
-    
+
     # New KV cache
     # Kn = torch.cat([K1, Kn], dim = 2)
     # Vn = torch.cat([V1, Vn], dim = 2)
@@ -386,7 +386,7 @@ def Gemma2Model_fast_forward_inference(
     past_key_values,
     position_ids,
     attention_mask = None,
-): 
+):
     out_weight = torch.empty_like(self.model.layers[0].input_layernorm.weight, dtype = torch.float32, device = "cuda:0")
     input_ids = input_ids[:,:self.max_seq_length]
     hidden_states = self.model.embed_tokens(input_ids)
@@ -423,13 +423,6 @@ def Gemma2Model_fast_forward_inference(
     next_decoder_cache = []
     for idx, decoder_layer in enumerate(self.model.layers):
-        # For pipeline parallelism, we need to move all tensors to the same device
-        # note that this movement is once per GPU in PP
-        layer_device = decoder_layer.self_attn.q_proj.weight.device
-        hidden_states, out_weight, position_ids = move_to_device(
-            layer_device, hidden_states, out_weight, position_ids
-        )
-
         use_sliding_window = idx % 2 == 0
 
         residual = hidden_states
@@ -487,7 +480,7 @@ def pre_patch():
     Gemma2ForCausalLM      .forward = CausalLM_fast_forward(Gemma2Model_fast_forward_inference)
     PeftModelForCausalLM   .forward = PeftModel_fast_forward
     fix_prepare_inputs_for_generation(Gemma2ForCausalLM)
-    
+
     # Solves https://github.com/unslothai/unsloth/issues/168
     # Static KV Cache was introduced in 4.38.0, causing training to be much slower.
     # Inference can now be CUDAGraphed, but we shall retain the old rotary embeddings.
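
The comment deleted from Gemma-2 above spells out the pattern the whole series revolves around: under pipeline parallelism each decoder layer owns a device, so activations should hop devices only at layer boundaries, which means once per GPU rather than once per layer. A minimal sketch of that pattern in plain PyTorch (`forward_layers` is an illustrative name, not unsloth's API):

    import torch
    import torch.nn as nn

    def forward_layers(layers: nn.ModuleList, hidden_states: torch.Tensor) -> torch.Tensor:
        for layer in layers:
            layer_device = next(layer.parameters()).device
            if hidden_states.device != layer_device:
                # Only layers on a new device trigger a cross-GPU copy.
                hidden_states = hidden_states.to(layer_device)
            hidden_states = layer(hidden_states)
        return hidden_states
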
diff --git a/unsloth/models/granite.py b/unsloth/models/granite.py
index 32e8dbc866..e9a62a7d42 100644
--- a/unsloth/models/granite.py
+++ b/unsloth/models/granite.py
@@ -71,7 +71,7 @@ def GraniteAttention_fast_forward(
     position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
     *args, **kwargs,
 ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
-    
+
     # Clear inference
     if hasattr(self, "paged_attention"):
         del self.paged_attention_K
@@ -162,7 +162,7 @@ def GraniteAttention_fast_forward(
         # Go back to (batch_size, seq_len, n_heads, head_dim)
         A = A.transpose(1, 2).contiguous()
     pass
-    
+
     attn_output = A.reshape(bsz, q_len, n_heads*head_dim)
     attn_output = self.apply_o(self, attn_output)
     attn_weights = None
@@ -256,7 +256,7 @@ def GraniteAttention_fast_forward_inference(
     use_sliding_window = False,
     position_embeddings : Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
 ):
-    
+
     assert position_embeddings is not None, f"Granite model requires position embeddings to be specified"
 
     Xn = hidden_states
@@ -326,7 +326,7 @@ def GraniteAttention_fast_forward_inference(
     torch.neg(RH_K[:,:,:,:h], out = RH_K[:,:,:,:h])
     Kn *= cos
     Kn.addcmul_(RH_K, sin)
-    
+
     # New KV cache
     # Kn = torch.cat([K1, Kn], dim = 2)
     # Vn = torch.cat([V1, Vn], dim = 2)
@@ -349,7 +349,7 @@ def GraniteAttention_fast_forward_inference(
     Qn *= self.scaling
 
     A = torch_matmul(Qn, Kn.transpose(2, 3), out = self.attention[:,:,:,:cached_len])
-    
+
     # if attention_mask is not None: A += attention_mask # Must add attention_mask for batched
     A[:] = torch_nn_functional_softmax(A, dim = -1, dtype = torch.float32)#.to(A.dtype)
@@ -400,11 +400,6 @@ def GraniteModel_fast_forward_inference(
     next_decoder_cache = []
     for idx, decoder_layer in enumerate(self.model.layers):
-        decoder_device = decoder_layer.self_attn.q_proj.weight.device
-        hidden_states, position_ids = move_to_device(
-            decoder_device, hidden_states, position_ids
-        )
-
         residual = hidden_states
         hidden_states = fast_rms_layernorm_inference(decoder_layer.input_layernorm, hidden_states)
         hidden_states, present_key_value = GraniteAttention_fast_forward_inference(
@@ -537,7 +532,7 @@ def post_patch(model, tokenizer):
         elif hasattr(module, "short_cos_cached") and \
             (module.short_cos_cached.dtype != correct_dtype):
-            
+
             module.short_cos_cached = module.short_cos_cached.to(correct_dtype)
             module.short_sin_cached = module.short_sin_cached.to(correct_dtype)
         pass
@@ -552,4 +547,3 @@ def post_patch(model, tokenizer):
     return model, tokenizer
 pass
 pass
-
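
The Granite hunks above only touch whitespace and the per-layer `move_to_device` call; the single-token decode step itself is unchanged: scale Q, one matmul against the cached K, a float32 softmax, then weight the cached V. The same arithmetic as a self-contained sketch (shapes and names are illustrative, not unsloth's preallocated buffers):

    import torch
    import torch.nn.functional as F

    def decode_step_attention(Qn, K_cache, V_cache, scaling):
        # Qn: (bsz, n_heads, 1, head_dim); caches: (bsz, n_heads, cached_len, head_dim)
        A = (Qn * scaling) @ K_cache.transpose(2, 3)          # (bsz, n_heads, 1, cached_len)
        A = F.softmax(A, dim=-1, dtype=torch.float32).to(Qn.dtype)
        return A @ V_cache                                    # (bsz, n_heads, 1, head_dim)

    bsz, n_heads, head_dim, cached_len = 1, 4, 8, 16
    Qn = torch.randn(bsz, n_heads, 1, head_dim)
    K  = torch.randn(bsz, n_heads, cached_len, head_dim)
    V  = torch.randn(bsz, n_heads, cached_len, head_dim)
    out = decode_step_attention(Qn, K, V, scaling=head_dim ** -0.5)
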
diff --git a/unsloth/models/llama.py b/unsloth/models/llama.py
index 035e04b170..e1bc14bb1d 100644
--- a/unsloth/models/llama.py
+++ b/unsloth/models/llama.py
@@ -20,7 +20,6 @@
 from ._utils import *
 from ._utils import patch_unsloth_smart_gradient_checkpointing
 from ._utils import __version__
-from ._utils import move_to_device
 from torch.nn.functional import scaled_dot_product_attention
 from transformers import __version__ as transformers_version
 from unsloth_zoo.utils import Version, _get_dtype
@@ -117,12 +116,12 @@ def _fast_prepare_inputs_for_generation(self, input_ids, attention_mask=None, **
     else:
         bs, cache_length = input_ids.shape
         input_ids = input_ids[:,[-1]]
-    
+
     # Get to the base model
     base_model = self
     if hasattr(base_model, 'base_model_prefix'):
         base_model = getattr(base_model, base_model.base_model_prefix)
-    
+
     if hasattr(base_model, "_prepare_4d_causal_attention_mask_with_cache_position"):
         def needs_device_kw(fn) -> bool:
             try:
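
`needs_device_kw` is cut off by the hunk context above; its role is plain introspection, so that `device=` is only forwarded when the transformers helper actually accepts it. One way such a check can be written (a sketch under that assumption, not necessarily the repo's exact body):

    import inspect

    def needs_device_kw(fn) -> bool:
        # True if `fn` declares a `device` parameter (or swallows **kwargs).
        try:
            params = inspect.signature(fn).parameters
        except (TypeError, ValueError):
            return False
        return "device" in params or any(
            p.kind is inspect.Parameter.VAR_KEYWORD for p in params.values()
        )
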
@@ -238,14 +237,14 @@ def LlamaAttention_fast_forward_inference(
         self.temp_QA = torch.empty((2, bsz, 1, attention_size), dtype = dtype, device = device)
         self.temp_KV = torch.empty((2, bsz, 1, n_kv_heads*head_dim), dtype = dtype, device = device)
         self.RH_Q = torch.empty((bsz, n_heads, 1, head_dim), dtype = dtype, device = device)
-        
+
         # Mistral Nemo 12b has weird dimensions
         if attention_size != hidden_size:
             self.temp_O = torch.empty((1, bsz, hidden_size), dtype = dtype, device = device)
         else:
             self.temp_O = self.temp_QA[1][:,:,:hidden_size]
         pass
-        
+
         self.attention = torch.empty((bsz, n_heads, 1, KV_CACHE_INCREMENT+seq_len), dtype = dtype, device = device)
         self.scalar = 1.0 / math_sqrt(self.head_dim)
         self.half_head_dim = head_dim // 2
@@ -287,7 +286,7 @@ def LlamaAttention_fast_forward_inference(
     RH_K[:,:,:,:h].neg_() #torch.neg(RH_K[:,:,:,:h], out = RH_K[:,:,:,:h])
     Kn *= cos
     Kn.addcmul_(RH_K, sin)
-    
+
     # New KV cache
     # Kn = torch.cat([K1, Kn], dim = 2)
     # Vn = torch.cat([V1, Vn], dim = 2)
@@ -349,10 +348,10 @@ def fast_swiglu_inference(self, X, temp_gate = None, temp_up = None, gate_multip
     # temp = torch.empty((2, bsz, 1, mlp_size), dtype = X.dtype, device = "cuda:0")
 
     gate = fast_linear_forward(self.gate_proj, X, out = temp_gate)
-    
+
     if gate_multiplier is not None:
         gate *= gate_multiplier
-    
+
     up   = fast_linear_forward(self.  up_proj, X, out = temp_up)
     gate = torch_nn_functional_silu(gate, inplace = True)
@@ -435,7 +434,7 @@ def LlamaAttention_fast_forward(
     position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
     *args, **kwargs,
 ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
-    
+
     # Clear inference
     if hasattr(self, "paged_attention"):
         del self.paged_attention_K
@@ -654,7 +653,7 @@ def LlamaModel_fast_forward(
     return_dict: Optional[bool] = None,
     *args, **kwargs,
 ) -> Union[Tuple, BaseModelOutputWithPast]:
-    
+
     output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
     assert(output_attentions is False)
     output_hidden_states = (
@@ -689,7 +688,7 @@ def LlamaModel_fast_forward(
             inputs_embeds = inputs_embeds[:,:self.max_seq_length,:]
         pass
     pass
-    
+
     past_key_values_length = 0
 
     if past_key_values is not None:
@@ -776,7 +775,7 @@ def LlamaModel_fast_forward(
     # Ignore attention_mask
     if attention_mask is None:
         padding_mask = None
-    elif self.training and os.environ.get("UNSLOTH_KEEP_PADDING", "0") != '1': 
+    elif self.training and os.environ.get("UNSLOTH_KEEP_PADDING", "0") != '1':
         attention_mask = None
         padding_mask = None
     else:
@@ -1021,10 +1020,6 @@ def LlamaModel_fast_forward_inference_custom(
     next_decoder_cache = []
     for idx, decoder_layer in enumerate(self.model.layers):
-        decoder_device = decoder_layer.self_attn.q_proj.weight.device
-        X, residual, temp_gate, temp_up, position_ids = move_to_device(
-            decoder_device, X, residual, temp_gate, temp_up, position_ids
-        )
         residual.copy_(X) # residual = X
         X = fast_rms_layernorm_inference(
             decoder_layer.input_layernorm,
@@ -1140,7 +1135,7 @@ def _CausalLM_fast_forward(
         logit_scaling = getattr(self.config, "logit_scale", 0)
         dtype = lm_head.dtype
         num_logits_to_keep = max(num_logits_to_keep, logits_to_keep)
-        
+
         # Move items to same device as lm_head
         hidden_states = hidden_states.to(lm_head_device)
         if labels is not None: labels = labels.to(lm_head_device)
@@ -1167,7 +1162,7 @@ def _CausalLM_fast_forward(
         RETURN_LOGITS = os.environ.get("UNSLOTH_RETURN_LOGITS", "0") == "1"
         # < 1024 Normal Unsloth uses less VRAM!
         if bsz*q_len <= 1024: RETURN_LOGITS = True
-        
+
         if not RETURN_LOGITS and HAS_CUT_CROSS_ENTROPY and labels is not None:
 
             n_items = kwargs.get("num_items_in_batch", None) or kwargs.get("n_items", None)
@@ -1281,16 +1276,16 @@ def PeftModel_fast_forward(
     **kwargs,
 ):
     is_classification = "Classification" in str(type( self.base_model.model))
-    if is_classification: 
+    if is_classification:
         #causal_mask = causal_mask,
         return self.base_model(
             input_ids = input_ids,
-            attention_mask = attention_mask, 
-            inputs_embeds = inputs_embeds, 
-            labels = labels, 
+            attention_mask = attention_mask,
+            inputs_embeds = inputs_embeds,
+            labels = labels,
             output_attentions = output_attentions,
-            output_hidden_states = output_hidden_states, 
-            return_dict = return_dict, 
+            output_hidden_states = output_hidden_states,
+            return_dict = return_dict,
             **kwargs,
         )
     else:
@@ -1344,7 +1339,7 @@ def __init__(self, dim = None, max_position_embeddings=2048, base=10000, device=
         # Build here to make `torch.jit.trace` work.
         for device_idx in range(torch.cuda.device_count()):
             self._set_cos_sin_cache(seq_len=self.current_rope_size, device=torch.device(device_idx), dtype=torch.get_default_dtype())
-        
+
         # dummy so that patch_utils doesn't fail for now
         self.cos_cached = torch.empty(1, device=torch.cuda.current_device(), dtype=torch.get_default_dtype())
         self.sin_cached = torch.empty(1, device=torch.cuda.current_device(), dtype=torch.get_default_dtype())
@@ -1462,7 +1457,7 @@ def __init__(self, dim = None, max_position_embeddings=2048, base=10000, device=
         # Build here to make `torch.jit.trace` work.
         for device_idx in range(torch.cuda.device_count()):
             self._set_cos_sin_cache(seq_len=self.current_rope_size, device=torch.device(device_idx), dtype=torch.get_default_dtype())
-        
+
         # dummy so that patch_utils doesn't fail for now
         self.cos_cached = torch.empty(1, device=torch.cuda.current_device(), dtype=torch.get_default_dtype())
         self.sin_cached = torch.empty(1, device=torch.cuda.current_device(), dtype=torch.get_default_dtype())
@@ -1472,7 +1467,7 @@ def _set_cos_sin_cache(self, seq_len, device, dtype):
     # Note: on the original Llama codebase, these tensors are created on the target device (and not on CPU) and
     # in FP32. They are applied (multiplied) in FP32 as well.
     self.current_rope_size = seq_len
-    
+
     t = torch.arange(self.current_rope_size, device=self.inv_freq.device, dtype=torch.int64).float()
 
     freqs = torch.outer(t, self.inv_freq)
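
The constructor hunks above keep the pattern of precomputing cos/sin tables once per visible GPU, so later cache lookups are device-local and decode never waits on a cross-device copy. A condensed sketch of that caching idea (an illustrative class, not the patched unsloth embeddings):

    import torch

    class PerDeviceRopeCache:
        # Precompute cos/sin for every visible CUDA device (plus CPU) up front,
        # assuming an even head dim and a fixed maximum sequence length.
        def __init__(self, dim: int, seq_len: int, base: float = 10000.0):
            inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
            t = torch.arange(seq_len).float()
            emb = torch.cat([torch.outer(t, inv_freq)] * 2, dim=-1)
            self.cos, self.sin = {}, {}
            devices = [torch.device("cpu")] + [
                torch.device(i) for i in range(torch.cuda.device_count())
            ]
            for dev in devices:
                self.cos[dev] = emb.cos().to(dev)
                self.sin[dev] = emb.sin().to(dev)

        def get_cached(self, device: torch.device):
            # Callers pass the exact device (including index) of their Q/K tensors.
            return self.cos[device], self.sin[device]
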
@@ -1595,21 +1590,21 @@ def __init__(self,
         # Short and long inv_freq
         self.register_buffer("short_inv_freq", short_inv_freq, persistent = False)
         self.register_buffer("long_inv_freq",  long_inv_freq,  persistent = False)
-        
+
         # Build here to make `torch.jit.trace` work.
         # Initialize short sequences cache for all devices
         dtype = torch.bfloat16 if is_bfloat16_supported() else torch.float16
         t = torch.arange(original_max_position_embeddings, device=self.short_inv_freq.device, dtype=torch.int64).float()
         freqs = torch.outer(t, self.short_inv_freq)
         emb = torch.cat((freqs, freqs), dim=-1)
-        
+
         for device_idx in range(torch.cuda.device_count()):
             device_obj = torch.device(device_idx)
             cos_cached = (emb.cos() * self.scaling_factor).to(dtype=dtype, device=device_obj, non_blocking=True)
             sin_cached = (emb.sin() * self.scaling_factor).to(dtype=dtype, device=device_obj, non_blocking=True)
             self.multi_gpu_short_cos_cached[device_idx] = cos_cached
             self.multi_gpu_short_sin_cached[device_idx] = sin_cached
-        
+
         # dummy so that patch_utils doesn't fail for now
         self.short_cos_cached = torch.empty(1, device=torch.cuda.current_device(), dtype=torch.get_default_dtype())
         self.short_sin_cached = torch.empty(1, device=torch.cuda.current_device(), dtype=torch.get_default_dtype())
@@ -1621,7 +1616,7 @@ def _set_cos_sin_cache(self, seq_len, device, dtype):
     # Note: on the original Llama codebase, these tensors are created on the target device (and not on CPU) and
     # in FP32. They are applied (multiplied) in FP32 as well.
     self.current_rope_size = seq_len
-    
+
     t = torch.arange(self.current_rope_size, device=self.long_inv_freq.device, dtype=torch.int64).float()
     # Long sequences
     freqs = torch.outer(t, self.long_inv_freq)
@@ -1787,7 +1782,7 @@ def from_pretrained(
     max_lora_rank = 16,
     disable_log_stats = False,
     unsloth_vllm_standby = False,
-    num_labels = None, 
+    num_labels = None,
     **kwargs,
 ):
     os.environ["UNSLOTH_USE_NEW_MODEL"] = "0"
@@ -1859,7 +1854,7 @@ def from_pretrained(
     if old_hf_transfer != "0":
         os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"
     model_patcher.pre_patch()
-    get_statistics() # For debugging - we use a download counter to see if environments are not breaking 
+    get_statistics() # For debugging - we use a download counter to see if environments are not breaking
 
     if dtype is None:
         dtype = torch.float16 if not SUPPORTS_BFLOAT16 else torch.bfloat16
@@ -1946,7 +1941,7 @@ def from_pretrained(
     # Cannot be None, since HF now checks for the config
     if load_in_4bit:
         kwargs["quantization_config"] = bnb_config
-    
+
     if num_labels is not None:
         model = AutoModelForSequenceClassification.from_pretrained(
             model_name,
@@ -2044,7 +2039,7 @@ def from_pretrained(
         except:
             raise RuntimeError('Unsloth: Unsuccessfully patched inner_training_loop')
     pass
-    
+
     import transformers.trainer
     items_in_trainer = dir(transformers.trainer)
     good_items = []
@@ -2391,7 +2386,7 @@ def get_peft_model(
         )
         loftq_config = LoftQConfig(loftq_bits = 4, loftq_iter = 1)
     pass
-    
+
     if hasattr(model.config, "quantization_config"):
         raise ValueError(
             "Unsloth: You are using `loftq` init, yet `load_in_4bit = True` was set.\n"\
@@ -2521,7 +2516,7 @@ def get_peft_model(
     is_classification = "Classification" in str(type(model))
 
     # Get LoRA
-    # 
+    #
 
     arguments = dict(
         r = r,
@@ -2547,7 +2542,7 @@ def get_peft_model(
     input_embeddings_device  = model.get_input_embeddings().weight.device
     if is_classification:
         output_embeddings_device = model.score.weight.device
-    else: 
+    else:
         output_embeddings_device = model.get_output_embeddings().weight.device
 
     if use_gradient_checkpointing == "unsloth":
@@ -2707,7 +2702,7 @@ def patch_peft_model(
         # model.peft_config[active_adapter].revision = f"unsloth"
     pass
 
-    from transformers.trainer import Trainer 
+    from transformers.trainer import Trainer
     if Trainer._inner_training_loop.__name__ != "_fast_inner_training_loop":
         raise RuntimeError("Unsloth: Unsuccessfully patched Trainer! Please file a bug report!")
     pass
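
The `from_pretrained` and `get_peft_model` hunks above are whitespace-only, so the documented entry points are unchanged. For orientation, a typical call sequence against this API (the checkpoint name is just an example):

    from unsloth import FastLanguageModel

    # Any Llama-family checkpoint works; 4-bit loading exercises the
    # quantization_config branch patched above.
    model, tokenizer = FastLanguageModel.from_pretrained(
        model_name     = "unsloth/llama-3-8b-bnb-4bit",
        max_seq_length = 2048,
        load_in_4bit   = True,
    )
    model = FastLanguageModel.get_peft_model(
        model,
        r = 16,
        target_modules = ["q_proj", "k_proj", "v_proj", "o_proj"],
        lora_alpha = 16,
        use_gradient_checkpointing = "unsloth",
    )
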
@@ -2830,7 +2825,7 @@ def patch_peft_model(
         internal_model.max_seq_length = max_seq_length
         internal_model = internal_model.model
     pass
-    internal_model.max_seq_length = max_seq_length 
+    internal_model.max_seq_length = max_seq_length
 
     # Patch tokenizer to pad to the right
     internal_model = model
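
With this final patch, device placement is left entirely to the loader. Before generating on a multi-GPU box, it can be useful to confirm where each decoder layer actually landed; a small sketch using the same `q_proj`-based probe the deleted loops used (assumes a Llama-style `model.model.layers`):

    def report_layer_devices(model):
        # Print the device owning each decoder layer's attention weights.
        for idx, layer in enumerate(model.model.layers):
            print(f"layer {idx}: {layer.self_attn.q_proj.weight.device}")
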