Skip to content
Closed
7 changes: 7 additions & 0 deletions unsloth/kernels/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,13 @@
create_flex_attention_causal_mask,
create_flex_attention_sliding_window_mask,
)
from ..utils.attnres import (
AttnResConfig,
AttnResState,
apply_attnres,
begin_attnres,
end_attnres,
)

import os

Expand Down
207 changes: 207 additions & 0 deletions unsloth/models/attnres.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,207 @@
# Copyright 2023-present Daniel Han-Chen & the Unsloth team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations

from dataclasses import dataclass, field
import math
import os
from typing import Optional

import torch


def _read_flag(value, default = False):
if value is None:
return default
if isinstance(value, str):
value = value.strip().lower()
return value in ("1", "true", "yes", "on")
return bool(value)


def _read_int(value, default):
try:
parsed = int(value)
except (ValueError, TypeError):
return default
Comment thread
kleeedolinux marked this conversation as resolved.
return max(1, parsed)


def _read_float(value, default):
try:
return float(value)
except (ValueError, TypeError):
return default
Comment thread
kleeedolinux marked this conversation as resolved.


@dataclass
class _AttnResState:
enabled: bool
block_size: int
alpha: float
num_layers: int
completed_block_summaries: list[torch.Tensor] = field(default_factory = list)
current_block_states: list[torch.Tensor] = field(default_factory = list)


def _get_config_value(config, names, default = None):
for name in names:
if hasattr(config, name):
value = getattr(config, name)
if value is not None:
return value
return default


def _build_state(model, *, use_cache = False):
config = getattr(model, "config", None)
if config is None:
return None

enabled = _read_flag(
_get_config_value(
config,
(
"attnres",
"attn_residual",
"attn_residuals",
"use_attnres",
"use_attn_residuals",
"_attnres",
"_use_attnres",
),
default = None,
),
default = False,
)
if not enabled:
enabled = _read_flag(os.environ.get("UNSLOTH_ATTNRES"), default = False)
if not enabled:
enabled = _read_flag(os.environ.get("ATTNRES"), default = False)
if not enabled:
return None

# Stateful accumulation can desync across recomputation passes.
if (
getattr(model, "training", False)
and getattr(model, "gradient_checkpointing", False)
and not use_cache
):
return None

block_size = _read_int(
_get_config_value(
config,
("attnres_block_size", "attn_residual_block_size", "attnres_block"),
default = os.environ.get("UNSLOTH_ATTNRES_BLOCK_SIZE", 8),
),
default = 8,
)
alpha = _read_float(
_get_config_value(
config,
("attnres_alpha", "attn_residual_alpha"),
default = os.environ.get("UNSLOTH_ATTNRES_ALPHA", 1.0),
),
default = 1.0,
)
num_layers = int(getattr(config, "num_hidden_layers", 0))
if num_layers <= 0 and hasattr(model, "layers"):
num_layers = len(model.layers)
return _AttnResState(
enabled = True,
block_size = block_size,
alpha = alpha,
num_layers = num_layers,
)


def begin_attnres_state(model, *, use_cache = False):
return _build_state(model, use_cache = use_cache)


def attnres_init_forward_state(
model,
hidden_states = None,
attention_mask = None,
position_ids = None,
past_key_values = None,
use_cache = False,
output_attentions = False,
output_hidden_states = False,
):
return begin_attnres_state(model, use_cache = use_cache)


def _compute_residual_mix(
query: torch.Tensor,
candidates: list[torch.Tensor],
) -> torch.Tensor:
if len(candidates) == 0:
return torch.zeros_like(query)

# (bsz, seqlen, n_states, dim)
stacked = torch.stack(candidates, dim = 2)
Comment thread
kleeedolinux marked this conversation as resolved.
dim = query.shape[-1]
scale = 1.0 / math.sqrt(float(dim))
logits = (query.unsqueeze(2) * stacked).sum(dim = -1) * scale
weights = torch.softmax(logits, dim = -1)
return (weights.unsqueeze(-1) * stacked).sum(dim = 2)


def attnres_transform_attention_output(
attention_output: torch.Tensor,
attnres_state: Optional[_AttnResState] = None,
attnres_layer_idx: Optional[int] = None,
residual: Optional[torch.Tensor] = None,
attention_mask = None,
causal_mask = None,
position_ids = None,
):
if attnres_state is None or not getattr(attnres_state, "enabled", False):
return attention_output

layer_idx = 0 if attnres_layer_idx is None else int(attnres_layer_idx)
query = residual if residual is not None else attention_output

candidates = list(attnres_state.completed_block_summaries)
candidates.extend(attnres_state.current_block_states)
if len(candidates) != 0:
mixed = _compute_residual_mix(query, candidates)
attention_output = attention_output + (attnres_state.alpha * mixed)

# Keep current-layer information for future layers in this block.
attnres_state.current_block_states.append(attention_output.clone())

# Finalize a block summary at boundaries.
at_block_end = ((layer_idx + 1) % attnres_state.block_size) == 0
at_model_end = (attnres_state.num_layers > 0) and (
(layer_idx + 1) >= attnres_state.num_layers
)
if at_block_end or at_model_end:
block_summary = torch.stack(attnres_state.current_block_states, dim = 0).sum(
dim = 0
)
attnres_state.completed_block_summaries.append(block_summary)
attnres_state.current_block_states.clear()

return attention_output


__all__ = [
"begin_attnres_state",
"attnres_init_forward_state",
"attnres_transform_attention_output",
]
38 changes: 38 additions & 0 deletions unsloth/models/cohere.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
# limitations under the License.

