diff --git a/unsloth/kernels/__init__.py b/unsloth/kernels/__init__.py index 15913413d9..6780e09608 100644 --- a/unsloth/kernels/__init__.py +++ b/unsloth/kernels/__init__.py @@ -60,6 +60,13 @@ create_flex_attention_causal_mask, create_flex_attention_sliding_window_mask, ) +from ..utils.attnres import ( + AttnResConfig, + AttnResState, + apply_attnres, + begin_attnres, + end_attnres, +) import os diff --git a/unsloth/models/attnres.py b/unsloth/models/attnres.py new file mode 100644 index 0000000000..ca5ec6fbc4 --- /dev/null +++ b/unsloth/models/attnres.py @@ -0,0 +1,207 @@ +# Copyright 2023-present Daniel Han-Chen & the Unsloth team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
from __future__ import annotations

from dataclasses import dataclass, field
import math
import os
from typing import Optional

import torch


# String spellings (after strip + lower-case) that switch a flag on.
_TRUTHY_STRINGS = ("1", "true", "yes", "on")

# Config attributes that may enable the feature, probed in this order.
_ENABLE_ATTRS = (
    "attnres",
    "attn_residual",
    "attn_residuals",
    "use_attnres",
    "use_attn_residuals",
    "_attnres",
    "_use_attnres",
)


def _read_flag(value, default = False):
    """Interpret ``value`` as an on/off switch.

    ``None`` yields ``default``. A string is true only when, after stripping
    and lower-casing, it is one of "1", "true", "yes" or "on". Any other
    object is passed through ``bool()``.
    """
    if value is None:
        return default
    if isinstance(value, str):
        return value.strip().lower() in _TRUTHY_STRINGS
    return bool(value)


def _read_int(value, default):
    """Parse ``value`` as an integer clamped to a minimum of 1.

    ``default`` is returned unchanged (not clamped) when ``value`` cannot be
    converted, including infinite floats.
    """
    try:
        parsed = int(value)
    except (ValueError, TypeError, OverflowError):
        return default
    return max(1, parsed)


def _read_float(value, default):
    """Parse ``value`` as a finite float, falling back to ``default``.

    NaN and infinity are rejected because the result is used as a multiplier
    on activations and would poison every later layer.
    """
    try:
        parsed = float(value)
    except (ValueError, TypeError):
        return default
    return parsed if math.isfinite(parsed) else default


@dataclass
class _AttnResState:
    """Mutable per-forward-pass bookkeeping for attention residual mixing."""

    # Checked (via getattr) before any mixing is attempted.
    enabled: bool
    # Number of consecutive layers folded into one block summary.
    block_size: int
    # Multiplier applied to the mixed residual before it is added.
    alpha: float
    # Total layer count; values <= 0 disable the end-of-model flush.
    num_layers: int
    # One summed tensor per finished block.
    completed_block_summaries: list[torch.Tensor] = field(default_factory = list)
    # Transformed attention outputs of the layers in the block being built.
    current_block_states: list[torch.Tensor] = field(default_factory = list)


def _get_config_value(config, names, default = None):
    """Return the first attribute in ``names`` that exists on ``config`` and is not ``None``."""
    for name in names:
        value = getattr(config, name, None)
        if value is not None:
            return value
    return default


def _resolve_num_layers(model, config):
    """Best-effort layer count: config first, then the length of a ``layers`` container.

    ``num_hidden_layers`` may be missing, ``None`` or non-numeric, so the
    conversion is guarded instead of raising. The container is looked up on
    ``model`` and on ``model.model``, because the inference entry points pass
    a wrapper whose layers live one level down.
    """
    try:
        count = int(getattr(config, "num_hidden_layers", 0))
    except (ValueError, TypeError, OverflowError):
        count = 0
    if count > 0:
        return count

    for owner in (model, getattr(model, "model", None)):
        layers = getattr(owner, "layers", None)
        if layers is None:
            continue
        try:
            return len(layers)
        except TypeError:
            continue
    return 0


