From 2b95cc7c4dce99136fb4451087df852fb051465f Mon Sep 17 00:00:00 2001 From: Datta Nimmaturi Date: Wed, 2 Jul 2025 13:22:34 +0000 Subject: [PATCH 01/17] Move tensors to right devices --- unsloth/models/llama.py | 16 ++++++++++++++-- unsloth/models/mistral.py | 5 +++++ 2 files changed, 19 insertions(+), 2 deletions(-) diff --git a/unsloth/models/llama.py b/unsloth/models/llama.py index d0ff413925..44961204ab 100644 --- a/unsloth/models/llama.py +++ b/unsloth/models/llama.py @@ -209,6 +209,8 @@ def LlamaAttention_fast_forward_inference( This means we can pass in a row of Q, but we need to remember K and V, which are called the KV cache. """ + if position_ids is not None: + position_ids = position_ids.to('cpu') Xn = hidden_states bsz, _, hd = hidden_states.size() K1, V1 = past_key_value @@ -269,8 +271,8 @@ def LlamaAttention_fast_forward_inference( # or else error self.rotary_emb.extend_rope_embedding(Vn, seq_len + 2) cos, sin = self.rotary_emb.get_cached(kv_seq_len) - cos = cos[position_ids].unsqueeze(1) - sin = sin[position_ids].unsqueeze(1) + cos = cos[position_ids].unsqueeze(1).to(Qn.device) + sin = sin[position_ids].unsqueeze(1).to(Qn.device) h = self.half_head_dim RH_Q = self.RH_Q @@ -1019,6 +1021,16 @@ def LlamaModel_fast_forward_inference_custom( next_decoder_cache = [] for idx, decoder_layer in enumerate(self.model.layers): + print(f'Inference through layer {idx}') + decoder_device = decoder_layer.self_attn.q_proj.weight.device + if X.device != decoder_device: + X = X.to(decoder_device) + if residual.device != decoder_device: + residual = residual.to(decoder_device) + if temp_gate.device != decoder_device: + temp_gate = temp_gate.to(decoder_device) + if temp_up.device != decoder_device: + temp_up = temp_up.to(decoder_device) residual.copy_(X) # residual = X X = fast_rms_layernorm_inference( decoder_layer.input_layernorm, diff --git a/unsloth/models/mistral.py b/unsloth/models/mistral.py index ef71c89c4a..78a8a792b9 100644 --- a/unsloth/models/mistral.py +++ 
b/unsloth/models/mistral.py @@ -89,6 +89,11 @@ def MistralAttention_fast_forward( Q, K = fast_rope_embedding(Q, K, cos, sin) else: cos, sin = self.rotary_emb(V, seq_len = kv_seq_len) + if cos.device != Q.device: + cos = cos.to(Q.device) + sin = sin.to(Q.device) + if Q.device != K.device: + raise ValueError("Q and K must be on the same device") Q, K = inplace_rope_embedding(Q, K, cos, sin, position_ids) pass From 2ca7875433a490daa079f9b4f5379668c39e8001 Mon Sep 17 00:00:00 2001 From: Datta Nimmaturi Date: Wed, 2 Jul 2025 15:29:52 +0000 Subject: [PATCH 02/17] fix multi gpu for non mistral models --- unsloth/models/cohere.py | 6 ++++-- unsloth/models/falcon_h1.py | 6 ++++-- unsloth/models/gemma2.py | 4 ++-- unsloth/models/llama.py | 3 ++- unsloth/models/mistral.py | 5 +++-- unsloth/models/qwen3.py | 6 ++++-- 6 files changed, 19 insertions(+), 11 deletions(-) diff --git a/unsloth/models/cohere.py b/unsloth/models/cohere.py index 25704301d1..1d32f62c86 100644 --- a/unsloth/models/cohere.py +++ b/unsloth/models/cohere.py @@ -254,6 +254,8 @@ def CohereAttention_fast_forward_inference( do_prefill = False, attention_mask = None, ): + if position_ids is not None: + position_ids = position_ids.to('cpu') Xn = hidden_states bsz, _, hd = hidden_states.size() K1, V1 = past_key_value @@ -321,8 +323,8 @@ def CohereAttention_fast_forward_inference( # cos, sin = self.rotary_emb(Vn, seq_len = kv_seq_len) # Qn, Kn = inplace_rope_embedding(Qn, Kn, cos, sin, position_ids) cos, sin = self.rotary_emb.get_cached(kv_seq_len) - cos = cos[position_ids].unsqueeze(1) - sin = sin[position_ids].unsqueeze(1) + cos = cos[position_ids].unsqueeze(1).to(Qn.device) + sin = sin[position_ids].unsqueeze(1).to(Qn.device) h = self.half_head_dim RH_Q = self.RH_Q diff --git a/unsloth/models/falcon_h1.py b/unsloth/models/falcon_h1.py index 8978e4db0d..7889f673b9 100644 --- a/unsloth/models/falcon_h1.py +++ b/unsloth/models/falcon_h1.py @@ -116,6 +116,8 @@ def FalconH1Attention_fast_forward( cos, sin = 
rotary_emb.get_cached(kv_seq_len) else: cos, sin = rotary_emb(V, seq_len = kv_seq_len) + cos = cos.to(Q.device) + sin = sin.to(Q.device) Q, K = fast_rope_embedding(Q, K, cos, sin) if past_key_value is not None: @@ -281,8 +283,8 @@ def FalconH1Attention_fast_forward_inference( # or else error self.rotary_emb.extend_rope_embedding(Vn, seq_len + 2) cos, sin = self.rotary_emb.get_cached(kv_seq_len) - cos = cos[position_ids].unsqueeze(1) - sin = sin[position_ids].unsqueeze(1) + cos = cos[position_ids].unsqueeze(1).to(Qn.device) + sin = sin[position_ids].unsqueeze(1).to(Qn.device) h = self.half_head_dim RH_Q = self.RH_Q diff --git a/unsloth/models/gemma2.py b/unsloth/models/gemma2.py index 23b91ff6f3..94635beb4c 100644 --- a/unsloth/models/gemma2.py +++ b/unsloth/models/gemma2.py @@ -307,8 +307,8 @@ def Gemma2Attention_fast_forward_inference( # cos, sin = self.rotary_emb(Vn, seq_len = kv_seq_len) # Qn, Kn = inplace_rope_embedding(Qn, Kn, cos, sin, position_ids) - cos = self.rotary_emb.cos_cached[position_ids].unsqueeze(1) - sin = self.rotary_emb.sin_cached[position_ids].unsqueeze(1) + cos = self.rotary_emb.cos_cached[position_ids].unsqueeze(1).to(Qn.device) + sin = self.rotary_emb.sin_cached[position_ids].unsqueeze(1).to(Qn.device) h = self.half_head_dim RH_Q = self.RH_Q diff --git a/unsloth/models/llama.py b/unsloth/models/llama.py index 44961204ab..ca009f9fae 100644 --- a/unsloth/models/llama.py +++ b/unsloth/models/llama.py @@ -1021,7 +1021,6 @@ def LlamaModel_fast_forward_inference_custom( next_decoder_cache = [] for idx, decoder_layer in enumerate(self.model.layers): - print(f'Inference through layer {idx}') decoder_device = decoder_layer.self_attn.q_proj.weight.device if X.device != decoder_device: X = X.to(decoder_device) @@ -1031,6 +1030,8 @@ def LlamaModel_fast_forward_inference_custom( temp_gate = temp_gate.to(decoder_device) if temp_up.device != decoder_device: temp_up = temp_up.to(decoder_device) + if position_ids.device != decoder_device: + position_ids = 
position_ids.to(decoder_device) residual.copy_(X) # residual = X X = fast_rms_layernorm_inference( decoder_layer.input_layernorm, diff --git a/unsloth/models/mistral.py b/unsloth/models/mistral.py index 78a8a792b9..f6538c65ac 100644 --- a/unsloth/models/mistral.py +++ b/unsloth/models/mistral.py @@ -90,10 +90,11 @@ def MistralAttention_fast_forward( else: cos, sin = self.rotary_emb(V, seq_len = kv_seq_len) if cos.device != Q.device: + # without this, even though V is on GPU0, by the time we reach the forward function + # the argument x is on GPU1, and hence cos, sin end up on GPU1 + # this is a hack to get around this quirk cos = cos.to(Q.device) sin = sin.to(Q.device) - if Q.device != K.device: - raise ValueError("Q and K must be on the same device") Q, K = inplace_rope_embedding(Q, K, cos, sin, position_ids) pass diff --git a/unsloth/models/qwen3.py b/unsloth/models/qwen3.py index 83c9dbea0a..21e52548f8 100644 --- a/unsloth/models/qwen3.py +++ b/unsloth/models/qwen3.py @@ -232,6 +232,8 @@ def Qwen3Attention_fast_forward_inference( This means we can pass in a row of Q, but we need to remember K and V, which are called the KV cache.
""" + if position_ids is not None: + position_ids = position_ids.to('cpu') Xn = hidden_states bsz, _, hd = hidden_states.size() K1, V1 = past_key_value @@ -298,8 +300,8 @@ def Qwen3Attention_fast_forward_inference( # or else error self.rotary_emb.extend_rope_embedding(Vn, seq_len + 2) cos, sin = self.rotary_emb.get_cached(kv_seq_len) - cos = cos[position_ids].unsqueeze(1) - sin = sin[position_ids].unsqueeze(1) + cos = cos[position_ids].unsqueeze(1).to(Qn.device) + sin = sin[position_ids].unsqueeze(1).to(Qn.device) h = self.half_head_dim RH_Q = self.RH_Q From 03c57c1c47c560cd6f56c58cf0843be0e8a65402 Mon Sep 17 00:00:00 2001 From: Datta Nimmaturi Date: Thu, 3 Jul 2025 15:57:29 +0000 Subject: [PATCH 03/17] multi GPU RoPE for gemma2 --- unsloth/models/gemma.py | 33 ++++++++++++++++++++++----------- unsloth/models/gemma2.py | 24 ++++++++++++++++++------ 2 files changed, 40 insertions(+), 17 deletions(-) diff --git a/unsloth/models/gemma.py b/unsloth/models/gemma.py index 8ad1c7e62d..9138baf78c 100644 --- a/unsloth/models/gemma.py +++ b/unsloth/models/gemma.py @@ -224,9 +224,16 @@ def __init__(self, dim = None, max_position_embeddings=2048, base=10000, device= self.base = base # Dynamic RoPE we first set it to a max of 4 * 8192 tokens then we iteratively grow this self.current_rope_size = min(4 * 8192, self.max_position_embeddings) + self.multi_gpu_cos_cached = {} + self.multi_gpu_sin_cached = {} # Build here to make `torch.jit.trace` work. 
- self._set_cos_sin_cache(seq_len=self.current_rope_size, device=device, dtype=torch.get_default_dtype()) + for device in range(torch.cuda.device_count()): + self._set_cos_sin_cache(seq_len=self.current_rope_size, device=torch.device(device), dtype=torch.get_default_dtype()) + + # dummy so that patch_utils doesn't fail for now + self.cos_cached = torch.empty(1, device=torch.cuda.current_device(), dtype=torch.get_default_dtype()) + self.sin_cached = torch.empty(1, device=torch.cuda.current_device(), dtype=torch.get_default_dtype()) pass def _set_cos_sin_cache(self, seq_len, device, dtype): @@ -245,32 +252,36 @@ def _set_cos_sin_cache(self, seq_len, device, dtype): emb = torch.cat((radians_new, radians_new), dim = -1) # We must do RoPE in float32! - cos = emb.cos().to(device = "cuda", non_blocking = True)#, dtype = dtype) - sin = emb.sin().to(device = "cuda", non_blocking = True)#, dtype = dtype) - self.register_buffer("cos_cached", cos, persistent = False) - self.register_buffer("sin_cached", sin, persistent = False) + cos = emb.cos().to(device = device, non_blocking = True)#, dtype = dtype) + sin = emb.sin().to(device = device, non_blocking = True)#, dtype = dtype) + self.multi_gpu_cos_cached[device] = cos + self.multi_gpu_sin_cached[device] = sin + return cos, sin pass def forward(self, x, position_ids=None, seq_len=None): # x: [bs, num_attention_heads, seq_len, head_size] - if seq_len > self.current_rope_size: + if seq_len is not None and seq_len > self.current_rope_size: self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype) return ( - self.cos_cached[:seq_len].to(dtype=x.dtype), - self.sin_cached[:seq_len].to(dtype=x.dtype), + self.multi_gpu_cos_cached[x.device][:seq_len].to(dtype=x.dtype), + self.multi_gpu_sin_cached[x.device][:seq_len].to(dtype=x.dtype), ) pass - def get_cached(self, seq_len = None): - return self.cos_cached, self.sin_cached + def get_cached(self, seq_len = None, device = None): + if device is None: + device = 
torch.cuda.current_device() + return self.multi_gpu_cos_cached[device], self.multi_gpu_sin_cached[device] pass def extend_rope_embedding(self, x, seq_len): if seq_len <= self.current_rope_size: return # Iteratively grow by increments of 8192 self.current_rope_size = math.ceil(seq_len / 8192) * 8192 - self._set_cos_sin_cache(self.current_rope_size, device = "cuda", dtype = x.dtype) + for device in range(torch.cuda.device_count()): + self._set_cos_sin_cache(self.current_rope_size, device = torch.device(device), dtype = x.dtype) pass pass diff --git a/unsloth/models/gemma2.py b/unsloth/models/gemma2.py index 94635beb4c..ebb06de8e4 100644 --- a/unsloth/models/gemma2.py +++ b/unsloth/models/gemma2.py @@ -114,11 +114,11 @@ def Gemma2Attention_fast_forward( kv_seq_len += past_key_value[0].shape[-2] if position_ids is None: - cos = self.rotary_emb.cos_cached - sin = self.rotary_emb.sin_cached + cos = self.rotary_emb.multi_gpu_cos_cached[Q.device] + sin = self.rotary_emb.multi_gpu_sin_cached[Q.device] Q, K = fast_rope_embedding(Q, K, cos, sin) else: - cos, sin = self.rotary_emb(V, seq_len = kv_seq_len) + cos, sin = self.rotary_emb.get_cached(seq_len = kv_seq_len, device = Q.device) Q, K = inplace_rope_embedding(Q, K, cos, sin, position_ids) pass @@ -307,8 +307,9 @@ def Gemma2Attention_fast_forward_inference( # cos, sin = self.rotary_emb(Vn, seq_len = kv_seq_len) # Qn, Kn = inplace_rope_embedding(Qn, Kn, cos, sin, position_ids) - cos = self.rotary_emb.cos_cached[position_ids].unsqueeze(1).to(Qn.device) - sin = self.rotary_emb.sin_cached[position_ids].unsqueeze(1).to(Qn.device) + cos, sin = self.rotary_emb.get_cached(seq_len = kv_seq_len, device = Qn.device) + cos = cos[position_ids].unsqueeze(1) + sin = sin[position_ids].unsqueeze(1) h = self.half_head_dim RH_Q = self.RH_Q @@ -385,7 +386,7 @@ def Gemma2Model_fast_forward_inference( past_key_values, position_ids, attention_mask = None, -): +): out_weight = torch.empty_like(self.model.layers[0].input_layernorm.weight, dtype = 
torch.float32, device = "cuda:0") input_ids = input_ids[:,:self.max_seq_length] hidden_states = self.model.embed_tokens(input_ids) @@ -422,6 +423,17 @@ def Gemma2Model_fast_forward_inference( next_decoder_cache = [] for idx, decoder_layer in enumerate(self.model.layers): + # For pipeline parallelism, we need to move all tensors to the same device + # note that this movement is once per GPU in PP + layer_device = decoder_layer.self_attn.q_proj.weight.device + if hidden_states.device != layer_device: + hidden_states = hidden_states.to(layer_device) + if out_weight.device != layer_device: + out_weight = out_weight.to(layer_device) + if position_ids.device != layer_device: + position_ids = position_ids.to(layer_device) + pass + use_sliding_window = idx % 2 == 0 residual = hidden_states From a937baa2f3da1228d8691d587c1aac1c337a6372 Mon Sep 17 00:00:00 2001 From: Datta Nimmaturi Date: Thu, 3 Jul 2025 17:22:53 +0000 Subject: [PATCH 04/17] Finish up multi GPU inference --- unsloth/models/_utils.py | 20 ++++ unsloth/models/cohere.py | 10 +- unsloth/models/falcon_h1.py | 12 +-- unsloth/models/gemma.py | 17 +++- unsloth/models/gemma2.py | 10 +- unsloth/models/granite.py | 7 +- unsloth/models/llama.py | 190 ++++++++++++++++++++++-------------- unsloth/models/mistral.py | 10 +- unsloth/models/qwen3.py | 8 +- 9 files changed, 173 insertions(+), 111 deletions(-) diff --git a/unsloth/models/_utils.py b/unsloth/models/_utils.py index 5da6ea67fe..f716e50043 100644 --- a/unsloth/models/_utils.py +++ b/unsloth/models/_utils.py @@ -773,6 +773,26 @@ def _prepare_backend(self, *args, **kwargs): return None, DistributedType.NO accelerate.accelerator.Accelerator.distributed_type = lambda *args, **kwargs: DistributedType.NO pass +# to move multiple tensors to the same device +def move_to_device(target_device, *tensors): + """ + Move multiple tensors to target device if they're not already there. 
+ + Args: + target_device: The target device to move tensors to + *tensors: Variable number of tensors to potentially move + + Returns: + tuple: The tensors on the target device (same objects if already on device, new if moved) + """ + moved_tensors = [] + for tensor in tensors: + if tensor.device != target_device: + moved_tensors.append(tensor.to(target_device)) + else: + moved_tensors.append(tensor) + return tuple(moved_tensors) if len(moved_tensors) > 1 else moved_tensors[0] + import transformers.utils.quantization_config transformers.utils.quantization_config.BitsAndBytesConfig.__init__ = _BitsAndBytesConfig__init__ # ============================================= diff --git a/unsloth/models/cohere.py b/unsloth/models/cohere.py index 1d32f62c86..8562895201 100644 --- a/unsloth/models/cohere.py +++ b/unsloth/models/cohere.py @@ -322,9 +322,9 @@ def CohereAttention_fast_forward_inference( # cos, sin = self.rotary_emb(Vn, seq_len = kv_seq_len) # Qn, Kn = inplace_rope_embedding(Qn, Kn, cos, sin, position_ids) - cos, sin = self.rotary_emb.get_cached(kv_seq_len) - cos = cos[position_ids].unsqueeze(1).to(Qn.device) - sin = sin[position_ids].unsqueeze(1).to(Qn.device) + cos, sin = self.rotary_emb.get_cached(kv_seq_len, device = Qn.device) + cos = cos[position_ids].unsqueeze(1) + sin = sin[position_ids].unsqueeze(1) h = self.half_head_dim RH_Q = self.RH_Q @@ -419,6 +419,10 @@ def CohereModel_fast_forward_inference( next_decoder_cache = [] for idx, decoder_layer in enumerate(self.model.layers): + decoder_device = decoder_layer.self_attn.q_proj.weight.device + hidden_states, out_weight, position_ids = move_to_device( + decoder_device, hidden_states, out_weight, position_ids + ) residual = hidden_states hidden_states = fast_layernorm_inference(decoder_layer.input_layernorm, hidden_states, out_weight) hidden_states_attention, present_key_value = CohereAttention_fast_forward_inference( diff --git a/unsloth/models/falcon_h1.py b/unsloth/models/falcon_h1.py index 
7889f673b9..58e8c94d9d 100644 --- a/unsloth/models/falcon_h1.py +++ b/unsloth/models/falcon_h1.py @@ -113,11 +113,9 @@ def FalconH1Attention_fast_forward( if position_ids is None: # Useful for LongRoPE - cos, sin = rotary_emb.get_cached(kv_seq_len) + cos, sin = rotary_emb.get_cached(kv_seq_len, device=Q.device) else: - cos, sin = rotary_emb(V, seq_len = kv_seq_len) - cos = cos.to(Q.device) - sin = sin.to(Q.device) + cos, sin = rotary_emb.get_cached(seq_len = kv_seq_len, device=Q.device) Q, K = fast_rope_embedding(Q, K, cos, sin) if past_key_value is not None: @@ -282,9 +280,9 @@ def FalconH1Attention_fast_forward_inference( # Need to do it prior 2 steps before hitting full on short KV cache # or else error self.rotary_emb.extend_rope_embedding(Vn, seq_len + 2) - cos, sin = self.rotary_emb.get_cached(kv_seq_len) - cos = cos[position_ids].unsqueeze(1).to(Qn.device) - sin = sin[position_ids].unsqueeze(1).to(Qn.device) + cos, sin = self.rotary_emb.get_cached(kv_seq_len, device = Qn.device) + cos = cos[position_ids].unsqueeze(1) + sin = sin[position_ids].unsqueeze(1) h = self.half_head_dim RH_Q = self.RH_Q diff --git a/unsloth/models/gemma.py b/unsloth/models/gemma.py index 9138baf78c..8c992aa274 100644 --- a/unsloth/models/gemma.py +++ b/unsloth/models/gemma.py @@ -170,6 +170,12 @@ def GemmaModel_fast_forward_inference( next_decoder_cache = [] for idx, decoder_layer in enumerate(self.model.layers): + + decoder_device = decoder_layer.self_attn.q_proj.weight.device + hidden_states, out_weight, position_ids = move_to_device( + decoder_device, hidden_states, out_weight, position_ids + ) + residual = hidden_states hidden_states = fast_rms_layernorm_inference_gemma(decoder_layer.input_layernorm, hidden_states, out_weight) hidden_states, present_key_value = LlamaAttention_fast_forward_inference( @@ -299,7 +305,7 @@ def __init__(self, dim = None, max_position_embeddings=2048, base=10000, device= pass def _set_cos_sin_cache(self, seq_len, device, dtype): -# Note: on the 
original Llama codebase, these tensors are created on the target device (and not on CPU) and + # Note: on the original Llama codebase, these tensors are created on the target device (and not on CPU) and # in FP32. They are applied (multiplied) in FP32 as well. self.current_rope_size = seq_len @@ -315,10 +321,11 @@ def _set_cos_sin_cache(self, seq_len, device, dtype): emb = torch.cat((radians_new, radians_new), dim = -1) # We must do RoPE in float32! - cos = emb.cos().to(device = "cuda", non_blocking = True)#, dtype = dtype) - sin = emb.sin().to(device = "cuda", non_blocking = True)#, dtype = dtype) - self.register_buffer("cos_cached", cos, persistent = False) - self.register_buffer("sin_cached", sin, persistent = False) + cos = emb.cos().to(device = device, non_blocking = True)#, dtype = dtype) + sin = emb.sin().to(device = device, non_blocking = True)#, dtype = dtype) + self.multi_gpu_cos_cached[device] = cos + self.multi_gpu_sin_cached[device] = sin + return cos, sin pass pass diff --git a/unsloth/models/gemma2.py b/unsloth/models/gemma2.py index ebb06de8e4..c4eb5a12c6 100644 --- a/unsloth/models/gemma2.py +++ b/unsloth/models/gemma2.py @@ -426,13 +426,9 @@ def Gemma2Model_fast_forward_inference( # For pipeline parallelism, we need to move all tensors to the same device # note that this movement is once per GPU in PP layer_device = decoder_layer.self_attn.q_proj.weight.device - if hidden_states.device != layer_device: - hidden_states = hidden_states.to(layer_device) - if out_weight.device != layer_device: - out_weight = out_weight.to(layer_device) - if position_ids.device != layer_device: - position_ids = position_ids.to(layer_device) - pass + hidden_states, out_weight, position_ids = move_to_device( + layer_device, hidden_states, out_weight, position_ids + ) use_sliding_window = idx % 2 == 0 diff --git a/unsloth/models/granite.py b/unsloth/models/granite.py index 0f88bbc55e..32e8dbc866 100644 --- a/unsloth/models/granite.py +++ b/unsloth/models/granite.py @@ 
-395,11 +395,16 @@ def GraniteModel_fast_forward_inference( attention_mask = None pass - position_embeddings = self.model.rotary_emb(hidden_states, position_ids, self.max_seq_length) + position_embeddings = self.model.rotary_emb.get_cached(seq_len = self.max_seq_length, device = hidden_states.device) next_decoder_cache = [] for idx, decoder_layer in enumerate(self.model.layers): + decoder_device = decoder_layer.self_attn.q_proj.weight.device + hidden_states, position_ids = move_to_device( + decoder_device, hidden_states, position_ids + ) + residual = hidden_states hidden_states = fast_rms_layernorm_inference(decoder_layer.input_layernorm, hidden_states) hidden_states, present_key_value = GraniteAttention_fast_forward_inference( diff --git a/unsloth/models/llama.py b/unsloth/models/llama.py index ca009f9fae..72664362d4 100644 --- a/unsloth/models/llama.py +++ b/unsloth/models/llama.py @@ -20,6 +20,7 @@ from ._utils import * from ._utils import patch_unsloth_smart_gradient_checkpointing from ._utils import __version__ +from ._utils import move_to_device from torch.nn.functional import scaled_dot_product_attention from transformers import __version__ as transformers_version from unsloth_zoo.utils import Version, _get_dtype @@ -270,9 +271,9 @@ def LlamaAttention_fast_forward_inference( # Need to do it prior 2 steps before hitting full on short KV cache # or else error self.rotary_emb.extend_rope_embedding(Vn, seq_len + 2) - cos, sin = self.rotary_emb.get_cached(kv_seq_len) - cos = cos[position_ids].unsqueeze(1).to(Qn.device) - sin = sin[position_ids].unsqueeze(1).to(Qn.device) + cos, sin = self.rotary_emb.get_cached(kv_seq_len, device = Qn.device) + cos = cos[position_ids].unsqueeze(1) + sin = sin[position_ids].unsqueeze(1) h = self.half_head_dim RH_Q = self.RH_Q @@ -472,11 +473,12 @@ def LlamaAttention_fast_forward( rotary_emb = self.rotary_emb rotary_emb.extend_rope_embedding(V, seq_len = kv_seq_len) - if position_ids is None: - # Useful for LongRoPE - cos, sin = 
rotary_emb.get_cached(kv_seq_len) - else: - cos, sin = rotary_emb(V, seq_len = kv_seq_len) + # if position_ids is None: + # # Useful for LongRoPE + # cos, sin = rotary_emb.get_cached(kv_seq_len, device = Q.device) + # else: + # cos, sin = rotary_emb.get_cached(seq_len = kv_seq_len, device = Q.device) + cos, sin = rotary_emb.get_cached(seq_len = kv_seq_len, device = Q.device) # Q, K = ( # fast_rope_embedding(Q, K, cos, sin) @@ -893,7 +895,7 @@ def LlamaModel_fast_forward( # Also, transformers 4.45.0 supports granite but with the attention refactor (it always had the refactor) # unsloth's check for granite too has "version >= 4.45.0 (rightly so)". # so let granite always use the attention refactor implementation. - position_embeddings = self.rotary_emb(hidden_states, position_ids, self.config.max_position_embeddings) + position_embeddings = self.rotary_emb.get_cached(seq_len = self.config.max_position_embeddings, device = hidden_states.device) else: position_embeddings = None @@ -1022,16 +1024,9 @@ def LlamaModel_fast_forward_inference_custom( for idx, decoder_layer in enumerate(self.model.layers): decoder_device = decoder_layer.self_attn.q_proj.weight.device - if X.device != decoder_device: - X = X.to(decoder_device) - if residual.device != decoder_device: - residual = residual.to(decoder_device) - if temp_gate.device != decoder_device: - temp_gate = temp_gate.to(decoder_device) - if temp_up.device != decoder_device: - temp_up = temp_up.to(decoder_device) - if position_ids.device != decoder_device: - position_ids = position_ids.to(decoder_device) + X, residual, temp_gate, temp_up, position_ids = move_to_device( + decoder_device, X, residual, temp_gate, temp_up, position_ids + ) residual.copy_(X) # residual = X X = fast_rms_layernorm_inference( decoder_layer.input_layernorm, @@ -1345,9 +1340,16 @@ def __init__(self, dim = None, max_position_embeddings=2048, base=10000, device= self.base = base # Dynamic RoPE we first set it to a max of 4 * 8192 tokens then we 
iteratively grow this self.current_rope_size = min(4 * 8192, self.max_position_embeddings) + self.multi_gpu_cos_cached = {} + self.multi_gpu_sin_cached = {} # Build here to make `torch.jit.trace` work. - self._set_cos_sin_cache(seq_len=self.current_rope_size, device=device, dtype=torch.get_default_dtype()) + for device_idx in range(torch.cuda.device_count()): + self._set_cos_sin_cache(seq_len=self.current_rope_size, device=torch.device(device_idx), dtype=torch.get_default_dtype()) + + # dummy so that patch_utils doesn't fail for now + self.cos_cached = torch.empty(1, device=torch.cuda.current_device(), dtype=torch.get_default_dtype()) + self.sin_cached = torch.empty(1, device=torch.cuda.current_device(), dtype=torch.get_default_dtype()) pass def _set_cos_sin_cache(self, seq_len, device, dtype): @@ -1362,30 +1364,36 @@ def _set_cos_sin_cache(self, seq_len, device, dtype): freqs = torch.outer(t, inv_freq) # Different from paper, but it uses a different permutation in order to obtain the same calculation emb = torch.cat((freqs, freqs), dim=-1) - self.register_buffer("cos_cached", emb.cos().to(dtype=dtype, device=device, non_blocking=True), persistent=False) - self.register_buffer("sin_cached", emb.sin().to(dtype=dtype, device=device, non_blocking=True), persistent=False) + cos = emb.cos().to(dtype=dtype, device=device, non_blocking=True) + sin = emb.sin().to(dtype=dtype, device=device, non_blocking=True) + self.multi_gpu_cos_cached[device] = cos + self.multi_gpu_sin_cached[device] = sin + return cos, sin pass def forward(self, x, position_ids=None, seq_len=None): # x: [bs, num_attention_heads, seq_len, head_size] - if seq_len > self.current_rope_size: + if seq_len is not None and seq_len > self.current_rope_size: self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype) return ( - self.cos_cached[:seq_len].to(dtype = x.dtype), - self.sin_cached[:seq_len].to(dtype = x.dtype), + self.multi_gpu_cos_cached[x.device][:seq_len].to(dtype = x.dtype), + 
self.multi_gpu_sin_cached[x.device][:seq_len].to(dtype = x.dtype), ) pass - def get_cached(self, seq_len = None): - return self.cos_cached, self.sin_cached + def get_cached(self, seq_len = None, device = None): + if device is None: + device = torch.cuda.current_device() + return self.multi_gpu_cos_cached[device], self.multi_gpu_sin_cached[device] pass def extend_rope_embedding(self, x, seq_len): if seq_len <= self.current_rope_size: return # Iteratively grow by increments of 8192 self.current_rope_size = ((seq_len // 8192) + ((seq_len % 8192) != 0)) * 8192 - self._set_cos_sin_cache(self.current_rope_size, device = DEVICE_TYPE, dtype = x.dtype) + for device_idx in range(torch.cuda.device_count()): + self._set_cos_sin_cache(self.current_rope_size, device = torch.device(device_idx), dtype = x.dtype) pass pass @@ -1413,8 +1421,11 @@ def _set_cos_sin_cache(self, seq_len, device, dtype): freqs = torch.outer(t, inv_freq) # Different from paper, but it uses a different permutation in order to obtain the same calculation emb = torch.cat((freqs, freqs), dim=-1) - self.register_buffer("cos_cached", emb.cos().to(dtype=dtype, device=device, non_blocking=True), persistent=False) - self.register_buffer("sin_cached", emb.sin().to(dtype=dtype, device=device, non_blocking=True), persistent=False) + cos = emb.cos().to(dtype=dtype, device=device, non_blocking=True) + sin = emb.sin().to(dtype=dtype, device=device, non_blocking=True) + self.multi_gpu_cos_cached[device] = cos + self.multi_gpu_sin_cached[device] = sin + return cos, sin pass pass @@ -1440,6 +1451,8 @@ def __init__(self, dim = None, max_position_embeddings=2048, base=10000, device= self.base = base # Dynamic RoPE we first set it to a max of 4 * 8192 tokens then we iteratively grow this self.current_rope_size = min(4 * 8192, self.max_position_embeddings) + self.multi_gpu_cos_cached = {} + self.multi_gpu_sin_cached = {} # Normal Llama-3 RoPE inv_freq = 1.0 / ( @@ -1449,7 +1462,12 @@ def __init__(self, dim = None, 
max_position_embeddings=2048, base=10000, device= self.register_buffer("inv_freq", inv_freq, persistent = False) # Build here to make `torch.jit.trace` work. - self._set_cos_sin_cache(seq_len=self.current_rope_size, device=device, dtype=torch.get_default_dtype()) + for device_idx in range(torch.cuda.device_count()): + self._set_cos_sin_cache(seq_len=self.current_rope_size, device=torch.device(device_idx), dtype=torch.get_default_dtype()) + + # dummy so that patch_utils doesn't fail for now + self.cos_cached = torch.empty(1, device=torch.cuda.current_device(), dtype=torch.get_default_dtype()) + self.sin_cached = torch.empty(1, device=torch.cuda.current_device(), dtype=torch.get_default_dtype()) pass def _set_cos_sin_cache(self, seq_len, device, dtype): @@ -1462,8 +1480,36 @@ def _set_cos_sin_cache(self, seq_len, device, dtype): freqs = torch.outer(t, self.inv_freq) # Different from paper, but it uses a different permutation in order to obtain the same calculation emb = torch.cat((freqs, freqs), dim=-1) - self.register_buffer("cos_cached", emb.cos().to(dtype=dtype, device=device, non_blocking=True), persistent=False) - self.register_buffer("sin_cached", emb.sin().to(dtype=dtype, device=device, non_blocking=True), persistent=False) + cos = emb.cos().to(dtype=dtype, device=device, non_blocking=True) + sin = emb.sin().to(dtype=dtype, device=device, non_blocking=True) + self.multi_gpu_cos_cached[device] = cos + self.multi_gpu_sin_cached[device] = sin + return cos, sin + pass + + def forward(self, x, position_ids=None, seq_len=None): + # x: [bs, num_attention_heads, seq_len, head_size] + if seq_len is not None and seq_len > self.current_rope_size: + self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype) + + return ( + self.multi_gpu_cos_cached[x.device][:seq_len].to(dtype = x.dtype), + self.multi_gpu_sin_cached[x.device][:seq_len].to(dtype = x.dtype), + ) + pass + + def get_cached(self, seq_len = None, device = None): + if device is None: + device = 
torch.cuda.current_device() + return self.multi_gpu_cos_cached[device], self.multi_gpu_sin_cached[device] + pass + + def extend_rope_embedding(self, x, seq_len): + if seq_len <= self.current_rope_size: return + # Iteratively grow by increments of 8192 + self.current_rope_size = ((seq_len // 8192) + ((seq_len % 8192) != 0)) * 8192 + for device_idx in range(torch.cuda.device_count()): + self._set_cos_sin_cache(self.current_rope_size, device = torch.device(device_idx), dtype = x.dtype) pass # From https://github.com/meta-llama/llama-models/blob/main/models/llama3_1/api/model.py#L41 @@ -1491,28 +1537,6 @@ def apply_scaling(self, freqs: torch.Tensor): new_freqs.append((1 - smooth) * freq / scale_factor + smooth * freq) return torch.tensor(new_freqs, dtype=freqs.dtype, device=freqs.device) pass - - def forward(self, x, position_ids=None, seq_len=None): - # x: [bs, num_attention_heads, seq_len, head_size] - if seq_len > self.current_rope_size: - self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype) - - return ( - self.cos_cached[:seq_len].to(dtype = x.dtype), - self.sin_cached[:seq_len].to(dtype = x.dtype), - ) - pass - - def get_cached(self, seq_len = None): - return self.cos_cached, self.sin_cached - pass - - def extend_rope_embedding(self, x, seq_len): - if seq_len <= self.current_rope_size: return - # Iteratively grow by increments of 8192 - self.current_rope_size = ((seq_len // 8192) + ((seq_len % 8192) != 0)) * 8192 - self._set_cos_sin_cache(self.current_rope_size, device = DEVICE_TYPE, dtype = x.dtype) - pass pass @@ -1548,6 +1572,10 @@ def __init__(self, self.base = base # Dynamic RoPE we first set it to a max of 4 * 8192 tokens then we iteratively grow this self.current_rope_size = min(original_max_position_embeddings, self.max_position_embeddings) + self.multi_gpu_short_cos_cached = {} + self.multi_gpu_short_sin_cached = {} + self.multi_gpu_long_cos_cached = {} + self.multi_gpu_long_sin_cached = {} # Long RoPE similar to RoPE except short 
sequences have 1 cos / sin # and long sequences have another cos / sin @@ -1569,18 +1597,26 @@ def __init__(self, # Short and long inv_freq self.register_buffer("short_inv_freq", short_inv_freq, persistent = False) self.register_buffer("long_inv_freq", long_inv_freq, persistent = False) + # Build here to make `torch.jit.trace` work. - # self._set_cos_sin_cache(seq_len=self.current_rope_size, device=device, dtype=torch.get_default_dtype()) - - # Short sequences + # Initialize short sequences cache for all devices dtype = torch.bfloat16 if is_bfloat16_supported() else torch.float16 t = torch.arange(original_max_position_embeddings, device=self.short_inv_freq.device, dtype=torch.int64).float() freqs = torch.outer(t, self.short_inv_freq) emb = torch.cat((freqs, freqs), dim=-1) - cos_cached = (emb.cos() * self.scaling_factor).to(dtype=dtype, device=device, non_blocking=True) - sin_cached = (emb.sin() * self.scaling_factor).to(dtype=dtype, device=device, non_blocking=True) - self.register_buffer("short_cos_cached", cos_cached, persistent=False) - self.register_buffer("short_sin_cached", sin_cached, persistent=False) + + for device_idx in range(torch.cuda.device_count()): + device_obj = torch.device(device_idx) + cos_cached = (emb.cos() * self.scaling_factor).to(dtype=dtype, device=device_obj, non_blocking=True) + sin_cached = (emb.sin() * self.scaling_factor).to(dtype=dtype, device=device_obj, non_blocking=True) + self.multi_gpu_short_cos_cached[device_obj] = cos_cached + self.multi_gpu_short_sin_cached[device_obj] = sin_cached + + # dummy so that patch_utils doesn't fail for now + self.short_cos_cached = torch.empty(1, device=torch.cuda.current_device(), dtype=torch.get_default_dtype()) + self.short_sin_cached = torch.empty(1, device=torch.cuda.current_device(), dtype=torch.get_default_dtype()) + self.long_cos_cached = torch.empty(1, device=torch.cuda.current_device(), dtype=torch.get_default_dtype()) + self.long_sin_cached = torch.empty(1, 
device=torch.cuda.current_device(), dtype=torch.get_default_dtype()) pass def _set_cos_sin_cache(self, seq_len, device, dtype): @@ -1594,39 +1630,43 @@ def _set_cos_sin_cache(self, seq_len, device, dtype): emb = torch.cat((freqs, freqs), dim=-1) cos_cached = (emb.cos() * self.scaling_factor).to(dtype=dtype, device=device, non_blocking=True) sin_cached = (emb.sin() * self.scaling_factor).to(dtype=dtype, device=device, non_blocking=True) - self.register_buffer("long_cos_cached", cos_cached, persistent=False) - self.register_buffer("long_sin_cached", sin_cached, persistent=False) + self.multi_gpu_long_cos_cached[device] = cos_cached + self.multi_gpu_long_sin_cached[device] = sin_cached + return cos_cached, sin_cached pass def forward(self, x, position_ids=None, seq_len=None): # x: [bs, num_attention_heads, seq_len, head_size] - if seq_len > self.current_rope_size: + if seq_len is not None and seq_len > self.current_rope_size: self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype) - if seq_len < self.original_max_position_embeddings: + if seq_len is not None and seq_len < self.original_max_position_embeddings: return ( - self.short_cos_cached[:seq_len].to(dtype = x.dtype), - self.short_sin_cached[:seq_len].to(dtype = x.dtype), + self.multi_gpu_short_cos_cached[x.device][:seq_len].to(dtype = x.dtype), + self.multi_gpu_short_sin_cached[x.device][:seq_len].to(dtype = x.dtype), ) else: return ( - self.long_cos_cached[:seq_len].to(dtype = x.dtype), - self.long_sin_cached[:seq_len].to(dtype = x.dtype), + self.multi_gpu_long_cos_cached[x.device][:seq_len].to(dtype = x.dtype), + self.multi_gpu_long_sin_cached[x.device][:seq_len].to(dtype = x.dtype), ) pass pass - def get_cached(self, seq_len = None): - if seq_len < self.original_max_position_embeddings: - return self.short_cos_cached, self.short_sin_cached - return self.long_cos_cached, self.long_sin_cached + def get_cached(self, seq_len = None, device = None): + if device is None: + device = 
torch.cuda.current_device() + if seq_len is not None and seq_len < self.original_max_position_embeddings: + return self.multi_gpu_short_cos_cached[device], self.multi_gpu_short_sin_cached[device] + return self.multi_gpu_long_cos_cached[device], self.multi_gpu_long_sin_cached[device] pass def extend_rope_embedding(self, x, seq_len): if seq_len <= self.current_rope_size: return # Iteratively grow by increments of 8192 self.current_rope_size = ((seq_len // 8192) + ((seq_len % 8192) != 0)) * 8192 - self._set_cos_sin_cache(self.current_rope_size, device = DEVICE_TYPE, dtype = x.dtype) + for device_idx in range(torch.cuda.device_count()): + self._set_cos_sin_cache(self.current_rope_size, device = torch.device(device_idx), dtype = x.dtype) pass pass diff --git a/unsloth/models/mistral.py b/unsloth/models/mistral.py index f6538c65ac..4aee2cabdf 100644 --- a/unsloth/models/mistral.py +++ b/unsloth/models/mistral.py @@ -83,18 +83,10 @@ def MistralAttention_fast_forward( # Extend RoPE dynamically to fit in VRAM self.rotary_emb.extend_rope_embedding(V, seq_len = kv_seq_len) + cos, sin = self.rotary_emb.get_cached(seq_len = kv_seq_len, device = Q.device) if position_ids is None: - cos = self.rotary_emb.cos_cached - sin = self.rotary_emb.sin_cached Q, K = fast_rope_embedding(Q, K, cos, sin) else: - cos, sin = self.rotary_emb(V, seq_len = kv_seq_len) - if cos.device != Q.device: - # without this, even though V is on GPU0, by the time we reach the foward function - # the argument x is on GPU1, and hence cos, sin end up on GPU1 - # this is a hack to get around this quirk - cos = cos.to(Q.device) - sin = sin.to(Q.device) Q, K = inplace_rope_embedding(Q, K, cos, sin, position_ids) pass diff --git a/unsloth/models/qwen3.py b/unsloth/models/qwen3.py index 21e52548f8..ffd60b9668 100644 --- a/unsloth/models/qwen3.py +++ b/unsloth/models/qwen3.py @@ -116,7 +116,7 @@ def Qwen3Attention_fast_forward( # Useful for LongRoPE cos, sin = rotary_emb.get_cached(kv_seq_len) else: - cos, sin = 
rotary_emb(V, seq_len = kv_seq_len) + cos, sin = rotary_emb.get_cached(seq_len = kv_seq_len, device = Q.device) Q, K = fast_rope_embedding(Q, K, cos, sin) if past_key_value is not None: @@ -299,9 +299,9 @@ def Qwen3Attention_fast_forward_inference( # Need to do it prior 2 steps before hitting full on short KV cache # or else error self.rotary_emb.extend_rope_embedding(Vn, seq_len + 2) - cos, sin = self.rotary_emb.get_cached(kv_seq_len) - cos = cos[position_ids].unsqueeze(1).to(Qn.device) - sin = sin[position_ids].unsqueeze(1).to(Qn.device) + cos, sin = self.rotary_emb.get_cached(kv_seq_len, device = Qn.device) + cos = cos[position_ids].unsqueeze(1) + sin = sin[position_ids].unsqueeze(1) h = self.half_head_dim RH_Q = self.RH_Q From 44955c321d2241c12084f91e2bd35cd5673f86c2 Mon Sep 17 00:00:00 2001 From: Datta Nimmaturi Date: Fri, 4 Jul 2025 05:04:51 +0000 Subject: [PATCH 05/17] Make multiGPU rope a list --- unsloth/models/gemma.py | 18 ++++++------ unsloth/models/gemma2.py | 4 +-- unsloth/models/llama.py | 61 ++++++++++++++++++++-------------------- 3 files changed, 42 insertions(+), 41 deletions(-) diff --git a/unsloth/models/gemma.py b/unsloth/models/gemma.py index 8c992aa274..4d32c42f7b 100644 --- a/unsloth/models/gemma.py +++ b/unsloth/models/gemma.py @@ -230,8 +230,8 @@ def __init__(self, dim = None, max_position_embeddings=2048, base=10000, device= self.base = base # Dynamic RoPE we first set it to a max of 4 * 8192 tokens then we iteratively grow this self.current_rope_size = min(4 * 8192, self.max_position_embeddings) - self.multi_gpu_cos_cached = {} - self.multi_gpu_sin_cached = {} + self.multi_gpu_cos_cached = [None]*torch.cuda.device_count() + self.multi_gpu_sin_cached = [None]*torch.cuda.device_count() # Build here to make `torch.jit.trace` work. for device in range(torch.cuda.device_count()): @@ -260,8 +260,8 @@ def _set_cos_sin_cache(self, seq_len, device, dtype): # We must do RoPE in float32! 
cos = emb.cos().to(device = device, non_blocking = True)#, dtype = dtype) sin = emb.sin().to(device = device, non_blocking = True)#, dtype = dtype) - self.multi_gpu_cos_cached[device] = cos - self.multi_gpu_sin_cached[device] = sin + self.multi_gpu_cos_cached[device.index] = cos + self.multi_gpu_sin_cached[device.index] = sin return cos, sin pass @@ -271,15 +271,15 @@ def forward(self, x, position_ids=None, seq_len=None): self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype) return ( - self.multi_gpu_cos_cached[x.device][:seq_len].to(dtype=x.dtype), - self.multi_gpu_sin_cached[x.device][:seq_len].to(dtype=x.dtype), + self.multi_gpu_cos_cached[x.device.index][:seq_len].to(dtype=x.dtype), + self.multi_gpu_sin_cached[x.device.index][:seq_len].to(dtype=x.dtype), ) pass def get_cached(self, seq_len = None, device = None): if device is None: device = torch.cuda.current_device() - return self.multi_gpu_cos_cached[device], self.multi_gpu_sin_cached[device] + return self.multi_gpu_cos_cached[device.index], self.multi_gpu_sin_cached[device.index] pass def extend_rope_embedding(self, x, seq_len): @@ -323,8 +323,8 @@ def _set_cos_sin_cache(self, seq_len, device, dtype): # We must do RoPE in float32! 
cos = emb.cos().to(device = device, non_blocking = True)#, dtype = dtype) sin = emb.sin().to(device = device, non_blocking = True)#, dtype = dtype) - self.multi_gpu_cos_cached[device] = cos - self.multi_gpu_sin_cached[device] = sin + self.multi_gpu_cos_cached[device.index] = cos + self.multi_gpu_sin_cached[device.index] = sin return cos, sin pass pass diff --git a/unsloth/models/gemma2.py b/unsloth/models/gemma2.py index c4eb5a12c6..0b3541ea7a 100644 --- a/unsloth/models/gemma2.py +++ b/unsloth/models/gemma2.py @@ -114,8 +114,8 @@ def Gemma2Attention_fast_forward( kv_seq_len += past_key_value[0].shape[-2] if position_ids is None: - cos = self.rotary_emb.multi_gpu_cos_cached[Q.device] - sin = self.rotary_emb.multi_gpu_sin_cached[Q.device] + cos = self.rotary_emb.multi_gpu_cos_cached[Q.device.index] + sin = self.rotary_emb.multi_gpu_sin_cached[Q.device.index] Q, K = fast_rope_embedding(Q, K, cos, sin) else: cos, sin = self.rotary_emb.get_cached(seq_len = kv_seq_len, device = Q.device) diff --git a/unsloth/models/llama.py b/unsloth/models/llama.py index 72664362d4..d89faf15b4 100644 --- a/unsloth/models/llama.py +++ b/unsloth/models/llama.py @@ -1340,8 +1340,8 @@ def __init__(self, dim = None, max_position_embeddings=2048, base=10000, device= self.base = base # Dynamic RoPE we first set it to a max of 4 * 8192 tokens then we iteratively grow this self.current_rope_size = min(4 * 8192, self.max_position_embeddings) - self.multi_gpu_cos_cached = {} - self.multi_gpu_sin_cached = {} + self.multi_gpu_cos_cached = [None]*torch.cuda.device_count() + self.multi_gpu_sin_cached = [None]*torch.cuda.device_count() # Build here to make `torch.jit.trace` work. 
for device_idx in range(torch.cuda.device_count()): @@ -1366,8 +1366,8 @@ def _set_cos_sin_cache(self, seq_len, device, dtype): emb = torch.cat((freqs, freqs), dim=-1) cos = emb.cos().to(dtype=dtype, device=device, non_blocking=True) sin = emb.sin().to(dtype=dtype, device=device, non_blocking=True) - self.multi_gpu_cos_cached[device] = cos - self.multi_gpu_sin_cached[device] = sin + self.multi_gpu_cos_cached[device.index] = cos + self.multi_gpu_sin_cached[device.index] = sin return cos, sin pass @@ -1377,15 +1377,15 @@ def forward(self, x, position_ids=None, seq_len=None): self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype) return ( - self.multi_gpu_cos_cached[x.device][:seq_len].to(dtype = x.dtype), - self.multi_gpu_sin_cached[x.device][:seq_len].to(dtype = x.dtype), + self.multi_gpu_cos_cached[x.device.index][:seq_len].to(dtype = x.dtype), + self.multi_gpu_sin_cached[x.device.index][:seq_len].to(dtype = x.dtype), ) pass def get_cached(self, seq_len = None, device = None): if device is None: device = torch.cuda.current_device() - return self.multi_gpu_cos_cached[device], self.multi_gpu_sin_cached[device] + return self.multi_gpu_cos_cached[device.index if hasattr(device, 'index') else device], self.multi_gpu_sin_cached[device.index if hasattr(device, 'index') else device] pass def extend_rope_embedding(self, x, seq_len): @@ -1423,8 +1423,8 @@ def _set_cos_sin_cache(self, seq_len, device, dtype): emb = torch.cat((freqs, freqs), dim=-1) cos = emb.cos().to(dtype=dtype, device=device, non_blocking=True) sin = emb.sin().to(dtype=dtype, device=device, non_blocking=True) - self.multi_gpu_cos_cached[device] = cos - self.multi_gpu_sin_cached[device] = sin + self.multi_gpu_cos_cached[device.index] = cos + self.multi_gpu_sin_cached[device.index] = sin return cos, sin pass pass @@ -1451,8 +1451,8 @@ def __init__(self, dim = None, max_position_embeddings=2048, base=10000, device= self.base = base # Dynamic RoPE we first set it to a max of 4 * 8192 tokens 
then we iteratively grow this self.current_rope_size = min(4 * 8192, self.max_position_embeddings) - self.multi_gpu_cos_cached = {} - self.multi_gpu_sin_cached = {} + self.multi_gpu_cos_cached = [None]*torch.cuda.device_count() + self.multi_gpu_sin_cached = [None]*torch.cuda.device_count() # Normal Llama-3 RoPE inv_freq = 1.0 / ( @@ -1482,8 +1482,8 @@ def _set_cos_sin_cache(self, seq_len, device, dtype): emb = torch.cat((freqs, freqs), dim=-1) cos = emb.cos().to(dtype=dtype, device=device, non_blocking=True) sin = emb.sin().to(dtype=dtype, device=device, non_blocking=True) - self.multi_gpu_cos_cached[device] = cos - self.multi_gpu_sin_cached[device] = sin + self.multi_gpu_cos_cached[device.index] = cos + self.multi_gpu_sin_cached[device.index] = sin return cos, sin pass @@ -1493,15 +1493,15 @@ def forward(self, x, position_ids=None, seq_len=None): self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype) return ( - self.multi_gpu_cos_cached[x.device][:seq_len].to(dtype = x.dtype), - self.multi_gpu_sin_cached[x.device][:seq_len].to(dtype = x.dtype), + self.multi_gpu_cos_cached[x.device.index][:seq_len].to(dtype = x.dtype), + self.multi_gpu_sin_cached[x.device.index][:seq_len].to(dtype = x.dtype), ) pass def get_cached(self, seq_len = None, device = None): if device is None: device = torch.cuda.current_device() - return self.multi_gpu_cos_cached[device], self.multi_gpu_sin_cached[device] + return self.multi_gpu_cos_cached[device.index if hasattr(device, 'index') else device], self.multi_gpu_sin_cached[device.index if hasattr(device, 'index') else device] pass def extend_rope_embedding(self, x, seq_len): @@ -1572,10 +1572,10 @@ def __init__(self, self.base = base # Dynamic RoPE we first set it to a max of 4 * 8192 tokens then we iteratively grow this self.current_rope_size = min(original_max_position_embeddings, self.max_position_embeddings) - self.multi_gpu_short_cos_cached = {} - self.multi_gpu_short_sin_cached = {} - self.multi_gpu_long_cos_cached = 
{} - self.multi_gpu_long_sin_cached = {} + self.multi_gpu_short_cos_cached = [None]*torch.cuda.device_count() + self.multi_gpu_short_sin_cached = [None]*torch.cuda.device_count() + self.multi_gpu_long_cos_cached = [None]*torch.cuda.device_count() + self.multi_gpu_long_sin_cached = [None]*torch.cuda.device_count() # Long RoPE similar to RoPE except short sequences have 1 cos / sin # and long sequences have another cos / sin @@ -1609,8 +1609,8 @@ def __init__(self, device_obj = torch.device(device_idx) cos_cached = (emb.cos() * self.scaling_factor).to(dtype=dtype, device=device_obj, non_blocking=True) sin_cached = (emb.sin() * self.scaling_factor).to(dtype=dtype, device=device_obj, non_blocking=True) - self.multi_gpu_short_cos_cached[device_obj] = cos_cached - self.multi_gpu_short_sin_cached[device_obj] = sin_cached + self.multi_gpu_short_cos_cached[device_idx] = cos_cached + self.multi_gpu_short_sin_cached[device_idx] = sin_cached # dummy so that patch_utils doesn't fail for now self.short_cos_cached = torch.empty(1, device=torch.cuda.current_device(), dtype=torch.get_default_dtype()) @@ -1630,8 +1630,8 @@ def _set_cos_sin_cache(self, seq_len, device, dtype): emb = torch.cat((freqs, freqs), dim=-1) cos_cached = (emb.cos() * self.scaling_factor).to(dtype=dtype, device=device, non_blocking=True) sin_cached = (emb.sin() * self.scaling_factor).to(dtype=dtype, device=device, non_blocking=True) - self.multi_gpu_long_cos_cached[device] = cos_cached - self.multi_gpu_long_sin_cached[device] = sin_cached + self.multi_gpu_long_cos_cached[device.index] = cos_cached + self.multi_gpu_long_sin_cached[device.index] = sin_cached return cos_cached, sin_cached pass @@ -1642,13 +1642,13 @@ def forward(self, x, position_ids=None, seq_len=None): if seq_len is not None and seq_len < self.original_max_position_embeddings: return ( - self.multi_gpu_short_cos_cached[x.device][:seq_len].to(dtype = x.dtype), - self.multi_gpu_short_sin_cached[x.device][:seq_len].to(dtype = x.dtype), + 
self.multi_gpu_short_cos_cached[x.device.index][:seq_len].to(dtype = x.dtype), + self.multi_gpu_short_sin_cached[x.device.index][:seq_len].to(dtype = x.dtype), ) else: return ( - self.multi_gpu_long_cos_cached[x.device][:seq_len].to(dtype = x.dtype), - self.multi_gpu_long_sin_cached[x.device][:seq_len].to(dtype = x.dtype), + self.multi_gpu_long_cos_cached[x.device.index][:seq_len].to(dtype = x.dtype), + self.multi_gpu_long_sin_cached[x.device.index][:seq_len].to(dtype = x.dtype), ) pass pass @@ -1656,9 +1656,10 @@ def forward(self, x, position_ids=None, seq_len=None): def get_cached(self, seq_len = None, device = None): if device is None: device = torch.cuda.current_device() + device_index = device.index if hasattr(device, 'index') else device if seq_len is not None and seq_len < self.original_max_position_embeddings: - return self.multi_gpu_short_cos_cached[device], self.multi_gpu_short_sin_cached[device] - return self.multi_gpu_long_cos_cached[device], self.multi_gpu_long_sin_cached[device] + return self.multi_gpu_short_cos_cached[device_index], self.multi_gpu_short_sin_cached[device_index] + return self.multi_gpu_long_cos_cached[device_index], self.multi_gpu_long_sin_cached[device_index] pass def extend_rope_embedding(self, x, seq_len): From e5158da39fb12a5af1c3a9eb973f92399d7b4289 Mon Sep 17 00:00:00 2001 From: Datta Nimmaturi Date: Fri, 4 Jul 2025 10:54:29 +0000 Subject: [PATCH 06/17] Remove unnecessary transfer to CPU --- unsloth/models/cohere.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/unsloth/models/cohere.py b/unsloth/models/cohere.py index 8562895201..3ce6d1260d 100644 --- a/unsloth/models/cohere.py +++ b/unsloth/models/cohere.py @@ -254,8 +254,7 @@ def CohereAttention_fast_forward_inference( do_prefill = False, attention_mask = None, ): - if position_ids is not None: - position_ids = position_ids.to('cpu') + Xn = hidden_states bsz, _, hd = hidden_states.size() K1, V1 = past_key_value From 
de38ecea317e7d44a927889f639e8c09aa269899 Mon Sep 17 00:00:00 2001 From: Datta Nimmaturi Date: Thu, 10 Jul 2025 05:13:47 +0000 Subject: [PATCH 07/17] Remove unnecessary move to CPU --- unsloth/models/llama.py | 2 -- unsloth/models/qwen3.py | 10 ++++------ 2 files changed, 4 insertions(+), 8 deletions(-) diff --git a/unsloth/models/llama.py b/unsloth/models/llama.py index d89faf15b4..035e04b170 100644 --- a/unsloth/models/llama.py +++ b/unsloth/models/llama.py @@ -210,8 +210,6 @@ def LlamaAttention_fast_forward_inference( This means we can pass in a row of Q, but we need to remember K and V, which are called the KV cache. """ - if position_ids is not None: - position_ids = position_ids.to('cpu') Xn = hidden_states bsz, _, hd = hidden_states.size() K1, V1 = past_key_value diff --git a/unsloth/models/qwen3.py b/unsloth/models/qwen3.py index ffd60b9668..a60fd8ab8c 100644 --- a/unsloth/models/qwen3.py +++ b/unsloth/models/qwen3.py @@ -66,7 +66,7 @@ def Qwen3Attention_fast_forward( position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, *args, **kwargs, ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: - + # Clear inference if hasattr(self, "paged_attention"): del self.paged_attention_K @@ -232,8 +232,6 @@ def Qwen3Attention_fast_forward_inference( This means we can pass in a row of Q, but we need to remember K and V, which are called the KV cache. 
""" - if position_ids is not None: - position_ids = position_ids.to('cpu') Xn = hidden_states bsz, _, hd = hidden_states.size() K1, V1 = past_key_value @@ -262,14 +260,14 @@ def Qwen3Attention_fast_forward_inference( self.temp_QA = torch.empty((2, bsz, 1, attention_size), dtype = dtype, device = device) self.temp_KV = torch.empty((2, bsz, 1, n_kv_heads*head_dim), dtype = dtype, device = device) self.RH_Q = torch.empty((bsz, n_heads, 1, head_dim), dtype = dtype, device = device) - + # Mistral Nemo 12b has weird dimensions if attention_size != hidden_size: self.temp_O = torch.empty((1, bsz, hidden_size), dtype = dtype, device = device) else: self.temp_O = self.temp_QA[1][:,:,:hidden_size] pass - + self.attention = torch.empty((bsz, n_heads, 1, KV_CACHE_INCREMENT+seq_len), dtype = dtype, device = device) self.scalar = 1.0 / math_sqrt(self.head_dim) self.half_head_dim = head_dim // 2 @@ -317,7 +315,7 @@ def Qwen3Attention_fast_forward_inference( RH_K[:,:,:,:h].neg_() #torch.neg(RH_K[:,:,:,:h], out = RH_K[:,:,:,:h]) Kn *= cos Kn.addcmul_(RH_K, sin) - + # New KV cache # Kn = torch.cat([K1, Kn], dim = 2) # Vn = torch.cat([V1, Vn], dim = 2) From 324b392587d9ad0739c1be54913ed90a3b787229 Mon Sep 17 00:00:00 2001 From: Datta Nimmaturi Date: Thu, 10 Jul 2025 05:17:30 +0000 Subject: [PATCH 08/17] Donot move inputs to device yet will be handled separately in another PR --- unsloth/models/_utils.py | 19 ---------- unsloth/models/cohere.py | 14 +++----- unsloth/models/gemma.py | 8 +---- unsloth/models/gemma2.py | 17 +++------ unsloth/models/granite.py | 18 ++++------ unsloth/models/llama.py | 75 ++++++++++++++++++--------------------- 6 files changed, 52 insertions(+), 99 deletions(-) diff --git a/unsloth/models/_utils.py b/unsloth/models/_utils.py index f716e50043..e1511df18a 100644 --- a/unsloth/models/_utils.py +++ b/unsloth/models/_utils.py @@ -773,25 +773,6 @@ def _prepare_backend(self, *args, **kwargs): return None, DistributedType.NO 
accelerate.accelerator.Accelerator.distributed_type = lambda *args, **kwargs: DistributedType.NO pass -# to move multiple tensors to the same device -def move_to_device(target_device, *tensors): - """ - Move multiple tensors to target device if they're not already there. - - Args: - target_device: The target device to move tensors to - *tensors: Variable number of tensors to potentially move - - Returns: - tuple: The tensors on the target device (same objects if already on device, new if moved) - """ - moved_tensors = [] - for tensor in tensors: - if tensor.device != target_device: - moved_tensors.append(tensor.to(target_device)) - else: - moved_tensors.append(tensor) - return tuple(moved_tensors) if len(moved_tensors) > 1 else moved_tensors[0] import transformers.utils.quantization_config transformers.utils.quantization_config.BitsAndBytesConfig.__init__ = _BitsAndBytesConfig__init__ diff --git a/unsloth/models/cohere.py b/unsloth/models/cohere.py index 3ce6d1260d..168824d635 100644 --- a/unsloth/models/cohere.py +++ b/unsloth/models/cohere.py @@ -78,7 +78,7 @@ def CohereAttention_fast_forward( position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, *args, **kwargs, ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: - + # Clear inference if hasattr(self, "paged_attention"): del self.paged_attention_K @@ -282,14 +282,14 @@ def CohereAttention_fast_forward_inference( self.temp_QA = torch.empty((2, bsz, 1, attention_size), dtype = dtype, device = "cuda:0") self.temp_KV = torch.empty((2, bsz, 1, n_kv_heads*head_dim), dtype = dtype, device = "cuda:0") self.RH_Q = torch.empty((bsz, n_heads, 1, head_dim), dtype = dtype, device = "cuda:0") - + # Mistral Nemo 12b has weird dimensions if attention_size != hidden_size: self.temp_O = torch.empty((1, bsz, hidden_size), dtype = dtype, device = "cuda:0") else: self.temp_O = self.temp_QA[1][:,:,:hidden_size] pass - + self.attention = torch.empty((bsz, n_heads, 1, 
KV_CACHE_INCREMENT+seq_len), dtype = dtype, device = "cuda:0") self.scalar = 1.0 / math_sqrt(self.head_dim) self.half_head_dim = head_dim // 2 @@ -339,7 +339,7 @@ def CohereAttention_fast_forward_inference( torch.neg(RH_K[:,:,:,:h], out = RH_K[:,:,:,:h]) Kn *= cos Kn.addcmul_(RH_K, sin) - + # New KV cache # Kn = torch.cat([K1, Kn], dim = 2) # Vn = torch.cat([V1, Vn], dim = 2) @@ -418,10 +418,6 @@ def CohereModel_fast_forward_inference( next_decoder_cache = [] for idx, decoder_layer in enumerate(self.model.layers): - decoder_device = decoder_layer.self_attn.q_proj.weight.device - hidden_states, out_weight, position_ids = move_to_device( - decoder_device, hidden_states, out_weight, position_ids - ) residual = hidden_states hidden_states = fast_layernorm_inference(decoder_layer.input_layernorm, hidden_states, out_weight) hidden_states_attention, present_key_value = CohereAttention_fast_forward_inference( @@ -473,7 +469,7 @@ def pre_patch(): CohereForCausalLM .forward = CausalLM_fast_forward(CohereModel_fast_forward_inference) PeftModelForCausalLM .forward = PeftModel_fast_forward fix_prepare_inputs_for_generation(CohereForCausalLM) - + import transformers.models.cohere.modeling_cohere transformers.models.cohere.modeling_cohere.CohereRotaryEmbedding = LlamaRotaryEmbedding return diff --git a/unsloth/models/gemma.py b/unsloth/models/gemma.py index 4d32c42f7b..66bd23af95 100644 --- a/unsloth/models/gemma.py +++ b/unsloth/models/gemma.py @@ -170,12 +170,6 @@ def GemmaModel_fast_forward_inference( next_decoder_cache = [] for idx, decoder_layer in enumerate(self.model.layers): - - decoder_device = decoder_layer.self_attn.q_proj.weight.device - hidden_states, out_weight, position_ids = move_to_device( - decoder_device, hidden_states, out_weight, position_ids - ) - residual = hidden_states hidden_states = fast_rms_layernorm_inference_gemma(decoder_layer.input_layernorm, hidden_states, out_weight) hidden_states, present_key_value = LlamaAttention_fast_forward_inference( @@ 
-236,7 +230,7 @@ def __init__(self, dim = None, max_position_embeddings=2048, base=10000, device= # Build here to make `torch.jit.trace` work. for device in range(torch.cuda.device_count()): self._set_cos_sin_cache(seq_len=self.current_rope_size, device=torch.device(device), dtype=torch.get_default_dtype()) - + # dummy so that patch_utils doesn't fail for now self.cos_cached = torch.empty(1, device=torch.cuda.current_device(), dtype=torch.get_default_dtype()) self.sin_cached = torch.empty(1, device=torch.cuda.current_device(), dtype=torch.get_default_dtype()) diff --git a/unsloth/models/gemma2.py b/unsloth/models/gemma2.py index 0b3541ea7a..7b8866f705 100644 --- a/unsloth/models/gemma2.py +++ b/unsloth/models/gemma2.py @@ -84,7 +84,7 @@ def Gemma2Attention_fast_forward( padding_mask: Optional[torch.LongTensor] = None, *args, **kwargs, ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: - + # Clear inference if hasattr(self, "paged_attention"): del self.paged_attention_K @@ -281,7 +281,7 @@ def Gemma2Attention_fast_forward_inference( # Only for Gemma2 self.temp_O = torch.empty((1, bsz, hidden_size), dtype = dtype, device = device) self.attention = torch.empty((bsz, n_heads, 1, KV_CACHE_INCREMENT+seq_len), dtype = dtype, device = device) - + # See https://github.com/google/gemma_pytorch/commit/03e657582d17cb5a8617ebf333c1c16f3694670e # Gemma 9b should use 256 and not 224 (hs / nah). 
27b uses the below # We default to using the config file itself @@ -325,7 +325,7 @@ def Gemma2Attention_fast_forward_inference( torch.neg(RH_K[:,:,:,:h], out = RH_K[:,:,:,:h]) Kn *= cos Kn.addcmul_(RH_K, sin) - + # New KV cache # Kn = torch.cat([K1, Kn], dim = 2) # Vn = torch.cat([V1, Vn], dim = 2) @@ -386,7 +386,7 @@ def Gemma2Model_fast_forward_inference( past_key_values, position_ids, attention_mask = None, -): +): out_weight = torch.empty_like(self.model.layers[0].input_layernorm.weight, dtype = torch.float32, device = "cuda:0") input_ids = input_ids[:,:self.max_seq_length] hidden_states = self.model.embed_tokens(input_ids) @@ -423,13 +423,6 @@ def Gemma2Model_fast_forward_inference( next_decoder_cache = [] for idx, decoder_layer in enumerate(self.model.layers): - # For pipeline parallelism, we need to move all tensors to the same device - # note that this movement is once per GPU in PP - layer_device = decoder_layer.self_attn.q_proj.weight.device - hidden_states, out_weight, position_ids = move_to_device( - layer_device, hidden_states, out_weight, position_ids - ) - use_sliding_window = idx % 2 == 0 residual = hidden_states @@ -487,7 +480,7 @@ def pre_patch(): Gemma2ForCausalLM .forward = CausalLM_fast_forward(Gemma2Model_fast_forward_inference) PeftModelForCausalLM .forward = PeftModel_fast_forward fix_prepare_inputs_for_generation(Gemma2ForCausalLM) - + # Solves https://github.com/unslothai/unsloth/issues/168 # Static KV Cache was introduced in 4.38.0, causing training to be much slower. # Inferene can now be CUDAGraphed, but we shall retain the old rotary embeddings. 
diff --git a/unsloth/models/granite.py b/unsloth/models/granite.py index 32e8dbc866..e9a62a7d42 100644 --- a/unsloth/models/granite.py +++ b/unsloth/models/granite.py @@ -71,7 +71,7 @@ def GraniteAttention_fast_forward( position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, *args, **kwargs, ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: - + # Clear inference if hasattr(self, "paged_attention"): del self.paged_attention_K @@ -162,7 +162,7 @@ def GraniteAttention_fast_forward( # Go back to (batch_size, seq_len, n_heads, head_dim) A = A.transpose(1, 2).contiguous() pass - + attn_output = A.reshape(bsz, q_len, n_heads*head_dim) attn_output = self.apply_o(self, attn_output) attn_weights = None @@ -256,7 +256,7 @@ def GraniteAttention_fast_forward_inference( use_sliding_window = False, position_embeddings : Optional[Tuple[torch.Tensor, torch.Tensor]] = None, ): - + assert position_embeddings is not None, f"Granite model requires position embeddings to be specified" Xn = hidden_states @@ -326,7 +326,7 @@ def GraniteAttention_fast_forward_inference( torch.neg(RH_K[:,:,:,:h], out = RH_K[:,:,:,:h]) Kn *= cos Kn.addcmul_(RH_K, sin) - + # New KV cache # Kn = torch.cat([K1, Kn], dim = 2) # Vn = torch.cat([V1, Vn], dim = 2) @@ -349,7 +349,7 @@ def GraniteAttention_fast_forward_inference( Qn *= self.scaling A = torch_matmul(Qn, Kn.transpose(2, 3), out = self.attention[:,:,:,:cached_len]) - + # if attention_mask is not None: A += attention_mask # Must add attention_mask for batched A[:] = torch_nn_functional_softmax(A, dim = -1, dtype = torch.float32)#.to(A.dtype) @@ -400,11 +400,6 @@ def GraniteModel_fast_forward_inference( next_decoder_cache = [] for idx, decoder_layer in enumerate(self.model.layers): - decoder_device = decoder_layer.self_attn.q_proj.weight.device - hidden_states, position_ids = move_to_device( - decoder_device, hidden_states, position_ids - ) - residual = hidden_states hidden_states = 
fast_rms_layernorm_inference(decoder_layer.input_layernorm, hidden_states) hidden_states, present_key_value = GraniteAttention_fast_forward_inference( @@ -537,7 +532,7 @@ def post_patch(model, tokenizer): elif hasattr(module, "short_cos_cached") and \ (module.short_cos_cached.dtype != correct_dtype): - + module.short_cos_cached = module.short_cos_cached.to(correct_dtype) module.short_sin_cached = module.short_sin_cached.to(correct_dtype) pass @@ -552,4 +547,3 @@ def post_patch(model, tokenizer): return model, tokenizer pass pass - diff --git a/unsloth/models/llama.py b/unsloth/models/llama.py index 035e04b170..e1bc14bb1d 100644 --- a/unsloth/models/llama.py +++ b/unsloth/models/llama.py @@ -20,7 +20,6 @@ from ._utils import * from ._utils import patch_unsloth_smart_gradient_checkpointing from ._utils import __version__ -from ._utils import move_to_device from torch.nn.functional import scaled_dot_product_attention from transformers import __version__ as transformers_version from unsloth_zoo.utils import Version, _get_dtype @@ -117,12 +116,12 @@ def _fast_prepare_inputs_for_generation(self, input_ids, attention_mask=None, ** else: bs, cache_length = input_ids.shape input_ids = input_ids[:,[-1]] - + # Get to the base model base_model = self if hasattr(base_model, 'base_model_prefix'): base_model = getattr(base_model, base_model.base_model_prefix) - + if hasattr(base_model, "_prepare_4d_causal_attention_mask_with_cache_position"): def needs_device_kw(fn) -> bool: try: @@ -238,14 +237,14 @@ def LlamaAttention_fast_forward_inference( self.temp_QA = torch.empty((2, bsz, 1, attention_size), dtype = dtype, device = device) self.temp_KV = torch.empty((2, bsz, 1, n_kv_heads*head_dim), dtype = dtype, device = device) self.RH_Q = torch.empty((bsz, n_heads, 1, head_dim), dtype = dtype, device = device) - + # Mistral Nemo 12b has weird dimensions if attention_size != hidden_size: self.temp_O = torch.empty((1, bsz, hidden_size), dtype = dtype, device = device) else: self.temp_O = 
self.temp_QA[1][:,:,:hidden_size] pass - + self.attention = torch.empty((bsz, n_heads, 1, KV_CACHE_INCREMENT+seq_len), dtype = dtype, device = device) self.scalar = 1.0 / math_sqrt(self.head_dim) self.half_head_dim = head_dim // 2 @@ -287,7 +286,7 @@ def LlamaAttention_fast_forward_inference( RH_K[:,:,:,:h].neg_() #torch.neg(RH_K[:,:,:,:h], out = RH_K[:,:,:,:h]) Kn *= cos Kn.addcmul_(RH_K, sin) - + # New KV cache # Kn = torch.cat([K1, Kn], dim = 2) # Vn = torch.cat([V1, Vn], dim = 2) @@ -349,10 +348,10 @@ def fast_swiglu_inference(self, X, temp_gate = None, temp_up = None, gate_multip # temp = torch.empty((2, bsz, 1, mlp_size), dtype = X.dtype, device = "cuda:0") gate = fast_linear_forward(self.gate_proj, X, out = temp_gate) - + if gate_multiplier is not None: gate *= gate_multiplier - + up = fast_linear_forward(self. up_proj, X, out = temp_up) gate = torch_nn_functional_silu(gate, inplace = True) @@ -435,7 +434,7 @@ def LlamaAttention_fast_forward( position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, *args, **kwargs, ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: - + # Clear inference if hasattr(self, "paged_attention"): del self.paged_attention_K @@ -654,7 +653,7 @@ def LlamaModel_fast_forward( return_dict: Optional[bool] = None, *args, **kwargs, ) -> Union[Tuple, BaseModelOutputWithPast]: - + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions assert(output_attentions is False) output_hidden_states = ( @@ -689,7 +688,7 @@ def LlamaModel_fast_forward( inputs_embeds = inputs_embeds[:,:self.max_seq_length,:] pass pass - + past_key_values_length = 0 if past_key_values is not None: @@ -776,7 +775,7 @@ def LlamaModel_fast_forward( # Ignore attention_mask if attention_mask is None: padding_mask = None - elif self.training and os.environ.get("UNSLOTH_KEEP_PADDING", "0") != '1': + elif self.training and os.environ.get("UNSLOTH_KEEP_PADDING", "0") != '1': 
attention_mask = None padding_mask = None else: @@ -1021,10 +1020,6 @@ def LlamaModel_fast_forward_inference_custom( next_decoder_cache = [] for idx, decoder_layer in enumerate(self.model.layers): - decoder_device = decoder_layer.self_attn.q_proj.weight.device - X, residual, temp_gate, temp_up, position_ids = move_to_device( - decoder_device, X, residual, temp_gate, temp_up, position_ids - ) residual.copy_(X) # residual = X X = fast_rms_layernorm_inference( decoder_layer.input_layernorm, @@ -1140,7 +1135,7 @@ def _CausalLM_fast_forward( logit_scaling = getattr(self.config, "logit_scale", 0) dtype = lm_head.dtype num_logits_to_keep = max(num_logits_to_keep, logits_to_keep) - + # Move items to same device as lm_head hidden_states = hidden_states.to(lm_head_device) if labels is not None: labels = labels.to(lm_head_device) @@ -1167,7 +1162,7 @@ def _CausalLM_fast_forward( RETURN_LOGITS = os.environ.get("UNSLOTH_RETURN_LOGITS", "0") == "1" # < 1024 Normal Unsloth uses less VRAM! if bsz*q_len <= 1024: RETURN_LOGITS = True - + if not RETURN_LOGITS and HAS_CUT_CROSS_ENTROPY and labels is not None: n_items = kwargs.get("num_items_in_batch", None) or kwargs.get("n_items", None) @@ -1281,16 +1276,16 @@ def PeftModel_fast_forward( **kwargs, ): is_classification = "Classification" in str(type( self.base_model.model)) - if is_classification: + if is_classification: #causal_mask = causal_mask, return self.base_model( input_ids = input_ids, - attention_mask = attention_mask, - inputs_embeds = inputs_embeds, - labels = labels, + attention_mask = attention_mask, + inputs_embeds = inputs_embeds, + labels = labels, output_attentions = output_attentions, - output_hidden_states = output_hidden_states, - return_dict = return_dict, + output_hidden_states = output_hidden_states, + return_dict = return_dict, **kwargs, ) else: @@ -1344,7 +1339,7 @@ def __init__(self, dim = None, max_position_embeddings=2048, base=10000, device= # Build here to make `torch.jit.trace` work. 
for device_idx in range(torch.cuda.device_count()): self._set_cos_sin_cache(seq_len=self.current_rope_size, device=torch.device(device_idx), dtype=torch.get_default_dtype()) - + # dummy so that patch_utils doesn't fail for now self.cos_cached = torch.empty(1, device=torch.cuda.current_device(), dtype=torch.get_default_dtype()) self.sin_cached = torch.empty(1, device=torch.cuda.current_device(), dtype=torch.get_default_dtype()) @@ -1462,7 +1457,7 @@ def __init__(self, dim = None, max_position_embeddings=2048, base=10000, device= # Build here to make `torch.jit.trace` work. for device_idx in range(torch.cuda.device_count()): self._set_cos_sin_cache(seq_len=self.current_rope_size, device=torch.device(device_idx), dtype=torch.get_default_dtype()) - + # dummy so that patch_utils doesn't fail for now self.cos_cached = torch.empty(1, device=torch.cuda.current_device(), dtype=torch.get_default_dtype()) self.sin_cached = torch.empty(1, device=torch.cuda.current_device(), dtype=torch.get_default_dtype()) @@ -1472,7 +1467,7 @@ def _set_cos_sin_cache(self, seq_len, device, dtype): # Note: on the original Llama codebase, these tensors are created on the target device (and not on CPU) and # in FP32. They are applied (multiplied) in FP32 as well. self.current_rope_size = seq_len - + t = torch.arange(self.current_rope_size, device=self.inv_freq.device, dtype=torch.int64).float() freqs = torch.outer(t, self.inv_freq) @@ -1595,21 +1590,21 @@ def __init__(self, # Short and long inv_freq self.register_buffer("short_inv_freq", short_inv_freq, persistent = False) self.register_buffer("long_inv_freq", long_inv_freq, persistent = False) - + # Build here to make `torch.jit.trace` work. 
# Initialize short sequences cache for all devices dtype = torch.bfloat16 if is_bfloat16_supported() else torch.float16 t = torch.arange(original_max_position_embeddings, device=self.short_inv_freq.device, dtype=torch.int64).float() freqs = torch.outer(t, self.short_inv_freq) emb = torch.cat((freqs, freqs), dim=-1) - + for device_idx in range(torch.cuda.device_count()): device_obj = torch.device(device_idx) cos_cached = (emb.cos() * self.scaling_factor).to(dtype=dtype, device=device_obj, non_blocking=True) sin_cached = (emb.sin() * self.scaling_factor).to(dtype=dtype, device=device_obj, non_blocking=True) self.multi_gpu_short_cos_cached[device_idx] = cos_cached self.multi_gpu_short_sin_cached[device_idx] = sin_cached - + # dummy so that patch_utils doesn't fail for now self.short_cos_cached = torch.empty(1, device=torch.cuda.current_device(), dtype=torch.get_default_dtype()) self.short_sin_cached = torch.empty(1, device=torch.cuda.current_device(), dtype=torch.get_default_dtype()) @@ -1621,7 +1616,7 @@ def _set_cos_sin_cache(self, seq_len, device, dtype): # Note: on the original Llama codebase, these tensors are created on the target device (and not on CPU) and # in FP32. They are applied (multiplied) in FP32 as well. 
self.current_rope_size = seq_len - + t = torch.arange(self.current_rope_size, device=self.long_inv_freq.device, dtype=torch.int64).float() # Long sequences freqs = torch.outer(t, self.long_inv_freq) @@ -1787,7 +1782,7 @@ def from_pretrained( max_lora_rank = 16, disable_log_stats = False, unsloth_vllm_standby = False, - num_labels = None, + num_labels = None, **kwargs, ): os.environ["UNSLOTH_USE_NEW_MODEL"] = "0" @@ -1859,7 +1854,7 @@ def from_pretrained( if old_hf_transfer != "0": os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1" model_patcher.pre_patch() - get_statistics() # For debugging - we use a download counter to see if environments are not breaking + get_statistics() # For debugging - we use a download counter to see if environments are not breaking if dtype is None: dtype = torch.float16 if not SUPPORTS_BFLOAT16 else torch.bfloat16 @@ -1946,7 +1941,7 @@ def from_pretrained( # Cannot be None, since HF now checks for the config if load_in_4bit: kwargs["quantization_config"] = bnb_config - + if num_labels is not None: model = AutoModelForSequenceClassification.from_pretrained( model_name, @@ -2044,7 +2039,7 @@ def from_pretrained( except: raise RuntimeError('Unsloth: Unsuccessfully patched inner_training_loop') pass - + import transformers.trainer items_in_trainer = dir(transformers.trainer) good_items = [] @@ -2391,7 +2386,7 @@ def get_peft_model( ) loftq_config = LoftQConfig(loftq_bits = 4, loftq_iter = 1) pass - + if hasattr(model.config, "quantization_config"): raise ValueError( "Unsloth: You are using `loftq` init, yet `load_in_4bit = True` was set.\n"\ @@ -2521,7 +2516,7 @@ def get_peft_model( is_classification = "Classification" in str(type(model)) # Get LoRA - # + # arguments = dict( r = r, @@ -2547,7 +2542,7 @@ def get_peft_model( input_embeddings_device = model.get_input_embeddings().weight.device if is_classification: output_embeddings_device = model.score.weight.device - else: + else: output_embeddings_device = 
model.get_output_embeddings().weight.device if use_gradient_checkpointing == "unsloth": @@ -2707,7 +2702,7 @@ def patch_peft_model( # model.peft_config[active_adapter].revision = f"unsloth" pass - from transformers.trainer import Trainer + from transformers.trainer import Trainer if Trainer._inner_training_loop.__name__ != "_fast_inner_training_loop": raise RuntimeError("Unsloth: Unsuccessfully patched Trainer! Please file a bug report!") pass @@ -2830,7 +2825,7 @@ def patch_peft_model( internal_model.max_seq_length = max_seq_length internal_model = internal_model.model pass - internal_model.max_seq_length = max_seq_length + internal_model.max_seq_length = max_seq_length # Patch tokenizer to pad to the right internal_model = model From 6c62402d2bf88ffb7078f332747230dc6985f675 Mon Sep 17 00:00:00 2001 From: Datta Nimmaturi Date: Thu, 10 Jul 2025 05:54:07 +0000 Subject: [PATCH 09/17] Move inputs to appropriate decoder device --- unsloth/models/_utils.py | 19 +++++++++++++++++++ unsloth/models/cohere.py | 3 +++ unsloth/models/gemma.py | 4 ++++ unsloth/models/gemma2.py | 6 ++++++ unsloth/models/granite.py | 4 ++++ unsloth/models/llama.py | 4 ++++ 6 files changed, 40 insertions(+) diff --git a/unsloth/models/_utils.py b/unsloth/models/_utils.py index e1511df18a..5b1a52a19b 100644 --- a/unsloth/models/_utils.py +++ b/unsloth/models/_utils.py @@ -773,6 +773,25 @@ def _prepare_backend(self, *args, **kwargs): return None, DistributedType.NO accelerate.accelerator.Accelerator.distributed_type = lambda *args, **kwargs: DistributedType.NO pass +# to move multiple tensors to the same device +def move_to_device(target_device, *tensors): + """ + Move multiple tensors to target device if they're not already there. 
+ + Args: + target_device: The target device to move tensors to + *tensors: Variable number of tensors to potentially move + + Returns: + tuple: The tensors on the target device (same objects if already on device, new if moved) + """ + moved_tensors = [] + for tensor in tensors: + if tensor.device != target_device: + moved_tensors.append(tensor.to(target_device)) + else: + moved_tensors.append(tensor) + return tuple(moved_tensors) if len(moved_tensors) > 1 else moved_tensors[0] import transformers.utils.quantization_config transformers.utils.quantization_config.BitsAndBytesConfig.__init__ = _BitsAndBytesConfig__init__ diff --git a/unsloth/models/cohere.py b/unsloth/models/cohere.py index 168824d635..693acdad11 100644 --- a/unsloth/models/cohere.py +++ b/unsloth/models/cohere.py @@ -418,6 +418,9 @@ def CohereModel_fast_forward_inference( next_decoder_cache = [] for idx, decoder_layer in enumerate(self.model.layers): + hidden_states, out_weight, position_ids = move_to_device( + decoder_layer._per_layer_device, hidden_states, out_weight, position_ids + ) residual = hidden_states hidden_states = fast_layernorm_inference(decoder_layer.input_layernorm, hidden_states, out_weight) hidden_states_attention, present_key_value = CohereAttention_fast_forward_inference( diff --git a/unsloth/models/gemma.py b/unsloth/models/gemma.py index 66bd23af95..165415126c 100644 --- a/unsloth/models/gemma.py +++ b/unsloth/models/gemma.py @@ -170,6 +170,10 @@ def GemmaModel_fast_forward_inference( next_decoder_cache = [] for idx, decoder_layer in enumerate(self.model.layers): + hidden_states, out_weight, position_ids = move_to_device( + decoder_layer._per_layer_device, hidden_states, out_weight, position_ids + ) + residual = hidden_states hidden_states = fast_rms_layernorm_inference_gemma(decoder_layer.input_layernorm, hidden_states, out_weight) hidden_states, present_key_value = LlamaAttention_fast_forward_inference( diff --git a/unsloth/models/gemma2.py b/unsloth/models/gemma2.py index 
7b8866f705..7910daad08 100644 --- a/unsloth/models/gemma2.py +++ b/unsloth/models/gemma2.py @@ -423,6 +423,12 @@ def Gemma2Model_fast_forward_inference( next_decoder_cache = [] for idx, decoder_layer in enumerate(self.model.layers): + # For pipeline parallelism, we need to move all tensors to the same device + # note that this movement is once per GPU in PP + hidden_states, out_weight, position_ids = move_to_device( + decoder_layer._per_layer_device, hidden_states, out_weight, position_ids + ) + use_sliding_window = idx % 2 == 0 residual = hidden_states diff --git a/unsloth/models/granite.py b/unsloth/models/granite.py index e9a62a7d42..292e7c70c8 100644 --- a/unsloth/models/granite.py +++ b/unsloth/models/granite.py @@ -400,6 +400,10 @@ def GraniteModel_fast_forward_inference( next_decoder_cache = [] for idx, decoder_layer in enumerate(self.model.layers): + hidden_states, position_ids = move_to_device( + decoder_layer._per_layer_device, hidden_states, position_ids + ) + residual = hidden_states hidden_states = fast_rms_layernorm_inference(decoder_layer.input_layernorm, hidden_states) hidden_states, present_key_value = GraniteAttention_fast_forward_inference( diff --git a/unsloth/models/llama.py b/unsloth/models/llama.py index e1bc14bb1d..f73bdef884 100644 --- a/unsloth/models/llama.py +++ b/unsloth/models/llama.py @@ -20,6 +20,7 @@ from ._utils import * from ._utils import patch_unsloth_smart_gradient_checkpointing from ._utils import __version__ +from ._utils import move_to_device from torch.nn.functional import scaled_dot_product_attention from transformers import __version__ as transformers_version from unsloth_zoo.utils import Version, _get_dtype @@ -1020,6 +1021,9 @@ def LlamaModel_fast_forward_inference_custom( next_decoder_cache = [] for idx, decoder_layer in enumerate(self.model.layers): + X, residual, temp_gate, temp_up, position_ids = move_to_device( + decoder_layer._per_layer_device, X, residual, temp_gate, temp_up, position_ids + ) residual.copy_(X) # 
residual = X X = fast_rms_layernorm_inference( decoder_layer.input_layernorm, From 88a31aee17d72d44f50a0066c9cfe84fbf2bcd62 Mon Sep 17 00:00:00 2001 From: Datta Nimmaturi Date: Thu, 10 Jul 2025 06:43:34 +0000 Subject: [PATCH 10/17] Make device count global variable --- unsloth/__init__.py | 15 +++++++++++++-- unsloth/kernels/utils.py | 22 +++++++++++----------- unsloth/models/gemma.py | 8 ++++---- unsloth/models/llama.py | 34 ++++++++++++++++------------------ unsloth/models/vision.py | 14 ++++++-------- 5 files changed, 50 insertions(+), 43 deletions(-) diff --git a/unsloth/__init__.py b/unsloth/__init__.py index a1a39fd192..4da08da13d 100644 --- a/unsloth/__init__.py +++ b/unsloth/__init__.py @@ -22,7 +22,7 @@ already_imported = [mod for mod in critical_modules if mod in sys.modules] # This check is critical because Unsloth optimizes these libraries by modifying -# their code at import time. If they're imported first, the original (slower, +# their code at import time. If they're imported first, the original (slower, # more memory-intensive) implementations will be used instead of Unsloth's # optimized versions, potentially causing OOM errors or slower training. 
@@ -73,6 +73,17 @@ def get_device_type(): pass DEVICE_TYPE : str = get_device_type() +def get_device_count(): + if DEVICE_TYPE == "cuda": + return torch.cuda.device_count() + elif DEVICE_TYPE == "xpu": + return torch.xpu.device_count() + else: + return 0 +pass + +DEVICE_COUNT : int = get_device_count() + # Reduce VRAM usage by reducing fragmentation # And optimize pinning of memory if DEVICE_TYPE == "cuda" and os.environ.get("UNSLOTH_VLLM_STANDBY", "0")=="0": @@ -237,4 +248,4 @@ def is_bf16_supported(): return SUPPORTS_BFLOAT16 from .trainer import * # Patch TRL trainers for backwards compatibility -_patch_trl_trainer() \ No newline at end of file +_patch_trl_trainer() diff --git a/unsloth/kernels/utils.py b/unsloth/kernels/utils.py index 1c65246b31..f5f111c943 100644 --- a/unsloth/kernels/utils.py +++ b/unsloth/kernels/utils.py @@ -100,7 +100,7 @@ def torch_gpu_device(device): return nullcontext() # INTEL GPU Specific Logic if DEVICE_TYPE == "xpu": - _gpu_getCurrentRawStream = torch._C._xpu_getCurrentRawStream + _gpu_getCurrentRawStream = torch._C._xpu_getCurrentRawStream # NVIDIA GPU Default Logic else: _gpu_getCurrentRawStream = torch._C._cuda_getCurrentRawStream @@ -126,7 +126,7 @@ def _get_tensor_stream(tensor: torch_Tensor) -> c_void_p: XPU_STREAMS = [None] * (max(_XPU_STREAMS.keys()) + 1) WEIGHT_BUFFERS = [None] * (max(_XPU_STREAMS.keys()) + 1) ABSMAX_BUFFERS = [None] * (max(_XPU_STREAMS.keys()) + 1) - for k, v in _XPU_STREAMS.items(): + for k, v in _XPU_STREAMS.items(): XPU_STREAMS[k] = v XPU_STREAMS = tuple(XPU_STREAMS) del _XPU_STREAMS @@ -152,16 +152,16 @@ def _get_tensor_stream(tensor: torch_Tensor) -> c_void_p: # TODO: After adding XPU BNB support, this function should be implemented def cdequantize_blockwise_fp32(*args, **kwargs): raise RuntimeError("XPU BNB support is not implemented yet. 
cdequantize_blockwise_fp32 should not be called now.") - + def cdequantize_blockwise_fp16_nf4(*args, **kwargs): raise RuntimeError("XPU BNB support is not implemented yet. cdequantize_blockwise_fp16_nf4 should not be called now.") - + def cdequantize_blockwise_bf16_nf4(*args, **kwargs): raise RuntimeError("XPU BNB support is not implemented yet. cdequantize_blockwise_bf16_nf4 should not be called now.") - + def cgemm_4bit_inference_naive_fp16(*args, **kwargs): raise RuntimeError("XPU BNB support is not implemented yet. cgemm_4bit_inference_naive_fp16 should not be called now.") - + def cgemm_4bit_inference_naive_bf16(*args, **kwargs): raise RuntimeError("XPU BNB support is not implemented yet. cgemm_4bit_inference_naive_bf16 should not be called now.") else: @@ -193,7 +193,7 @@ def get_lora_parameters(proj): adapter = getattr(proj, "active_adapters", None) if adapter is None: adapter = getattr(proj, "active_adapter", ("default")) adapter = adapter[0] - + return ( W, getattr(W, "quant_state", None), @@ -232,7 +232,7 @@ def get_lora_parameters_bias(proj): if DEVICE_TYPE == "xpu" and HAS_XPU_STREAM: @torch.inference_mode def fast_dequantize(W, quant_state = None, out = None, use_global_buffer = False): - # TODO: After adding XPU BNB support, check this function + # TODO: After adding XPU BNB support, check this function if quant_state is None: return W if type(quant_state) is not list: # New quant_state as a class @@ -535,7 +535,7 @@ def fast_gemv(X, W, quant_state, out = None): device = W.device device_index = device.index CUDA_STREAM = CUDA_STREAMS[device_index] - + # assert(dtype == X.dtype) bout = shape[0] @@ -669,7 +669,7 @@ def fast_linear_forward(proj, X, temp_lora = None, out = None): lora_A._fast_lora = lora_A.to(dtype) lora_B._fast_lora = lora_B.to(dtype) pass - + if bsz == 1: out = out.view(out_dim) temp_lora = torch_mv(lora_A._fast_lora, X.ravel(), out = temp_lora) @@ -709,6 +709,6 @@ def matmul_lora(X, W, W_quant, A, B, s, out = None): out.addmm_(XA, 
B.to(dtype), alpha = s) # out += (X @ A.to(dtype)) @ (s * B.to(dtype)) pass - + return out.view(batch, seq_len, -1) if reshape else out pass diff --git a/unsloth/models/gemma.py b/unsloth/models/gemma.py index 165415126c..75bf980473 100644 --- a/unsloth/models/gemma.py +++ b/unsloth/models/gemma.py @@ -228,11 +228,11 @@ def __init__(self, dim = None, max_position_embeddings=2048, base=10000, device= self.base = base # Dynamic RoPE we first set it to a max of 4 * 8192 tokens then we iteratively grow this self.current_rope_size = min(4 * 8192, self.max_position_embeddings) - self.multi_gpu_cos_cached = [None]*torch.cuda.device_count() - self.multi_gpu_sin_cached = [None]*torch.cuda.device_count() + self.multi_gpu_cos_cached = [None]*DEVICE_COUNT + self.multi_gpu_sin_cached = [None]*DEVICE_COUNT # Build here to make `torch.jit.trace` work. - for device in range(torch.cuda.device_count()): + for device in range(DEVICE_COUNT): self._set_cos_sin_cache(seq_len=self.current_rope_size, device=torch.device(device), dtype=torch.get_default_dtype()) # dummy so that patch_utils doesn't fail for now @@ -284,7 +284,7 @@ def extend_rope_embedding(self, x, seq_len): if seq_len <= self.current_rope_size: return # Iteratively grow by increments of 8192 self.current_rope_size = math.ceil(seq_len / 8192) * 8192 - for device in range(torch.cuda.device_count()): + for device in range(DEVICE_COUNT): self._set_cos_sin_cache(self.current_rope_size, device = torch.device(device), dtype = x.dtype) pass pass diff --git a/unsloth/models/llama.py b/unsloth/models/llama.py index 12aba4fb91..58e47fb9f7 100644 --- a/unsloth/models/llama.py +++ b/unsloth/models/llama.py @@ -25,7 +25,7 @@ from transformers import __version__ as transformers_version from unsloth_zoo.utils import Version, _get_dtype from unsloth_zoo.peft_utils import SKIP_QUANTIZATION_MODULES -from unsloth import DEVICE_TYPE +from unsloth import DEVICE_TYPE, DEVICE_COUNT transformers_version = Version(transformers_version) # 
Transformers moved rotary embeddings out of all attention layers @@ -1357,11 +1357,11 @@ def __init__(self, dim = None, max_position_embeddings=2048, base=10000, device= self.base = base # Dynamic RoPE we first set it to a max of 4 * 8192 tokens then we iteratively grow this self.current_rope_size = min(4 * 8192, self.max_position_embeddings) - self.multi_gpu_cos_cached = [None]*torch.cuda.device_count() - self.multi_gpu_sin_cached = [None]*torch.cuda.device_count() + self.multi_gpu_cos_cached = [None]*NUM_GPUS + self.multi_gpu_sin_cached = [None]*NUM_GPUS # Build here to make `torch.jit.trace` work. - for device_idx in range(torch.cuda.device_count()): + for device_idx in range(NUM_GPUS): self._set_cos_sin_cache(seq_len=self.current_rope_size, device=torch.device(device_idx), dtype=torch.get_default_dtype()) # dummy so that patch_utils doesn't fail for now @@ -1409,7 +1409,7 @@ def extend_rope_embedding(self, x, seq_len): if seq_len <= self.current_rope_size: return # Iteratively grow by increments of 8192 self.current_rope_size = ((seq_len // 8192) + ((seq_len % 8192) != 0)) * 8192 - for device_idx in range(torch.cuda.device_count()): + for device_idx in range(NUM_GPUS): self._set_cos_sin_cache(self.current_rope_size, device = torch.device(device_idx), dtype = x.dtype) pass pass @@ -1468,8 +1468,8 @@ def __init__(self, dim = None, max_position_embeddings=2048, base=10000, device= self.base = base # Dynamic RoPE we first set it to a max of 4 * 8192 tokens then we iteratively grow this self.current_rope_size = min(4 * 8192, self.max_position_embeddings) - self.multi_gpu_cos_cached = [None]*torch.cuda.device_count() - self.multi_gpu_sin_cached = [None]*torch.cuda.device_count() + self.multi_gpu_cos_cached = [None]*NUM_GPUS + self.multi_gpu_sin_cached = [None]*NUM_GPUS # Normal Llama-3 RoPE inv_freq = 1.0 / ( @@ -1479,7 +1479,7 @@ def __init__(self, dim = None, max_position_embeddings=2048, base=10000, device= self.register_buffer("inv_freq", inv_freq, persistent = 
False) # Build here to make `torch.jit.trace` work. - for device_idx in range(torch.cuda.device_count()): + for device_idx in range(NUM_GPUS): self._set_cos_sin_cache(seq_len=self.current_rope_size, device=torch.device(device_idx), dtype=torch.get_default_dtype()) # dummy so that patch_utils doesn't fail for now @@ -1525,7 +1525,7 @@ def extend_rope_embedding(self, x, seq_len): if seq_len <= self.current_rope_size: return # Iteratively grow by increments of 8192 self.current_rope_size = ((seq_len // 8192) + ((seq_len % 8192) != 0)) * 8192 - for device_idx in range(torch.cuda.device_count()): + for device_idx in range(NUM_GPUS): self._set_cos_sin_cache(self.current_rope_size, device = torch.device(device_idx), dtype = x.dtype) pass @@ -1589,10 +1589,10 @@ def __init__(self, self.base = base # Dynamic RoPE we first set it to a max of 4 * 8192 tokens then we iteratively grow this self.current_rope_size = min(original_max_position_embeddings, self.max_position_embeddings) - self.multi_gpu_short_cos_cached = [None]*torch.cuda.device_count() - self.multi_gpu_short_sin_cached = [None]*torch.cuda.device_count() - self.multi_gpu_long_cos_cached = [None]*torch.cuda.device_count() - self.multi_gpu_long_sin_cached = [None]*torch.cuda.device_count() + self.multi_gpu_short_cos_cached = [None]*NUM_GPUS + self.multi_gpu_short_sin_cached = [None]*NUM_GPUS + self.multi_gpu_long_cos_cached = [None]*NUM_GPUS + self.multi_gpu_long_sin_cached = [None]*NUM_GPUS # Long RoPE similar to RoPE except short sequences have 1 cos / sin # and long sequences have another cos / sin @@ -1622,7 +1622,7 @@ def __init__(self, freqs = torch.outer(t, self.short_inv_freq) emb = torch.cat((freqs, freqs), dim=-1) - for device_idx in range(torch.cuda.device_count()): + for device_idx in range(NUM_GPUS): device_obj = torch.device(device_idx) cos_cached = (emb.cos() * self.scaling_factor).to(dtype=dtype, device=device_obj, non_blocking=True) sin_cached = (emb.sin() * self.scaling_factor).to(dtype=dtype, 
device=device_obj, non_blocking=True) @@ -1683,7 +1683,7 @@ def extend_rope_embedding(self, x, seq_len): if seq_len <= self.current_rope_size: return # Iteratively grow by increments of 8192 self.current_rope_size = ((seq_len // 8192) + ((seq_len % 8192) != 0)) * 8192 - for device_idx in range(torch.cuda.device_count()): + for device_idx in range(NUM_GPUS): self._set_cos_sin_cache(self.current_rope_size, device = torch.device(device_idx), dtype = x.dtype) pass pass @@ -1839,7 +1839,6 @@ def from_pretrained( gpu_stats = torch.cuda.get_device_properties(0) gpu_version = torch.version.cuda gpu_stats_snippet = f"CUDA: {gpu_stats.major}.{gpu_stats.minor}. CUDA Toolkit: {gpu_version}." - num_gpus = torch.cuda.device_count() from importlib.metadata import version as importlib_version try: vllm_version = f" vLLM: {importlib_version('vllm')}." @@ -1847,7 +1846,6 @@ def from_pretrained( elif DEVICE_TYPE == "xpu": gpu_stats = torch.xpu.get_device_properties(0) gpu_version = torch.version.xpu - num_gpus = torch.xpu.device_count() gpu_stats_snippet = f"Intel Toolkit: {gpu_version}." try: vllm_version = f" vLLM: {importlib_version('vllm')}." @@ -1859,7 +1857,7 @@ def from_pretrained( statistics = \ f"==((====))== Unsloth {__version__}: Fast {model_patcher.__name__[4:-5]} patching. Transformers: {transformers_version}.{vllm_version}\n"\ - f" {chr(92)}{chr(92)} /| {gpu_stats.name}. Num GPUs = {num_gpus}. Max memory: {max_memory} GB. Platform: {platform_system}.\n"\ + f" {chr(92)}{chr(92)} /| {gpu_stats.name}. Num GPUs = {DEVICE_COUNT}. Max memory: {max_memory} GB. Platform: {platform_system}.\n"\ f"O^O/ {chr(92)}_/ {chr(92)} Torch: {torch.__version__}. {gpu_stats_snippet} Triton: {triton_version}\n"\ f"{chr(92)} / Bfloat16 = {str(SUPPORTS_BFLOAT16).upper()}. FA [Xformers = {xformers_version}. 
FA2 = {HAS_FLASH_ATTENTION}]\n"\ f' "-____-" Free license: http://github.com/unslothai/unsloth' diff --git a/unsloth/models/vision.py b/unsloth/models/vision.py index 53c5497424..a358594d85 100644 --- a/unsloth/models/vision.py +++ b/unsloth/models/vision.py @@ -61,7 +61,7 @@ # Old HF Hub versions <= 0.0.25 from huggingface_hub.utils._token import get_token pass -from unsloth import DEVICE_TYPE +from unsloth import DEVICE_TYPE, DEVICE_COUNT __all__ = [ "FastBaseModel", @@ -281,7 +281,6 @@ def from_pretrained( gpu_stats = torch.cuda.get_device_properties(0) gpu_version = torch.version.cuda gpu_stats_snippet = f"CUDA: {gpu_stats.major}.{gpu_stats.minor}. CUDA Toolkit: {gpu_version}." - num_gpus = torch.cuda.device_count() from importlib.metadata import version as importlib_version try: vllm_version = f" vLLM: {importlib_version('vllm')}." @@ -289,7 +288,6 @@ def from_pretrained( elif DEVICE_TYPE == "xpu": gpu_stats = torch.xpu.get_device_properties(0) gpu_version = torch.version.xpu - num_gpus = torch.xpu.device_count() gpu_stats_snippet = f"Intel Toolkit: {gpu_version}." # TODO: After adding vLLM support for XPU, changed this @@ -306,11 +304,11 @@ def from_pretrained( statistics = \ f"==((====))== Unsloth {__version__}: Fast {model_type_arch.title()} patching. Transformers: {transformers_version}.{vllm_version}\n"\ - f" {chr(92)}{chr(92)} /| {gpu_stats.name}. Num GPUs = {num_gpus}. Max memory: {max_memory} GB. Platform: {platform_system}.\n"\ + f" {chr(92)}{chr(92)} /| {gpu_stats.name}. Num GPUs = {DEVICE_COUNT}. Max memory: {max_memory} GB. Platform: {platform_system}.\n"\ f"O^O/ {chr(92)}_/ {chr(92)} Torch: {torch.__version__}. {gpu_stats_snippet} Triton: {triton_version}\n"\ f"{chr(92)} / Bfloat16 = {str(SUPPORTS_BFLOAT16).upper()}. FA [Xformers = {xformers_version}. 
FA2 = {HAS_FLASH_ATTENTION}]\n"\ f' "-____-" Free license: http://github.com/unslothai/unsloth' - + print(statistics) # Warn about fast transfers @@ -325,7 +323,7 @@ def from_pretrained( pass if old_hf_transfer != "0": os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1" - get_statistics() # For debugging - we use a download counter to see if environments are not breaking + get_statistics() # For debugging - we use a download counter to see if environments are not breaking if dtype is None: dtype = torch.float16 if not SUPPORTS_BFLOAT16 else torch.bfloat16 @@ -604,7 +602,7 @@ def get_peft_model( else: assert(type(target_modules) in (list, tuple,)) pass - + # Clear deleted GPU items for _ in range(3): gc.collect() @@ -678,7 +676,7 @@ def post_patch_model( float32_mixed_precision = float32_mixed_precision, ) - from transformers.trainer import Trainer + from transformers.trainer import Trainer if Trainer._inner_training_loop.__name__ != "_fast_inner_training_loop" and trust_remote_code == False: raise RuntimeError('Unsloth: Unsuccessfully patched inner_training_loop') pass From f1c6bd6ad8a199e100edc2d7569a0c35903296e9 Mon Sep 17 00:00:00 2001 From: Datta Nimmaturi Date: Thu, 10 Jul 2025 06:54:25 +0000 Subject: [PATCH 11/17] Cleanup RoPE device code --- unsloth/models/gemma.py | 6 ++++-- unsloth/models/llama.py | 27 +++++++++++++++------------ 2 files changed, 19 insertions(+), 14 deletions(-) diff --git a/unsloth/models/gemma.py b/unsloth/models/gemma.py index 75bf980473..300b00a470 100644 --- a/unsloth/models/gemma.py +++ b/unsloth/models/gemma.py @@ -268,9 +268,11 @@ def forward(self, x, position_ids=None, seq_len=None): if seq_len is not None and seq_len > self.current_rope_size: self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype) + device_index = x.device.index + return ( - self.multi_gpu_cos_cached[x.device.index][:seq_len].to(dtype=x.dtype), - self.multi_gpu_sin_cached[x.device.index][:seq_len].to(dtype=x.dtype), + 
self.multi_gpu_cos_cached[device_index][:seq_len].to(dtype=x.dtype), + self.multi_gpu_sin_cached[device_index][:seq_len].to(dtype=x.dtype), ) pass diff --git a/unsloth/models/llama.py b/unsloth/models/llama.py index 58e47fb9f7..18ee5a8fe3 100644 --- a/unsloth/models/llama.py +++ b/unsloth/models/llama.py @@ -1393,16 +1393,17 @@ def forward(self, x, position_ids=None, seq_len=None): if seq_len is not None and seq_len > self.current_rope_size: self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype) + device_index = x.device.index return ( - self.multi_gpu_cos_cached[x.device.index][:seq_len].to(dtype = x.dtype), - self.multi_gpu_sin_cached[x.device.index][:seq_len].to(dtype = x.dtype), + self.multi_gpu_cos_cached[device_index][:seq_len].to(dtype = x.dtype), + self.multi_gpu_sin_cached[device_index][:seq_len].to(dtype = x.dtype), ) pass def get_cached(self, seq_len = None, device = None): if device is None: device = torch.cuda.current_device() - return self.multi_gpu_cos_cached[device.index if hasattr(device, 'index') else device], self.multi_gpu_sin_cached[device.index if hasattr(device, 'index') else device] + return self.multi_gpu_cos_cached[device.index], self.multi_gpu_sin_cached[device.index] pass def extend_rope_embedding(self, x, seq_len): @@ -1508,17 +1509,17 @@ def forward(self, x, position_ids=None, seq_len=None): # x: [bs, num_attention_heads, seq_len, head_size] if seq_len is not None and seq_len > self.current_rope_size: self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype) - + device_index = x.device.index return ( - self.multi_gpu_cos_cached[x.device.index][:seq_len].to(dtype = x.dtype), - self.multi_gpu_sin_cached[x.device.index][:seq_len].to(dtype = x.dtype), + self.multi_gpu_cos_cached[device_index][:seq_len].to(dtype = x.dtype), + self.multi_gpu_sin_cached[device_index][:seq_len].to(dtype = x.dtype), ) pass def get_cached(self, seq_len = None, device = None): if device is None: device = 
torch.cuda.current_device() - return self.multi_gpu_cos_cached[device.index if hasattr(device, 'index') else device], self.multi_gpu_sin_cached[device.index if hasattr(device, 'index') else device] + return self.multi_gpu_cos_cached[device.index], self.multi_gpu_sin_cached[device.index] pass def extend_rope_embedding(self, x, seq_len): @@ -1657,15 +1658,17 @@ def forward(self, x, position_ids=None, seq_len=None): if seq_len is not None and seq_len > self.current_rope_size: self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype) + device_index = x.device.index + if seq_len is not None and seq_len < self.original_max_position_embeddings: return ( - self.multi_gpu_short_cos_cached[x.device.index][:seq_len].to(dtype = x.dtype), - self.multi_gpu_short_sin_cached[x.device.index][:seq_len].to(dtype = x.dtype), + self.multi_gpu_short_cos_cached[device_index][:seq_len].to(dtype = x.dtype), + self.multi_gpu_short_sin_cached[device_index][:seq_len].to(dtype = x.dtype), ) else: return ( - self.multi_gpu_long_cos_cached[x.device.index][:seq_len].to(dtype = x.dtype), - self.multi_gpu_long_sin_cached[x.device.index][:seq_len].to(dtype = x.dtype), + self.multi_gpu_long_cos_cached[device_index][:seq_len].to(dtype = x.dtype), + self.multi_gpu_long_sin_cached[device_index][:seq_len].to(dtype = x.dtype), ) pass pass @@ -1673,7 +1676,7 @@ def forward(self, x, position_ids=None, seq_len=None): def get_cached(self, seq_len = None, device = None): if device is None: device = torch.cuda.current_device() - device_index = device.index if hasattr(device, 'index') else device + device_index = device.index if seq_len is not None and seq_len < self.original_max_position_embeddings: return self.multi_gpu_short_cos_cached[device_index], self.multi_gpu_short_sin_cached[device_index] return self.multi_gpu_long_cos_cached[device_index], self.multi_gpu_long_sin_cached[device_index] From 704f8ec60262802eeb5d74f17650f4a0f3450c28 Mon Sep 17 00:00:00 2001 From: Datta Nimmaturi Date: Thu, 
10 Jul 2025 07:03:20 +0000 Subject: [PATCH 12/17] Fixup num_gpu to device count --- unsloth/models/llama.py | 28 ++++++++++++++-------------- 1 file changed, 14 insertions(+), 14 deletions(-) diff --git a/unsloth/models/llama.py b/unsloth/models/llama.py index 18ee5a8fe3..8cbaf09c8b 100644 --- a/unsloth/models/llama.py +++ b/unsloth/models/llama.py @@ -1357,11 +1357,11 @@ def __init__(self, dim = None, max_position_embeddings=2048, base=10000, device= self.base = base # Dynamic RoPE we first set it to a max of 4 * 8192 tokens then we iteratively grow this self.current_rope_size = min(4 * 8192, self.max_position_embeddings) - self.multi_gpu_cos_cached = [None]*NUM_GPUS - self.multi_gpu_sin_cached = [None]*NUM_GPUS + self.multi_gpu_cos_cached = [None]*DEVICE_COUNT + self.multi_gpu_sin_cached = [None]*DEVICE_COUNT # Build here to make `torch.jit.trace` work. - for device_idx in range(NUM_GPUS): + for device_idx in range(DEVICE_COUNT): self._set_cos_sin_cache(seq_len=self.current_rope_size, device=torch.device(device_idx), dtype=torch.get_default_dtype()) # dummy so that patch_utils doesn't fail for now @@ -1410,7 +1410,7 @@ def extend_rope_embedding(self, x, seq_len): if seq_len <= self.current_rope_size: return # Iteratively grow by increments of 8192 self.current_rope_size = ((seq_len // 8192) + ((seq_len % 8192) != 0)) * 8192 - for device_idx in range(NUM_GPUS): + for device_idx in range(DEVICE_COUNT): self._set_cos_sin_cache(self.current_rope_size, device = torch.device(device_idx), dtype = x.dtype) pass pass @@ -1469,8 +1469,8 @@ def __init__(self, dim = None, max_position_embeddings=2048, base=10000, device= self.base = base # Dynamic RoPE we first set it to a max of 4 * 8192 tokens then we iteratively grow this self.current_rope_size = min(4 * 8192, self.max_position_embeddings) - self.multi_gpu_cos_cached = [None]*NUM_GPUS - self.multi_gpu_sin_cached = [None]*NUM_GPUS + self.multi_gpu_cos_cached = [None]*DEVICE_COUNT + self.multi_gpu_sin_cached = 
[None]*DEVICE_COUNT # Normal Llama-3 RoPE inv_freq = 1.0 / ( @@ -1480,7 +1480,7 @@ def __init__(self, dim = None, max_position_embeddings=2048, base=10000, device= self.register_buffer("inv_freq", inv_freq, persistent = False) # Build here to make `torch.jit.trace` work. - for device_idx in range(NUM_GPUS): + for device_idx in range(DEVICE_COUNT): self._set_cos_sin_cache(seq_len=self.current_rope_size, device=torch.device(device_idx), dtype=torch.get_default_dtype()) # dummy so that patch_utils doesn't fail for now @@ -1526,7 +1526,7 @@ def extend_rope_embedding(self, x, seq_len): if seq_len <= self.current_rope_size: return # Iteratively grow by increments of 8192 self.current_rope_size = ((seq_len // 8192) + ((seq_len % 8192) != 0)) * 8192 - for device_idx in range(NUM_GPUS): + for device_idx in range(DEVICE_COUNT): self._set_cos_sin_cache(self.current_rope_size, device = torch.device(device_idx), dtype = x.dtype) pass @@ -1590,10 +1590,10 @@ def __init__(self, self.base = base # Dynamic RoPE we first set it to a max of 4 * 8192 tokens then we iteratively grow this self.current_rope_size = min(original_max_position_embeddings, self.max_position_embeddings) - self.multi_gpu_short_cos_cached = [None]*NUM_GPUS - self.multi_gpu_short_sin_cached = [None]*NUM_GPUS - self.multi_gpu_long_cos_cached = [None]*NUM_GPUS - self.multi_gpu_long_sin_cached = [None]*NUM_GPUS + self.multi_gpu_short_cos_cached = [None]*DEVICE_COUNT + self.multi_gpu_short_sin_cached = [None]*DEVICE_COUNT + self.multi_gpu_long_cos_cached = [None]*DEVICE_COUNT + self.multi_gpu_long_sin_cached = [None]*DEVICE_COUNT # Long RoPE similar to RoPE except short sequences have 1 cos / sin # and long sequences have another cos / sin @@ -1623,7 +1623,7 @@ def __init__(self, freqs = torch.outer(t, self.short_inv_freq) emb = torch.cat((freqs, freqs), dim=-1) - for device_idx in range(NUM_GPUS): + for device_idx in range(DEVICE_COUNT): device_obj = torch.device(device_idx) cos_cached = (emb.cos() * 
self.scaling_factor).to(dtype=dtype, device=device_obj, non_blocking=True) sin_cached = (emb.sin() * self.scaling_factor).to(dtype=dtype, device=device_obj, non_blocking=True) @@ -1686,7 +1686,7 @@ def extend_rope_embedding(self, x, seq_len): if seq_len <= self.current_rope_size: return # Iteratively grow by increments of 8192 self.current_rope_size = ((seq_len // 8192) + ((seq_len % 8192) != 0)) * 8192 - for device_idx in range(NUM_GPUS): + for device_idx in range(DEVICE_COUNT): self._set_cos_sin_cache(self.current_rope_size, device = torch.device(device_idx), dtype = x.dtype) pass pass From 21b65b0172e6dca2be7e0fd0b311a2003ae63bfb Mon Sep 17 00:00:00 2001 From: Datta Nimmaturi Date: Thu, 10 Jul 2025 07:16:27 +0000 Subject: [PATCH 13/17] Cleanup device counts --- unsloth/kernels/utils.py | 15 ++++++++------- unsloth/models/_utils.py | 13 +++---------- 2 files changed, 11 insertions(+), 17 deletions(-) diff --git a/unsloth/kernels/utils.py b/unsloth/kernels/utils.py index f5f111c943..645319d423 100644 --- a/unsloth/kernels/utils.py +++ b/unsloth/kernels/utils.py @@ -18,7 +18,7 @@ next_power_of_2 = triton.next_power_of_2 import functools from typing import Optional -from unsloth import DEVICE_TYPE +from unsloth import DEVICE_TYPE, DEVICE_COUNT # torch.cuda.amp.custom_fwd is deprecated >= 2.4 import torch @@ -89,10 +89,11 @@ def get_ptr(x: Optional[torch.Tensor]): get_ptr = bnb.functional.get_ptr -if DEVICE_TYPE == "cuda" and torch.cuda.device_count() > 1: - torch_gpu_device = torch.cuda.device -elif DEVICE_TYPE == "xpu" and torch.xpu.device_count() > 1: - torch_gpu_device = torch.xpu.device +if DEVICE_COUNT > 1: + if DEVICE_TYPE == "cuda": + torch_gpu_device = torch.cuda.device + elif DEVICE_TYPE == "xpu": + torch_gpu_device = torch.xpu.device else: from contextlib import nullcontext def torch_gpu_device(device): return nullcontext() @@ -121,7 +122,7 @@ def _get_tensor_stream(tensor: torch_Tensor) -> c_void_p: if DEVICE_TYPE == "xpu": _XPU_STREAMS = { (index := 
torch.xpu.device(i).idx) : ctypes.c_void_p(torch._C._xpu_getCurrentRawStream(index)) - for i in range(torch.xpu.device_count()) + for i in range(DEVICE_COUNT) } XPU_STREAMS = [None] * (max(_XPU_STREAMS.keys()) + 1) WEIGHT_BUFFERS = [None] * (max(_XPU_STREAMS.keys()) + 1) @@ -134,7 +135,7 @@ def _get_tensor_stream(tensor: torch_Tensor) -> c_void_p: # NVIDIA GPU Default Logic _CUDA_STREAMS = { (index := torch.cuda.device(i).idx) : ctypes.c_void_p(torch._C._cuda_getCurrentRawStream(index)) - for i in range(torch.cuda.device_count()) + for i in range(DEVICE_COUNT) } CUDA_STREAMS = [None] * (max(_CUDA_STREAMS.keys()) + 1) WEIGHT_BUFFERS = [None] * (max(_CUDA_STREAMS.keys()) + 1) diff --git a/unsloth/models/_utils.py b/unsloth/models/_utils.py index fb6f407648..71e38871e8 100644 --- a/unsloth/models/_utils.py +++ b/unsloth/models/_utils.py @@ -77,7 +77,7 @@ import re import warnings, subprocess, re, inspect, psutil, os, math from unsloth_zoo.utils import Version -from unsloth_zoo import DEVICE_TYPE +from unsloth import DEVICE_TYPE, DEVICE_COUNT from unsloth_zoo.tokenizer_utils import ( patch_tokenizer as _patch_tokenizer, @@ -142,12 +142,6 @@ import logging logging.getLogger("transformers.tokenization_utils_base").setLevel(logging.CRITICAL+1) -def get_device_num(): - if DEVICE_TYPE == "xpu": - return torch.xpu.device_count() - else: - return torch.cuda.device_count() - # Ignore logging messages class HideLoggingMessage(logging.Filter): __slots__ = "text", @@ -746,8 +740,7 @@ def get_statistics(): pass pass try: - devices = get_device_num() - _get_statistics(f"{devices if devices <= 8 else 9}") + _get_statistics(f"{DEVICE_COUNT if DEVICE_COUNT <= 8 else 9}") except: pass if disabled: enable_progress_bars() @@ -773,7 +766,7 @@ def get_statistics(): ) exec(BitsAndBytesConfig__init__, globals()) -if get_device_num() == 1: +if DEVICE_COUNT == 1: from accelerate.utils.dataclasses import DistributedType def _prepare_backend(self, *args, **kwargs): return None, 
DistributedType.NO import accelerate.state From 53759c8a21af8f7434cb2e1a544ceea49e359ae3 Mon Sep 17 00:00:00 2001 From: Datta Nimmaturi Date: Thu, 10 Jul 2025 09:14:23 +0000 Subject: [PATCH 14/17] Use device index for RoPE get_cache --- unsloth/models/_utils.py | 7 +++++++ unsloth/models/cohere.py | 13 ++++++------ unsloth/models/falcon_h1.py | 17 ++++++++-------- unsloth/models/gemma.py | 25 ++++++++++++----------- unsloth/models/gemma2.py | 26 +++++++++++++----------- unsloth/models/granite.py | 2 +- unsloth/models/llama.py | 40 ++++++++++++++++++------------------- unsloth/models/mistral.py | 18 ++++++++--------- unsloth/models/qwen3.py | 7 ++++--- 9 files changed, 84 insertions(+), 71 deletions(-) diff --git a/unsloth/models/_utils.py b/unsloth/models/_utils.py index 71e38871e8..4248a36c58 100644 --- a/unsloth/models/_utils.py +++ b/unsloth/models/_utils.py @@ -786,6 +786,13 @@ def move_to_device(target_device, *tensors): Returns: tuple: The tensors on the target device (same objects if already on device, new if moved) """ + if isinstance(target_device, int) or isinstance(target_device, str): + target_device = torch.device(target_device) + elif isinstance(target_device, torch.device): + pass + else: + raise ValueError(f"Invalid target device: {target_device}") + pass moved_tensors = [] for tensor in tensors: if tensor.device != target_device: diff --git a/unsloth/models/cohere.py b/unsloth/models/cohere.py index 29e2b30d24..4378d9439e 100644 --- a/unsloth/models/cohere.py +++ b/unsloth/models/cohere.py @@ -321,7 +321,7 @@ def CohereAttention_fast_forward_inference( # cos, sin = self.rotary_emb(Vn, seq_len = kv_seq_len) # Qn, Kn = inplace_rope_embedding(Qn, Kn, cos, sin, position_ids) - cos, sin = self.rotary_emb.get_cached(kv_seq_len, device = Qn.device) + cos, sin = self.rotary_emb.get_cached(kv_seq_len, Qn.device.index) cos = cos[position_ids].unsqueeze(1) sin = sin[position_ids].unsqueeze(1) h = self.half_head_dim @@ -398,7 +398,7 @@ def 
CohereModel_fast_forward_inference( position_ids, attention_mask = None, ): - out_weight = torch.empty_like(self.model.layers[0].input_layernorm.weight, dtype = torch.float32, device = "cuda:0") + out_weights = [torch.empty_like(self.model.layers[0].input_layernorm.weight, dtype = torch.float32, device = torch.device(x)) for x in range(DEVICE_COUNT)] input_ids = input_ids[:,:self.max_seq_length] hidden_states = self.model.embed_tokens(input_ids) hidden_states = hidden_states.to(self.config.torch_dtype) @@ -418,11 +418,12 @@ def CohereModel_fast_forward_inference( next_decoder_cache = [] for idx, decoder_layer in enumerate(self.model.layers): - hidden_states, out_weight, position_ids = move_to_device( - decoder_layer._per_layer_device, hidden_states, out_weight, position_ids + device_index = decoder_layer._per_layer_device.index + hidden_states, position_ids = move_to_device( + device_index, hidden_states, position_ids ) residual = hidden_states - hidden_states = fast_layernorm_inference(decoder_layer.input_layernorm, hidden_states, out_weight) + hidden_states = fast_layernorm_inference(decoder_layer.input_layernorm, hidden_states, out_weights[device_index]) hidden_states_attention, present_key_value = CohereAttention_fast_forward_inference( decoder_layer.self_attn, hidden_states = hidden_states, @@ -439,7 +440,7 @@ def CohereModel_fast_forward_inference( next_decoder_cache.append(present_key_value) pass - hidden_states = fast_layernorm_inference(self.model.norm, hidden_states, out_weight) + hidden_states = fast_layernorm_inference(self.model.norm, hidden_states, out_weights[device_index]) return BaseModelOutputWithPast( last_hidden_state = hidden_states, diff --git a/unsloth/models/falcon_h1.py b/unsloth/models/falcon_h1.py index 58e8c94d9d..2cbb78f8ad 100644 --- a/unsloth/models/falcon_h1.py +++ b/unsloth/models/falcon_h1.py @@ -69,7 +69,7 @@ def FalconH1Attention_fast_forward( position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, *args, 
**kwargs, ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: - + # Clear inference if hasattr(self, "paged_attention"): del self.paged_attention_K @@ -110,12 +110,13 @@ def FalconH1Attention_fast_forward( # Extend RoPE dynamically to fit in VRA rotary_emb = self.rotary_emb rotary_emb.extend_rope_embedding(V, seq_len = kv_seq_len) + device_index = Q.device.index if position_ids is None: # Useful for LongRoPE - cos, sin = rotary_emb.get_cached(kv_seq_len, device=Q.device) + cos, sin = rotary_emb.get_cached(kv_seq_len, device_index) else: - cos, sin = rotary_emb.get_cached(seq_len = kv_seq_len, device=Q.device) + cos, sin = rotary_emb.get_cached(kv_seq_len, device_index) Q, K = fast_rope_embedding(Q, K, cos, sin) if past_key_value is not None: @@ -245,14 +246,14 @@ def FalconH1Attention_fast_forward_inference( self.temp_QA = torch.empty((2, bsz, 1, attention_size), dtype = dtype, device = device) self.temp_KV = torch.empty((2, bsz, 1, n_kv_heads*head_dim), dtype = dtype, device = device) self.RH_Q = torch.empty((bsz, n_heads, 1, head_dim), dtype = dtype, device = device) - + # Mistral Nemo 12b has weird dimensions if attention_size != hidden_size: self.temp_O = torch.empty((1, bsz, hidden_size), dtype = dtype, device = device) else: self.temp_O = self.temp_QA[1][:,:,:hidden_size] pass - + self.attention = torch.empty((bsz, n_heads, 1, KV_CACHE_INCREMENT+seq_len), dtype = dtype, device = device) self.scalar = 1.0 / math_sqrt(self.head_dim) self.half_head_dim = head_dim // 2 @@ -280,7 +281,7 @@ def FalconH1Attention_fast_forward_inference( # Need to do it prior 2 steps before hitting full on short KV cache # or else error self.rotary_emb.extend_rope_embedding(Vn, seq_len + 2) - cos, sin = self.rotary_emb.get_cached(kv_seq_len, device = Qn.device) + cos, sin = self.rotary_emb.get_cached(kv_seq_len, Qn.device.index) cos = cos[position_ids].unsqueeze(1) sin = sin[position_ids].unsqueeze(1) h = self.half_head_dim @@ -298,7 +299,7 @@ def 
FalconH1Attention_fast_forward_inference( RH_K[:,:,:,:h].neg_() #torch.neg(RH_K[:,:,:,:h], out = RH_K[:,:,:,:h]) Kn *= cos Kn.addcmul_(RH_K, sin) - + # New KV cache # Kn = torch.cat([K1, Kn], dim = 2) # Vn = torch.cat([V1, Vn], dim = 2) @@ -580,7 +581,7 @@ def _fast_prepare_inputs_for_generation( **kwargs,): # Overwitten -- has a unique cache type, `FalconHybridMambaAttentionDynamicCache` empty_past_kv = past_key_values is None - + # If we have cache: let's slice `input_ids` through `cache_position`, to keep only the unprocessed tokens # Exception 1: when passing input_embeds, input_ids may be missing entries # Exception 2: some generation methods do special slicing of input_ids, so we don't need to do it here diff --git a/unsloth/models/gemma.py b/unsloth/models/gemma.py index 300b00a470..5d37c9fc4e 100644 --- a/unsloth/models/gemma.py +++ b/unsloth/models/gemma.py @@ -149,7 +149,7 @@ def GemmaModel_fast_forward_inference( position_ids, attention_mask = None, ): - out_weight = torch.empty_like(self.model.layers[0].input_layernorm.weight, dtype = torch.float32, device = "cuda:0") + out_weights = [torch.empty_like(self.model.layers[0].input_layernorm.weight, dtype = torch.float32, device = torch.device(x)) for x in range(DEVICE_COUNT)] input_ids = input_ids[:,:self.max_seq_length] hidden_states = self.model.embed_tokens(input_ids) hidden_states = hidden_states.to(self.config.torch_dtype) @@ -170,12 +170,13 @@ def GemmaModel_fast_forward_inference( next_decoder_cache = [] for idx, decoder_layer in enumerate(self.model.layers): - hidden_states, out_weight, position_ids = move_to_device( - decoder_layer._per_layer_device, hidden_states, out_weight, position_ids + device_index = decoder_layer._per_layer_device.index + hidden_states, position_ids = move_to_device( + device_index, hidden_states, position_ids ) residual = hidden_states - hidden_states = fast_rms_layernorm_inference_gemma(decoder_layer.input_layernorm, hidden_states, out_weight) + hidden_states = 
fast_rms_layernorm_inference_gemma(decoder_layer.input_layernorm, hidden_states, out_weights[device_index]) hidden_states, present_key_value = LlamaAttention_fast_forward_inference( decoder_layer.self_attn, hidden_states = hidden_states, @@ -187,13 +188,13 @@ def GemmaModel_fast_forward_inference( hidden_states += residual residual = hidden_states - hidden_states = fast_rms_layernorm_inference_gemma(decoder_layer.post_attention_layernorm, hidden_states, out_weight) + hidden_states = fast_rms_layernorm_inference_gemma(decoder_layer.post_attention_layernorm, hidden_states, out_weights[device_index]) hidden_states = fast_geglu_inference(decoder_layer.mlp, hidden_states) hidden_states += residual next_decoder_cache.append(present_key_value) pass - hidden_states = fast_rms_layernorm_inference_gemma(self.model.norm, hidden_states, out_weight) + hidden_states = fast_rms_layernorm_inference_gemma(self.model.norm, hidden_states, out_weights[device_index]) return BaseModelOutputWithPast( last_hidden_state = hidden_states, @@ -271,15 +272,15 @@ def forward(self, x, position_ids=None, seq_len=None): device_index = x.device.index return ( - self.multi_gpu_cos_cached[device_index][:seq_len].to(dtype=x.dtype), - self.multi_gpu_sin_cached[device_index][:seq_len].to(dtype=x.dtype), + self.multi_gpu_cos_cached[device_index][:seq_len], + self.multi_gpu_sin_cached[device_index][:seq_len], ) pass - def get_cached(self, seq_len = None, device = None): - if device is None: - device = torch.cuda.current_device() - return self.multi_gpu_cos_cached[device.index], self.multi_gpu_sin_cached[device.index] + def get_cached(self, seq_len = None, device_index = None): + if device_index is None: + device_index = torch.cuda.current_device() + return self.multi_gpu_cos_cached[device_index], self.multi_gpu_sin_cached[device_index] pass def extend_rope_embedding(self, x, seq_len): diff --git a/unsloth/models/gemma2.py b/unsloth/models/gemma2.py index 7910daad08..b6bdf940bd 100644 --- 
a/unsloth/models/gemma2.py +++ b/unsloth/models/gemma2.py @@ -113,12 +113,13 @@ def Gemma2Attention_fast_forward( if past_key_value is not None: kv_seq_len += past_key_value[0].shape[-2] + device_index = Q.device.index if position_ids is None: - cos = self.rotary_emb.multi_gpu_cos_cached[Q.device.index] - sin = self.rotary_emb.multi_gpu_sin_cached[Q.device.index] + cos = self.rotary_emb.multi_gpu_cos_cached[device_index] + sin = self.rotary_emb.multi_gpu_sin_cached[device_index] Q, K = fast_rope_embedding(Q, K, cos, sin) else: - cos, sin = self.rotary_emb.get_cached(seq_len = kv_seq_len, device = Q.device) + cos, sin = self.rotary_emb.get_cached(kv_seq_len, device_index) Q, K = inplace_rope_embedding(Q, K, cos, sin, position_ids) pass @@ -307,7 +308,7 @@ def Gemma2Attention_fast_forward_inference( # cos, sin = self.rotary_emb(Vn, seq_len = kv_seq_len) # Qn, Kn = inplace_rope_embedding(Qn, Kn, cos, sin, position_ids) - cos, sin = self.rotary_emb.get_cached(seq_len = kv_seq_len, device = Qn.device) + cos, sin = self.rotary_emb.get_cached(kv_seq_len, Qn.device.index) cos = cos[position_ids].unsqueeze(1) sin = sin[position_ids].unsqueeze(1) h = self.half_head_dim @@ -387,7 +388,7 @@ def Gemma2Model_fast_forward_inference( position_ids, attention_mask = None, ): - out_weight = torch.empty_like(self.model.layers[0].input_layernorm.weight, dtype = torch.float32, device = "cuda:0") + out_weights = [torch.empty_like(self.model.layers[0].input_layernorm.weight, dtype = torch.float32, device = torch.device(x)) for x in range(DEVICE_COUNT)] input_ids = input_ids[:,:self.max_seq_length] hidden_states = self.model.embed_tokens(input_ids) hidden_states = hidden_states.to(self.config.torch_dtype) @@ -425,14 +426,15 @@ def Gemma2Model_fast_forward_inference( # For pipeline parallelism, we need to move all tensors to the same device # note that this movement is once per GPU in PP - hidden_states, out_weight, position_ids = move_to_device( - decoder_layer._per_layer_device, 
hidden_states, out_weight, position_ids + device_index = decoder_layer._per_layer_device.index + hidden_states, position_ids = move_to_device( + device_index, hidden_states, position_ids ) use_sliding_window = idx % 2 == 0 residual = hidden_states - hidden_states = fast_rms_layernorm_inference_gemma(decoder_layer.input_layernorm, hidden_states, out_weight) + hidden_states = fast_rms_layernorm_inference_gemma(decoder_layer.input_layernorm, hidden_states, out_weights[device_index]) hidden_states, present_key_value = Gemma2Attention_fast_forward_inference( decoder_layer.self_attn, hidden_states = hidden_states, @@ -442,18 +444,18 @@ def Gemma2Model_fast_forward_inference( do_prefill = not hasattr(decoder_layer.self_attn, "paged_attention"), use_sliding_window = use_sliding_window, ) - hidden_states = fast_rms_layernorm_inference_gemma(decoder_layer.post_attention_layernorm, hidden_states, out_weight) + hidden_states = fast_rms_layernorm_inference_gemma(decoder_layer.post_attention_layernorm, hidden_states, out_weights[device_index]) hidden_states += residual residual = hidden_states - hidden_states = fast_rms_layernorm_inference_gemma(decoder_layer. pre_feedforward_layernorm, hidden_states, out_weight) + hidden_states = fast_rms_layernorm_inference_gemma(decoder_layer. 
pre_feedforward_layernorm, hidden_states, out_weights[device_index]) hidden_states = fast_geglu_inference(decoder_layer.mlp, hidden_states) - hidden_states = fast_rms_layernorm_inference_gemma(decoder_layer.post_feedforward_layernorm, hidden_states, out_weight) + hidden_states = fast_rms_layernorm_inference_gemma(decoder_layer.post_feedforward_layernorm, hidden_states, out_weights[device_index]) hidden_states += residual next_decoder_cache.append(present_key_value) pass - hidden_states = fast_rms_layernorm_inference_gemma(self.model.norm, hidden_states, out_weight) + hidden_states = fast_rms_layernorm_inference_gemma(self.model.norm, hidden_states, out_weights[device_index]) return BaseModelOutputWithPast( last_hidden_state = hidden_states, diff --git a/unsloth/models/granite.py b/unsloth/models/granite.py index a10ba1c7d1..d9d90611a3 100644 --- a/unsloth/models/granite.py +++ b/unsloth/models/granite.py @@ -395,7 +395,7 @@ def GraniteModel_fast_forward_inference( attention_mask = None pass - position_embeddings = self.model.rotary_emb.get_cached(seq_len = self.max_seq_length, device = hidden_states.device) + position_embeddings = self.model.rotary_emb.get_cached(self.max_seq_length, hidden_states.device.index) next_decoder_cache = [] for idx, decoder_layer in enumerate(self.model.layers): diff --git a/unsloth/models/llama.py b/unsloth/models/llama.py index 8cbaf09c8b..15d4652c69 100644 --- a/unsloth/models/llama.py +++ b/unsloth/models/llama.py @@ -275,7 +275,7 @@ def LlamaAttention_fast_forward_inference( # Need to do it prior 2 steps before hitting full on short KV cache # or else error self.rotary_emb.extend_rope_embedding(Vn, seq_len + 2) - cos, sin = self.rotary_emb.get_cached(kv_seq_len, device = Qn.device) + cos, sin = self.rotary_emb.get_cached(kv_seq_len, Qn.device.index) cos = cos[position_ids].unsqueeze(1) sin = sin[position_ids].unsqueeze(1) h = self.half_head_dim @@ -489,7 +489,7 @@ def LlamaAttention_fast_forward( # cos, sin = 
rotary_emb.get_cached(kv_seq_len, device = Q.device) # else: # cos, sin = rotary_emb.get_cached(seq_len = kv_seq_len, device = Q.device) - cos, sin = rotary_emb.get_cached(seq_len = kv_seq_len, device = Q.device) + cos, sin = rotary_emb.get_cached(kv_seq_len, Q.device.index) # Q, K = ( # fast_rope_embedding(Q, K, cos, sin) @@ -913,7 +913,7 @@ def LlamaModel_fast_forward( # Also, transformers 4.45.0 supports granite but with the attention refactor (it always had the refactor) # unsloth's check for granite too has "version >= 4.45.0 (rightly so)". # so let granite always use the attention refactor implementation. - position_embeddings = self.rotary_emb.get_cached(seq_len = self.config.max_position_embeddings, device = hidden_states.device) + position_embeddings = self.rotary_emb.get_cached(self.config.max_position_embeddings, hidden_states.device.index) else: position_embeddings = None @@ -1023,7 +1023,7 @@ def LlamaModel_fast_forward_inference_custom( XX, XX2 = _XX[0], _XX[1] variance = torch.empty((bsz, q_len, 1), dtype = torch.float32, device = f"{DEVICE_TYPE}:0") temp_mlp = torch.empty((2, bsz, 1, mlp_size), dtype = X.dtype, device = f"{DEVICE_TYPE}:0") - temp_gate, temp_up = temp_mlp[0], temp_mlp[1] + temp_gates, temp_ups = [temp_mlp[0].to(torch.device(x)) for x in range(DEVICE_COUNT)], [temp_mlp[1].to(torch.device(x)) for x in range(DEVICE_COUNT)] seq_len = past_key_values[0][0].shape[-2] if bsz != 1: @@ -1041,8 +1041,9 @@ def LlamaModel_fast_forward_inference_custom( next_decoder_cache = [] for idx, decoder_layer in enumerate(self.model.layers): - X, residual, temp_gate, temp_up, position_ids = move_to_device( - decoder_layer._per_layer_device, X, residual, temp_gate, temp_up, position_ids + device_index = decoder_layer._per_layer_device.index + X, residual, position_ids = move_to_device( + device_index, X, residual, position_ids ) residual.copy_(X) # residual = X X = fast_rms_layernorm_inference( @@ -1073,8 +1074,8 @@ def 
LlamaModel_fast_forward_inference_custom( X = mlp_fast_forward_inference( decoder_layer.mlp, X, - temp_gate = temp_gate, - temp_up = temp_up, + temp_gate = temp_gates[device_index], + temp_up = temp_ups[device_index], ) X += residual @@ -1400,10 +1401,10 @@ def forward(self, x, position_ids=None, seq_len=None): ) pass - def get_cached(self, seq_len = None, device = None): - if device is None: - device = torch.cuda.current_device() - return self.multi_gpu_cos_cached[device.index], self.multi_gpu_sin_cached[device.index] + def get_cached(self, seq_len = None, device_index = None): + if device_index is None: + device_index = torch.cuda.current_device() + return self.multi_gpu_cos_cached[device_index], self.multi_gpu_sin_cached[device_index] pass def extend_rope_embedding(self, x, seq_len): @@ -1516,10 +1517,10 @@ def forward(self, x, position_ids=None, seq_len=None): ) pass - def get_cached(self, seq_len = None, device = None): - if device is None: - device = torch.cuda.current_device() - return self.multi_gpu_cos_cached[device.index], self.multi_gpu_sin_cached[device.index] + def get_cached(self, seq_len = None, device_index = None): + if device_index is None: + device_index = torch.cuda.current_device() + return self.multi_gpu_cos_cached[device_index], self.multi_gpu_sin_cached[device_index] pass def extend_rope_embedding(self, x, seq_len): @@ -1673,10 +1674,9 @@ def forward(self, x, position_ids=None, seq_len=None): pass pass - def get_cached(self, seq_len = None, device = None): - if device is None: - device = torch.cuda.current_device() - device_index = device.index + def get_cached(self, seq_len = None, device_index = None): + if device_index is None: + device_index = torch.cuda.current_device() if seq_len is not None and seq_len < self.original_max_position_embeddings: return self.multi_gpu_short_cos_cached[device_index], self.multi_gpu_short_sin_cached[device_index] return self.multi_gpu_long_cos_cached[device_index], 
self.multi_gpu_long_sin_cached[device_index] diff --git a/unsloth/models/mistral.py b/unsloth/models/mistral.py index 6fc7bbafd2..68d4ba43fb 100644 --- a/unsloth/models/mistral.py +++ b/unsloth/models/mistral.py @@ -51,7 +51,7 @@ def MistralAttention_fast_forward( position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, *args, **kwargs, ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: - + # Clear inference if hasattr(self, "paged_attention"): del self.paged_attention_K @@ -83,7 +83,7 @@ def MistralAttention_fast_forward( # Extend RoPE dynamically to fit in VRAM self.rotary_emb.extend_rope_embedding(V, seq_len = kv_seq_len) - cos, sin = self.rotary_emb.get_cached(seq_len = kv_seq_len, device = Q.device) + cos, sin = self.rotary_emb.get_cached(kv_seq_len, Q.device.index) if position_ids is None: Q, K = fast_rope_embedding(Q, K, cos, sin) else: @@ -160,7 +160,7 @@ def MistralAttention_fast_forward( # Go back to (batch_size, seq_len, n_heads, head_dim) A = A.transpose(1, 2).contiguous() pass - + attn_output = A.reshape(bsz, q_len, n_heads*head_dim) attn_output = self.apply_o(self, attn_output) attn_weights = None @@ -199,7 +199,7 @@ def MistralForCausalLM_fast_forward( causal_mask = xformers.attn_bias.BlockDiagonalCausalMask\ .from_seqlens([q_len]*bsz)\ .make_local_attention(window_size = sliding_window) - + elif not HAS_XFORMERS and attention_mask is None: if sliding_window is None or sliding_window == "null" or sliding_window <= 0 or q_len <= sliding_window: # Fully causal mask @@ -210,10 +210,10 @@ def MistralForCausalLM_fast_forward( # Sliding window attention q_indices = torch.arange(q_len, device=input_ids.device).view(-1, 1) k_indices = torch.arange(q_len, device=input_ids.device).view(1, -1) - + causal_bool_mask = k_indices <= q_indices window_bool_mask = (q_indices - k_indices) < sliding_window - + mask = torch.where(causal_bool_mask & window_bool_mask, 0.0, -torch.inf) attention_mask = mask[None, None, :, 
:].expand(bsz, 1, q_len, q_len) @@ -256,7 +256,7 @@ def MistralForCausalLM_fast_forward( bsz, q_len, hd = hidden_states.shape lm_head = self.lm_head.weight lm_head_device = lm_head.device - + # Move items to same device as lm_head hidden_states = hidden_states.to(lm_head_device) if labels is not None: labels = labels.to(lm_head_device) @@ -299,7 +299,7 @@ def MistralForCausalLM_fast_forward( if not return_dict: output = (logits,) + outputs[1:] return (loss,) + output if loss is not None else output - + output = CausalLMOutputWithPast( loss = loss, logits = EMPTY_LOGITS, @@ -388,7 +388,7 @@ def pre_patch(): MistralForCausalLM .forward = MistralForCausalLM_fast_forward PeftModelForCausalLM .forward = PeftModel_fast_forward fix_prepare_inputs_for_generation(MistralForCausalLM) - + # Solves https://github.com/unslothai/unsloth/issues/168 # Static KV Cache was introduced in 4.38.0, causing training to be much slower. # Inferene can now be CUDAGraphed, but we shall retain the old rotary embeddings. 
diff --git a/unsloth/models/qwen3.py b/unsloth/models/qwen3.py index 38cc369087..b20f22dab1 100644 --- a/unsloth/models/qwen3.py +++ b/unsloth/models/qwen3.py @@ -111,12 +111,13 @@ def Qwen3Attention_fast_forward( # Extend RoPE dynamically to fit in VRA rotary_emb = self.rotary_emb rotary_emb.extend_rope_embedding(V, seq_len = kv_seq_len) + device_index = Q.device.index if position_ids is None: # Useful for LongRoPE - cos, sin = rotary_emb.get_cached(kv_seq_len) + cos, sin = rotary_emb.get_cached(kv_seq_len, device_index) else: - cos, sin = rotary_emb.get_cached(seq_len = kv_seq_len, device = Q.device) + cos, sin = rotary_emb.get_cached(kv_seq_len, device_index) Q, K = fast_rope_embedding(Q, K, cos, sin) if past_key_value is not None: @@ -297,7 +298,7 @@ def Qwen3Attention_fast_forward_inference( # Need to do it prior 2 steps before hitting full on short KV cache # or else error self.rotary_emb.extend_rope_embedding(Vn, seq_len + 2) - cos, sin = self.rotary_emb.get_cached(kv_seq_len, device = Qn.device) + cos, sin = self.rotary_emb.get_cached(kv_seq_len, Qn.device.index) cos = cos[position_ids].unsqueeze(1) sin = sin[position_ids].unsqueeze(1) h = self.half_head_dim From dac2ae8cb25804601bbeead78a57b4dd982225bf Mon Sep 17 00:00:00 2001 From: Datta Nimmaturi Date: Thu, 10 Jul 2025 09:27:09 +0000 Subject: [PATCH 15/17] Donot typecast --- unsloth/models/llama.py | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/unsloth/models/llama.py b/unsloth/models/llama.py index d217a90e72..1f9e6baedc 100644 --- a/unsloth/models/llama.py +++ b/unsloth/models/llama.py @@ -1395,8 +1395,8 @@ def forward(self, x, position_ids=None, seq_len=None): device_index = x.device.index return ( - self.multi_gpu_cos_cached[device_index][:seq_len].to(dtype = x.dtype), - self.multi_gpu_sin_cached[device_index][:seq_len].to(dtype = x.dtype), + self.multi_gpu_cos_cached[device_index][:seq_len], + self.multi_gpu_sin_cached[device_index][:seq_len], ) pass @@ -1511,8 
+1511,8 @@ def forward(self, x, position_ids=None, seq_len=None): self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype) device_index = x.device.index return ( - self.multi_gpu_cos_cached[device_index][:seq_len].to(dtype = x.dtype), - self.multi_gpu_sin_cached[device_index][:seq_len].to(dtype = x.dtype), + self.multi_gpu_cos_cached[device_index][:seq_len], + self.multi_gpu_sin_cached[device_index][:seq_len], ) pass @@ -1662,13 +1662,13 @@ def forward(self, x, position_ids=None, seq_len=None): if seq_len is not None and seq_len < self.original_max_position_embeddings: return ( - self.multi_gpu_short_cos_cached[device_index][:seq_len].to(dtype = x.dtype), - self.multi_gpu_short_sin_cached[device_index][:seq_len].to(dtype = x.dtype), + self.multi_gpu_short_cos_cached[device_index][:seq_len], + self.multi_gpu_short_sin_cached[device_index][:seq_len], ) else: return ( - self.multi_gpu_long_cos_cached[device_index][:seq_len].to(dtype = x.dtype), - self.multi_gpu_long_sin_cached[device_index][:seq_len].to(dtype = x.dtype), + self.multi_gpu_long_cos_cached[device_index][:seq_len], + self.multi_gpu_long_sin_cached[device_index][:seq_len], ) pass pass From 464df7cb7220b80aefce45186a68c905d1f5e505 Mon Sep 17 00:00:00 2001 From: Datta Nimmaturi Date: Thu, 10 Jul 2025 10:00:26 +0000 Subject: [PATCH 16/17] Use tuple instead of list for tensors. 
Use device index directly --- unsloth/models/cohere.py | 4 ++-- unsloth/models/gemma.py | 4 ++-- unsloth/models/gemma2.py | 4 ++-- unsloth/models/granite.py | 4 ++-- unsloth/models/llama.py | 4 ++-- 5 files changed, 10 insertions(+), 10 deletions(-) diff --git a/unsloth/models/cohere.py b/unsloth/models/cohere.py index 4378d9439e..d4691fb5d2 100644 --- a/unsloth/models/cohere.py +++ b/unsloth/models/cohere.py @@ -398,7 +398,7 @@ def CohereModel_fast_forward_inference( position_ids, attention_mask = None, ): - out_weights = [torch.empty_like(self.model.layers[0].input_layernorm.weight, dtype = torch.float32, device = torch.device(x)) for x in range(DEVICE_COUNT)] + out_weights = tuple(torch.empty_like(self.model.layers[0].input_layernorm.weight, dtype = torch.float32, device = torch.device(x)) for x in range(DEVICE_COUNT)) input_ids = input_ids[:,:self.max_seq_length] hidden_states = self.model.embed_tokens(input_ids) hidden_states = hidden_states.to(self.config.torch_dtype) @@ -418,7 +418,7 @@ def CohereModel_fast_forward_inference( next_decoder_cache = [] for idx, decoder_layer in enumerate(self.model.layers): - device_index = decoder_layer._per_layer_device.index + device_index = getattr(decoder_layer, "_per_layer_device_index", 0) hidden_states, position_ids = move_to_device( device_index, hidden_states, position_ids ) diff --git a/unsloth/models/gemma.py b/unsloth/models/gemma.py index 5d37c9fc4e..e43b205ecd 100644 --- a/unsloth/models/gemma.py +++ b/unsloth/models/gemma.py @@ -149,7 +149,7 @@ def GemmaModel_fast_forward_inference( position_ids, attention_mask = None, ): - out_weights = [torch.empty_like(self.model.layers[0].input_layernorm.weight, dtype = torch.float32, device = torch.device(x)) for x in range(DEVICE_COUNT)] + out_weights = tuple(torch.empty_like(self.model.layers[0].input_layernorm.weight, dtype = torch.float32, device = torch.device(x)) for x in range(DEVICE_COUNT)) input_ids = input_ids[:,:self.max_seq_length] hidden_states = 
self.model.embed_tokens(input_ids) hidden_states = hidden_states.to(self.config.torch_dtype) @@ -170,7 +170,7 @@ def GemmaModel_fast_forward_inference( next_decoder_cache = [] for idx, decoder_layer in enumerate(self.model.layers): - device_index = decoder_layer._per_layer_device.index + device_index = getattr(decoder_layer, "_per_layer_device_index", 0) hidden_states, position_ids = move_to_device( device_index, hidden_states, position_ids ) diff --git a/unsloth/models/gemma2.py b/unsloth/models/gemma2.py index b6bdf940bd..5597995b05 100644 --- a/unsloth/models/gemma2.py +++ b/unsloth/models/gemma2.py @@ -388,7 +388,7 @@ def Gemma2Model_fast_forward_inference( position_ids, attention_mask = None, ): - out_weights = [torch.empty_like(self.model.layers[0].input_layernorm.weight, dtype = torch.float32, device = torch.device(x)) for x in range(DEVICE_COUNT)] + out_weights = tuple(torch.empty_like(self.model.layers[0].input_layernorm.weight, dtype = torch.float32, device = torch.device(x)) for x in range(DEVICE_COUNT)) input_ids = input_ids[:,:self.max_seq_length] hidden_states = self.model.embed_tokens(input_ids) hidden_states = hidden_states.to(self.config.torch_dtype) @@ -426,7 +426,7 @@ def Gemma2Model_fast_forward_inference( # For pipeline parallelism, we need to move all tensors to the same device # note that this movement is once per GPU in PP - device_index = decoder_layer._per_layer_device.index + device_index = getattr(decoder_layer, "_per_layer_device_index", 0) hidden_states, position_ids = move_to_device( device_index, hidden_states, position_ids ) diff --git a/unsloth/models/granite.py b/unsloth/models/granite.py index d9d90611a3..a3d79c8333 100644 --- a/unsloth/models/granite.py +++ b/unsloth/models/granite.py @@ -399,9 +399,9 @@ def GraniteModel_fast_forward_inference( next_decoder_cache = [] for idx, decoder_layer in enumerate(self.model.layers): - + device_index = getattr(decoder_layer, "_per_layer_device_index", 0) hidden_states, position_ids = 
move_to_device( - decoder_layer._per_layer_device, hidden_states, position_ids + device_index, hidden_states, position_ids ) residual = hidden_states diff --git a/unsloth/models/llama.py b/unsloth/models/llama.py index 1f9e6baedc..db0e8843cf 100644 --- a/unsloth/models/llama.py +++ b/unsloth/models/llama.py @@ -1023,7 +1023,7 @@ def LlamaModel_fast_forward_inference_custom( XX, XX2 = _XX[0], _XX[1] variance = torch.empty((bsz, q_len, 1), dtype = torch.float32, device = f"{DEVICE_TYPE}:0") temp_mlp = torch.empty((2, bsz, 1, mlp_size), dtype = X.dtype, device = f"{DEVICE_TYPE}:0") - temp_gates, temp_ups = [temp_mlp[0].to(torch.device(x)) for x in range(DEVICE_COUNT)], [temp_mlp[1].to(torch.device(x)) for x in range(DEVICE_COUNT)] + temp_gates, temp_ups = tuple(temp_mlp[0].to(torch.device(x)) for x in range(DEVICE_COUNT)), tuple(temp_mlp[1].to(torch.device(x)) for x in range(DEVICE_COUNT)) seq_len = past_key_values[0][0].shape[-2] if bsz != 1: @@ -1041,7 +1041,7 @@ def LlamaModel_fast_forward_inference_custom( next_decoder_cache = [] for idx, decoder_layer in enumerate(self.model.layers): - device_index = decoder_layer._per_layer_device.index + device_index = getattr(decoder_layer, "_per_layer_device_index", 0) X, residual, position_ids = move_to_device( device_index, X, residual, position_ids ) From da2bf8401a5b7530eca2d9d1f999c10da5055883 Mon Sep 17 00:00:00 2001 From: Datta Nimmaturi Date: Thu, 10 Jul 2025 10:08:16 +0000 Subject: [PATCH 17/17] fixup move to device logic --- unsloth/models/_utils.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/unsloth/models/_utils.py b/unsloth/models/_utils.py index 4248a36c58..c6ff7b6260 100644 --- a/unsloth/models/_utils.py +++ b/unsloth/models/_utils.py @@ -786,7 +786,10 @@ def move_to_device(target_device, *tensors): Returns: tuple: The tensors on the target device (same objects if already on device, new if moved) """ - if isinstance(target_device, int) or isinstance(target_device, str): + if 
isinstance(target_device, int): + target_device = torch.device(target_device) + elif isinstance(target_device, str): + # if string we expect it to be a device name like "cuda:0" target_device = torch.device(target_device) elif isinstance(target_device, torch.device): pass