from .llama import *
from .llama import attnres_init_forward_state, attnres_transform_attention_output
from ._utils import __version__
from unsloth_zoo.hf_utils import dtype_from_config
from unsloth_zoo.utils import _get_dtype, Version
Expand Down Expand Up @@ -184,6 +185,8 @@ def CohereDecoderLayer_fast_forward(
*args,
**kwargs,
):
attnres_state = kwargs.get("attnres_state", None)
attnres_layer_idx = kwargs.get("attnres_layer_idx", None)
if use_cache and hasattr(
self, "_flag_for_generation"
): # past_key_value is not None:
Expand All @@ -209,6 +212,15 @@ def CohereDecoderLayer_fast_forward(
padding_mask = padding_mask,
**kwargs,
)
hidden_states_attention = attnres_transform_attention_output(
hidden_states_attention,
attnres_state = attnres_state,
attnres_layer_idx = attnres_layer_idx,
residual = residual,
attention_mask = attention_mask,
causal_mask = causal_mask,
position_ids = position_ids,
)

# Fully Connected
hidden_states_mlp = fast_swiglu_inference(self.mlp, hidden_states)
Expand All @@ -229,6 +241,15 @@ def CohereDecoderLayer_fast_forward(
padding_mask = padding_mask,
**kwargs,
)
hidden_states_attention = attnres_transform_attention_output(
hidden_states_attention,
attnres_state = attnres_state,
attnres_layer_idx = attnres_layer_idx,
residual = residual,
attention_mask = attention_mask,
causal_mask = causal_mask,
position_ids = position_ids,
)

# Fully Connected
hidden_states_mlp = self.mlp(hidden_states)
Expand Down Expand Up @@ -468,6 +489,15 @@ def CohereModel_fast_forward_inference(
else:
attention_mask = None

attnres_state = attnres_init_forward_state(
self,
hidden_states = hidden_states,
attention_mask = attention_mask,
position_ids = position_ids,
past_key_values = past_key_values,
use_cache = True,
)

next_decoder_cache = []
for idx, decoder_layer in enumerate(self.model.layers):
device_index = getattr(decoder_layer, "_per_layer_device_index", 0)
Expand All @@ -488,6 +518,14 @@ def CohereModel_fast_forward_inference(
do_prefill = not hasattr(decoder_layer.self_attn, "paged_attention"),
)
)
hidden_states_attention = attnres_transform_attention_output(
hidden_states_attention,
attnres_state = attnres_state,
attnres_layer_idx = idx,
residual = residual,
attention_mask = attention_mask,
position_ids = position_ids,
)

hidden_states_mlp = fast_swiglu_inference(decoder_layer.mlp, hidden_states)
residual += hidden_states_attention
Expand Down
38 changes: 38 additions & 0 deletions unsloth/models/falcon_h1.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
# limitations under the License.

from .llama import *
from .llama import attnres_init_forward_state, attnres_transform_attention_output
import os
from ._utils import __version__
from unsloth_zoo.utils import Version, _get_dtype
Expand Down Expand Up @@ -423,6 +424,8 @@ def FalconH1DecoderLayer_fast_forward(
(see `past_key_values`).
past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
"""
attnres_state = kwargs.get("attnres_state", None)
attnres_layer_idx = kwargs.get("attnres_layer_idx", None)
if use_cache and hasattr(self, "_flag_for_generation"):
residual = hidden_states
hidden_states = fast_rms_layernorm_inference(
Expand All @@ -441,6 +444,15 @@ def FalconH1DecoderLayer_fast_forward(
**kwargs,
)
attention_hidden_states = attention_hidden_states * self.attn_out_multiplier
attention_hidden_states = attnres_transform_attention_output(
attention_hidden_states,
attnres_state = attnres_state,
attnres_layer_idx = attnres_layer_idx,
residual = residual,
attention_mask = attention_mask,
causal_mask = causal_mask,
position_ids = position_ids,
)

mamba_hidden_states = self.mamba(
hidden_states = hidden_states,
Expand Down Expand Up @@ -486,6 +498,15 @@ def FalconH1DecoderLayer_fast_forward(
**kwargs,
)
attention_hidden_states = attention_hidden_states * self.attn_out_multiplier
attention_hidden_states = attnres_transform_attention_output(
attention_hidden_states,
attnres_state = attnres_state,
attnres_layer_idx = attnres_layer_idx,
residual = residual,
attention_mask = attention_mask,
causal_mask = causal_mask,
position_ids = position_ids,
)

hidden_states = mamba_hidden_states + attention_hidden_states

Expand Down Expand Up @@ -560,6 +581,15 @@ def FalconH1Model_fast_forward_inference_custom(
else:
attention_mask = None

attnres_state = attnres_init_forward_state(
self,
hidden_states = X,
attention_mask = attention_mask,
position_ids = position_ids,
past_key_values = past_key_values,
use_cache = True,
)

next_decoder_cache = []

for idx, decoder_layer in enumerate(self.model.layers):
Expand All @@ -584,6 +614,14 @@ def FalconH1Model_fast_forward_inference_custom(
attention_hidden_states = (
attention_hidden_states * decoder_layer.attn_out_multiplier
)
attention_hidden_states = attnres_transform_attention_output(
attention_hidden_states,
attnres_state = attnres_state,
attnres_layer_idx = idx,
residual = residual,
attention_mask = attention_mask,
position_ids = position_ids,
)
mamba_hidden_states = decoder_layer.mamba(
hidden_states = X,
cache_params = present_key_value,
Expand Down
Loading