def _build_state(model, *, use_cache = False):
    """Create a fresh ``_AttnResState`` for one forward pass, or ``None`` when inactive.

    The feature is enabled by any of the ``_ENABLE_ATTRS`` config attributes
    or, failing that, by the ``UNSLOTH_ATTNRES`` / ``ATTNRES`` environment
    variables. Block size and alpha come from the config, then from the
    ``UNSLOTH_ATTNRES_BLOCK_SIZE`` / ``UNSLOTH_ATTNRES_ALPHA`` environment
    variables, then from the defaults 8 and 1.0.
    """
    config = getattr(model, "config", None)
    if config is None:
        return None

    enabled = _read_flag(_get_config_value(config, _ENABLE_ATTRS), default = False)
    if not enabled:
        enabled = _read_flag(os.environ.get("UNSLOTH_ATTNRES"), default = False)
    if not enabled:
        enabled = _read_flag(os.environ.get("ATTNRES"), default = False)
    if not enabled:
        return None

    # Stateful accumulation can desync across recomputation passes.
    if (
        getattr(model, "training", False)
        and getattr(model, "gradient_checkpointing", False)
        and not use_cache
    ):
        return None

    block_size = _read_int(
        _get_config_value(
            config,
            ("attnres_block_size", "attn_residual_block_size", "attnres_block"),
            default = os.environ.get("UNSLOTH_ATTNRES_BLOCK_SIZE", 8),
        ),
        default = 8,
    )
    alpha = _read_float(
        _get_config_value(
            config,
            ("attnres_alpha", "attn_residual_alpha"),
            default = os.environ.get("UNSLOTH_ATTNRES_ALPHA", 1.0),
        ),
        default = 1.0,
    )
    return _AttnResState(
        enabled = True,
        block_size = block_size,
        alpha = alpha,
        num_layers = _resolve_num_layers(model, config),
    )


def begin_attnres_state(model, *, use_cache = False):
    """Public entry point: build the per-forward state for ``model`` (``None`` if disabled)."""
    return _build_state(model, use_cache = use_cache)


def attnres_init_forward_state(
    model,
    hidden_states = None,
    attention_mask = None,
    position_ids = None,
    past_key_values = None,
    use_cache = False,
    output_attentions = False,
    output_hidden_states = False,
):
    """Forward-signature-shaped wrapper around ``begin_attnres_state``.

    Only ``model`` and ``use_cache`` are used. The remaining arguments are
    accepted so call sites can pass their forward arguments through, and are
    ignored.
    """
    return begin_attnres_state(model, use_cache = use_cache)


def _compute_residual_mix(
    query: torch.Tensor,
    candidates: list[torch.Tensor],
) -> torch.Tensor:
    """Softmax-weighted mix of ``candidates``, scored by scaled dot product with ``query``.

    Returns zeros shaped like ``query`` when there are no candidates.
    """
    if len(candidates) == 0:
        return torch.zeros_like(query)

    # States recorded by earlier layers may sit on another device when layers
    # are spread over several devices; ``.to`` is a no-op when they match.
    aligned = [candidate.to(query.device) for candidate in candidates]

    # (bsz, seqlen, n_states, dim)
    stacked = torch.stack(aligned, dim = 2)
    scale = 1.0 / math.sqrt(float(query.shape[-1]))
    logits = (query.unsqueeze(2) * stacked).sum(dim = -1) * scale

    if logits.dtype in (torch.float16, torch.bfloat16):
        # Normalise in float32 for numerical stability, then cast back so the
        # result dtype is unchanged.
        weights = torch.softmax(logits, dim = -1, dtype = torch.float32).to(
            logits.dtype
        )
    else:
        weights = torch.softmax(logits, dim = -1)
    return (weights.unsqueeze(-1) * stacked).sum(dim = 2)


