7 changes: 4 additions & 3 deletions unsloth/models/cohere.py
@@ -75,6 +75,7 @@ def CohereAttention_fast_forward(
     output_attentions: bool = False,
     use_cache: bool = False,
     padding_mask: Optional[torch.LongTensor] = None,
+    position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
     *args, **kwargs,
 ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:

@@ -112,12 +113,11 @@ def CohereAttention_fast_forward(
     if past_key_value is not None:
         kv_seq_len += past_key_value[0].shape[-2]
 
+    cos, sin = position_embeddings
     if position_ids is None:
-        cos = self.rotary_emb.cos_cached
-        sin = self.rotary_emb.sin_cached
         Q, K = fast_rope_embedding(Q, K, cos, sin)
     else:
-        cos, sin = self.rotary_emb(V, seq_len = kv_seq_len)
+        cos, sin = cos[position_ids], sin[position_ids]
         Q, K = inplace_rope_embedding(Q, K, cos, sin, position_ids)
     pass

@@ -190,6 +190,7 @@ def CohereDecoderLayer_fast_forward(
     output_attentions: Optional[bool] = False,
     use_cache: Optional[bool] = False,
     padding_mask: Optional[torch.LongTensor] = None,
+    position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
     *args, **kwargs,
 ):
     if use_cache and hasattr(self, "_flag_for_generation"): #past_key_value is not None:
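The cohere path above now unpacks `position_embeddings` unconditionally, so it assumes the model-level forward always hands every attention layer the same precomputed `(cos, sin)` pair. A runnable sketch of that contract, with illustrative helper names (`rope_tables` and `apply_rope` are not unsloth's fused kernels):

import torch

def rope_tables(seq_len, head_dim, base = 10000.0):
    # Standard RoPE tables, shape (seq_len, head_dim); computed once per forward.
    inv_freq = 1.0 / base ** (torch.arange(0, head_dim, 2).float() / head_dim)
    freqs = torch.outer(torch.arange(seq_len).float(), inv_freq)
    emb = torch.cat((freqs, freqs), dim = -1)
    return emb.cos(), emb.sin()

def apply_rope(x, cos, sin):
    # Rotate-half formulation; x is (batch, heads, seq_len, head_dim).
    half = x.shape[-1] // 2
    rotated = torch.cat((-x[..., half:], x[..., :half]), dim = -1)
    return x * cos + rotated * sin

position_embeddings = rope_tables(seq_len = 8, head_dim = 64)  # model level, once
cos, sin = position_embeddings                                 # per layer, just unpack
Q = torch.randn(1, 4, 8, 64)
K = torch.randn(1, 4, 8, 64)
Q, K = apply_rope(Q, cos, sin), apply_rope(K, cos, sin)

Each decoder layer then shares one set of tables instead of consulting its own rotary cache.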
40 changes: 25 additions & 15 deletions unsloth/models/llama.py
@@ -337,6 +337,7 @@ def LlamaAttention_fast_forward(
     output_attentions: bool = False,
     use_cache: bool = False,
     padding_mask: Optional[torch.LongTensor] = None,
+    position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
     *args, **kwargs,
 ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:

@@ -368,20 +369,24 @@ def LlamaAttention_fast_forward(
     if past_key_value is not None:
         kv_seq_len += past_key_value[0].shape[-2]
 
-    # Extend RoPE dynamically to fit in VRAM
-    rotary_emb = self.rotary_emb
-    rotary_emb.extend_rope_embedding(V, seq_len = kv_seq_len)
-
-    if position_ids is None:
-        # Useful for LongRoPE
-        cos, sin = rotary_emb.get_cached(kv_seq_len)
-        # cos = self.rotary_emb.cos_cached
-        # sin = self.rotary_emb.sin_cached
-        Q, K = fast_rope_embedding(Q, K, cos, sin)
+    if position_embeddings:
+        cos, sin = position_embeddings
     else:
-        cos, sin = rotary_emb(V, seq_len = kv_seq_len)
-        Q, K = inplace_rope_embedding(Q, K, cos, sin, position_ids)
-    pass
+        # Extend RoPE dynamically to fit in VRAM
+        rotary_emb = self.rotary_emb
+        rotary_emb.extend_rope_embedding(V, seq_len = kv_seq_len)
+
+        if position_ids is None:
+            # Useful for LongRoPE
+            cos, sin = rotary_emb.get_cached(kv_seq_len)
+        else:
+            cos, sin = rotary_emb(V, seq_len = kv_seq_len)
+
+    Q, K = (
+        fast_rope_embedding(Q, K, cos, sin)
+        if position_ids is None
+        else inplace_rope_embedding(Q, K, cos, sin, position_ids)
+    )
 
     if past_key_value is not None:
         K = torch.cat([past_key_value[0], K], dim = 2)
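Worth noting for the dispatch above: `if position_embeddings:` separates the two sources by tuple truthiness. The argument is either `None` or a two-tensor tuple, and a non-empty tuple is truthy even when its tensors are all zeros, so the ambiguous `bool()` of a tensor is never evaluated. A minimal check (the `pick` helper is hypothetical, for illustration only):

import torch

def pick(position_embeddings):
    return "precomputed" if position_embeddings else "local fallback"

cos, sin = torch.zeros(8, 64), torch.zeros(8, 64)
assert pick((cos, sin)) == "precomputed"   # truthy even with all-zero tensors
assert pick(None) == "local fallback"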
@@ -452,6 +457,7 @@ def LlamaDecoderLayer_fast_forward(
     output_attentions: Optional[bool] = False,
     use_cache: Optional[bool] = False,
     padding_mask: Optional[torch.LongTensor] = None,
+    position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
     *args, **kwargs,
 ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
     """
@@ -479,6 +485,7 @@ def LlamaDecoderLayer_fast_forward(
             output_attentions=output_attentions,
             use_cache=use_cache,
             padding_mask=padding_mask,
+            position_embeddings=position_embeddings,
         )
         hidden_states += residual

@@ -499,6 +506,7 @@ def LlamaDecoderLayer_fast_forward(
             output_attentions=output_attentions,
             use_cache=use_cache,
             padding_mask=padding_mask,
+            position_embeddings=position_embeddings,
         )
         hidden_states = residual + hidden_states

@@ -777,8 +785,10 @@ def LlamaModel_fast_forward(
     pass
 
 
-    if IS_GRANITE:
-        position_embeddings = self.rotary_emb(hidden_states, position_ids, self.max_position_embeddings)
+    if transformers_version > "4.47.1" and hasattr(self, "rotary_emb"):
+        # Transformers main has made it mandatory to pass position_embeddings
+        # https://github.com/huggingface/transformers/pull/34858
+        position_embeddings = self.rotary_emb(hidden_states, position_ids, self.config.max_position_embeddings)
     else:
         position_embeddings = None

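One caveat on the gate above: `transformers_version > "4.47.1"` compares strings, which happens to put "4.48.0" after "4.47.1" but misorders larger minor versions. A more robust spelling, offered as a suggestion rather than what this diff ships, goes through `packaging.version`, which transformers already depends on:

from packaging.version import Version

transformers_version = "4.100.0"                          # hypothetical future release
assert transformers_version < "4.47.1"                    # string compare: wrong order
assert Version(transformers_version) > Version("4.47.1")  # semantic compare: correct

The `hasattr(self, "rotary_emb")` guard stays useful either way, presumably because not every wrapped model exposes a model-level rotary module.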
1 change: 1 addition & 0 deletions unsloth/models/mistral.py
@@ -47,6 +47,7 @@ def MistralAttention_fast_forward(
     output_attentions: bool = False,
     use_cache: bool = False,
     padding_mask: Optional[torch.LongTensor] = None,
+    position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
     *args, **kwargs,
 ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
