From 17f629289c0246308357ff6f16e03d1e1635cbea Mon Sep 17 00:00:00 2001 From: datta0 Date: Sun, 27 Oct 2024 12:39:00 +0000 Subject: [PATCH 1/5] [WIP] Support for Granite --- unsloth/kernels/flex_attention.py | 80 +++-- unsloth/models/__init__.py | 1 + unsloth/models/_utils.py | 2 +- unsloth/models/gemma2.py | 2 +- unsloth/models/granite.py | 545 ++++++++++++++++++++++++++++++ unsloth/models/llama.py | 18 +- unsloth/models/loader.py | 6 +- 7 files changed, 623 insertions(+), 31 deletions(-) create mode 100644 unsloth/models/granite.py diff --git a/unsloth/kernels/flex_attention.py b/unsloth/kernels/flex_attention.py index 08426b69e0..7edf88e406 100644 --- a/unsloth/kernels/flex_attention.py +++ b/unsloth/kernels/flex_attention.py @@ -41,7 +41,7 @@ # Logit softcapping @torch.compile(fullgraph = True, dynamic = True, options = torch_compile_options) - def slow_attention_softcapping(Q, K, V, causal_mask, self, bsz, q_len): + def slow_attention_softcapping(Q, K, V, causal_mask, self, bsz, q_len, scale=1, is_gemma2=False): n_heads = self.num_heads head_dim = self.head_dim n_kv_heads = self.num_key_value_heads @@ -53,14 +53,22 @@ def slow_attention_softcapping(Q, K, V, causal_mask, self, bsz, q_len): K = K.reshape(bsz, n_heads, q_len, head_dim) V = V.reshape(bsz, n_heads, q_len, head_dim) - # See https://github.com/google/gemma_pytorch/commit/03e657582d17cb5a8617ebf333c1c16f3694670e - # Gemma 9b should use 256 and not 224 (hs / nah). 27b uses the below - # We default to using the config file itself - # s = self.config.hidden_size // self.config.num_attention_heads - s = self.config.query_pre_attn_scalar - t = self.config.attn_logit_softcapping + if is_gemma2: + # See https://github.com/google/gemma_pytorch/commit/03e657582d17cb5a8617ebf333c1c16f3694670e + # Gemma 9b should use 256 and not 224 (hs / nah). 
27b uses the below + # We default to using the config file itself + # s = self.config.hidden_size // self.config.num_attention_heads + s = self.config.query_pre_attn_scalar + t = self.config.attn_logit_softcapping + + Q = Q * torch.tensor(s**-0.5, dtype = Q.dtype) # Follow Keras exactly + A = torch_matmul(Q, K.transpose(2, 3)) + + # Logit softcapping + A /= t; torch_tanh(A, out = A); A *= t; + else: + A = torch_matmul(Q, K.transpose(2, 3)) * scale - Q = Q * torch.tensor(s**-0.5, dtype = Q.dtype) # Follow Keras exactly A = torch.matmul(Q, K.transpose(2, 3)) A = t * torch.tanh(A / t) # Logit softcapping A += causal_mask[:q_len, :q_len] @@ -86,6 +94,9 @@ def tanh_softcap(x, b, h, q_idx, kv_idx): return t * torch.tanh(x / t) return tanh_softcap pass + def noop(score, b, h, q_idx, kv_idx): + return score + pass def causal_masker(b, h, q_idx, kv_idx): return q_idx >= kv_idx pass @@ -120,20 +131,31 @@ def create_flex_attention_sliding_window_mask(max_seq_length = 8192, sliding_win pass @functools.lru_cache - def flex_attention(s, t): - scale = 1.0 / math.sqrt(s) - score_mod = generate_tanh_softcap(t) + def flex_attention(s=1,t=1, is_gemma2=False): + if is_gemma2: + scale = 1.0 / math.sqrt(s) + enable_gqa = True + score_mod = generate_tanh_softcap(s, t) + else: + # mostly for granite + scale = s + enable_gqa = False + score_mod = noop() return functools.partial( - _flex_attention, score_mod = score_mod, scale = scale, enable_gqa = True, + _flex_attention, score_mod = score_mod, scale = scale, enable_gqa = enable_gqa, ) pass - def slow_attention_softcapping(Q, K, V, causal_mask, self, bsz, q_len): + def slow_attention_softcapping(Q, K, V, causal_mask, self, bsz, q_len, scale=1, is_gemma2=False): n_heads = self.num_heads head_dim = self.head_dim - s = self.config.query_pre_attn_scalar - t = self.config.attn_logit_softcapping - fx = flex_attention(s, t) + if is_gemma2: + s = 1.0/ math.sqrt(self.config.query_pre_attn_scalar) + t = self.config.attn_logit_softcapping + else: + s 
= scale + t = 1.0 + fx = flex_attention(s,t, is_gemma2) A = fx(query = Q, key = K, value = V, block_mask = causal_mask) A = A.transpose(1, 2).contiguous() A = A.reshape(bsz, q_len, n_heads*head_dim) @@ -145,7 +167,7 @@ def slow_attention_softcapping(Q, K, V, causal_mask, self, bsz, q_len): torch_matmul = torch.matmul torch_tanh = torch.tanh torch_nn_functional_softmax = torch.nn.functional.softmax -def slow_inference_attention_softcapping(Q, K, V, causal_mask, self, bsz, q_len): +def slow_inference_attention_softcapping(Q, K, V, causal_mask, self, bsz, q_len, scale, is_gemma2=True): n_heads = self.num_heads head_dim = self.head_dim n_kv_heads = self.num_key_value_heads @@ -157,18 +179,22 @@ def slow_inference_attention_softcapping(Q, K, V, causal_mask, self, bsz, q_len) K = K.reshape(bsz, n_heads, q_len, head_dim) V = V.reshape(bsz, n_heads, q_len, head_dim) - # See https://github.com/google/gemma_pytorch/commit/03e657582d17cb5a8617ebf333c1c16f3694670e - # Gemma 9b should use 256 and not 224 (hs / nah). 27b uses the below - # We default to using the config file itself - # s = self.config.hidden_size // self.config.num_attention_heads - s = self.config.query_pre_attn_scalar - t = self.config.attn_logit_softcapping + if is_gemma2: + # See https://github.com/google/gemma_pytorch/commit/03e657582d17cb5a8617ebf333c1c16f3694670e + # Gemma 9b should use 256 and not 224 (hs / nah). 
27b uses the below + # We default to using the config file itself + # s = self.config.hidden_size // self.config.num_attention_heads + s = self.config.query_pre_attn_scalar + t = self.config.attn_logit_softcapping - Q = Q * torch.tensor(s**-0.5, dtype = Q.dtype) # Follow Keras exactly - A = torch_matmul(Q, K.transpose(2, 3)) + Q = Q * torch.tensor(s**-0.5, dtype = Q.dtype) # Follow Keras exactly + A = torch_matmul(Q, K.transpose(2, 3)) - # Logit softcapping - A /= t; torch_tanh(A, out = A); A *= t; + # Logit softcapping + A /= t; torch_tanh(A, out = A); A *= t; + else: + A = torch_matmul(Q, K.transpose(2, 3)) * scale + A += causal_mask[:q_len, :q_len] # Much slower in torch compile! # A.masked_fill_(causal_mask[:q_len, :q_len], -float("inf")) diff --git a/unsloth/models/__init__.py b/unsloth/models/__init__.py index e67a9e5fad..e9f0fd2383 100644 --- a/unsloth/models/__init__.py +++ b/unsloth/models/__init__.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
+from .granite import FastGraniteModel from .loader import FastLanguageModel from .llama import FastLlamaModel from .mistral import FastMistralModel diff --git a/unsloth/models/_utils.py b/unsloth/models/_utils.py index 68e294f157..b792cf3e2f 100644 --- a/unsloth/models/_utils.py +++ b/unsloth/models/_utils.py @@ -94,7 +94,7 @@ def patch_mistral_nemo_config(config): from transformers import __version__ as transformers_version from transformers import PretrainedConfig -model_architectures = ["llama", "mistral", "gemma", "gemma2", "qwen2",] +model_architectures = ["llama", "mistral", "gemma", "gemma2", "qwen2", "granite"] for model_name in model_architectures: config_filepath = f"transformers.models.{model_name}.configuration_{model_name}" diff --git a/unsloth/models/gemma2.py b/unsloth/models/gemma2.py index bf40ea8a27..ff30c2dc51 100644 --- a/unsloth/models/gemma2.py +++ b/unsloth/models/gemma2.py @@ -159,7 +159,7 @@ def Gemma2Attention_fast_forward( fx = slow_inference_attention_softcapping \ if "_flag_for_generation" in kwargs else \ slow_attention_softcapping - A = fx(Q, K, V, causal_mask, self, bsz, kv_seq_len) + A = fx(Q, K, V, causal_mask, self, bsz, kv_seq_len, is_gemma2=True) pass A = self.apply_o(self, A) return A, None, past_key_value diff --git a/unsloth/models/granite.py b/unsloth/models/granite.py new file mode 100644 index 0000000000..209bcabafe --- /dev/null +++ b/unsloth/models/granite.py @@ -0,0 +1,545 @@ +# Copyright 2023-present Daniel Han-Chen & the Unsloth team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from .llama import *
+import os
+from ._utils import __version__
+from .llama import (
+    LlamaRotaryEmbedding,
+    LlamaLinearScalingRotaryEmbedding,
+)
+from .mistral import *
+
+try:
+    from transformers.models.granite.modeling_granite import (
+        GraniteAttention,
+        GraniteDecoderLayer,
+        GraniteModel,
+        GraniteForCausalLM,
+    )
+except:
+    from packaging.version import Version
+
+    transformers_version = Version(transformers_version)
+    if not transformers_version >= Version("4.45.0"):
+        raise ImportError(
+            f"Unsloth: Your transformers version of {transformers_version} does not support Granite.\n"\
+            f"The minimum required version is 4.45.0.\n"\
+            f'Try `pip install --upgrade "transformers>=4.45.0"`\n'\
+            f"to obtain the latest transformers build, then restart this session."\
+        )
+    pass
+pass
+
+from transformers.modeling_attn_mask_utils import (
+    _prepare_4d_causal_attention_mask_for_sdpa,
+)
+
+# For Pytorch 2.1.1
+try:
+    from transformers.models.granite.modeling_granite import (
+        GraniteSdpaAttention,
+        GraniteFlashAttention2,
+    )
+except:
+    GraniteSdpaAttention = GraniteAttention
+    GraniteFlashAttention2 = GraniteAttention
+pass
+
+def GraniteAttention_fast_forward(
+    self,
+    hidden_states: torch.Tensor,
+    causal_mask: Optional[xformers.attn_bias.BlockDiagonalCausalMask] = None,
+    attention_mask: Optional[torch.Tensor] = None,
+    position_ids: Optional[torch.LongTensor] = None,
+    past_key_value: Optional[Tuple[torch.Tensor]] = None,
+    output_attentions: bool = False,
+    use_cache: bool = False,
+    padding_mask: Optional[torch.LongTensor] = None,
+    position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
+    *args, **kwargs,
+) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
+
+    # Clear inference
+    if hasattr(self, "paged_attention"):
+        del self.paged_attention_K
+        del self.paged_attention_V
+        del 
self.paged_attention + del self.temp_QA + del self.temp_KV + del self.RH_Q + del self.attention + pass + + bsz, q_len, _ = hidden_states.size() + + n_heads = self.num_heads + n_groups = self.num_key_value_groups + n_kv_heads = self.num_key_value_heads + head_dim = self.head_dim + assert(n_kv_heads * n_groups == n_heads) + + Q, K, V = self.apply_qkv(self, hidden_states) + Q = Q.view(bsz, q_len, n_heads, head_dim).transpose(1, 2) + K = K.view(bsz, q_len, n_kv_heads, head_dim).transpose(1, 2) + V = V.view(bsz, q_len, n_kv_heads, head_dim).transpose(1, 2) + + kv_seq_len = K.shape[-2] + if past_key_value is not None: + kv_seq_len += past_key_value[0].shape[-2] + + assert position_embeddings is not None + cos, sin = position_embeddings + if position_ids is None: + Q, K = fast_rope_embedding(Q, K, cos, sin) + else: + Q, K = inplace_rope_embedding(Q, K, cos, sin, position_ids) + + if past_key_value is not None: + K = torch.cat([past_key_value[0], K], dim = 2) + V = torch.cat([past_key_value[1], V], dim = 2) + pass + past_key_value = (K, V) if use_cache else None + + # Attention module + if (not HAS_FLASH_ATTENTION and attention_mask is None): + # Xformers memory efficient attention + Q = Q.transpose(1, 2) + K = K.transpose(1, 2) + V = V.transpose(1, 2) + K_M = V_M = bsz * kv_seq_len + Q_M = bsz * q_len + + has_swa = isinstance(causal_mask, xformers.attn_bias.BlockDiagonalCausalMask) + + # Group query attention + K = K .view(bsz, kv_seq_len, n_kv_heads, 1, head_dim) + V = V .view(bsz, kv_seq_len, n_kv_heads, 1, head_dim) + K = K.expand(bsz, kv_seq_len, n_kv_heads, n_groups, head_dim) + V = V.expand(bsz, kv_seq_len, n_kv_heads, n_groups, head_dim) + if hidden_states.requires_grad: + K = K.reshape(bsz, kv_seq_len, n_heads, head_dim) + V = V.reshape(bsz, kv_seq_len, n_heads, head_dim) + + if has_swa: + Q = Q.view(1, Q_M, n_heads, head_dim) + K = K.view(1, K_M, n_heads, head_dim) + V = V.view(1, V_M, n_heads, head_dim) + pass + else: + # Xformers does support the forward pass 
though + Q = Q.view(bsz, q_len, n_kv_heads, n_groups, head_dim) + + if has_swa: + Q = Q.view(1, Q_M, n_kv_heads, n_groups, head_dim) + K = K.view(1, K_M, n_kv_heads, n_groups, head_dim) + V = V.view(1, V_M, n_kv_heads, n_groups, head_dim) + pass + pass + + A = xformers_attention(Q, K, V, attn_bias = causal_mask, scale=self.scaling) + A = A.view(bsz, q_len, n_heads, head_dim) + + elif HAS_FLASH_ATTENTION and attention_mask is None: + Q = Q.transpose(1, 2) + K = K.transpose(1, 2) + V = V.transpose(1, 2) + sw = getattr(self.config, "sliding_window", None) + sw = kv_seq_len if (sw is None or sw == "null") else sw + window = (-1, -1) if (kv_seq_len <= sw) else (sw, sw) + A = flash_attn_func(Q, K, V, causal = True, window_size = window, softmax_scale=self.scaling) + else: + # Grouped query attention + # if n_groups != 1: + K = K[:, :, None, :, :].expand(bsz, n_kv_heads, n_groups, kv_seq_len, head_dim) + V = V[:, :, None, :, :].expand(bsz, n_kv_heads, n_groups, kv_seq_len, head_dim) + K = K.reshape(bsz, n_heads, kv_seq_len, head_dim) + V = V.reshape(bsz, n_heads, kv_seq_len, head_dim) + # pass + # Must be contiguous or else results are False! + # https://github.com/pytorch/pytorch/issues/112577 + Q, K, V = Q.contiguous(), K.contiguous(), V.contiguous() + # Needs (batch_size, n_heads, seq_len, head_dim) + # is_casual and attention_mask must not be both set! 
+ A = scaled_dot_product_attention(Q, K, V, attn_mask = attention_mask, scale = self.scaling, is_causal = False) + # Go back to (batch_size, seq_len, n_heads, head_dim) + A = A.transpose(1, 2).contiguous() + pass + + attn_output = A.reshape(bsz, q_len, n_heads*head_dim) + attn_output = self.apply_o(self, attn_output) + attn_weights = None + return attn_output, attn_weights, past_key_value +pass + + +def GraniteDecoderLayer_fast_forward( + self, + hidden_states: torch.Tensor, + causal_mask: Optional[xformers.attn_bias.BlockDiagonalCausalMask] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Tuple[torch.Tensor]] = None, + output_attentions: Optional[bool] = False, + use_cache: Optional[bool] = False, + padding_mask: Optional[torch.LongTensor] = None, + position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + *args, **kwargs, +): + if use_cache and hasattr(self, "_flag_for_generation"): #past_key_value is not None: + residual = hidden_states + hidden_states = fast_rms_layernorm_inference(self.input_layernorm, hidden_states) + hidden_states, self_attn_weights, present_key_value = self.self_attn( + hidden_states=hidden_states, + causal_mask=causal_mask, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_value, + output_attentions=output_attentions, + use_cache=use_cache, + padding_mask=padding_mask, + position_embeddings = position_embeddings, + _flag_for_generation=True, + ) + hidden_states = residual + hidden_states * self.residual_multiplier + + # Fully Connected + residual = hidden_states + hidden_states = fast_rms_layernorm_inference(self.post_attention_layernorm, hidden_states) + hidden_states = fast_swiglu_inference(self.mlp, hidden_states) + hidden_states = residual + hidden_states * self.residual_multiplier + else: + residual = hidden_states + hidden_states = fast_rms_layernorm(self.input_layernorm, hidden_states) + 
hidden_states, self_attn_weights, present_key_value = self.self_attn( + hidden_states=hidden_states, + causal_mask=causal_mask, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_value, + output_attentions=output_attentions, + use_cache=use_cache, + padding_mask=padding_mask, + position_embeddings = position_embeddings, + ) + hidden_states = residual + hidden_states * self.residual_multiplier + + # Fully Connected + residual = hidden_states + hidden_states = fast_rms_layernorm(self.post_attention_layernorm, hidden_states) + hidden_states = self.mlp(hidden_states) + hidden_states = residual + hidden_states * self.residual_multiplier + pass + + outputs = (hidden_states,) + if output_attentions: outputs += (self_attn_weights,) + if use_cache: outputs += (present_key_value,) + return outputs +pass + + +from math import sqrt as math_sqrt +KV_CACHE_INCREMENT = 256 # KV Cache update size +torch_nn_functional_softmax = torch.nn.functional.softmax +torch_matmul = torch.matmul +torch_tanh = torch.tanh + +def GraniteAttention_fast_forward_inference( + self, + hidden_states: torch.Tensor, + past_key_value: Optional[Tuple[torch.Tensor]], + position_ids, + do_prefill = False, + attention_mask = None, + use_sliding_window = False, +): + Xn = hidden_states + bsz, _, hd = hidden_states.size() + K1, V1 = past_key_value + dtype = Xn.dtype + + n_heads = self.num_heads + n_groups = self.num_key_value_groups + n_kv_heads = self.num_key_value_heads + head_dim = self.head_dim + attention_size = n_heads*head_dim + # assert(n_kv_heads * n_groups == n_heads) + seq_len = K1.shape[-2] + kv_seq_len = seq_len + 1 + + # Prefill phase + # if not hasattr(self, "paged_attention"): + if do_prefill: + self.paged_attention = torch.empty((KV_CACHE_INCREMENT+seq_len+1, 2, bsz, n_kv_heads, head_dim), dtype = dtype, device = "cuda:0") + self.paged_attention_K = self.paged_attention[:,0] + self.paged_attention_V = self.paged_attention[:,1] + 
self.paged_attention_K[:seq_len] = K1.permute(2, 0, 1, 3) + self.paged_attention_V[:seq_len] = V1.permute(2, 0, 1, 3) + self.temp_QA = torch.empty((2, bsz, 1, attention_size), dtype = dtype, device = "cuda:0") + self.temp_KV = torch.empty((2, bsz, 1, n_kv_heads*head_dim), dtype = dtype, device = "cuda:0") + self.RH_Q = torch.empty((bsz, n_heads, 1, head_dim), dtype = dtype, device = "cuda:0") + # Only for Gemma2 + self.temp_O = torch.empty((1, bsz, self.hidden_size), dtype = dtype, device = "cuda:0") + self.attention = torch.empty((bsz, n_heads, 1, KV_CACHE_INCREMENT+seq_len), dtype = dtype, device = "cuda:0") + + + self.half_head_dim = head_dim // 2 + elif kv_seq_len >= self.paged_attention.shape[0]: + self.paged_attention.resize_((self.paged_attention.shape[0]+KV_CACHE_INCREMENT, 2, bsz, n_kv_heads, head_dim)) + self.paged_attention_K = self.paged_attention[:,0] + self.paged_attention_V = self.paged_attention[:,1] + self.attention.resize_((bsz, n_heads, 1, self.attention.shape[-1]+KV_CACHE_INCREMENT)) + pass + + Qn = fast_linear_forward(self.q_proj, Xn, out = self.temp_QA[0]) + Kn = fast_linear_forward(self.k_proj, Xn, out = self.temp_KV[0]) + Vn = fast_linear_forward(self.v_proj, Xn, out = self.temp_KV[1]) + Qn = Qn.view(bsz, 1, n_heads, head_dim).transpose(1, 2) + Kn = Kn.view(bsz, 1, n_kv_heads, head_dim).transpose(1, 2) + Vn = Vn.view(bsz, 1, n_kv_heads, head_dim).transpose(1, 2) + + # cos, sin = self.rotary_emb(Vn, seq_len = kv_seq_len) + # Qn, Kn = inplace_rope_embedding(Qn, Kn, cos, sin, position_ids) + cos = self.rotary_emb.cos_cached[position_ids].unsqueeze(1) + sin = self.rotary_emb.sin_cached[position_ids].unsqueeze(1) + h = self.half_head_dim + + RH_Q = self.RH_Q + RH_Q[:,:,:,:h] = Qn[:,:,:,h:] + RH_Q[:,:,:,h:] = Qn[:,:,:,:h] + torch.neg(RH_Q[:,:,:,:h], out = RH_Q[:,:,:,:h]) + Qn *= cos + Qn.addcmul_(RH_Q, sin) + + RH_K = RH_Q[:,:n_kv_heads,:,:] # torch.empty((n_kv_heads, 1, head_dim), dtype = dtype, device = "cuda:0") + RH_K[:,:,:,:h] = Kn[:,:,:,h:] 
+ RH_K[:,:,:,h:] = Kn[:,:,:,:h] + torch.neg(RH_K[:,:,:,:h], out = RH_K[:,:,:,:h]) + Kn *= cos + Kn.addcmul_(RH_K, sin) + + # New KV cache + # Kn = torch.cat([K1, Kn], dim = 2) + # Vn = torch.cat([V1, Vn], dim = 2) + self.paged_attention_K[seq_len] = Kn.permute(2, 0, 1, 3) + self.paged_attention_V[seq_len] = Vn.permute(2, 0, 1, 3) + Kn = self.paged_attention_K[:kv_seq_len].permute(1, 2, 0, 3) + Vn = self.paged_attention_V[:kv_seq_len].permute(1, 2, 0, 3) + + # Handle sliding windows + sliding_window = self.config.sliding_window + if use_sliding_window and kv_seq_len > sliding_window: + # From https://github.com/huggingface/transformers/blob/main/src/transformers/models/mistral/modeling_mistral.py#L193 + slicing_tokens = 1 - sliding_window + Knn = Kn[:, :, slicing_tokens:, :]#.contiguous() + Vnn = Vn[:, :, slicing_tokens:, :]#.contiguous() + else: + Knn, Vnn = Kn, Vn + pass + + # Grouped query attention + _, _, cached_len, _ = Knn.shape + if n_groups != 1: + Knn = Knn[:, :, None, :, :].expand(bsz, n_kv_heads, n_groups, cached_len, head_dim) + Vnn = Vnn[:, :, None, :, :].expand(bsz, n_kv_heads, n_groups, cached_len, head_dim) + Knn = Knn.reshape(bsz, n_heads, cached_len, head_dim) + Vnn = Vnn.reshape(bsz, n_heads, cached_len, head_dim) + pass + # else: + # Knn, Vnn = Knn, Vnn + # pass + + Qn *= self.scaling + A = torch_matmul(Qn, Knn.transpose(2, 3), out = self.attention[:,:,:,:cached_len]) + + # if attention_mask is not None: A += attention_mask # Must add attention_mask for batched + + A[:] = torch_nn_functional_softmax(A, dim = -1, dtype = torch.float32)#.to(A.dtype) + A = torch_matmul(A, Vnn, out = Qn) + # else: + # A = scaled_dot_product_attention(Qn, Knn, Vnn, attn_mask = attention_mask, is_causal = False) + # pass + A = A.transpose(1, 2) + A = A.reshape(bsz, 1, attention_size) + A = fast_linear_forward(self.o_proj, A, out = self.temp_O) + return A, (Kn, Vn) +pass + + +# 
https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py#L825 +# @torch.inference_mode +def GraniteModel_fast_forward_inference( + self, + input_ids, + past_key_values, + position_ids, + attention_mask = None, +): + input_ids = input_ids[:,:self.max_seq_length] + hidden_states = self.model.embed_tokens(input_ids) + hidden_states *= self.embedding_multiplier + hidden_states = hidden_states.to(self.config.torch_dtype) + + bsz, q_len, hd = hidden_states.shape + seq_len = past_key_values[0][0].shape[-2] + if bsz != 1: + attention_mask = _prepare_4d_causal_attention_mask_for_sdpa( + attention_mask, + (bsz, q_len), + hidden_states, + seq_len, + ) + else: + attention_mask = None + pass + + position_embeddings = self.rotary_emb(hidden_states, position_ids) + + next_decoder_cache = [] + for idx, decoder_layer in enumerate(self.model.layers): + + residual = hidden_states + hidden_states = fast_rms_layernorm_inference(decoder_layer.input_layernorm, hidden_states) + hidden_states, present_key_value = GraniteAttention_fast_forward_inference( + decoder_layer.self_attn, + hidden_states = hidden_states, + past_key_value = past_key_values[idx], + position_ids = position_ids, + attention_mask = attention_mask, + do_prefill = not hasattr(decoder_layer.self_attn, "paged_attention"), + position_embeddings = position_embeddings, + ) + hidden_states = residual + hidden_states * self.residual_multiplier + + residual = hidden_states + hidden_states = fast_rms_layernorm_inference(decoder_layer. 
pre_feedforward_layernorm, hidden_states) + hidden_states = fast_swiglu_inference(decoder_layer.mlp, hidden_states) + hidden_states = residual + hidden_states * self.residual_multiplier + + next_decoder_cache.append(present_key_value) + pass + hidden_states = fast_rms_layernorm_inference(self.model.norm, hidden_states) + + return BaseModelOutputWithPast( + last_hidden_state = hidden_states, + past_key_values = next_decoder_cache, + hidden_states = [], + attentions = [], + ) +pass + +class GraniteRotaryEmbedding(LlamaRotaryEmbedding): + def __init__(self, config): + super().__init__(config = config) + +class FastGraniteModel(FastLlamaModel): + + @staticmethod + def pre_patch(): + init_name, function = patch_linear_scaling( + model_name = "granite", + rope_module = GraniteRotaryEmbedding, + scaled_rope_module = LlamaLinearScalingRotaryEmbedding, + attention_module = GraniteAttention, + ) + if init_name is not None: + exec(function, globals()) + GraniteAttention.__init__ = eval(init_name) + pass + GraniteAttention .forward = GraniteAttention_fast_forward + GraniteSdpaAttention .forward = GraniteAttention_fast_forward + GraniteFlashAttention2.forward = GraniteAttention_fast_forward + GraniteDecoderLayer .forward = GraniteDecoderLayer_fast_forward + GraniteModel .forward = LlamaModel_fast_forward + GraniteForCausalLM .forward = CausalLM_fast_forward(GraniteModel_fast_forward_inference) + PeftModelForCausalLM .forward = PeftModelForCausalLM_fast_forward + fix_prepare_inputs_for_generation(GraniteForCausalLM) + + import transformers.models.granite.modeling_granite + transformers.models.granite.modeling_granite.GraniteRotaryEmbedding = GraniteRotaryEmbedding + + return + pass + + + @staticmethod + def post_patch(model): + + # Torch.compile fails on embedding matrix?? 
+ # Workaround randomnly fixes it for torch versions < 2.2 + model.model.embed_tokens = torch.nn.Embedding.from_pretrained(model.model.embed_tokens.weight) + model.config.update({"unsloth_version" : __version__}) + + # We also do this for the lm_head + lm_head = torch.nn.Linear(1, 1, bias = None) + del lm_head.weight + lm_head.weight = model.lm_head.weight + lm_head.in_features = lm_head.weight.shape[1] + lm_head.out_features = lm_head.weight.shape[0] + model.lm_head = lm_head + + # Granite has tied weights! This means lm_head == embed_tokens + if model.model.embed_tokens.weight.data_ptr() != model.lm_head.weight.data_ptr(): + lm_head = torch.nn.Linear(1, 1, bias = None) + del lm_head.weight + lm_head.weight = model.model.embed_tokens.weight + lm_head.in_features = lm_head.weight.shape[1] + lm_head.out_features = lm_head.weight.shape[0] + model.lm_head = lm_head + pass + + # Also patch all dtypes - BnB seems to not allocate the correct type? + # BnB default dtype seems to be float16! + correct_dtype = lm_head.weight.dtype + + for name, module in model.named_modules(): + if isinstance(module, (Bnb_Linear4bit, Peft_Linear4bit)): + weight = module.weight + quant_state = weight.quant_state + + if type(quant_state) is list: + # BnB seems to have float16 as default! 
+ module.weight.quant_state[2] = correct_dtype # Cast to correct dtype + else: + # https://github.com/TimDettmers/bitsandbytes/pull/763/files + quant_state.dtype = correct_dtype + pass + pass + # Downcast RoPE embedding to correct data type + if (name.endswith("rotary_emb") or hasattr(module, "cos_cached")): + + if hasattr(module, "cos_cached") and \ + (module.cos_cached.dtype != correct_dtype): + + module.cos_cached = module.cos_cached.to(correct_dtype) + module.sin_cached = module.sin_cached.to(correct_dtype) + + elif hasattr(module, "short_cos_cached") and \ + (module.short_cos_cached.dtype != correct_dtype): + + module.short_cos_cached = module.short_cos_cached.to(correct_dtype) + module.short_sin_cached = module.short_sin_cached.to(correct_dtype) + pass + pass + pass + + # Clear deleted GPU items + import gc + for _ in range(3): + gc.collect() + torch.cuda.empty_cache() + return model + pass +pass + diff --git a/unsloth/models/llama.py b/unsloth/models/llama.py index 9c5499dc75..78d366d37a 100644 --- a/unsloth/models/llama.py +++ b/unsloth/models/llama.py @@ -617,6 +617,7 @@ def LlamaModel_fast_forward( IS_GEMMA = self.config.model_type.startswith("gemma") IS_GEMMA2 = self.config.model_type.startswith("gemma2") IS_COHERE = self.config.model_type.startswith("cohere") + IS_GRANITE = self.config.model_type.startswith("granite") train_embed_tokens = self.embed_tokens.weight.requires_grad if IS_GEMMA: @@ -682,6 +683,8 @@ def LlamaModel_fast_forward( pass hidden_states = inputs_embeds + if IS_GRANITE: #granite has embedding multiplier + hidden_states = self.embedding_multiplier * hidden_states if past_key_values is None and self.training: use_cache = False @@ -763,6 +766,12 @@ def LlamaModel_fast_forward( pass pass + + if IS_GRANITE: + position_embeddings = self.rotary_emb(hidden_states, position_ids, self.max_position_embeddings) + else: + position_embeddings = None + # Go through every layer! 
for idx, decoder_layer in enumerate(self.layers): @@ -782,12 +791,14 @@ def LlamaModel_fast_forward( past_key_values, output_attentions, use_cache, + None, + position_embeddings, )[0] elif gradient_checkpointing: def create_custom_forward(module): def custom_forward(*inputs): - return module(*inputs, past_key_value, output_attentions, padding_mask = padding_mask) + return module(*inputs, past_key_value, output_attentions, padding_mask = padding_mask, position_embeddings = position_embeddings) return custom_forward pass @@ -812,6 +823,7 @@ def custom_forward(*inputs): output_attentions=output_attentions, use_cache=use_cache, padding_mask=padding_mask, + position_embeddings = position_embeddings ) hidden_states = layer_outputs[0] pass @@ -974,6 +986,9 @@ def _CausalLM_fast_forward( loss = None logit_softcapping = getattr(self.config, "final_logit_softcapping", 0) logit_scaling = getattr(self.config, "logit_scale", 0) + if self.config.model_type == "granite": + # granite uses logit_scaling as key and they divide by the scale unlike cohere + logit_scaling = 1 / getattr(self.config, "logits_scaling", 1) if labels is not None: shift_logits = logits if not hasattr(self, "extra_ignored_labels"): @@ -2260,6 +2275,7 @@ def patch_peft_model( elif model_type == "gemma": apply_lora_mlp = apply_lora_mlp_geglu_approx elif model_type == "gemma2": apply_lora_mlp = apply_lora_mlp_geglu_approx elif model_type == "cohere": apply_lora_mlp = apply_lora_mlp_swiglu + elif model_type == "granite": apply_lora_mlp = apply_lora_mlp_swiglu else: raise NotImplementedError(f"Unsloth: {model_type} is not yet implemented!") pass diff --git a/unsloth/models/loader.py b/unsloth/models/loader.py index db7259b1d9..75b4adaa14 100644 --- a/unsloth/models/loader.py +++ b/unsloth/models/loader.py @@ -13,6 +13,7 @@ # limitations under the License. 
from ._utils import is_bfloat16_supported, HAS_FLASH_ATTENTION, HAS_FLASH_ATTENTION_SOFTCAPPING +from .granite import FastGraniteModel from .llama import FastLlamaModel, logger from .mistral import FastMistralModel from .qwen2 import FastQwen2Model @@ -38,6 +39,7 @@ SUPPORTS_GEMMA2 = transformers_version >= Version("4.42") SUPPORTS_LLAMA31 = transformers_version >= Version("4.43.2") SUPPORTS_LLAMA32 = transformers_version > Version("4.45.0") +SUPPORTS_GRANITE = transformers_version >= Version("4.46.0") if SUPPORTS_GEMMA: from .gemma import FastGemmaModel if SUPPORTS_GEMMA2: @@ -256,7 +258,7 @@ def from_pretrained( model_type = model_config.model_type - if model_type == "llama": + if model_type == "llama": scaling_type = None if getattr(model_config, "rope_scaling", None) is not None: scaling_type1 = model_config.rope_scaling.get("type", None) @@ -312,6 +314,8 @@ def from_pretrained( dispatch_model = FastQwen2Model elif model_type == "cohere": dispatch_model = FastCohereModel + elif model_type == "granite": + dispatch_model = FastGraniteModel else: raise NotImplementedError( f"Unsloth: {model_name} not supported yet!\n"\ From 1713728257554a9c823cdd2f510f42845a2a2598 Mon Sep 17 00:00:00 2001 From: datta0 Date: Tue, 29 Oct 2024 05:04:14 +0000 Subject: [PATCH 2/5] Fixup inference --- unsloth/models/granite.py | 20 ++++++++++++-------- 1 file changed, 12 insertions(+), 8 deletions(-) diff --git a/unsloth/models/granite.py b/unsloth/models/granite.py index 209bcabafe..428453e07e 100644 --- a/unsloth/models/granite.py +++ b/unsloth/models/granite.py @@ -263,7 +263,11 @@ def GraniteAttention_fast_forward_inference( do_prefill = False, attention_mask = None, use_sliding_window = False, + position_embeddings : Optional[Tuple[torch.Tensor, torch.Tensor]] = None, ): + + assert position_embeddings is not None, f"Granite model requires position embeddings to be specified" + Xn = hidden_states bsz, _, hd = hidden_states.size() K1, V1 = past_key_value @@ -311,8 +315,8 @@ def 
GraniteAttention_fast_forward_inference( # cos, sin = self.rotary_emb(Vn, seq_len = kv_seq_len) # Qn, Kn = inplace_rope_embedding(Qn, Kn, cos, sin, position_ids) - cos = self.rotary_emb.cos_cached[position_ids].unsqueeze(1) - sin = self.rotary_emb.sin_cached[position_ids].unsqueeze(1) + cos, sin = position_embeddings + cos, sin = cos[position_ids], sin[position_ids] h = self.half_head_dim RH_Q = self.RH_Q @@ -338,7 +342,7 @@ def GraniteAttention_fast_forward_inference( Vn = self.paged_attention_V[:kv_seq_len].permute(1, 2, 0, 3) # Handle sliding windows - sliding_window = self.config.sliding_window + sliding_window = self.config.sliding_window if hasattr(self.config, "sliding_window") else self.config.max_position_embeddings if use_sliding_window and kv_seq_len > sliding_window: # From https://github.com/huggingface/transformers/blob/main/src/transformers/models/mistral/modeling_mistral.py#L193 slicing_tokens = 1 - sliding_window @@ -388,7 +392,7 @@ def GraniteModel_fast_forward_inference( ): input_ids = input_ids[:,:self.max_seq_length] hidden_states = self.model.embed_tokens(input_ids) - hidden_states *= self.embedding_multiplier + hidden_states *= self.model.embedding_multiplier hidden_states = hidden_states.to(self.config.torch_dtype) bsz, q_len, hd = hidden_states.shape @@ -404,7 +408,7 @@ def GraniteModel_fast_forward_inference( attention_mask = None pass - position_embeddings = self.rotary_emb(hidden_states, position_ids) + position_embeddings = self.model.rotary_emb(hidden_states, position_ids, self.max_seq_length) next_decoder_cache = [] for idx, decoder_layer in enumerate(self.model.layers): @@ -420,12 +424,12 @@ def GraniteModel_fast_forward_inference( do_prefill = not hasattr(decoder_layer.self_attn, "paged_attention"), position_embeddings = position_embeddings, ) - hidden_states = residual + hidden_states * self.residual_multiplier + hidden_states = residual + hidden_states * self.config.residual_multiplier residual = hidden_states - hidden_states = 
fast_rms_layernorm_inference(decoder_layer. pre_feedforward_layernorm, hidden_states) + hidden_states = fast_rms_layernorm_inference(decoder_layer.post_attention_layernorm, hidden_states) hidden_states = fast_swiglu_inference(decoder_layer.mlp, hidden_states) - hidden_states = residual + hidden_states * self.residual_multiplier + hidden_states = residual + hidden_states * self.config.residual_multiplier next_decoder_cache.append(present_key_value) pass From afb3f019a1ee25b343e33ff857cafd62934315b9 Mon Sep 17 00:00:00 2001 From: datta0 Date: Tue, 29 Oct 2024 17:00:05 +0000 Subject: [PATCH 3/5] Cleanup flex attention --- unsloth/kernels/flex_attention.py | 80 +++++++++++-------------------- unsloth/models/gemma2.py | 2 +- 2 files changed, 28 insertions(+), 54 deletions(-) diff --git a/unsloth/kernels/flex_attention.py b/unsloth/kernels/flex_attention.py index 7edf88e406..08426b69e0 100644 --- a/unsloth/kernels/flex_attention.py +++ b/unsloth/kernels/flex_attention.py @@ -41,7 +41,7 @@ # Logit softcapping @torch.compile(fullgraph = True, dynamic = True, options = torch_compile_options) - def slow_attention_softcapping(Q, K, V, causal_mask, self, bsz, q_len, scale=1, is_gemma2=False): + def slow_attention_softcapping(Q, K, V, causal_mask, self, bsz, q_len): n_heads = self.num_heads head_dim = self.head_dim n_kv_heads = self.num_key_value_heads @@ -53,22 +53,14 @@ def slow_attention_softcapping(Q, K, V, causal_mask, self, bsz, q_len, scale=1, K = K.reshape(bsz, n_heads, q_len, head_dim) V = V.reshape(bsz, n_heads, q_len, head_dim) - if is_gemma2: - # See https://github.com/google/gemma_pytorch/commit/03e657582d17cb5a8617ebf333c1c16f3694670e - # Gemma 9b should use 256 and not 224 (hs / nah). 
27b uses the below - # We default to using the config file itself - # s = self.config.hidden_size // self.config.num_attention_heads - s = self.config.query_pre_attn_scalar - t = self.config.attn_logit_softcapping - - Q = Q * torch.tensor(s**-0.5, dtype = Q.dtype) # Follow Keras exactly - A = torch_matmul(Q, K.transpose(2, 3)) - - # Logit softcapping - A /= t; torch_tanh(A, out = A); A *= t; - else: - A = torch_matmul(Q, K.transpose(2, 3)) * scale + # See https://github.com/google/gemma_pytorch/commit/03e657582d17cb5a8617ebf333c1c16f3694670e + # Gemma 9b should use 256 and not 224 (hs / nah). 27b uses the below + # We default to using the config file itself + # s = self.config.hidden_size // self.config.num_attention_heads + s = self.config.query_pre_attn_scalar + t = self.config.attn_logit_softcapping + Q = Q * torch.tensor(s**-0.5, dtype = Q.dtype) # Follow Keras exactly A = torch.matmul(Q, K.transpose(2, 3)) A = t * torch.tanh(A / t) # Logit softcapping A += causal_mask[:q_len, :q_len] @@ -94,9 +86,6 @@ def tanh_softcap(x, b, h, q_idx, kv_idx): return t * torch.tanh(x / t) return tanh_softcap pass - def noop(score, b, h, q_idx, kv_idx): - return score - pass def causal_masker(b, h, q_idx, kv_idx): return q_idx >= kv_idx pass @@ -131,31 +120,20 @@ def create_flex_attention_sliding_window_mask(max_seq_length = 8192, sliding_win pass @functools.lru_cache - def flex_attention(s=1,t=1, is_gemma2=False): - if is_gemma2: - scale = 1.0 / math.sqrt(s) - enable_gqa = True - score_mod = generate_tanh_softcap(s, t) - else: - # mostly for granite - scale = s - enable_gqa = False - score_mod = noop() + def flex_attention(s, t): + scale = 1.0 / math.sqrt(s) + score_mod = generate_tanh_softcap(t) return functools.partial( - _flex_attention, score_mod = score_mod, scale = scale, enable_gqa = enable_gqa, + _flex_attention, score_mod = score_mod, scale = scale, enable_gqa = True, ) pass - def slow_attention_softcapping(Q, K, V, causal_mask, self, bsz, q_len, scale=1, 
is_gemma2=False): + def slow_attention_softcapping(Q, K, V, causal_mask, self, bsz, q_len): n_heads = self.num_heads head_dim = self.head_dim - if is_gemma2: - s = 1.0/ math.sqrt(self.config.query_pre_attn_scalar) - t = self.config.attn_logit_softcapping - else: - s = scale - t = 1.0 - fx = flex_attention(s,t, is_gemma2) + s = self.config.query_pre_attn_scalar + t = self.config.attn_logit_softcapping + fx = flex_attention(s, t) A = fx(query = Q, key = K, value = V, block_mask = causal_mask) A = A.transpose(1, 2).contiguous() A = A.reshape(bsz, q_len, n_heads*head_dim) @@ -167,7 +145,7 @@ def slow_attention_softcapping(Q, K, V, causal_mask, self, bsz, q_len, scale=1, torch_matmul = torch.matmul torch_tanh = torch.tanh torch_nn_functional_softmax = torch.nn.functional.softmax -def slow_inference_attention_softcapping(Q, K, V, causal_mask, self, bsz, q_len, scale, is_gemma2=True): +def slow_inference_attention_softcapping(Q, K, V, causal_mask, self, bsz, q_len): n_heads = self.num_heads head_dim = self.head_dim n_kv_heads = self.num_key_value_heads @@ -179,22 +157,18 @@ def slow_inference_attention_softcapping(Q, K, V, causal_mask, self, bsz, q_len, K = K.reshape(bsz, n_heads, q_len, head_dim) V = V.reshape(bsz, n_heads, q_len, head_dim) - if is_gemma2: - # See https://github.com/google/gemma_pytorch/commit/03e657582d17cb5a8617ebf333c1c16f3694670e - # Gemma 9b should use 256 and not 224 (hs / nah). 27b uses the below - # We default to using the config file itself - # s = self.config.hidden_size // self.config.num_attention_heads - s = self.config.query_pre_attn_scalar - t = self.config.attn_logit_softcapping + # See https://github.com/google/gemma_pytorch/commit/03e657582d17cb5a8617ebf333c1c16f3694670e + # Gemma 9b should use 256 and not 224 (hs / nah). 
27b uses the below + # We default to using the config file itself + # s = self.config.hidden_size // self.config.num_attention_heads + s = self.config.query_pre_attn_scalar + t = self.config.attn_logit_softcapping - Q = Q * torch.tensor(s**-0.5, dtype = Q.dtype) # Follow Keras exactly - A = torch_matmul(Q, K.transpose(2, 3)) + Q = Q * torch.tensor(s**-0.5, dtype = Q.dtype) # Follow Keras exactly + A = torch_matmul(Q, K.transpose(2, 3)) - # Logit softcapping - A /= t; torch_tanh(A, out = A); A *= t; - else: - A = torch_matmul(Q, K.transpose(2, 3)) * scale - + # Logit softcapping + A /= t; torch_tanh(A, out = A); A *= t; A += causal_mask[:q_len, :q_len] # Much slower in torch compile! # A.masked_fill_(causal_mask[:q_len, :q_len], -float("inf")) diff --git a/unsloth/models/gemma2.py b/unsloth/models/gemma2.py index ff30c2dc51..bf40ea8a27 100644 --- a/unsloth/models/gemma2.py +++ b/unsloth/models/gemma2.py @@ -159,7 +159,7 @@ def Gemma2Attention_fast_forward( fx = slow_inference_attention_softcapping \ if "_flag_for_generation" in kwargs else \ slow_attention_softcapping - A = fx(Q, K, V, causal_mask, self, bsz, kv_seq_len, is_gemma2=True) + A = fx(Q, K, V, causal_mask, self, bsz, kv_seq_len) pass A = self.apply_o(self, A) return A, None, past_key_value From 7b777f6111b3caaa2d87d565ddbc4c4891d0792f Mon Sep 17 00:00:00 2001 From: datta0 Date: Tue, 26 Nov 2024 12:54:16 +0000 Subject: [PATCH 4/5] remove sliding window --- unsloth/models/gemma2.py | 2 +- unsloth/models/granite.py | 51 +++++++++------------------------------ unsloth/models/llama.py | 9 ++++--- 3 files changed, 19 insertions(+), 43 deletions(-) diff --git a/unsloth/models/gemma2.py b/unsloth/models/gemma2.py index 4eb9d64313..12925fd9cf 100644 --- a/unsloth/models/gemma2.py +++ b/unsloth/models/gemma2.py @@ -194,7 +194,7 @@ def Gemma2DecoderLayer_fast_forward( output_attentions=output_attentions, use_cache=use_cache, padding_mask=padding_mask, - _flag_for_generation=True, + 
_flag_for_generation=self._flag_for_generation, ) hidden_states = fast_rms_layernorm_inference_gemma(self.post_attention_layernorm, hidden_states, out_weight) hidden_states += residual diff --git a/unsloth/models/granite.py b/unsloth/models/granite.py index 428453e07e..9431cf5ede 100644 --- a/unsloth/models/granite.py +++ b/unsloth/models/granite.py @@ -121,8 +121,6 @@ def GraniteAttention_fast_forward( K_M = V_M = bsz * kv_seq_len Q_M = bsz * q_len - has_swa = isinstance(causal_mask, xformers.attn_bias.BlockDiagonalCausalMask) - # Group query attention K = K .view(bsz, kv_seq_len, n_kv_heads, 1, head_dim) V = V .view(bsz, kv_seq_len, n_kv_heads, 1, head_dim) @@ -131,21 +129,9 @@ def GraniteAttention_fast_forward( if hidden_states.requires_grad: K = K.reshape(bsz, kv_seq_len, n_heads, head_dim) V = V.reshape(bsz, kv_seq_len, n_heads, head_dim) - - if has_swa: - Q = Q.view(1, Q_M, n_heads, head_dim) - K = K.view(1, K_M, n_heads, head_dim) - V = V.view(1, V_M, n_heads, head_dim) - pass else: # Xformers does support the forward pass though Q = Q.view(bsz, q_len, n_kv_heads, n_groups, head_dim) - - if has_swa: - Q = Q.view(1, Q_M, n_kv_heads, n_groups, head_dim) - K = K.view(1, K_M, n_kv_heads, n_groups, head_dim) - V = V.view(1, V_M, n_kv_heads, n_groups, head_dim) - pass pass A = xformers_attention(Q, K, V, attn_bias = causal_mask, scale=self.scaling) @@ -155,9 +141,7 @@ def GraniteAttention_fast_forward( Q = Q.transpose(1, 2) K = K.transpose(1, 2) V = V.transpose(1, 2) - sw = getattr(self.config, "sliding_window", None) - sw = kv_seq_len if (sw is None or sw == "null") else sw - window = (-1, -1) if (kv_seq_len <= sw) else (sw, sw) + window = (kv_seq_len, kv_seq_len) A = flash_attn_func(Q, K, V, causal = True, window_size = window, softmax_scale=self.scaling) else: # Grouped query attention @@ -210,7 +194,7 @@ def GraniteDecoderLayer_fast_forward( use_cache=use_cache, padding_mask=padding_mask, position_embeddings = position_embeddings, - _flag_for_generation=True, 
+ _flag_for_generation=self._flag_for_generation, ) hidden_states = residual + hidden_states * self.residual_multiplier @@ -341,38 +325,27 @@ def GraniteAttention_fast_forward_inference( Kn = self.paged_attention_K[:kv_seq_len].permute(1, 2, 0, 3) Vn = self.paged_attention_V[:kv_seq_len].permute(1, 2, 0, 3) - # Handle sliding windows - sliding_window = self.config.sliding_window if hasattr(self.config, "sliding_window") else self.config.max_position_embeddings - if use_sliding_window and kv_seq_len > sliding_window: - # From https://github.com/huggingface/transformers/blob/main/src/transformers/models/mistral/modeling_mistral.py#L193 - slicing_tokens = 1 - sliding_window - Knn = Kn[:, :, slicing_tokens:, :]#.contiguous() - Vnn = Vn[:, :, slicing_tokens:, :]#.contiguous() - else: - Knn, Vnn = Kn, Vn - pass - # Grouped query attention - _, _, cached_len, _ = Knn.shape + _, _, cached_len, _ = Kn.shape if n_groups != 1: - Knn = Knn[:, :, None, :, :].expand(bsz, n_kv_heads, n_groups, cached_len, head_dim) - Vnn = Vnn[:, :, None, :, :].expand(bsz, n_kv_heads, n_groups, cached_len, head_dim) - Knn = Knn.reshape(bsz, n_heads, cached_len, head_dim) - Vnn = Vnn.reshape(bsz, n_heads, cached_len, head_dim) + Kn = Kn[:, :, None, :, :].expand(bsz, n_kv_heads, n_groups, cached_len, head_dim) + Vn = Vn[:, :, None, :, :].expand(bsz, n_kv_heads, n_groups, cached_len, head_dim) + Kn = Kn.reshape(bsz, n_heads, cached_len, head_dim) + Vn = Vn.reshape(bsz, n_heads, cached_len, head_dim) pass # else: - # Knn, Vnn = Knn, Vnn + # Kn, Vn = Kn, Vn # pass Qn *= self.scaling - A = torch_matmul(Qn, Knn.transpose(2, 3), out = self.attention[:,:,:,:cached_len]) + A = torch_matmul(Qn, Kn.transpose(2, 3), out = self.attention[:,:,:,:cached_len]) # if attention_mask is not None: A += attention_mask # Must add attention_mask for batched A[:] = torch_nn_functional_softmax(A, dim = -1, dtype = torch.float32)#.to(A.dtype) - A = torch_matmul(A, Vnn, out = Qn) + A = torch_matmul(A, Vn, out = Qn) # else: - 
# A = scaled_dot_product_attention(Qn, Knn, Vnn, attn_mask = attention_mask, is_causal = False) + # A = scaled_dot_product_attention(Qn, Kn, Vn, attn_mask = attention_mask, is_causal = False) # pass A = A.transpose(1, 2) A = A.reshape(bsz, 1, attention_size) @@ -392,8 +365,8 @@ def GraniteModel_fast_forward_inference( ): input_ids = input_ids[:,:self.max_seq_length] hidden_states = self.model.embed_tokens(input_ids) - hidden_states *= self.model.embedding_multiplier hidden_states = hidden_states.to(self.config.torch_dtype) + hidden_states *= self.model.embedding_multiplier bsz, q_len, hd = hidden_states.shape seq_len = past_key_values[0][0].shape[-2] diff --git a/unsloth/models/llama.py b/unsloth/models/llama.py index 710f27e489..dbcd827c2d 100644 --- a/unsloth/models/llama.py +++ b/unsloth/models/llama.py @@ -616,9 +616,9 @@ def LlamaModel_fast_forward( pass # Normalized from Gemma - IS_GEMMA = self.config.model_type.startswith("gemma") - IS_GEMMA2 = self.config.model_type.startswith("gemma2") - IS_COHERE = self.config.model_type.startswith("cohere") + IS_GEMMA = self.config.model_type.startswith("gemma") + IS_GEMMA2 = self.config.model_type.startswith("gemma2") + IS_COHERE = self.config.model_type.startswith("cohere") IS_GRANITE = self.config.model_type.startswith("granite") train_embed_tokens = self.embed_tokens.weight.requires_grad @@ -990,6 +990,9 @@ def _CausalLM_fast_forward( logit_scaling = getattr(self.config, "logit_scale", 0) if self.config.model_type == "granite": # granite uses logit_scaling as key and they divide by the scale unlike cohere + # notice that for granite, logits_scale is 16 and for cohere it is 0.125 (aka 1/8) in their respective configs + # granite: https://github.com/huggingface/transformers/blob/4d1d0f29a493098e6bc6b904b82e29cb331827f5/src/transformers/models/granite/modeling_granite.py#L1103 + # cohere: 
https://github.com/huggingface/transformers/blob/4d1d0f29a493098e6bc6b904b82e29cb331827f5/src/transformers/models/cohere/modeling_cohere.py#L1176 logit_scaling = 1 / getattr(self.config, "logits_scaling", 1) if labels is not None: shift_logits = logits From fdd1d95ee2e73d236a36546610b28ff175806a7f Mon Sep 17 00:00:00 2001 From: datta0 Date: Tue, 26 Nov 2024 13:07:38 +0000 Subject: [PATCH 5/5] Use torch.add for residual multiplier --- unsloth/models/granite.py | 13 +++++++------ 1 file changed, 7 insertions(+), 6 deletions(-) diff --git a/unsloth/models/granite.py b/unsloth/models/granite.py index 9431cf5ede..2229636e9e 100644 --- a/unsloth/models/granite.py +++ b/unsloth/models/granite.py @@ -196,13 +196,13 @@ def GraniteDecoderLayer_fast_forward( position_embeddings = position_embeddings, _flag_for_generation=self._flag_for_generation, ) - hidden_states = residual + hidden_states * self.residual_multiplier + hidden_states = torch.add(residual, hidden_states, alpha = self.config.residual_multiplier) # Fully Connected residual = hidden_states hidden_states = fast_rms_layernorm_inference(self.post_attention_layernorm, hidden_states) hidden_states = fast_swiglu_inference(self.mlp, hidden_states) - hidden_states = residual + hidden_states * self.residual_multiplier + hidden_states = torch.add(residual, hidden_states, alpha = self.config.residual_multiplier) else: residual = hidden_states hidden_states = fast_rms_layernorm(self.input_layernorm, hidden_states) @@ -217,13 +217,13 @@ def GraniteDecoderLayer_fast_forward( padding_mask=padding_mask, position_embeddings = position_embeddings, ) - hidden_states = residual + hidden_states * self.residual_multiplier + hidden_states = torch.add(residual, hidden_states, alpha = self.config.residual_multiplier) # Fully Connected residual = hidden_states hidden_states = fast_rms_layernorm(self.post_attention_layernorm, hidden_states) hidden_states = self.mlp(hidden_states) - hidden_states = residual + hidden_states * 
self.residual_multiplier + hidden_states = torch.add(residual, hidden_states, alpha = self.config.residual_multiplier) pass outputs = (hidden_states,) @@ -397,12 +397,13 @@ def GraniteModel_fast_forward_inference( do_prefill = not hasattr(decoder_layer.self_attn, "paged_attention"), position_embeddings = position_embeddings, ) - hidden_states = residual + hidden_states * self.config.residual_multiplier + + hidden_states = torch.add(residual, hidden_states, alpha = self.config.residual_multiplier) residual = hidden_states hidden_states = fast_rms_layernorm_inference(decoder_layer.post_attention_layernorm, hidden_states) hidden_states = fast_swiglu_inference(decoder_layer.mlp, hidden_states) - hidden_states = residual + hidden_states * self.config.residual_multiplier + hidden_states = torch.add(residual, hidden_states, alpha = self.config.residual_multiplier) next_decoder_cache.append(present_key_value) pass