def attnres_transform_attention_output(
    attention_output: torch.Tensor,
    attnres_state: Optional[_AttnResState] = None,
    attnres_layer_idx: Optional[int] = None,
    residual: Optional[torch.Tensor] = None,
    attention_mask = None,
    causal_mask = None,
    position_ids = None,
):
    """Add a mix of earlier layers' states to ``attention_output`` and record the result.

    Args:
        attention_output: Attention output of the current layer.
        attnres_state: State from ``begin_attnres_state``. When it is ``None``
            or disabled, the input is returned untouched.
        attnres_layer_idx: Zero-based layer index; ``None`` is treated as 0.
        residual: Tensor used to score earlier states; falls back to
            ``attention_output`` when omitted.
        attention_mask, causal_mask, position_ids: Accepted for call-site
            symmetry; not used.

    Returns:
        ``attention_output`` itself when nothing was mixed in, otherwise a new
        tensor ``attention_output + alpha * mix``.
    """
    if attnres_state is None or not getattr(attnres_state, "enabled", False):
        return attention_output

    layer_idx = 0 if attnres_layer_idx is None else int(attnres_layer_idx)
    query = attention_output if residual is None else residual

    history = [
        *attnres_state.completed_block_summaries,
        *attnres_state.current_block_states,
    ]
    if history:
        mixed = _compute_residual_mix(query, history)
        attention_output = attention_output + (
            attnres_state.alpha * mixed.to(attention_output.device)
        )

    # Keep current-layer information for future layers in this block. The
    # copy is required: callers add the residual to the returned tensor in
    # place, which would otherwise corrupt the stored state.
    attnres_state.current_block_states.append(attention_output.clone())

    # Finalize a block summary at boundaries.
    layers_done = layer_idx + 1
    at_block_end = (layers_done % attnres_state.block_size) == 0
    at_model_end = (attnres_state.num_layers > 0) and (
        layers_done >= attnres_state.num_layers
    )
    if at_block_end or at_model_end:
        device = attention_output.device
        block_summary = torch.stack(
            [saved.to(device) for saved in attnres_state.current_block_states],
            dim = 0,
        ).sum(dim = 0)
        attnres_state.completed_block_summaries.append(block_summary)
        attnres_state.current_block_states.clear()

    return attention_output


__all__ = [
    "begin_attnres_state",
    "attnres_init_forward_state",
    "attnres_transform_attention_output",
]

# Flattened patch text: the header of the next file section follows, kept verbatim.
# diff --git a/unsloth/models/cohere.py b/unsloth/models/cohere.py index 294e8d0c7e..df8d67f27a 100644 --- a/unsloth/models/cohere.py +++ b/unsloth/models/cohere.py @@ -13,6 +13,7 @@ # limitations under the License.
from .llama import * +from .llama import attnres_init_forward_state, attnres_transform_attention_output from ._utils import __version__ from unsloth_zoo.hf_utils import dtype_from_config from unsloth_zoo.utils import _get_dtype, Version @@ -184,6 +185,8 @@ def CohereDecoderLayer_fast_forward( *args, **kwargs, ): + attnres_state = kwargs.get("attnres_state", None) + attnres_layer_idx = kwargs.get("attnres_layer_idx", None) if use_cache and hasattr( self, "_flag_for_generation" ): # past_key_value is not None: @@ -209,6 +212,15 @@ def CohereDecoderLayer_fast_forward( padding_mask = padding_mask, **kwargs, ) + hidden_states_attention = attnres_transform_attention_output( + hidden_states_attention, + attnres_state = attnres_state, + attnres_layer_idx = attnres_layer_idx, + residual = residual, + attention_mask = attention_mask, + causal_mask = causal_mask, + position_ids = position_ids, + ) # Fully Connected hidden_states_mlp = fast_swiglu_inference(self.mlp, hidden_states) @@ -229,6 +241,15 @@ def CohereDecoderLayer_fast_forward( padding_mask = padding_mask, **kwargs, ) + hidden_states_attention = attnres_transform_attention_output( + hidden_states_attention, + attnres_state = attnres_state, + attnres_layer_idx = attnres_layer_idx, + residual = residual, + attention_mask = attention_mask, + causal_mask = causal_mask, + position_ids = position_ids, + ) # Fully Connected hidden_states_mlp = self.mlp(hidden_states) @@ -468,6 +489,15 @@ def CohereModel_fast_forward_inference( else: attention_mask = None + attnres_state = attnres_init_forward_state( + self, + hidden_states = hidden_states, + attention_mask = attention_mask, + position_ids = position_ids, + past_key_values = past_key_values, + use_cache = True, + ) + next_decoder_cache = [] for idx, decoder_layer in enumerate(self.model.layers): device_index = getattr(decoder_layer, "_per_layer_device_index", 0) @@ -488,6 +518,14 @@ def CohereModel_fast_forward_inference( do_prefill = not hasattr(decoder_layer.self_attn, 
"paged_attention"), ) ) + hidden_states_attention = attnres_transform_attention_output( + hidden_states_attention, + attnres_state = attnres_state, + attnres_layer_idx = idx, + residual = residual, + attention_mask = attention_mask, + position_ids = position_ids, + ) hidden_states_mlp = fast_swiglu_inference(decoder_layer.mlp, hidden_states) residual += hidden_states_attention diff --git a/unsloth/models/falcon_h1.py b/unsloth/models/falcon_h1.py index 659d27de54..d7e74c138c 100644 --- a/unsloth/models/falcon_h1.py +++ b/unsloth/models/falcon_h1.py @@ -13,6 +13,7 @@ # limitations under the License. from .llama import * +from .llama import attnres_init_forward_state, attnres_transform_attention_output import os from ._utils import __version__ from unsloth_zoo.utils import Version, _get_dtype @@ -423,6 +424,8 @@ def FalconH1DecoderLayer_fast_forward( (see `past_key_values`). past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states """ + attnres_state = kwargs.get("attnres_state", None) + attnres_layer_idx = kwargs.get("attnres_layer_idx", None) if use_cache and hasattr(self, "_flag_for_generation"): residual = hidden_states hidden_states = fast_rms_layernorm_inference( @@ -441,6 +444,15 @@ def FalconH1DecoderLayer_fast_forward( **kwargs, ) attention_hidden_states = attention_hidden_states * self.attn_out_multiplier + attention_hidden_states = attnres_transform_attention_output( + attention_hidden_states, + attnres_state = attnres_state, + attnres_layer_idx = attnres_layer_idx, + residual = residual, + attention_mask = attention_mask, + causal_mask = causal_mask, + position_ids = position_ids, + ) mamba_hidden_states = self.mamba( hidden_states = hidden_states, @@ -486,6 +498,15 @@ def FalconH1DecoderLayer_fast_forward( **kwargs, ) attention_hidden_states = attention_hidden_states * self.attn_out_multiplier + attention_hidden_states = attnres_transform_attention_output( + attention_hidden_states, + attnres_state = 
attnres_state, + attnres_layer_idx = attnres_layer_idx, + residual = residual, + attention_mask = attention_mask, + causal_mask = causal_mask, + position_ids = position_ids, + ) hidden_states = mamba_hidden_states + attention_hidden_states @@ -560,6 +581,15 @@ def FalconH1Model_fast_forward_inference_custom( else: attention_mask = None + attnres_state = attnres_init_forward_state( + self, + hidden_states = X, + attention_mask = attention_mask, + position_ids = position_ids, + past_key_values = past_key_values, + use_cache = True, + ) + next_decoder_cache = [] for idx, decoder_layer in enumerate(self.model.layers): @@ -584,6 +614,14 @@ def FalconH1Model_fast_forward_inference_custom( attention_hidden_states = ( attention_hidden_states * decoder_layer.attn_out_multiplier ) + attention_hidden_states = attnres_transform_attention_output( + attention_hidden_states, + attnres_state = attnres_state, + attnres_layer_idx = idx, + residual = residual, + attention_mask = attention_mask, + position_ids = position_ids, + ) mamba_hidden_states = decoder_layer.mamba( hidden_states = X, cache_params = present_key_value, diff --git a/unsloth/models/gemma2.py b/unsloth/models/gemma2.py index 720c9a7414..450a690c5b 100644 --- a/unsloth/models/gemma2.py +++ b/unsloth/models/gemma2.py @@ -13,6 +13,7 @@ # limitations under the License. 
from .llama import * +from .llama import attnres_init_forward_state, attnres_transform_attention_output from ._utils import __version__ from unsloth_zoo.utils import _get_dtype, Version from unsloth_zoo.hf_utils import dtype_from_config @@ -218,6 +219,8 @@ def Gemma2DecoderLayer_fast_forward( *args, **kwargs, ): + attnres_state = kwargs.get("attnres_state", None) + attnres_layer_idx = kwargs.get("attnres_layer_idx", None) if use_cache and hasattr( self, "_flag_for_generation" ): # past_key_value is not None: @@ -244,6 +247,15 @@ def Gemma2DecoderLayer_fast_forward( _flag_for_generation = self._flag_for_generation, **kwargs, ) + hidden_states = attnres_transform_attention_output( + hidden_states, + attnres_state = attnres_state, + attnres_layer_idx = attnres_layer_idx, + residual = residual, + attention_mask = attention_mask, + causal_mask = causal_mask, + position_ids = position_ids, + ) hidden_states = fast_rms_layernorm_inference_gemma( self.post_attention_layernorm, hidden_states, out_weight ) @@ -275,6 +287,15 @@ def Gemma2DecoderLayer_fast_forward( padding_mask = padding_mask, **kwargs, ) + hidden_states = attnres_transform_attention_output( + hidden_states, + attnres_state = attnres_state, + attnres_layer_idx = attnres_layer_idx, + residual = residual, + attention_mask = attention_mask, + causal_mask = causal_mask, + position_ids = position_ids, + ) hidden_states = fast_rms_layernorm( self.post_attention_layernorm, hidden_states, gemma = True ) @@ -527,6 +548,16 @@ def Gemma2Model_fast_forward_inference( else: SWA = attention_mask GA = attention_mask + + attnres_state = attnres_init_forward_state( + self, + hidden_states = hidden_states, + attention_mask = attention_mask, + position_ids = position_ids, + past_key_values = past_key_values, + use_cache = True, + ) + next_decoder_cache = [] for idx, decoder_layer in enumerate(self.model.layers): # For pipeline parallelism, we need to move all tensors to the same device @@ -551,6 +582,14 @@ def 
Gemma2Model_fast_forward_inference( do_prefill = not hasattr(decoder_layer.self_attn, "paged_attention"), use_sliding_window = use_sliding_window, ) + hidden_states = attnres_transform_attention_output( + hidden_states, + attnres_state = attnres_state, + attnres_layer_idx = idx, + residual = residual, + attention_mask = SWA if use_sliding_window else GA, + position_ids = position_ids, + ) hidden_states = fast_rms_layernorm_inference_gemma( decoder_layer.post_attention_layernorm, hidden_states, diff --git a/unsloth/models/granite.py b/unsloth/models/granite.py index fea3dc1b36..a5d044f388 100644 --- a/unsloth/models/granite.py +++ b/unsloth/models/granite.py @@ -13,6 +13,7 @@ # limitations under the License. from .llama import * +from .llama import attnres_init_forward_state, attnres_transform_attention_output import os from ._utils import __version__ from unsloth_zoo.utils import _get_dtype, Version @@ -204,6 +205,8 @@ def GraniteDecoderLayer_fast_forward( if hasattr(self, "residual_multiplier") else self.config.residual_multiplier ) + attnres_state = kwargs.get("attnres_state", None) + attnres_layer_idx = kwargs.get("attnres_layer_idx", None) if use_cache and hasattr( self, "_flag_for_generation" @@ -225,6 +228,15 @@ def GraniteDecoderLayer_fast_forward( _flag_for_generation = self._flag_for_generation, **kwargs, ) + hidden_states = attnres_transform_attention_output( + hidden_states, + attnres_state = attnres_state, + attnres_layer_idx = attnres_layer_idx, + residual = residual, + attention_mask = attention_mask, + causal_mask = causal_mask, + position_ids = position_ids, + ) hidden_states = torch.add(residual, hidden_states, alpha = residual_multiplier) # Fully Connected @@ -249,6 +261,15 @@ def GraniteDecoderLayer_fast_forward( position_embeddings = position_embeddings, **kwargs, ) + hidden_states = attnres_transform_attention_output( + hidden_states, + attnres_state = attnres_state, + attnres_layer_idx = attnres_layer_idx, + residual = residual, + 
attention_mask = attention_mask, + causal_mask = causal_mask, + position_ids = position_ids, + ) hidden_states = torch.add(residual, hidden_states, alpha = residual_multiplier) # Fully Connected @@ -473,6 +494,15 @@ def GraniteModel_fast_forward_inference( self.max_seq_length, hidden_states.device.index ) + attnres_state = attnres_init_forward_state( + self, + hidden_states = hidden_states, + attention_mask = attention_mask, + position_ids = position_ids, + past_key_values = past_key_values, + use_cache = True, + ) + next_decoder_cache = [] for idx, decoder_layer in enumerate(self.model.layers): device_index = getattr(decoder_layer, "_per_layer_device_index", 0) @@ -493,6 +523,14 @@ def GraniteModel_fast_forward_inference( do_prefill = not hasattr(decoder_layer.self_attn, "paged_attention"), position_embeddings = position_embeddings, ) + hidden_states = attnres_transform_attention_output( + hidden_states, + attnres_state = attnres_state, + attnres_layer_idx = idx, + residual = residual, + attention_mask = attention_mask, + position_ids = position_ids, + ) hidden_states = torch.add(residual, hidden_states, alpha = residual_multiplier) diff --git a/unsloth/models/llama.py b/unsloth/models/llama.py index 2f61913550..007a1f9488 100644 --- a/unsloth/models/llama.py +++ b/unsloth/models/llama.py @@ -78,6 +78,11 @@ from ..tokenizer_utils import * from .vision import FastBaseModel +from .attnres import ( + attnres_init_forward_state, + attnres_transform_attention_output, +) + # Final patching code from transformers.models.llama.modeling_llama import ( LlamaAttention, @@ -824,6 +829,8 @@ def LlamaDecoderLayer_fast_forward( (see `past_key_values`). 
past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states """ + attnres_state = kwargs.get("attnres_state", None) + attnres_layer_idx = kwargs.get("attnres_layer_idx", None) if use_cache and hasattr(self, "_flag_for_generation"): residual = hidden_states hidden_states = fast_rms_layernorm_inference( @@ -841,6 +848,15 @@ def LlamaDecoderLayer_fast_forward( position_embeddings = position_embeddings, **kwargs, ) + hidden_states = attnres_transform_attention_output( + hidden_states, + attnres_state = attnres_state, + attnres_layer_idx = attnres_layer_idx, + residual = residual, + attention_mask = attention_mask, + causal_mask = causal_mask, + position_ids = position_ids, + ) hidden_states += residual # Fully Connected @@ -865,6 +881,15 @@ def LlamaDecoderLayer_fast_forward( position_embeddings = position_embeddings, **kwargs, ) + hidden_states = attnres_transform_attention_output( + hidden_states, + attnres_state = attnres_state, + attnres_layer_idx = attnres_layer_idx, + residual = residual, + attention_mask = attention_mask, + causal_mask = causal_mask, + position_ids = position_ids, + ) hidden_states = residual + hidden_states # Fully Connected @@ -1198,6 +1223,17 @@ def LlamaModel_fast_forward( else: position_embeddings = None + attnres_state = attnres_init_forward_state( + self, + hidden_states = hidden_states, + attention_mask = attention_mask, + position_ids = position_ids, + past_key_values = past_key_values, + use_cache = use_cache, + output_attentions = output_attentions, + output_hidden_states = output_hidden_states, + ) + # Go through every layer! 
for idx, decoder_layer in enumerate(self.layers): if output_hidden_states: @@ -1205,19 +1241,22 @@ def LlamaModel_fast_forward( past_key_value = past_key_values[idx] if past_key_values is not None else None mask = causal_mask + layer_kwargs = dict(kwargs) if IS_GEMMA2: use_sliding_window = idx % 2 == 0 if use_sliding_window: mask = self.SWA_mask if use_static_mask else dynamic_SWA_mask else: mask = self.GA_mask if use_static_mask else dynamic_GA_mask - kwargs["use_sliding_window"] = use_sliding_window + layer_kwargs["use_sliding_window"] = use_sliding_window + layer_kwargs["attnres_state"] = attnres_state + layer_kwargs["attnres_layer_idx"] = idx if gradient_checkpointing and not isinstance( decoder_layer, GradientCheckpointingLayer ): - def create_custom_forward(module): + def create_custom_forward(module, layer_kwargs = layer_kwargs): def custom_forward(*inputs): return module( *inputs, @@ -1225,7 +1264,7 @@ def custom_forward(*inputs): output_attentions, padding_mask = padding_mask, position_embeddings = position_embeddings, - **kwargs, + **layer_kwargs, ) return custom_forward @@ -1252,7 +1291,7 @@ def custom_forward(*inputs): use_cache = use_cache, padding_mask = padding_mask, position_embeddings = position_embeddings, - **kwargs, + **layer_kwargs, ) hidden_states = layer_outputs[0] @@ -1362,6 +1401,15 @@ def LlamaModel_fast_forward_inference_custom( # Compute rotary_seq_len once to avoid per-layer GPU-CPU sync from .item() rotary_seq_len = max(kv_seq_len, int(position_ids.max().item()) + 1) + attnres_state = attnres_init_forward_state( + self, + hidden_states = X, + attention_mask = attention_mask, + position_ids = position_ids, + past_key_values = past_key_values, + use_cache = True, + ) + next_decoder_cache = [] for idx, decoder_layer in enumerate(self.model.layers): @@ -1386,6 +1434,14 @@ def LlamaModel_fast_forward_inference_custom( do_prefill = not hasattr(decoder_layer.self_attn, "paged_attention"), rotary_seq_len = rotary_seq_len, ) + X = 
attnres_transform_attention_output( + X, + attnres_state = attnres_state, + attnres_layer_idx = idx, + residual = residual, + attention_mask = attention_mask, + position_ids = position_ids, + ) X += residual residual.copy_(X) # residual = X diff --git a/unsloth/utils/__init__.py b/unsloth/utils/__init__.py index 9a093fedd7..16d3a51e97 100644 --- a/unsloth/utils/__init__.py +++ b/unsloth/utils/__init__.py @@ -30,6 +30,13 @@ run_attention, select_attention_backend, ) +from .attnres import ( + AttnResConfig, + AttnResState, + apply_attnres, + begin_attnres, + end_attnres, +) __all__ = [ "configure_sample_packing", @@ -39,6 +46,11 @@ "mark_allow_overlength", "AttentionConfig", "AttentionContext", + "AttnResConfig", + "AttnResState", + "apply_attnres", + "begin_attnres", + "end_attnres", "FLASH_VARLEN", "FLASH_DENSE", "XFORMERS", diff --git a/unsloth/utils/attention_dispatch.py b/unsloth/utils/attention_dispatch.py index 51e7c8e7a9..9a14e02cfa 100644 --- a/unsloth/utils/attention_dispatch.py +++ b/unsloth/utils/attention_dispatch.py @@ -24,6 +24,7 @@ from torch import Tensor from torch.nn.functional import scaled_dot_product_attention +from .attnres import AttnResState, apply_attnres from ..models._utils import * from ..utils.packing import ( build_sdpa_packed_attention_mask, @@ -88,6 +89,7 @@ class AttentionContext: attention_mask: Optional[Tensor] causal_mask: Optional[Any] sliding_window: Optional[int] = None + attnres_state: Optional[AttnResState] = None def select_attention_backend(use_varlen: bool = False) -> str: @@ -145,6 +147,9 @@ def run_attention( sdpa_kwargs = config.sdpa_kwargs or {} xformers_kwargs = config.xformers_kwargs or {} + def finalize(output: Tensor) -> Tensor: + return apply_attnres(output, context.attnres_state) + bsz = context.bsz n_heads = context.n_heads q_len = context.q_len @@ -158,7 +163,7 @@ def run_attention( K_f = K.transpose(1, 2).reshape(bsz * q_len, config.n_kv_heads, head_dim) V_f = V.transpose(1, 2).reshape(bsz * q_len, 
config.n_kv_heads, head_dim) _, cu_seqlens, max_seqlen = context.seq_info - return flash_attn_varlen_func( + out = flash_attn_varlen_func( Q_f, K_f, V_f, @@ -168,13 +173,15 @@ def run_attention( max_seqlen, **flash_varlen_kwargs, ).view(bsz, q_len, n_heads, head_dim) + return finalize(out) elif backend == FLASH_DENSE: Q_t = Q.transpose(1, 2) K_t = K.transpose(1, 2) V_t = V.transpose(1, 2) - return flash_attn_func(Q_t, K_t, V_t, **flash_dense_kwargs).reshape( + out = flash_attn_func(Q_t, K_t, V_t, **flash_dense_kwargs).reshape( bsz, q_len, n_heads, head_dim ) + return finalize(out) elif backend == XFORMERS: attn_bias = build_xformers_block_causal_mask( context.seq_info, @@ -241,7 +248,7 @@ def run_attention( out = out.reshape(bsz, q_len, n_heads, head_dim) else: out = out.view(bsz, q_len, n_heads, head_dim) - return out + return finalize(out) else: local_mask = context.attention_mask is_causal_local = False @@ -329,7 +336,7 @@ def run_attention( if use_sdpa_gqa: kwargs.setdefault("enable_gqa", True) out = scaled_dot_product_attention(Q, K, V, **kwargs) - return out.transpose(1, 2) + return finalize(out.transpose(1, 2)) K_mod = K V_mod = V @@ -349,7 +356,7 @@ def run_attention( V_mod.contiguous(), **kwargs, ) - return out.transpose(1, 2).contiguous() + return finalize(out.transpose(1, 2).contiguous()) __all__ = [ diff --git a/unsloth/utils/attnres.py b/unsloth/utils/attnres.py new file mode 100644 index 0000000000..19fe08d7c5 --- /dev/null +++ b/unsloth/utils/attnres.py @@ -0,0 +1,102 @@ +# Copyright 2023-present Daniel Han-Chen, Michael Han-Chen & the Unsloth team. All rights reserved. +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU Lesser General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. 
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU Lesser General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.

"""Optional attention residual hooks used by attention dispatch."""

from __future__ import annotations

from dataclasses import dataclass, field
from typing import Callable, Optional

from torch import Tensor

# A hook receives the attention output and returns the tensor to use instead.
AttnResKernelHook = Callable[[Tensor], Tensor]


@dataclass
class AttnResConfig:
    """Static settings for the attention residual hook."""

    # Requests that states built from this config start enabled.
    enabled: bool = False
    # Default hook, used when the state does not carry its own.
    kernel_hook: Optional[AttnResKernelHook] = None


@dataclass
class AttnResState:
    """Runtime switch plus the hook applied to attention outputs.

    ``config = None`` is accepted and replaced by a default ``AttnResConfig``,
    matching how ``begin_attnres`` treats a missing config.
    """

    config: Optional[AttnResConfig] = field(default_factory = AttnResConfig)
    kernel_hook: Optional[AttnResKernelHook] = None
    enabled: bool = False

    def __post_init__(self):
        if self.config is None:
            self.config = AttnResConfig()
        if self.kernel_hook is None:
            # Inherit the config's hook when none was given explicitly.
            self.kernel_hook = self.config.kernel_hook
        self.enabled = self._wants_enabled()

    def _wants_enabled(self) -> bool:
        """True when the flag, the config or the presence of a hook asks for activation."""
        config = self.config
        return bool(
            self.enabled
            or (config is not None and config.enabled)
            or self.kernel_hook is not None
        )

    def begin(self) -> "AttnResState":
        """Re-evaluate ``enabled`` and return ``self`` for chaining."""
        self.enabled = self._wants_enabled()
        return self

    def end(self) -> None:
        """Switch the state off; ``apply`` becomes a pass-through until ``begin`` is called."""
        self.enabled = False

    def apply(self, output: Tensor) -> Tensor:
        """Run ``apply_attnres`` with this state."""
        return apply_attnres(output, self)


def begin_attnres(
    config: Optional[AttnResConfig] = None,
    *,
    kernel_hook: Optional[AttnResKernelHook] = None,
) -> AttnResState:
    """Create a state from ``config`` (default config when ``None``) and begin it."""
    if config is None:
        config = AttnResConfig()
    return AttnResState(config = config, kernel_hook = kernel_hook).begin()


def end_attnres(state: Optional[AttnResState]) -> None:
    """Disable ``state``; a ``None`` state is ignored."""
    if state is not None:
        state.end()


def apply_attnres(
    output: Tensor,
    state: Optional[AttnResState] = None,
    *,
    kernel_hook: Optional[AttnResKernelHook] = None,
) -> Tensor:
    """Pass ``output`` through the active hook, if any.

    ``output`` is returned unchanged when ``state`` is ``None`` or disabled,
    or when no hook can be found. Hook lookup order: the ``kernel_hook``
    argument, then ``state.kernel_hook``, then ``state.config.kernel_hook``.
    """
    if state is None or not state.enabled:
        return output

    hook = kernel_hook
    if hook is None:
        hook = state.kernel_hook
    if hook is None and state.config is not None:
        # ``config`` may have been reassigned after construction, so guard it.
        hook = state.config.kernel_hook
    if hook is None:
        return output
    return hook(output)


__all__ = [
    "AttnResConfig",
    "AttnResState",
    "begin_attnres",
    "end_attnres",
    "apply_attnres",
]