diff --git a/modelopt/torch/puzzletron/decilm/deci_lm_hf_code/configuration_decilm.py b/modelopt/torch/puzzletron/decilm/deci_lm_hf_code/configuration_decilm.py index 6ff0e26a4ee..34a7e8cfcf8 100644 --- a/modelopt/torch/puzzletron/decilm/deci_lm_hf_code/configuration_decilm.py +++ b/modelopt/torch/puzzletron/decilm/deci_lm_hf_code/configuration_decilm.py @@ -27,13 +27,11 @@ # fakes imports to make AutoConfig infer dependencies from .transformers_4_44_2__modeling_rope_utils import rope_config_validation -from .transformers_4_51_3__cache_utils import HybridChunkedCache from .transformers_4_51_3__configuration_llama4 import Llama4Config # make sure that auto-formatting doesn't remove the fake imports rope_config_validation Llama4Config -HybridChunkedCache class DeciLMConfig(LlamaConfig): diff --git a/modelopt/torch/puzzletron/decilm/deci_lm_hf_code/megatron_lm__mamba_mixer.py b/modelopt/torch/puzzletron/decilm/deci_lm_hf_code/megatron_lm__mamba_mixer.py deleted file mode 100644 index 76dbb3473b6..00000000000 --- a/modelopt/torch/puzzletron/decilm/deci_lm_hf_code/megatron_lm__mamba_mixer.py +++ /dev/null @@ -1,527 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -# Copyright (c) 2024, Tri Dao, Albert Gu. - -# Adapted from megatron.core.ssm.mamba_mixer.MambaMixer: -# https://gitlab-master.nvidia.com/ADLR/megatron-lm/-/blob/0b5140009fb9011eceaef6d36ea1181a8d176479/megatron/core/ssm/mamba_mixer.py - -# ruff: noqa: N803, N806 - -# Some of this code was adopted from https://github.com/state-spaces/mamba/ -# This source code is licensed under the Apache license found in the -# LICENSE file in the root directory of this source tree. - -import math -import warnings - -import torch -import torch.nn as nn -import torch.nn.functional as F - -try: - from causal_conv1d import causal_conv1d_fn, causal_conv1d_update - from einops import rearrange, repeat - from mamba_ssm.ops.triton.layernorm_gated import RMSNorm as RMSNormGated - from mamba_ssm.ops.triton.selective_state_update import selective_state_update - from mamba_ssm.ops.triton.ssd_combined import ( - mamba_chunk_scan_combined, - mamba_split_conv1d_scan_combined, - ) - - class MambaMixerMegatron(nn.Module): - """ - Args: - d_model: The hidden size of the model. - d_state: The state size of the SSM. - d_conv: The number of channels in the causal convolution. - conv_init: The initialization range for the causal convolution weights. - nheads: The number of Mamba heads. Used to calculate the expansion factor for the SSM - instead of the deprecated arg "expand". - headdim: The hidden size of each attention head. - ngroups: The number of attention heads. - A_init_range: The initialization range for the attention weights. - D_has_hdim: Whether the D parameter has the same number of dimensions as the hidden - state. - rmsnorm: Whether to use root mean square normalization. - norm_before_gate: Whether to apply normalization before the gating mechanism. - dt_min: The minimum value of the dt parameter. - dt_max: The maximum value of the dt parameter. - dt_init: The initialization value of the dt parameter. - dt_scale: The scaling factor for the dt parameter. - dt_init_floor: The minimum value of the dt parameter after initialization. - bias: Whether to use bias in the linear layers. - conv_bias: Whether to use bias in the causal convolution. - chunk_size: The chunk size for the fused kernel. - use_mem_eff_path: Whether to use the memory-efficient path for the Mamba model. - layer_number: The layer number of this Mamba layer. - """ - - def __init__( - self, - d_model, - d_state=256, - d_conv=4, - conv_init=None, - nheads=256, - headdim=64, - ngroups=8, - A_init_range=(1, 16), - D_has_hdim=False, - rmsnorm=True, - norm_before_gate=False, - dt_min=0.001, - dt_max=0.1, - dt_init="random", - dt_scale=1.0, - dt_init_floor=1e-4, - bias=False, - conv_bias=True, - # Fused kernel and sharding options - chunk_size=128, - use_mem_eff_path=True, - layer_number=None, - ): - super().__init__() - self.d_model = d_model - self.d_state = d_state - self.d_conv = d_conv - self.conv_init = conv_init - self.nheads = nheads - self.headdim = headdim - self.ngroups = ngroups - self.D_has_hdim = D_has_hdim - self.rmsnorm = rmsnorm - self.norm_before_gate = norm_before_gate - self.chunk_size = chunk_size - self.use_mem_eff_path = use_mem_eff_path - self.layer_number = layer_number - - self.d_inner = self.nheads * self.headdim - - self.tensor_model_parallel_size = 1 - assert self.d_inner % self.tensor_model_parallel_size == 0 - assert self.ngroups % self.tensor_model_parallel_size == 0 - assert self.nheads % self.tensor_model_parallel_size == 0 - assert not bias - assert not self.norm_before_gate - - self.d_inner_local = self.d_inner // self.tensor_model_parallel_size - self.ngroups_local = self.ngroups // self.tensor_model_parallel_size - self.nheads_local = self.nheads // self.tensor_model_parallel_size - - assert self.d_inner_local % self.ngroups_local == 0 - - # Assume sequence parallelism: input is already partitioned along the - # sequence dimension - self.in_proj = nn.Linear( - self.d_model, - self.d_inner * 2 + 2 * self.ngroups * self.d_state + self.nheads, # AB CD E - bias=False, - ) - - conv_dim = self.d_inner_local + 2 * self.ngroups_local * self.d_state # A CD - - # weight dim: [conv_dim, conv_dim, d_conv] - self.conv1d = nn.Conv1d( - in_channels=conv_dim, - out_channels=conv_dim, - bias=conv_bias, - kernel_size=d_conv, - groups=conv_dim, - padding=d_conv - 1, - ) - - if self.conv_init is not None: - nn.init.uniform_(self.conv1d.weight, -self.conv_init, self.conv_init) - - self.activation = "silu" - self.act = nn.SiLU() - - # Initialize dt bias so that F.softplus(dt_bias) is between dt_min and dt_max - dt = torch.exp( - torch.rand(self.nheads_local) * (math.log(dt_max) - math.log(dt_min)) - + math.log(dt_min) - ).clamp(min=dt_init_floor) - # Inverse of softplus: https://github.com/pytorch/pytorch/issues/72759 - inv_dt = dt + torch.log(-torch.expm1(-dt)) - self.dt_bias = nn.Parameter(inv_dt) - # Our initialization would set all Linear.bias to zero, - # need to mark this one as _no_reinit - self.dt_bias._no_reinit = True - # Just to be explicit. Without this we already don't - # put wd on dt_bias because of the check - - # name.endswith("bias") in param_grouping.py - self.dt_bias._no_weight_decay = True - - assert A_init_range[0] > 0 and A_init_range[1] >= A_init_range[0] - A = torch.empty(self.nheads_local, dtype=torch.float32).uniform_(*A_init_range) - A_log = torch.log(A) # Keep A_log in fp32 - self.A_log = nn.Parameter(A_log) - self.A_log._no_weight_decay = True - - # D "skip" parameter - self.D = nn.Parameter( - torch.ones( - self.d_inner_local if self.D_has_hdim else self.nheads_local, - ) - ) # Keep in fp32 - self.D._no_weight_decay = True - - if self.rmsnorm: - self.norm = RMSNormGated( - self.d_inner_local, - eps=1e-5, - group_size=self.d_inner_local // self.ngroups_local, - norm_before_gate=self.norm_before_gate, - ) - - # Assume sequence parallelism: input is partitioned along d_inner and - # output is partitioned along the sequence dimension - self.out_proj = nn.Linear( - self.d_inner, - self.d_model, - bias=False, - ) - - def forward(self, hidden_states, inference_params=None): - """ - hidden_states: (nL, B, D) / (L B D) - Returns: same shape as hidden_states - """ - _, batch, dim = hidden_states.shape - - conv_state, ssm_state = None, None - if inference_params is not None: - # assert not self.config.sequence_parallel - conv_state, ssm_state = self._get_states_from_cache(inference_params, batch) - if inference_params.seqlen_offset > 0: - # The states are updated inplace - out, out_bias, _, _ = self.step(hidden_states, conv_state, ssm_state) - return out, out_bias - - # (nheads_local) - A = -torch.exp(self.A_log.float()) - - # xz, _ = self.in_proj(hidden_states) # TransformerEngine also returns bias - xz = self.in_proj(hidden_states) - - # transpose: l b pd --> b l pd - xz = rearrange(xz, "l b d -> b l d").contiguous() - - if self.use_mem_eff_path and inference_params is None: - assert ssm_state is None - - if self.conv1d.bias is not None: - self.conv1d.bias.data_ptr() - - y = mamba_split_conv1d_scan_combined( - xz, - rearrange(self.conv1d.weight, "d 1 w -> d w"), - self.conv1d.bias, - self.dt_bias.float(), - A, - D=( - rearrange(self.D.float(), "(h p) -> h p", p=self.headdim) - if self.D_has_hdim - else self.D - ), - chunk_size=self.chunk_size, - activation=self.activation, - headdim=None if self.D_has_hdim else self.headdim, - ngroups=self.ngroups_local, - norm_before_gate=self.norm_before_gate, - ) - - if self.rmsnorm: - y = self.norm(y) - else: - z, xBC, dt = torch.split( - xz, - [ - self.d_inner_local, - self.d_inner_local + 2 * self.ngroups_local * self.d_state, - self.nheads_local, - ], - dim=-1, - ) - - # transpose: b l pd --> b pd l - xBC = rearrange(xBC, "b l d -> b d l").contiguous() - - # Compute short convolution - if conv_state is not None: - # If we just take x[:, :, -self.d_conv :], it will error if seqlen < self.d_conv - # Instead F.pad will pad with zeros if seqlen < self.d_conv, and truncate otherwise. - conv_state.copy_( - F.pad(xBC, (self.d_conv - xBC.shape[-1], 0)) - ) # Update state (B D W) - - seqlen = xBC.size(2) - if causal_conv1d_fn is None: - xBC = self.act(self.conv1d(xBC)[..., :seqlen]) - else: - assert self.activation in ["silu", "swish"] - xBC = causal_conv1d_fn( - x=xBC, - weight=rearrange(self.conv1d.weight, "d 1 w -> d w"), - bias=self.conv1d.bias, - activation=self.activation, - ) - - # transpose b pd l --> b l pd - xBC = rearrange(xBC, "b d l -> b l d").contiguous() - - x, B, C = torch.split( - xBC, - [ - self.d_inner_local, - self.ngroups_local * self.d_state, - self.ngroups_local * self.d_state, - ], - dim=-1, - ) - - # TO DO Vijay: fuse most of the transposes with the GEMMS - x = rearrange(x, "b l (h p) -> b l h p", p=self.headdim).contiguous() - dt = dt.contiguous() - B = rearrange(B, "b l (g n) -> b l g n", n=self.d_state).contiguous() - C = rearrange(C, "b l (g n) -> b l g n", n=self.d_state).contiguous() - z = rearrange(z, "b l (h p) -> b l h p", p=self.headdim).contiguous() - y = mamba_chunk_scan_combined( - x, - dt, - A, - B, - C, - self.chunk_size, - D=( - rearrange(self.D.float(), "(h p) -> h p", p=self.headdim) - if self.D_has_hdim - else self.D - ), - z=z if not self.rmsnorm else None, - dt_bias=self.dt_bias.float(), - dt_softplus=True, - return_final_states=ssm_state is not None, - ) - - if ssm_state is not None: - y, last_state = y - ssm_state.copy_(last_state) - - if self.rmsnorm: - y = rearrange(y, "b l h p -> b l (h p)").contiguous() - z = rearrange(z, "b l h p -> b l (h p)").contiguous() - y = self.norm(y, z) - else: - y = rearrange(y, "b l h p -> b l (h p)").contiguous() - - y = rearrange(y, "b l d -> l b d").contiguous() - # out, out_bias = self.out_proj(y) # TransformerEngine also returns bias - out = self.out_proj(y) - - return out - - def step(self, hidden_states, conv_state, ssm_state): - """ - Performs inference step for decoding - """ - # assert self.ngroups_local == 1, "Only support ngroups=1 for inference for now" - dtype = hidden_states.dtype - assert hidden_states.shape[0] == 1, ( - "Only support decoding with 1 token at a time for now" - ) - - # l b d --> b d - hidden_states = hidden_states.squeeze(0) - - # b d_model --> b p(2d) - xz, _ = self.in_proj(hidden_states) - - z, xBC, dt = torch.split( - xz, - [ - self.d_inner_local, - self.d_inner_local + 2 * self.ngroups_local * self.d_state, - self.nheads_local, - ], - dim=-1, - ) - - # Conv step - if causal_conv1d_update is None: - conv_state.copy_(torch.roll(conv_state, shifts=-1, dims=-1)) # Update state (B D W) - conv_state[:, :, -1] = xBC - xBC = torch.sum( - conv_state * rearrange(self.conv1d.weight, "d 1 w -> d w"), dim=-1 - ) # (B D) - if self.conv1d.bias is not None: - xBC = xBC + self.conv1d.bias - xBC = self.act(xBC).to(dtype=dtype) - else: - xBC = causal_conv1d_update( - xBC, - conv_state, - rearrange(self.conv1d.weight, "d 1 w -> d w"), - self.conv1d.bias, - self.activation, - ) - - x, B, C = torch.split( - xBC, - [ - self.d_inner_local, - self.ngroups_local * self.d_state, - self.ngroups_local * self.d_state, - ], - dim=-1, - ) - A = -torch.exp(self.A_log.float()) - - # SSM step - if selective_state_update is None: - if self.ngroups_local > 1: - B = rearrange(B, "b (g n) -> b g n", n=self.d_state) - C = rearrange(C, "b (g n) -> b g n", n=self.d_state) - B = repeat(B, "b g n -> b (g h) n", h=self.d_inner_local // self.ngroups_local) - C = repeat(C, "b g n -> b (g h) n", h=self.d_inner_local // self.ngroups_local) - - dt = repeat(dt, "b h -> b (h p)", p=self.headdim) - dt_bias = repeat(self.dt_bias, "h -> (h p)", p=self.headdim) - A = repeat(A, "h -> (h p) n", p=self.headdim, n=self.d_state) - D = repeat(self.D, "h -> (h p)", p=self.headdim) - - dt = F.softplus(dt + dt_bias.to(dtype=dt.dtype)) - dA = torch.exp(torch.einsum("bd,dn->bdn", dt, A)) - - dB_x = torch.einsum("bd,bdn,bd->bdn", dt, B, x) - ssm_state.copy_( - ssm_state * rearrange(dA, "b (h p) n -> b h p n", p=self.headdim) - + rearrange(dB_x, "b (h p) n -> b h p n", p=self.headdim) - ) - - y = torch.einsum( - "bdn,bdn->bd", - rearrange(ssm_state.to(dtype), "b h p n -> b (h p) n", p=self.headdim), - C, - ) - y = y + D.to(dtype) * x - if not self.rmsnorm: - y = y * self.act(z) # (B D) - else: - # Discretize A and B (b (g n)) - dt = F.softplus(dt + self.dt_bias.to(dtype=dt.dtype)) # (batch, nheads) - dA = torch.exp(dt * A) - x = rearrange(x, "b (h p) -> b h p", p=self.headdim) - dBx = torch.einsum("bh,bn,bhp->bhpn", dt, B, x) - ssm_state.copy_(ssm_state * rearrange(dA, "b h -> b h 1 1") + dBx) - y = torch.einsum("bhpn,bn->bhp", ssm_state.to(dtype), C) - y = y + rearrange(self.D.to(dtype), "h -> h 1") * x - y = rearrange(y, "b h p -> b (h p)") - if not self.rmsnorm: - y = y * self.act(z) # (B D) - else: - A = repeat(A, "h -> h p n", p=self.headdim, n=self.d_state).to(dtype=torch.float32) - dt = repeat(dt, "b h -> b h p", p=self.headdim) - dt_bias = repeat(self.dt_bias, "h -> h p", p=self.headdim) - D = repeat(self.D, "h -> h p", p=self.headdim) - B = rearrange(B, "b (g n) -> b g n", g=self.ngroups_local) - C = rearrange(C, "b (g n) -> b g n", g=self.ngroups_local) - x_reshaped = rearrange(x, "b (h p) -> b h p", p=self.headdim) - if not self.rmsnorm: - z = rearrange(z, "b (h p) -> b h p", p=self.headdim) - y = selective_state_update( - ssm_state, - x_reshaped, - dt, - A, - B, - C, - D, - z=z if not self.rmsnorm else None, - dt_bias=dt_bias, - dt_softplus=True, - ) - y = rearrange(y, "b h p -> b (h p)") - - if self.rmsnorm: - y = self.norm(y, z) - - # b pd --> b d - out, out_bias = self.out_proj(y) - return out.unsqueeze(0), out_bias, conv_state, ssm_state - - def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None): - """ - allocate inference cache - """ - device = self.out_proj.weight.device - conv_dtype = self.conv1d.weight.dtype if dtype is None else dtype - conv_state = torch.zeros( - batch_size, - self.conv1d.weight.shape[0], - self.d_conv, - device=device, - dtype=conv_dtype, - ) - ssm_dtype = self.in_proj.weight.dtype if dtype is None else dtype - # ssm_dtype = torch.float32 - ssm_state = torch.zeros( - batch_size, - self.nheads_local, - self.headdim, - self.d_state, - device=device, - dtype=ssm_dtype, - ) - return conv_state, ssm_state - - def _get_states_from_cache(self, inference_params, batch_size, initialize_states=False): - assert self.layer_number is not None - if self.layer_number not in inference_params.key_value_memory_dict: - conv_state = torch.zeros( - batch_size, - self.conv1d.weight.shape[0], - self.d_conv, - device=self.conv1d.weight.device, - dtype=self.conv1d.weight.dtype, - ) - ssm_state = torch.zeros( - batch_size, - self.nheads_local, - self.headdim, - self.d_state, - device=self.in_proj.weight.device, - dtype=self.in_proj.weight.dtype, - ) - inference_params.key_value_memory_dict[self.layer_number] = (conv_state, ssm_state) - else: - conv_state, ssm_state = inference_params.key_value_memory_dict[self.layer_number] - # TO DO: What if batch size changes between generation, and we reuse the same states? - if initialize_states: - conv_state.zero_() - ssm_state.zero_() - return conv_state, ssm_state - -except ImportError as exception: - mamba_error_message = f"Cannot declare MambaMixer due to missing dependencies: {exception=}." - warnings.warn(mamba_error_message) - - # TODO: Investigate why this type ignore is needed - class MambaMixerMegatron(nn.Module): # type: ignore[no-redef] - def __init__(self, *args, **kwargs): - raise ImportError(mamba_error_message) diff --git a/modelopt/torch/puzzletron/decilm/deci_lm_hf_code/modeling_decilm.py b/modelopt/torch/puzzletron/decilm/deci_lm_hf_code/modeling_decilm.py index 84496bc4a3d..0102fc3a95c 100644 --- a/modelopt/torch/puzzletron/decilm/deci_lm_hf_code/modeling_decilm.py +++ b/modelopt/torch/puzzletron/decilm/deci_lm_hf_code/modeling_decilm.py @@ -15,123 +15,19 @@ # Copyright 2024 Nvidia Corporation, Google Inc, HuggingFace Inc, EleutherAI. All rights reserved. # -# This code for Nvidia's model is based on the Llama modeling code by HuggingFace, -# which is in turn based on EleutherAI's GPT-NeoX library and the GPT-NeoX and -# OPT implementations in this library. -# Sliding window code based on Gemma2 by Google. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. +# Pared-down DeciLM building blocks for Model-Optimizer puzzletron / AnyModel flows. +# The full HF DeciLM decoder stack (decoder layers, attention, rope, etc.) is not vendored here; +# AnyModel loads real models via transformers. This module keeps shared helpers: RMSNorm, +# gated/vanilla MLP (used by MoE accounting), MoE, and LMHead for replacement / validation code. # mypy: ignore-errors -import math - import torch import torch.nn.functional as F -import torch.utils.checkpoint from torch import nn -from transformers.utils import is_flash_attn_greater_or_equal_2_10, logging -from .block_config import AttentionConfig, FFNConfig, MambaConfig, MoEConfig +from .block_config import FFNConfig, MoEConfig from .configuration_decilm import DeciLMConfig -from .megatron_lm__mamba_mixer import MambaMixerMegatron from .transformers_4_44_2__activations import ACT2FN -from .transformers_4_44_2__cache_utils import Cache, StaticCache -from .transformers_4_44_2__modeling_attn_mask_utils import AttentionMaskConverter -from .transformers_4_44_2__modeling_flash_attention_utils_backward_compat import ( - _flash_attention_forward, -) -from .transformers_4_44_2__modeling_rope_utils import ROPE_INIT_FUNCTIONS -from .transformers_4_44_2__pytorch_utils import ALL_LAYERNORM_LAYERS -from .transformers_4_51_3__modeling_llama4_attention import Llama4TextAttention, Llama4TextConfig -from .variable_cache import VariableCache -from .vllm_yarn_utils import YaRNScalingRotaryEmbedding - -# from transformers.models.llama4.modeling_llama4 import Llama4TextL2Norm -logger = logging.get_logger(__name__) - -_CONFIG_FOR_DOC = "DeciLMConfig" - - -def _prepare_4d_causal_attention_mask_with_cache_position( - attention_mask: torch.Tensor, - sequence_length: int, - target_length: int, - dtype: torch.dtype, - device: torch.device, - min_dtype: float, - cache_position: torch.Tensor, - batch_size: int, -): - """ - Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape - `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing. - - Args: - attention_mask (`torch.Tensor`): - A 2D attention mask of shape `(batch_size, key_value_length)` or - a 4D attention mask of shape `(batch_size, 1, query_length, key_value_length)`. - sequence_length (`int`): - The sequence length being processed. - target_length (`int`): - The target length: when generating with static cache, the mask should be - as long as the static cache, to account for the 0 padding, the part of the cache that is not filled yet. - dtype (`torch.dtype`): - The dtype to use for the 4D attention mask. - device (`torch.device`): - The device to place the 4D attention mask on. - min_dtype (`float`): - The minimum value representable with the dtype `dtype`. - cache_position (`torch.Tensor`): - Indices depicting the position of the input sequence tokens in the sequence. - batch_size (`torch.Tensor`): - Batch size. - """ - if attention_mask is not None and attention_mask.dim() == 4: - # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing. - causal_mask = attention_mask - else: - causal_mask = torch.full( - (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device - ) - if sequence_length != 1: - causal_mask = torch.triu(causal_mask, diagonal=1) - causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1) - causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1) - if attention_mask is not None: - causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit - mask_length = attention_mask.shape[-1] - padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :] - padding_mask = padding_mask == 0 - causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill( - padding_mask, min_dtype - ) - - return causal_mask - - -class Llama4TextL2Norm(torch.nn.Module): - def __init__(self, eps: float = 1e-6): - super().__init__() - self.eps = eps - - def _norm(self, x): - return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps) - - def forward(self, x): - return self._norm(x.float()).type_as(x) - - def extra_repr(self): - return f"eps={self.eps}" class DeciLMRMSNorm(nn.Module): @@ -154,349 +50,10 @@ def extra_repr(self): return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}" -ALL_LAYERNORM_LAYERS.append(DeciLMRMSNorm) - - -class DeciLMRotaryEmbedding(nn.Module): - def __init__( - self, - dim=None, - max_position_embeddings=2048, - base=10000, - device=None, - scaling_factor=1.0, - rope_type="default", - config: DeciLMConfig | None = None, - ): - super().__init__() - # TODO (joao): remove the `if` below, only used for BC - self.rope_kwargs = {} - if config is None: - logger.warning_once( - "`DeciLMRotaryEmbedding` can now be fully parameterized by passing the model config through the " - "`config` argument. All other arguments will be removed in v4.45" - ) - self.rope_kwargs = { - "rope_type": rope_type, - "factor": scaling_factor, - "dim": dim, - "base": base, - "max_position_embeddings": max_position_embeddings, - } - self.rope_type = rope_type - self.max_seq_len_cached = max_position_embeddings - self.original_max_seq_len = max_position_embeddings - else: - # BC: "rope_type" was originally "type" - if config.rope_scaling is not None: - self.rope_type = config.rope_scaling.get( - "rope_type", config.rope_scaling.get("type") - ) - else: - self.rope_type = "default" - self.max_seq_len_cached = config.max_position_embeddings - self.original_max_seq_len = config.max_position_embeddings - - self.config = config - self.rope_impl = "rope" if config is None else config.position_embedding_type - self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type] - - def _set_inv_freq_if_needed(self, device: torch.device) -> None: - is_missing_inv_freq = not hasattr(self, "inv_freq") - is_meta_mismatch = not is_missing_inv_freq and ( - str(device) != "meta" and self.inv_freq.is_meta - ) - - if is_missing_inv_freq or is_meta_mismatch: - with torch.device(device): - inv_freq, self.attention_scaling = self.rope_init_fn( - self.config, device, **self.rope_kwargs - ) - self.original_inv_freq = inv_freq - self.register_buffer("inv_freq", inv_freq, persistent=False) - - def _dynamic_frequency_update(self, position_ids, device): - """ - dynamic RoPE layers should recompute `inv_freq` in the following situations: - 1 - growing beyond the cached sequence length (allow scaling) - 2 - the current sequence length is in the original scale (avoid losing precision with small sequences) - """ - seq_len = torch.max(position_ids) + 1 - if seq_len > self.max_seq_len_cached: # growth - inv_freq, self.attention_scaling = self.rope_init_fn( - self.config, device, seq_len=seq_len, **self.rope_kwargs - ) - self.register_buffer( - "inv_freq", inv_freq, persistent=False - ) # TODO joao: may break with compilation - self.max_seq_len_cached = seq_len - - if ( - seq_len < self.original_max_seq_len - and self.max_seq_len_cached > self.original_max_seq_len - ): # reset - self.register_buffer("inv_freq", self.original_inv_freq, persistent=False) - self.max_seq_len_cached = self.original_max_seq_len - - @torch.no_grad() - def forward(self, x, position_ids): - self._set_inv_freq_if_needed(x.device) - - if self.rope_impl == "rope_llama4": - return self.llama4_forward(x, position_ids) - else: - return self.llama3_forward(x, position_ids) - - def llama3_forward(self, x, position_ids): - if "dynamic" in self.rope_type: - self._dynamic_frequency_update(position_ids, device=x.device) - - # Core RoPE block - inv_freq_expanded = ( - self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1) - ) - position_ids_expanded = position_ids[:, None, :].float() - # Force float32 (see https://github.com/huggingface/transformers/pull/29285) - device_type = x.device.type - device_type = ( - device_type if isinstance(device_type, str) and device_type != "mps" else "cpu" - ) - with torch.autocast(device_type=device_type, enabled=False): - freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2) - emb = torch.cat((freqs, freqs), dim=-1) - cos = emb.cos() - sin = emb.sin() - - # Advanced RoPE types (e.g. yarn) apply a post-processing scaling factor, equivalent to scaling attention - cos = cos * self.attention_scaling - sin = sin * self.attention_scaling - - return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) - - def llama4_forward(self, x, position_ids): - if "dynamic" in self.rope_type: - self._dynamic_frequency_update(position_ids, device=x.device) - # Core RoPE block - inv_freq_expanded = ( - self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1) - ) - position_ids_expanded = position_ids[:, None, :].float() - # Force float32 (see https://github.com/huggingface/transformers/pull/29285) - device_type = x.device.type - device_type = ( - device_type if isinstance(device_type, str) and device_type != "mps" else "cpu" - ) - with torch.autocast(device_type=device_type, enabled=False): - freqs = (inv_freq_expanded.to(x.device) @ position_ids_expanded).transpose(1, 2) - freqs_cis = torch.polar( - torch.ones_like(freqs), freqs - ) # Convert to complex representation - - # Advanced RoPE types (e.g. yarn) apply a post-processing scaling factor, equivalent to scaling attention - freqs_cis = freqs_cis * self.attention_scaling - return freqs_cis - - -class DeciMistralYarnRotaryEmbedding(nn.Module): - def __init__(self, config: DeciLMConfig): - super().__init__() - self.config = config - self.rope_scaling = config.rope_scaling - self.base = config.rope_theta - self.rope_impl = config.position_embedding_type - self.head_size = config.hidden_size // config.num_attention_heads - self.yarn = YaRNScalingRotaryEmbedding( - head_size=self.head_size, - rotary_dim=self.head_size, - max_position_embeddings=self.rope_scaling["original_max_position_embeddings"], - base=self.base, - is_neox_style=True, - scaling_factor=self.rope_scaling["factor"], - beta_fast=self.rope_scaling["beta_fast"], - beta_slow=self.rope_scaling["beta_slow"], - dtype=torch.float32, - ) - self.attention_scaling = self.yarn.mscale - self.scaling_factor = self.rope_scaling["factor"] - self.rope_impl = "rope" if config is None else config.position_embedding_type - self.rope_impl = "even_odd" - - def _set_inv_freq_if_needed(self, device: torch.device) -> None: - is_missing_inv_freq = not hasattr(self, "inv_freq") - is_meta_mismatch = not is_missing_inv_freq and ( - str(device) != "meta" and self.inv_freq.is_meta - ) - - if is_missing_inv_freq or is_meta_mismatch: - with torch.device(device): - inv_freq = self.yarn._compute_inv_freq(self.scaling_factor) - self.register_buffer("inv_freq", inv_freq, persistent=False) - - def halves_forward(self, x, position_ids): - device_type = x.device.type - device_type = ( - device_type if isinstance(device_type, str) and device_type != "mps" else "cpu" - ) - - self._set_inv_freq_if_needed(x.device) - - # print(f"halves_forward") - inv_freq_expanded = ( - self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1) - ) - inv_freq_expanded = inv_freq_expanded.to(x.device) - # print(f"inv_freq_expanded: {inv_freq_expanded.device}") - position_ids_expanded = position_ids[:, None, :].float() - # Force float32 (see https://github.com/huggingface/transformers/pull/29285) - with torch.autocast(device_type=device_type, enabled=False): - freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2) - emb = torch.cat((freqs, freqs), dim=-1) - cos = emb.cos() - sin = emb.sin() - - # Advanced RoPE types (e.g. yarn) apply a post-processing scaling factor, equivalent to scaling attention - cos = cos * self.attention_scaling - sin = sin * self.attention_scaling - - return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) - - def forward(self, x, position_ids): - if self.rope_impl == "halves": - return self.halves_forward(x, position_ids) - elif self.rope_impl == "even_odd": - return self.even_odd_forward(x, position_ids) - else: - raise ValueError(f"Invalid rope implementation: {self.rope_impl}") - - def even_odd_forward(self, x, position_ids): - device_type = x.device.type - device_type = ( - device_type if isinstance(device_type, str) and device_type != "mps" else "cpu" - ) - - self._set_inv_freq_if_needed(x.device) - - # print(f"even_odd_forward") - # Core RoPE block - inv_freq_expanded = ( - self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1) - ) - position_ids_expanded = position_ids[:, None, :].float() - # Force float32 (see https://github.com/huggingface/transformers/pull/29285) - with torch.autocast(device_type=device_type, enabled=False): - freqs = (inv_freq_expanded.to(x.device) @ position_ids_expanded).transpose(1, 2) - freqs_cis = torch.polar( - torch.ones_like(freqs), freqs - ) # Convert to complex representation - - # Advanced RoPE types (e.g. yarn) apply a post-processing scaling factor, equivalent to scaling attention - freqs_cis = freqs_cis * self.attention_scaling - return freqs_cis - - -class DeciLMLinearScalingRotaryEmbedding(DeciLMRotaryEmbedding): - """DeciLMRotaryEmbedding extended with linear scaling. Credits to the Reddit user /u/kaiokendev""" - - def __init__(self, *args, **kwargs): - logger.warning_once( - "`DeciLMLinearScalingRotaryEmbedding` is deprecated an will be removed in v4.45. Please use " - "`DeciLMRotaryEmbedding`, which now also does linear scaling (simply pass the model config to __init__)." - ) - kwargs["rope_type"] = "linear" - super().__init__(*args, **kwargs) - - -class DeciLMDynamicNTKScalingRotaryEmbedding(DeciLMRotaryEmbedding): - """DeciLMRotaryEmbedding extended with Dynamic NTK scaling. Credits to the Reddit users /u/bloc97 and /u/emozilla""" - - def __init__(self, *args, **kwargs): - logger.warning_once( - "`DeciLMDynamicNTKScalingRotaryEmbedding` is deprecated an will be removed in v4.45. Please use " - "`DeciLMRotaryEmbedding`, which now also does dynamic ntk scaling (simply pass the model config to " - "__init__)." - ) - kwargs["rope_type"] = "dynamic" - super().__init__(*args, **kwargs) - - -rope_type_to_class = { - "default": DeciLMRotaryEmbedding, - "linear": DeciLMLinearScalingRotaryEmbedding, - "dynamic": DeciLMDynamicNTKScalingRotaryEmbedding, - "rope_llama4": DeciLMRotaryEmbedding, - "rope": DeciLMRotaryEmbedding, - "mistral_yarn": DeciMistralYarnRotaryEmbedding, -} - - -def rotate_half(x): - """Rotates half the hidden dims of the input.""" - x1 = x[..., : x.shape[-1] // 2] - x2 = x[..., x.shape[-1] // 2 :] - return torch.cat((-x2, x1), dim=-1) - - -def apply_rotary_pos_emb(q, k, freqs_cis, position_ids=None, unsqueeze_dim=1): - """Applies Rotary Position Embedding to the query and key tensors. - - Args: - q (`torch.Tensor`): The query tensor. - k (`torch.Tensor`): The key tensor. - freqs_cis (`torch.Tensor`): The frequency tensor. - a tuple of two tensors, cos and sin. - position_ids (`torch.Tensor`, *optional*): - Deprecated and unused. - unsqueeze_dim (`int`, *optional*, defaults to 1): - The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and - sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note - that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and - k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes - cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have - the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2. - Returns: - `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding. - """ - # print(f"applying first half-second half") - cos, sin = freqs_cis - cos = cos.unsqueeze(unsqueeze_dim) - sin = sin.unsqueeze(unsqueeze_dim) - q_embed = (q * cos) + (rotate_half(q) * sin) - k_embed = (k * cos) + (rotate_half(k) * sin) - return q_embed, k_embed - - -def vllm_apply_rotary_emb_torch( - x: torch.Tensor, - cos: torch.Tensor, - sin: torch.Tensor, - is_neox_style: bool, -) -> torch.Tensor: - cos = cos.unsqueeze(-2).to(x.dtype) - sin = sin.unsqueeze(-2).to(x.dtype) - if is_neox_style: - x1, x2 = torch.chunk(x, 2, dim=-1) - else: - x1 = x[..., ::2] - x2 = x[..., 1::2] - o1 = x1 * cos - x2 * sin - o2 = x2 * cos + x1 * sin - if is_neox_style: - return torch.cat((o1, o2), dim=-1) - else: - return torch.stack((o1, o2), dim=-1).flatten(-2) - - -def apply_rotary_emb( - xq: torch.Tensor, - xk: torch.Tensor, - freqs_cis: torch.Tensor, -) -> tuple[torch.Tensor, torch.Tensor]: - xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2)) - xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2)) - # print(f"freqs_cis: {freqs_cis.shape}, xq_: {xq_.shape}, xk_: {xk_.shape}") - xq_out = torch.view_as_real(xq_ * freqs_cis[:, None, :, :]).flatten(3) - xk_out = torch.view_as_real(xk_ * freqs_cis[:, None, :, :]).flatten(3) - return xq_out.type_as(xq), xk_out.type_as(xk) +def sparsity_backward_hook(*args, **kwargs): + raise NotImplementedError( + "No support for sparsity when training HF DeciLM (inference is ok though)" + ) class DeciLMGatedMLP(nn.Module): @@ -545,1040 +102,6 @@ def forward(self, x): return down_proj -class DeciLMVanillaMLP(nn.Module): - def __init__( - self, - config: DeciLMConfig, - ffn_config: FFNConfig, - ): - super().__init__() - self.config = config - self.ffn_config = ffn_config - self.hidden_size = config.hidden_size - self.intermediate_size = ffn_config.intermediate_size - self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=config.mlp_bias) - self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=config.mlp_bias) - self.act_fn = ACT2FN[getattr(ffn_config, "hidden_act", "silu")] - - if ffn_config.sparsify is not None: - self.register_full_backward_hook(sparsity_backward_hook) - - assert self.config.pretraining_tp == 1, ( - "Unsupported pretraining_tp != 1 for DeciLMVanillaMLP" - ) - - def forward(self, x): - return self.down_proj(self.act_fn(self.up_proj(x))) - - -def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: - """ - This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch, - num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim) - """ - batch, num_key_value_heads, slen, head_dim = hidden_states.shape - if n_rep == 1: - return hidden_states - hidden_states = hidden_states[:, :, None, :, :].expand( - batch, num_key_value_heads, n_rep, slen, head_dim - ) - return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) - - -class DeciLMAttention(nn.Module): - """Multi-headed attention from 'Attention Is All You Need' paper""" - - def __init__( - self, - config: DeciLMConfig, - attention_config: AttentionConfig, - layer_idx: int | None = None, - ): - super().__init__() - self.config = config - self.attention_config = attention_config # type: AttentionConfig - self.layer_idx = layer_idx - if layer_idx is None: - logger.warning_once( - f"Instantiating {self.__class__.__name__} without passing a `layer_idx` is not recommended and will " - "lead to errors during the forward call if caching is used. Please make sure to provide a `layer_idx` " - "when creating this class." - ) - - self.attention_dropout = config.attention_dropout - self.hidden_size = config.hidden_size - self.num_heads = config.num_attention_heads - if config.head_dim is not None: - self.head_dim = config.head_dim - else: - self.head_dim = self.hidden_size // self.num_heads - self.num_key_value_groups = attention_config.n_heads_in_group # DeciLM-specific code - self.num_key_value_heads = ( - self.num_heads // self.num_key_value_groups - ) # DeciLM-specific code - self.max_position_embeddings = config.max_position_embeddings - self.rope_theta = config.rope_theta - self.is_causal = True - - # llama4 attention specific - self.llama4_attn_config = attention_config.llama4 - - self.q_proj = nn.Linear( - self.hidden_size, self.num_heads * self.head_dim, bias=config.attention_bias - ) - self.k_proj = nn.Linear( - self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias - ) - self.v_proj = nn.Linear( - self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias - ) - self.o_proj = nn.Linear( - self.num_heads * self.head_dim, self.hidden_size, bias=config.o_proj_bias - ) - - if self.config.position_embedding_type in ["rope", "rope_llama4", "mistral_yarn"]: - # TO DO (joao): remove in v4.45 (RoPE is computed in the model, not in the decoder layers) - self.rotary_emb = rope_type_to_class[self.config.position_embedding_type]( - config=self.config - ) - - if attention_config.sparsify is not None: - self.register_full_backward_hook(sparsity_backward_hook) - - self.is_llama4 = self.llama4_attn_config is not None - if ( - self.is_llama4 - and self.llama4_attn_config.use_qk_norm - and self.llama4_attn_config.use_rope - ): - self.qk_norm = Llama4TextL2Norm(self.config.rms_norm_eps) - - self.use_rope = ( - self.llama4_attn_config.use_rope - if self.is_llama4 - else self.config.position_embedding_type in ["rope", "mistral_yarn"] - ) - self.rope_impl = self.rotary_emb.rope_impl - self.apply_rope_fn = ( - apply_rotary_emb - if self.rope_impl in ["even_odd", "rope_llama4"] - else apply_rotary_pos_emb - ) - # self.apply_rope_fn = apply_rotary_emb - - def forward( - self, - hidden_states: torch.Tensor, - attention_mask: torch.Tensor | None = None, - position_ids: torch.LongTensor | None = None, - past_key_value: Cache | None = None, - output_attentions: bool = False, - use_cache: bool = False, - cache_position: torch.LongTensor | None = None, - position_embeddings: tuple[torch.Tensor, torch.Tensor] - | None = None, # will become mandatory in v4.45 - **kwargs, - ) -> tuple[torch.Tensor, torch.Tensor | None, tuple[torch.Tensor] | None]: - bsz, q_len, _ = hidden_states.size() - input_shape = hidden_states.shape[:-1] - - if self.config.pretraining_tp > 1: - key_value_slicing = ( - self.num_key_value_heads * self.head_dim - ) // self.config.pretraining_tp - query_slices = self.q_proj.weight.split( - (self.num_heads * self.head_dim) // self.config.pretraining_tp, dim=0 - ) - key_slices = self.k_proj.weight.split(key_value_slicing, dim=0) - value_slices = self.v_proj.weight.split(key_value_slicing, dim=0) - - query_states = [ - F.linear(hidden_states, query_slices[i]) for i in range(self.config.pretraining_tp) - ] - query_states = torch.cat(query_states, dim=-1) - - key_states = [ - F.linear(hidden_states, key_slices[i]) for i in range(self.config.pretraining_tp) - ] - key_states = torch.cat(key_states, dim=-1) - - value_states = [ - F.linear(hidden_states, value_slices[i]) for i in range(self.config.pretraining_tp) - ] - value_states = torch.cat(value_states, dim=-1) - - else: - query_states = self.q_proj(hidden_states) - key_states = self.k_proj(hidden_states) - value_states = self.v_proj(hidden_states) - - query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) - key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose( - 1, 2 - ) - value_states = value_states.view( - bsz, q_len, self.num_key_value_heads, self.head_dim - ).transpose(1, 2) - - if self.use_rope: - if position_embeddings is None: - logger.warning_once( - "The attention layers in this model are transitioning from computing the RoPE " - "embeddings internally through `position_ids` (2D tensor with the indexes of the " - "tokens), to using externally computed `position_embeddings` (Tuple of tensors, " - "containing cos and sin). In v4.45 `position_ids` will be removed and " - "`position_embeddings` will be mandatory." - ) - freqs_cis = self.rotary_emb(value_states, position_ids) - else: - freqs_cis = position_embeddings - - query_states, key_states = self.apply_rope_fn(query_states, key_states, freqs_cis) - - if hasattr(self, "qk_norm"): # the 128E model does not use qk_norm - query_states = self.qk_norm(query_states) - key_states = self.qk_norm(key_states) - - if self.is_llama4: - query_states = self.apply_attention_scaling(input_shape, cache_position, query_states) - - if past_key_value is not None: - # sin and cos are specific to RoPE models; cache_position needed for the static cache - # print(f"cache_position: {cache_position}") - cache_kwargs = {"cache_position": cache_position} - key_states, value_states = past_key_value.update( - key_states, value_states, self.layer_idx, cache_kwargs - ) - - key_states = repeat_kv(key_states, self.num_key_value_groups) - value_states = repeat_kv(value_states, self.num_key_value_groups) - - attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt( - self.head_dim - ) - - if attention_mask is not None: # no matter the length, we just slice it - causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] - attn_weights = attn_weights + causal_mask - - # upcast attention to fp32 - attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to( - query_states.dtype - ) - attn_weights = nn.functional.dropout( - attn_weights, p=self.attention_dropout, training=self.training - ) - attn_output = torch.matmul(attn_weights, value_states) - - if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim): - raise ValueError( - f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is" - f" {attn_output.size()}" - ) - - attn_output = attn_output.transpose(1, 2).contiguous() - - attn_output = attn_output.reshape(bsz, q_len, -1) - - if self.config.pretraining_tp > 1: - attn_output = attn_output.split(self.hidden_size // self.config.pretraining_tp, dim=2) - o_proj_slices = self.o_proj.weight.split( - self.hidden_size // self.config.pretraining_tp, dim=1 - ) - attn_output = sum( - [ - F.linear(attn_output[i], o_proj_slices[i]) - for i in range(self.config.pretraining_tp) - ] - ) - else: - attn_output = self.o_proj(attn_output) - - if not output_attentions: - attn_weights = None - - return attn_output, attn_weights, past_key_value - - def apply_attention_scaling(self, input_shape, cache_position, query_states): - # Use temperature tuning from https://arxiv.org/abs/2501.19399) to NoROPE layers - if self.llama4_attn_config.attn_temperature_tuning and not self.use_rope: - attn_scales = ( - torch.log( - torch.floor( - (cache_position.float() + 1.0) / self.llama4_attn_config.floor_scale - ) - + 1.0 - ) - * self.llama4_attn_config.attn_scale - + 1.0 - ) - attn_scales = attn_scales.view((*input_shape, 1, 1)).transpose(1, 2) - query_states = (query_states * attn_scales).to(query_states.dtype) - return query_states - return query_states - - -class DeciLMFlashAttention2(DeciLMAttention): - """ - DeciLM flash attention module. This module inherits from `DeciLMAttention` as the weights of the module stays - untouched. The only required change would be on the forward pass where it needs to correctly call the public API of - flash attention and deal with padding tokens in case the input contains any of them. - """ - - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - - # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1. - # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is - # bottom-right alignement, that was made default for flash_attn>=2.1. This attribute is - # used to handle this difference. - # Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0. - # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case - # q_seqlen == 1) produces a wrong mask (top-left). - self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10() - - self.sliding_window = self.attention_config.prefill_sliding_window - - self.pre_attention_identity_query = nn.Identity() # for debugging hooks - self.pre_attention_identity_key = nn.Identity() # for debugging hooks - - def forward( - self, - hidden_states: torch.Tensor, - attention_mask: torch.LongTensor | None = None, - position_ids: torch.LongTensor | None = None, - past_key_value: Cache | None = None, - output_attentions: bool = False, - use_cache: bool = False, - cache_position: torch.LongTensor | None = None, - position_embeddings: tuple[torch.Tensor, torch.Tensor] - | None = None, # will become mandatory in v4.45 - ) -> tuple[torch.Tensor, torch.Tensor | None, tuple[torch.Tensor] | None]: - output_attentions = False - - bsz, q_len, _ = hidden_states.size() - - query_states = self.q_proj(hidden_states) - key_states = self.k_proj(hidden_states) - value_states = self.v_proj(hidden_states) - - # Flash attention requires the input to have the shape - # batch_size x seq_length x head_dim x hidden_dim - # therefore we just need to keep the original shape - query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) - key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose( - 1, 2 - ) - value_states = value_states.view( - bsz, q_len, self.num_key_value_heads, self.head_dim - ).transpose(1, 2) - - if self.config.position_embedding_type in ["rope", "mistral_yarn"]: - # llama4 doesn't use flash attention - if position_embeddings is None: - logger.warning_once( - "The attention layers in this model are transitioning from computing the RoPE " - "embeddings internally through `position_ids` (2D tensor with the indexes of the " - "tokens), to using externally computed `position_embeddings` (Tuple of tensors, " - "containing cos and sin). In v4.45 `position_ids` will be removed and " - "`position_embeddings` will be mandatory." - ) - freqs_cis = self.rotary_emb(value_states, position_ids) - else: - freqs_cis = position_embeddings - - query_states, key_states = self.apply_rope_fn(query_states, key_states, freqs_cis) - # query_states, key_states = apply_rotary_pos_emb(query_states, key_states, freq_cis) - # print(f"applying even odd rope") - - if past_key_value is not None: - # sin and cos are specific to RoPE models; cache_position needed for the static cache - cache_kwargs = {"cache_position": cache_position} - key_states, value_states = past_key_value.update( - key_states, value_states, self.layer_idx, cache_kwargs - ) - - # TODO: These transpose are quite inefficient but Flash Attention requires the layout - # [batch_size, sequence_length, num_heads, head_dim]. We would need to refactor the KV - # cache to be able to avoid many of these transpose/reshape/view. - query_states = query_states.transpose(1, 2) - key_states = key_states.transpose(1, 2) - value_states = value_states.transpose(1, 2) - - dropout_rate = self.attention_dropout if self.training else 0.0 - - # In PEFT, usually we cast the layer norms in float32 for training stability reasons - # therefore the input hidden states gets silently casted in float32. Hence, we need - # cast them back in the correct dtype just to be sure everything works as expected. - # This might slowdown training & inference so it is recommended to not cast the LayerNorms - # in fp32. (DeciLMRMSNorm handles it correctly) - - input_dtype = query_states.dtype - if input_dtype == torch.float32: - if torch.is_autocast_enabled(): - target_dtype = torch.get_autocast_gpu_dtype() - # Handle the case where the model is quantized - elif hasattr(self.config, "_pre_quantization_dtype"): - target_dtype = self.config._pre_quantization_dtype - else: - target_dtype = self.q_proj.weight.dtype - - logger.warning_once( - f"The input hidden states seems to be silently casted in float32, this might be related to" - f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in" - f" {target_dtype}." - ) - - query_states = query_states.to(target_dtype) - key_states = key_states.to(target_dtype) - value_states = value_states.to(target_dtype) - - query_states = self.pre_attention_identity_query(query_states) - key_states = self.pre_attention_identity_key(key_states) - - attn_output = _flash_attention_forward( - query_states, - key_states, - value_states, - attention_mask, - q_len, - position_ids=position_ids, - dropout=dropout_rate, - sliding_window=self.sliding_window, - use_top_left_mask=self._flash_attn_uses_top_left_mask, - is_causal=self.is_causal, - ) - - attn_output = attn_output.reshape(bsz, q_len, -1).contiguous() - attn_output = self.o_proj(attn_output) - - if not output_attentions: - attn_weights = None - - return attn_output, attn_weights, past_key_value - - -DECILM_ATTENTION_CLASSES = { - "eager": DeciLMAttention, - "flash_attention_2": DeciLMFlashAttention2, -} - - -class DeciLMLlama4TextAttention(Llama4TextAttention): - def __init__(self, config: DeciLMConfig, layer_idx: int, attention_config: AttentionConfig): - llama4_text_config = Llama4TextConfig( - hidden_size=config.hidden_size, - num_attention_heads=config.num_attention_heads, - num_key_value_heads=config.num_attention_heads // attention_config.n_heads_in_group, - head_dim=getattr(config, "head_dim", config.hidden_size // config.num_attention_heads), - attn_scale=attention_config.llama4.attn_scale, - floor_scale=attention_config.llama4.floor_scale, - attn_temperature_tuning=attention_config.llama4.attn_temperature_tuning, - attention_dropout=attention_config.llama4.attention_dropout, - use_qk_norm=attention_config.llama4.use_qk_norm, - use_rope=attention_config.llama4.use_rope, - rms_norm_eps=config.rms_norm_eps, - attention_bias=config.attention_bias, - attn_implementation=config.llama4_attn_implementation, - rope_scaling=config.rope_scaling, - max_position_embeddings=config.max_position_embeddings, - attention_chunk_size=attention_config.llama4.attention_chunk_size, - ) - super().__init__(llama4_text_config, layer_idx, use_rope=attention_config.llama4.use_rope) - - -class DeciLMDecoderLayer(nn.Module): - # DeciLM-specific code - def __init__(self, config: DeciLMConfig, layer_idx: int | tuple[int, ...]): - super().__init__() - self.config = config - self.hidden_size = config.hidden_size - self.block_config = config.get_block_config(layer_idx) - - self.attention_config = self.block_config.attention - self.ffn_config = self.block_config.ffn - self.layer_idx = layer_idx - - if not config._attn_implementation: - config._attn_implementation = "eager" - - if not self.attention_config.no_op: - self.input_layernorm = DeciLMRMSNorm(config.hidden_size, eps=config.rms_norm_eps) - if self.attention_config.replace_with_linear: - self.self_attn = DeciLMLinearAttention(config) - elif self.attention_config.is_mamba: - self.self_attn = DeciLMMambaMixer(config, self.attention_config.mamba) - elif not self.attention_config.is_llama4: - self.self_attn = DECILM_ATTENTION_CLASSES[config._attn_implementation]( - config=config, attention_config=self.attention_config, layer_idx=layer_idx - ) - else: - self.self_attn = DeciLMLlama4TextAttention(config, layer_idx, self.attention_config) - - if not (self.ffn_config.no_op or self.attention_config.is_mamba): - if getattr(self.ffn_config, "hidden_act", None) is None: - print(f"WARNING: FFN hidden_act is None for layer {layer_idx}") - - self.post_attention_layernorm = DeciLMRMSNorm( - config.hidden_size, eps=config.rms_norm_eps - ) - if self.ffn_config.replace_with_linear: - self.mlp = DeciLMLinearMLP(config) - elif self.ffn_config.is_moe: - self.mlp = DeciLMMoe(config, self.ffn_config) - else: - self.mlp = ( - DeciLMGatedMLP(config, self.ffn_config) - if self.ffn_config.gated - else DeciLMVanillaMLP(config, self.ffn_config) - ) - - self.is_sliding = self.attention_config.is_sliding - self.sliding_window = self.attention_config.prefill_sliding_window - self.return_only_hidden_states = self.config.block_return_only_hidden_states - - @property - def device(self): - try: - return next(self.parameters()).device - except StopIteration: - return None - - def forward( - self, - hidden_states: torch.Tensor, - attention_mask: torch.Tensor | None = None, - position_ids: torch.LongTensor | None = None, - past_key_value: Cache | None = None, - output_attentions: bool | None = False, - output_router_logits: bool | None = False, - use_cache: bool | None = False, - cache_position: torch.LongTensor | None = None, - position_embeddings: tuple[torch.Tensor, torch.Tensor] - | None = None, # necessary, but kept here for BC - **kwargs, - ) -> ( - tuple[torch.FloatTensor, tuple[torch.FloatTensor, torch.FloatTensor] | None] - | torch.FloatTensor - ): - """ - Args: - hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` - attention_mask (`torch.FloatTensor`, *optional*): - attention mask of size `(batch_size, sequence_length)` if flash attention is used or `(batch_size, 1, - query_sequence_length, key_sequence_length)` if default attention is used. - output_attentions (`bool`, *optional*): - Whether or not to return the attentions tensors of all attention layers. See `attentions` under - returned tensors for more detail. - use_cache (`bool`, *optional*): - If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding - (see `past_key_values`). - past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states - cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*): - Indices depicting the position of the input sequence tokens in the sequence - position_embeddings (`Tuple[torch.FloatTensor, torch.FloatTensor]`, *optional*): - Tuple containing the cosine and sine positional embeddings of shape `(batch_size, seq_len, head_dim)`, - with `head_dim` being the embedding dimension of each attention head. - kwargs (`dict`, *optional*): - Arbitrary kwargs to be ignored, used for FSDP and other methods that injects code - into the model - """ - paramz = list(self.parameters()) - device = paramz[0].device if len(paramz) > 0 else None - if isinstance(hidden_states, tuple): - # could happen when sewing kit sends the output of the previous layer - # to this layer without going through the model forward unpacking code. - # can be avoided by using config.block_return_only_hidden_states=True - hidden_states = hidden_states[0] - - hidden_states = hidden_states.to(device) - - if cache_position is not None: - cache_position = cache_position.to(device) - - if self.attention_config.llama4 is not None: - # chunk_size = self.attention_config.llama4.attention_chunk_size - # print(f"pre-llama4_update: {attention_mask=}") - # causal_mask, chunk_causal_mask = self._llama4_update_causal_mask( - # attention_mask, hidden_states, cache_position, past_key_value, output_attentions, use_cache=use_cache, - # ) - # attention_mask = causal_mask if (chunk_size is None) else chunk_causal_mask - # if (past_key_value is not None) and isinstance(attention_mask, BlockMask): - # print(f"pre-adjust: {attention_mask.shape=}") - # print(f"pre-adjust: {hidden_states.shape=}") - # print(f"pre-adjust: {past_key_value.get_seq_length()=}") - # q_len = hidden_states.shape[1] - # kv_len = past_key_value.get_seq_length() - # if kv_len == 0: - # kv_len = q_len - # print(f"pre-adjust: {kv_len=} {q_len=}") - # print(f"post-adjust: {attention_mask.shape=}") - assert self.config.llama4_attn_implementation != "flex_attention", ( - "We have a mask issue with flex attention" - ) - - causal_mask, chunk_causal_mask = self._llama4_update_causal_mask( - attention_mask, - hidden_states, - cache_position, - past_key_value, - output_attentions, - use_cache=use_cache, - ) - is_chunked = self.attention_config.llama4.attention_chunk_size is not None - attention_mask = ( - chunk_causal_mask if is_chunked and (chunk_causal_mask is not None) else causal_mask - ) - - else: - attention_mask = self._llama3_update_causal_mask( - attention_mask, hidden_states, cache_position, past_key_value, output_attentions - ) - if self.attention_config.unshifted_sink and self.attention_config.is_sink: - attention_mask = self._unshifted_sink_mask( - attention_mask, - hidden_states, - self.attention_config.window_length, - self.attention_config.num_sink_tokens, - ) - else: - attention_mask = self._gemma2_window_mask( - attention_mask, hidden_states, past_key_value - ) - - self_attn_weights = None - present_key_value = past_key_value - router_logits = None - - if self.attention_config.no_op: - pass - elif self.attention_config.replace_with_linear or self.attention_config.is_mamba: - if self.attention_config.is_mamba: - assert past_key_value is None, "DeciLM does not support generation with Mamba yet" - residual = hidden_states - hidden_states = self.input_layernorm(hidden_states) - hidden_states = self.self_attn(hidden_states) - hidden_states = residual + hidden_states - else: - residual = hidden_states - hidden_states = self.input_layernorm(hidden_states) - attn_out = self.self_attn( - hidden_states=hidden_states, - attention_mask=attention_mask, - position_ids=position_ids, - past_key_value=past_key_value, - output_attentions=output_attentions, - use_cache=use_cache, - cache_position=cache_position, - position_embeddings=position_embeddings, - **kwargs, - ) - hidden_states, self_attn_weights = attn_out[:2] - if len(attn_out) > 2: - present_key_value = attn_out[2] - - hidden_states = residual + hidden_states - - if not self.ffn_config.no_op: - residual = hidden_states - hidden_states = self.post_attention_layernorm(hidden_states) - - # Handle MoE layers differently as they return router logits - if self.ffn_config.is_moe: - hidden_states, router_logits = self.mlp(hidden_states) - else: - hidden_states = self.mlp(hidden_states) - - hidden_states = residual + hidden_states - - if self.return_only_hidden_states: - return hidden_states - - outputs = (hidden_states,) - - if output_attentions: - outputs += (self_attn_weights,) - - if use_cache: - outputs += (present_key_value,) - - if output_router_logits and router_logits is not None: - outputs += (router_logits,) - - return outputs - - def _gemma2_window_mask( - self, - attention_mask: torch.Tensor | None, - hidden_states: torch.Tensor, - past_key_value: VariableCache | None, - ) -> torch.Tensor | None: - if self.is_sliding and attention_mask is not None: # efficient SDPA and no padding - # Flash-attn is a 2D tensor - if self.config._attn_implementation == "flash_attention_2": - if past_key_value is not None: # when decoding - attention_mask = attention_mask[:, -self.sliding_window :] - else: - min_dtype = torch.finfo(hidden_states.dtype).min - sliding_window_mask = torch.tril( - torch.ones_like(attention_mask, dtype=torch.bool), diagonal=-self.sliding_window - ) - attention_mask = torch.where(sliding_window_mask, min_dtype, attention_mask) - if attention_mask.shape[-1] <= 1: # when decoding - attention_mask = attention_mask[:, :, :, -self.sliding_window :] - return attention_mask - - def _unshifted_sink_mask( - self, - attention_mask: torch.Tensor, - hidden_states: torch.Tensor, - window_length: int, - num_sink_tokens: int | None, - ) -> torch.Tensor: - assert self.config._attn_implementation == "eager", ( - "Unshifted sink is only supported in 'eager' mode." - ) - assert attention_mask is not None, "The attention mask seems to not be prepared" - - attention_mask = attention_mask.clone() - min_dtype = torch.finfo(hidden_states.dtype).min - - if window_length == 0: - attention_mask = torch.full_like(attention_mask, fill_value=min_dtype) - else: - query_length = attention_mask.shape[-2] - is_decode = query_length == 1 - if is_decode: - attention_mask[:, :, :, :-window_length] = min_dtype - else: - sliding_window_mask = torch.tril( - torch.ones_like(attention_mask, dtype=torch.bool), diagonal=-window_length - ) - attention_mask = torch.where(sliding_window_mask, min_dtype, attention_mask) - - attention_mask[:, :, :, :num_sink_tokens] = 0 - return attention_mask - - def _llama3_update_causal_mask( - self, - attention_mask: torch.Tensor, - input_tensor: torch.Tensor, - cache_position: torch.Tensor, - past_key_values: Cache, - output_attentions: bool, - ): - # TODO: As of torch==2.2.0, the `attention_mask` passed to the model in `generate` is - # 2D and of dynamic length even when the static KV cache is used. This is an issue for - # torch.compile which then recaptures cudagraphs at each decode steps due to the dynamic - # shapes. (`recording cudagraph tree for symint key 13`, etc.), which is VERY slow. - # A workaround is `@torch.compiler.disable`, but this prevents using `fullgraph=True`. - # See more context in https://github.com/huggingface/transformers/pull/29114 - - if self.config._attn_implementation == "flash_attention_2": - if attention_mask is not None and 0.0 in attention_mask: - return attention_mask - return None - - # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in - # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail - # to infer the attention mask. - past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 - assert not isinstance(past_key_values, StaticCache), "DeciLM does not support StaticCache" - using_static_cache = isinstance(past_key_values, StaticCache) - - # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward - if ( - self.config._attn_implementation == "sdpa" - and not using_static_cache - and not output_attentions - ): - if ( - AttentionMaskConverter._ignore_causal_mask_sdpa( - attention_mask, - inputs_embeds=input_tensor, - past_key_values_length=past_seen_tokens, - is_training=self.training, - ) - and not self.is_sliding - ): - return None - - dtype, device = input_tensor.dtype, input_tensor.device - min_dtype = torch.finfo(dtype).min - sequence_length = input_tensor.shape[1] - if using_static_cache: - target_length = past_key_values.get_max_length() - else: - target_length = ( - attention_mask.shape[-1] - if isinstance(attention_mask, torch.Tensor) - else past_seen_tokens + sequence_length + 1 - ) - - # In case the provided `attention` mask is 2D, we generate a causal mask here (4D). - causal_mask = _prepare_4d_causal_attention_mask_with_cache_position( - attention_mask, - sequence_length=sequence_length, - target_length=target_length, - dtype=dtype, - device=device, - min_dtype=min_dtype, - cache_position=cache_position, - batch_size=input_tensor.shape[0], - ) - - if ( - self.config._attn_implementation == "sdpa" - and attention_mask is not None - and attention_mask.device.type == "cuda" - and not output_attentions - ): - # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when - # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path. - # Details: https://github.com/pytorch/pytorch/issues/110213 - causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype) - - return causal_mask - - @torch.compiler.disable(recursive=False) # the operations in this method are not compilable - def _llama4_update_causal_mask( - self, - attention_mask: torch.Tensor, - input_tensor: torch.Tensor, - cache_position: torch.Tensor, - past_key_values: Cache | None, - output_attentions: bool = False, - chunked_attention_mask=None, - use_cache=True, - ): - attn_implementation = self.config.llama4_attn_implementation - - if attn_implementation == "flash_attention_2": - if attention_mask is not None and (attention_mask == 0.0).any(): - return ( - attention_mask, - attention_mask, - ) # flash does not support chunked attn TODO support flash - return None, None - - if attn_implementation not in ["sdpa", "flex_attention", "eager"]: - return None, None - - sequence_length = input_tensor.shape[1] - cache_position = cache_position.to(self.device) - attention_chunk_size = self.attention_config.llama4.attention_chunk_size - if attention_chunk_size is None: - # let the function build some chunked mask, we won't use it since it's not a chunked - # attention layer. We still need to know the chunk size for this if statement that - # comes later on: if attn_implementation == "sdpa" and chunked_attention_mask is not None - # otherwise the mask dtype is wrong for sdpa :bufo-wat: - attention_chunk_size = self.config.get_min_attention_chunk_size() - if attention_chunk_size is None: - logger.warning_once( - "Could not infer attention_chunk_size since the model (or the model shard) " - "has no chunked attention, using 8192 as default for mask construction" - ) - attention_chunk_size = 8192 - - first_cache_position = cache_position[0] - - if past_key_values is not None: - full_cache_length = past_key_values.get_max_cache_shape() or sequence_length - else: - full_cache_length = ( - attention_mask.shape[-1] if attention_mask is not None else sequence_length - ) - - cond1 = first_cache_position >= attention_chunk_size - cond2 = (first_cache_position < attention_chunk_size) & ( - first_cache_position + sequence_length > attention_chunk_size - ) - key_length = ( - torch.where( - cond1, - attention_chunk_size + sequence_length - 1, - torch.where(cond2, first_cache_position + sequence_length, attention_chunk_size), - ) - if use_cache - else full_cache_length - ) - - if attn_implementation == "flex_attention": - raise NotImplementedError("DeciLM Llama4 does not support flex attention") - # if isinstance(attention_mask, torch.Tensor): - # offsets = (first_cache_position, max(first_cache_position - attention_chunk_size + 1, 0)) - # chunked_attention_mask = make_flex_block_causal_mask( - # attention_mask, attention_chunk_size, sequence_length, key_length, offsets=offsets - # ) - # attention_mask = make_flex_block_causal_mask( - # attention_mask, - # query_length=sequence_length, - # key_length=full_cache_length, - # offsets=(first_cache_position, 0), - # ) - # return attention_mask, chunked_attention_mask - # if isinstance(attention_mask, BlockMask): - # return attention_mask, chunked_attention_mask - - # In case the provided `attention` mask is 2D, we generate a causal mask here (4D). - dtype, device = input_tensor.dtype, input_tensor.device - causal_mask = _prepare_4d_causal_attention_mask_with_cache_position( - attention_mask, - sequence_length=sequence_length, - target_length=max(full_cache_length, attention_chunk_size), - dtype=dtype, - device=device, - cache_position=cache_position, - batch_size=input_tensor.shape[0], - min_dtype=torch.finfo(dtype).min, - ) - if full_cache_length > attention_chunk_size: - start_idx = max(first_cache_position - attention_chunk_size + 1, 0) - end_idx = start_idx + key_length - chunked_attention_mask = self.create_chunked_attention_mask( - attention_chunk_size, - start=start_idx, # same offset as with flex - end=end_idx, - device=device, - ) - - ### Deci: we added this code to patch a bug in transformers - if attention_mask is None: - if past_key_values is not None: - raise NotImplementedError("We only support attention_mask=None is prefill") - attention_mask = torch.ones( - input_tensor.shape[0], input_tensor.shape[1], device=device, dtype=torch.long - ) - - local_attention_mask = attention_mask[:, start_idx:end_idx] # offset here as well - # It may be smaller than attention_chunk_size -> pad it - requires_padding = local_attention_mask.shape[-1] < attention_chunk_size - if requires_padding: - local_attention_mask = nn.functional.pad( - local_attention_mask, (0, attention_chunk_size - local_attention_mask.shape[-1]) - ) - # Depending on the padding, take the query tokens from the end or the cache_position - if not requires_padding: - chunked_attention_mask = chunked_attention_mask[None, None, -sequence_length:, :] - else: - chunked_attention_mask = chunked_attention_mask[None, None, cache_position, :] - - chunked_attention_mask = chunked_attention_mask.expand( - input_tensor.shape[0], -1, -1, -1 - ) - chunked_attention_mask = chunked_attention_mask * local_attention_mask[:, None, None, :] - if attn_implementation == "eager": - min_dtype = torch.finfo(dtype).min - chunked_attention_mask = torch.where( - chunked_attention_mask == 0, min_dtype, 0.0 - ).to(dtype) - - # print(f"{output_attentions=}") - - if ( - attn_implementation == "sdpa" - and attention_mask is not None - and attention_mask.device.type in ["cuda", "xpu"] - and attention_mask.ndim == 4 - and not output_attentions # Only unmask for 4d masks - ): - # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when - # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path. - # Details: https://github.com/pytorch/pytorch/issues/110213 - min_dtype = torch.finfo(dtype).min - causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype) - - # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward - if attn_implementation == "sdpa" and chunked_attention_mask is not None: - chunked_attention_mask = chunked_attention_mask.bool() - causal_mask = causal_mask.bool() - if AttentionMaskConverter._ignore_causal_mask_sdpa( - attention_mask, - inputs_embeds=input_tensor, - past_key_values_length=first_cache_position, - is_training=self.training, - ): - causal_mask = None - return causal_mask, chunked_attention_mask - - def create_chunked_attention_mask( - self, attention_chunk_size: int, start: int, end: int, device: torch.device - ) -> torch.Tensor: - """ - Generate the following: - - 'What' : 0 ■ ⬚ ⬚ ⬚ ⬚ ⬚ | - '▁is' : 1 ■ ■ ⬚ ⬚ ⬚ ⬚ | - '▁ch' : 2 ■ ■ ■ ⬚ ⬚ ⬚ | - 'unked' : 3 ⬚ ⬚ ⬚ ■ ⬚ ⬚ | - '▁attention': 4 ⬚ ⬚ ⬚ ■ ■ ⬚ | - '?' : 5 ⬚ ⬚ ⬚ ■ ■ ■ | - - If the chunk size is 3. - This can just be appplied over the already created attention mask - """ - arange_vector = torch.arange(start, end, device=device) - block_pos = torch.abs( - arange_vector.unsqueeze(0) // attention_chunk_size - - arange_vector.unsqueeze(1) // attention_chunk_size - ) - token_pos = arange_vector.unsqueeze(0) - arange_vector.unsqueeze(1) - mask = (block_pos == 0) & (token_pos <= 0) - return mask.to(device) - - -class DeciLMMultiDecoderLayer(nn.Module): - def __init__(self, config: DeciLMConfig, layer_idx: int): - super().__init__() - self.config = config - block_config = config.block_configs[layer_idx] - assert block_config.parallel_blocks is not None - num_parallel_blocks = len(block_config.parallel_blocks) - self.parallel_blocks = nn.ModuleList( - [ - DeciLMDecoderLayer(config, (layer_idx, internal_block_idx)) - for internal_block_idx in range(num_parallel_blocks) - ] - ) - - def forward( - self, - hidden_states: torch.Tensor, - *args, - **kwargs, - ) -> tuple[torch.FloatTensor, tuple[torch.FloatTensor, torch.FloatTensor] | None]: - block_outputs = [block(hidden_states, *args, **kwargs) for block in self.parallel_blocks] - output_hidden_states = [ - out[0].to(hidden_states.device) - if isinstance(out, tuple) - else out.to(hidden_states.device) - for out in block_outputs - ] - output_hidden_states = torch.stack(output_hidden_states, dim=0).sum(dim=0) - output_hidden_states = ( - output_hidden_states - (len(self.parallel_blocks) - 1) * hidden_states - ) - - if self.config.block_return_only_hidden_states: - return output_hidden_states - - other_outputs = block_outputs[0][1:] - outputs = (output_hidden_states, *other_outputs) - return outputs - - -######################################################################## -# DeciLM-specific code -######################################################################## - - -def _find_multiple(n: int, k: int) -> int: - # DeciLM-specific code - if n % k == 0: - return n - return n + k - (n % k) - - class DeciLMMoe(nn.Module): """ Implementation of Mixture of Experts module for DeciLM. @@ -1680,64 +203,6 @@ def extra_repr(self) -> str: ) -class DeciLMLinearMLP(nn.Module): - # DeciLM-specific code - def __init__( - self, - config: DeciLMConfig, - ): - super().__init__() - self.linear_mlp = nn.Linear( - in_features=config.hidden_size, out_features=config.hidden_size, bias=False - ) - - def forward(self, x: torch.Tensor) -> torch.Tensor: - return self.linear_mlp.forward(x) - - -class DeciLMLinearAttention(nn.Module): - # DeciLM-specific code - def __init__( - self, - config: DeciLMConfig, - ): - super().__init__() - self.linear_attn = nn.Linear( - in_features=config.hidden_size, out_features=config.hidden_size, bias=False - ) - - def forward(self, x: torch.Tensor) -> torch.Tensor: - return self.linear_attn.forward(x) - - -def sparsity_backward_hook(*args, **kwargs): - raise NotImplementedError( - "No support for sparsity when training HF DeciLM (inference is ok though)" - ) - - -class DeciLMMambaMixer(nn.Module): - def __init__( - self, - config: DeciLMConfig, - mamba_config: MambaConfig, - ): - super().__init__() - self.mamba_mixer = MambaMixerMegatron( - d_model=config.hidden_size, - d_state=mamba_config.state_dim, - nheads=mamba_config.num_heads, - headdim=mamba_config.head_dim, - ngroups=mamba_config.num_groups, - ) - - def forward(self, x: torch.Tensor, *args, **kwargs) -> torch.Tensor: - x = x.permute([1, 0, 2]) # MambaMixerMegatron expects [Sequence, Batch, Embedding] - out = self.mamba_mixer(x) - out = out.permute([1, 0, 2]) # go back to [Batch, Sequence, Embedding] - return out - - class LMHead(nn.Linear): """ Special class to allow FSDP wrapping without affecting other Linear layers in the model. diff --git a/modelopt/torch/puzzletron/decilm/deci_lm_hf_code/transformers_4_44_2__modeling_attn_mask_utils.py b/modelopt/torch/puzzletron/decilm/deci_lm_hf_code/transformers_4_44_2__modeling_attn_mask_utils.py deleted file mode 100644 index 72578006787..00000000000 --- a/modelopt/torch/puzzletron/decilm/deci_lm_hf_code/transformers_4_44_2__modeling_attn_mask_utils.py +++ /dev/null @@ -1,498 +0,0 @@ -# Copyright 2023 The HuggingFace Team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# -# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from dataclasses import dataclass -from typing import List, Optional, Tuple, Union - -import torch - - -@dataclass -class AttentionMaskConverter: - """ - A utility attention mask class that allows one to: - - Create a causal 4d mask - - Create a causal 4d mask with slided window - - Convert a 2d attention mask (batch_size, query_length) to a 4d attention mask (batch_size, 1, query_length, - key_value_length) that can be multiplied with attention scores - - Examples: - - ```python - >>> import torch - >>> from transformers.modeling_attn_mask_utils import AttentionMaskConverter - - >>> converter = AttentionMaskConverter(True) - >>> converter.to_4d(torch.tensor([[0, 0, 0, 1, 1]]), 5, key_value_length=5, dtype=torch.float32) - tensor([[[[-3.4028e+38, -3.4028e+38, -3.4028e+38, -3.4028e+38, -3.4028e+38], - [-3.4028e+38, -3.4028e+38, -3.4028e+38, -3.4028e+38, -3.4028e+38], - [-3.4028e+38, -3.4028e+38, -3.4028e+38, -3.4028e+38, -3.4028e+38], - [-3.4028e+38, -3.4028e+38, -3.4028e+38, 0.0000e+00, -3.4028e+38], - [-3.4028e+38, -3.4028e+38, -3.4028e+38, 0.0000e+00, 0.0000e+00]]]]) - ``` - - Parameters: - is_causal (`bool`): - Whether the attention mask should be a uni-directional (causal) or bi-directional mask. - - sliding_window (`int`, *optional*): - Optionally, the sliding window masks can be created if `sliding_window` is defined to a positive integer. - """ - - is_causal: bool - sliding_window: int - - def __init__(self, is_causal: bool, sliding_window: Optional[int] = None): - self.is_causal = is_causal - self.sliding_window = sliding_window - - if self.sliding_window is not None and self.sliding_window <= 0: - raise ValueError( - f"Make sure that when passing `sliding_window` that its value is a strictly positive integer, not `{self.sliding_window}`" - ) - - def to_causal_4d( - self, - batch_size: int, - query_length: int, - key_value_length: int, - dtype: torch.dtype, - device: Union[torch.device, "str"] = "cpu", - ) -> Optional[torch.Tensor]: - """ - Creates a causal 4D mask of (bsz, head_dim=1, query_length, key_value_length) shape and adds large negative - bias to upper right hand triangular matrix (causal mask). - """ - if not self.is_causal: - raise ValueError(f"Please use `to_causal_4d` only if {self.__class__} has `is_causal` set to True.") - - # If shape is not cached, create a new causal mask and cache it - input_shape = (batch_size, query_length) - past_key_values_length = key_value_length - query_length - - # create causal mask - # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] - causal_4d_mask = None - if input_shape[-1] > 1 or self.sliding_window is not None: - causal_4d_mask = self._make_causal_mask( - input_shape, - dtype, - device=device, - past_key_values_length=past_key_values_length, - sliding_window=self.sliding_window, - ) - - return causal_4d_mask - - def to_4d( - self, - attention_mask_2d: torch.Tensor, - query_length: int, - dtype: torch.dtype, - key_value_length: Optional[int] = None, - ) -> torch.Tensor: - """ - Converts 2D attention mask to 4D attention mask by expanding mask to (bsz, head_dim=1, query_length, - key_value_length) shape and by adding a large negative bias to not-attended positions. If attention_mask is - causal, a causal mask will be added. - """ - input_shape = (attention_mask_2d.shape[0], query_length) - - # create causal mask - # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] - causal_4d_mask = None - if (input_shape[-1] > 1 or self.sliding_window is not None) and self.is_causal: - if key_value_length is None: - raise ValueError( - "This attention mask converter is causal. Make sure to pass `key_value_length` to correctly create a causal mask." - ) - - past_key_values_length = key_value_length - query_length - causal_4d_mask = self._make_causal_mask( - input_shape, - dtype, - device=attention_mask_2d.device, - past_key_values_length=past_key_values_length, - sliding_window=self.sliding_window, - ) - elif self.sliding_window is not None: - raise NotImplementedError("Sliding window is currently only implemented for causal masking") - - # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] - expanded_attn_mask = self._expand_mask(attention_mask_2d, dtype, tgt_len=input_shape[-1]).to( - attention_mask_2d.device - ) - - if causal_4d_mask is not None: - expanded_attn_mask = causal_4d_mask.masked_fill(expanded_attn_mask.bool(), torch.finfo(dtype).min) - - # expanded_attn_mask + causal_4d_mask can cause some overflow - expanded_4d_mask = expanded_attn_mask - - return expanded_4d_mask - - @staticmethod - def _make_causal_mask( - input_ids_shape: torch.Size, - dtype: torch.dtype, - device: torch.device, - past_key_values_length: int = 0, - sliding_window: Optional[int] = None, - ): - """ - Make causal mask used for bi-directional self-attention. - """ - bsz, tgt_len = input_ids_shape - mask = torch.full((tgt_len, tgt_len), torch.finfo(dtype).min, device=device) - mask_cond = torch.arange(mask.size(-1), device=device) - mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0) - - mask = mask.to(dtype) - - if past_key_values_length > 0: - mask = torch.cat([torch.zeros(tgt_len, past_key_values_length, dtype=dtype, device=device), mask], dim=-1) - - # add lower triangular sliding window mask if necessary - if sliding_window is not None: - diagonal = past_key_values_length - sliding_window - 1 - - context_mask = torch.tril(torch.ones_like(mask, dtype=torch.bool), diagonal=diagonal) - mask.masked_fill_(context_mask, torch.finfo(dtype).min) - - return mask[None, None, :, :].expand(bsz, 1, tgt_len, tgt_len + past_key_values_length) - - @staticmethod - def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None): - """ - Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`. - """ - bsz, src_len = mask.size() - tgt_len = tgt_len if tgt_len is not None else src_len - - expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype) - - inverted_mask = 1.0 - expanded_mask - - return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min) - - @staticmethod - def _unmask_unattended( - expanded_mask: torch.FloatTensor, - min_dtype: float, - ): - # fmt: off - """ - Attend to all tokens in masked rows from the expanded attention mask, for example the relevant first rows when - using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path. - Details: https://github.com/pytorch/pytorch/issues/110213 - - `expanded_mask` is [bsz, num_masks, tgt_seq_len, src_seq_len] or [bsz, tgt_seq_len, src_seq_len]. - `attention_mask` is [bsz, src_seq_len]. - - The dimension num_masks of `expanded_mask` is most often 1, but it can also be the number of heads in the case of alibi attention bias. - - For example, if `expanded_mask` is (e.g. here left-padding case) - ``` - [[[[0, 0, 0], - [0, 0, 0], - [0, 0, 1]]], - [[[1, 0, 0], - [1, 1, 0], - [1, 1, 1]]], - [[[0, 0, 0], - [0, 1, 0], - [0, 1, 1]]]] - ``` - then the modified `expanded_mask` will be - ``` - [[[[1, 1, 1], <-- modified - [1, 1, 1], <-- modified - [0, 0, 1]]], - [[[1, 0, 0], - [1, 1, 0], - [1, 1, 1]]], - [[[1, 1, 1], <-- modified - [0, 1, 0], - [0, 1, 1]]]] - ``` - """ - # fmt: on - if expanded_mask.dtype == torch.bool: - raise ValueError( - "AttentionMaskConverter._unmask_unattended expects a float `expanded_mask`, got a BoolTensor." - ) - - return expanded_mask.mul(~torch.all(expanded_mask == min_dtype, dim=-1, keepdim=True)) - - @staticmethod - def _ignore_causal_mask_sdpa( - attention_mask: Optional[torch.Tensor], - inputs_embeds: torch.Tensor, - past_key_values_length: int, - sliding_window: Optional[int] = None, - is_training: bool = False, - ) -> bool: - """ - Detects whether the optional user-specified attention_mask & the automatically created causal mask can be ignored in case PyTorch's SDPA is used, rather relying on SDPA's `is_causal` argument. - - In case no token is masked in the `attention_mask` argument, if `query_length == 1` or - `key_value_length == query_length`, we rather rely on SDPA `is_causal` argument to use causal/non-causal masks, - allowing to dispatch to the flash attention kernel (that can otherwise not be used if a custom `attn_mask` is passed). - """ - - _, query_length = inputs_embeds.shape[0], inputs_embeds.shape[1] - key_value_length = query_length + past_key_values_length - - is_tracing = ( - torch.jit.is_tracing() - or isinstance(inputs_embeds, torch.fx.Proxy) - or (hasattr(torch, "_dynamo") and torch._dynamo.is_compiling()) - ) - - ignore_causal_mask = False - - if attention_mask is None: - # TODO: When tracing with TorchDynamo with fullgraph=True, the model is recompiled depending on the input shape, thus SDPA's `is_causal` argument is rightfully updated (see https://gist.github.com/fxmarty/1313f39037fc1c112508989628c57363). However, when using `torch.export` or - # or `torch.onnx.dynamo_export`, we must pass an example input, and `is_causal` behavior is hard-coded. If a user exports a model with q_len > 1, the exported model will hard-code `is_causal=True` which is in general wrong (see https://github.com/pytorch/pytorch/issues/108108). - # Thus, we only set `ignore_causal_mask = True` if the model is set to training. - # - # Besides, jit.trace can not handle the `q_len > 1` condition for `is_causal` ("TypeError: scaled_dot_product_attention(): argument 'is_causal' must be bool, not Tensor"). - if ( - (is_training or not is_tracing) - and (query_length == 1 or key_value_length == query_length) - and (sliding_window is None or key_value_length < sliding_window) - ): - ignore_causal_mask = True - elif sliding_window is None or key_value_length < sliding_window: - if len(attention_mask.shape) == 4: - return False - elif (is_training or not is_tracing) and torch.all(attention_mask == 1): - if query_length == 1 or key_value_length == query_length: - # For query_length == 1, causal attention and bi-directional attention are the same. - ignore_causal_mask = True - - # Unfortunately, for query_length > 1 and key_value_length != query_length, we cannot generally ignore the attention mask, as SDPA causal mask generation - # may be wrong. We will set `is_causal=False` in SDPA and rely on Transformers attention_mask instead, hence not setting it to None here. - # Reference: https://github.com/pytorch/pytorch/issues/108108 - # TODO: maybe revisit this with https://github.com/pytorch/pytorch/pull/114823 in PyTorch 2.3. - - return ignore_causal_mask - - -def _prepare_4d_causal_attention_mask( - attention_mask: Optional[torch.Tensor], - input_shape: Union[torch.Size, Tuple, List], - inputs_embeds: torch.Tensor, - past_key_values_length: int, - sliding_window: Optional[int] = None, -): - """ - Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape - `(batch_size, key_value_length)` - - Args: - attention_mask (`torch.Tensor` or `None`): - A 2D attention mask of shape `(batch_size, key_value_length)` - input_shape (`tuple(int)` or `list(int)` or `torch.Size`): - The input shape should be a tuple that defines `(batch_size, query_length)`. - inputs_embeds (`torch.Tensor`): - The embedded inputs as a torch Tensor. - past_key_values_length (`int`): - The length of the key value cache. - sliding_window (`int`, *optional*): - If the model uses windowed attention, a sliding window should be passed. - """ - attn_mask_converter = AttentionMaskConverter(is_causal=True, sliding_window=sliding_window) - - key_value_length = input_shape[-1] + past_key_values_length - - # 4d mask is passed through the layers - if attention_mask is not None and len(attention_mask.shape) == 2: - attention_mask = attn_mask_converter.to_4d( - attention_mask, input_shape[-1], key_value_length=key_value_length, dtype=inputs_embeds.dtype - ) - elif attention_mask is not None and len(attention_mask.shape) == 4: - expected_shape = (input_shape[0], 1, input_shape[1], key_value_length) - if tuple(attention_mask.shape) != expected_shape: - raise ValueError( - f"Incorrect 4D attention_mask shape: {tuple(attention_mask.shape)}; expected: {expected_shape}." - ) - else: - # if the 4D mask has correct shape - invert it and fill with negative infinity - inverted_mask = 1.0 - attention_mask - attention_mask = inverted_mask.masked_fill( - inverted_mask.to(torch.bool), torch.finfo(inputs_embeds.dtype).min - ) - else: - attention_mask = attn_mask_converter.to_causal_4d( - input_shape[0], input_shape[-1], key_value_length, dtype=inputs_embeds.dtype, device=inputs_embeds.device - ) - - return attention_mask - - -# Adapted from _prepare_4d_causal_attention_mask -def _prepare_4d_causal_attention_mask_for_sdpa( - attention_mask: Optional[torch.Tensor], - input_shape: Union[torch.Size, Tuple, List], - inputs_embeds: torch.Tensor, - past_key_values_length: int, - sliding_window: Optional[int] = None, -): - """ - Prepares the correct `attn_mask` argument to be used by `torch.nn.functional.scaled_dot_product_attention`. - - In case no token is masked in the `attention_mask` argument, we simply set it to `None` for the cases `query_length == 1` and - `key_value_length == query_length`, and rely instead on SDPA `is_causal` argument to use causal/non-causal masks, - allowing to dispatch to the flash attention kernel (that can otherwise not be used if a custom `attn_mask` is passed). - """ - attn_mask_converter = AttentionMaskConverter(is_causal=True, sliding_window=sliding_window) - - key_value_length = input_shape[-1] + past_key_values_length - - # torch.jit.trace, symbolic_trace and torchdynamo with fullgraph=True are unable to capture the controlflow `is_causal=attention_mask is None and q_len > 1` - # used as an SDPA argument. We keep compatibility with these tracing tools by always using SDPA's `attn_mask` argument in case we are tracing. - # TODO: For dynamo, rather use a check on fullgraph=True once this is possible (https://github.com/pytorch/pytorch/pull/120400). - is_tracing = ( - torch.jit.is_tracing() - or isinstance(inputs_embeds, torch.fx.Proxy) - or (hasattr(torch, "_dynamo") and torch._dynamo.is_compiling()) - ) - - ignore_causal_mask = AttentionMaskConverter._ignore_causal_mask_sdpa( - attention_mask=attention_mask, - inputs_embeds=inputs_embeds, - past_key_values_length=past_key_values_length, - sliding_window=sliding_window, - ) - - if ignore_causal_mask: - expanded_4d_mask = None - elif attention_mask is None: - expanded_4d_mask = attn_mask_converter.to_causal_4d( - input_shape[0], input_shape[-1], key_value_length, dtype=inputs_embeds.dtype, device=inputs_embeds.device - ) - else: - if attention_mask.dim() == 4: - # in this case we assume that the mask comes already in inverted form and requires no inversion or slicing - if attention_mask.max() != 0: - raise ValueError("Custom 4D attention mask should be passed in inverted form with max==0`") - expanded_4d_mask = attention_mask - else: - expanded_4d_mask = attn_mask_converter.to_4d( - attention_mask, - input_shape[-1], - dtype=inputs_embeds.dtype, - key_value_length=key_value_length, - ) - - # Attend to all tokens in masked rows from the causal_mask, for example the relevant first rows when - # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path. - # Details: https://github.com/pytorch/pytorch/issues/110213 - if not is_tracing and expanded_4d_mask.device.type == "cuda": - expanded_4d_mask = AttentionMaskConverter._unmask_unattended( - expanded_4d_mask, min_dtype=torch.finfo(inputs_embeds.dtype).min - ) - - return expanded_4d_mask - - -def _prepare_4d_attention_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None): - """ - Creates a non-causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape - `(batch_size, key_value_length)` - - Args: - mask (`torch.Tensor`): - A 2D attention mask of shape `(batch_size, key_value_length)` - dtype (`torch.dtype`): - The torch dtype the created mask shall have. - tgt_len (`int`): - The target length or query length the created mask shall have. - """ - return AttentionMaskConverter._expand_mask(mask=mask, dtype=dtype, tgt_len=tgt_len) - - -def _prepare_4d_attention_mask_for_sdpa(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None): - """ - Creates a non-causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape - `(batch_size, key_value_length)` - - Args: - mask (`torch.Tensor`): - A 2D attention mask of shape `(batch_size, key_value_length)` - dtype (`torch.dtype`): - The torch dtype the created mask shall have. - tgt_len (`int`): - The target length or query length the created mask shall have. - """ - _, key_value_length = mask.shape - tgt_len = tgt_len if tgt_len is not None else key_value_length - - is_tracing = ( - torch.jit.is_tracing() - or isinstance(mask, torch.fx.Proxy) - or (hasattr(torch, "_dynamo") and torch._dynamo.is_compiling()) - ) - - # torch.jit.trace, symbolic_trace and torchdynamo with fullgraph=True are unable to capture data-dependent controlflows. - if not is_tracing and torch.all(mask == 1): - return None - else: - return AttentionMaskConverter._expand_mask(mask=mask, dtype=dtype, tgt_len=tgt_len) - - -def _create_4d_causal_attention_mask( - input_shape: Union[torch.Size, Tuple, List], - dtype: torch.dtype, - device: torch.device, - past_key_values_length: int = 0, - sliding_window: Optional[int] = None, -) -> Optional[torch.Tensor]: - """ - Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` - - Args: - input_shape (`tuple(int)` or `list(int)` or `torch.Size`): - The input shape should be a tuple that defines `(batch_size, query_length)`. - dtype (`torch.dtype`): - The torch dtype the created mask shall have. - device (`int`): - The torch device the created mask shall have. - sliding_window (`int`, *optional*): - If the model uses windowed attention, a sliding window should be passed. - """ - attn_mask_converter = AttentionMaskConverter(is_causal=True, sliding_window=sliding_window) - - key_value_length = past_key_values_length + input_shape[-1] - attention_mask = attn_mask_converter.to_causal_4d( - input_shape[0], input_shape[-1], key_value_length, dtype=dtype, device=device - ) - - return attention_mask diff --git a/modelopt/torch/puzzletron/decilm/deci_lm_hf_code/transformers_4_44_2__modeling_flash_attention_utils_backward_compat.py b/modelopt/torch/puzzletron/decilm/deci_lm_hf_code/transformers_4_44_2__modeling_flash_attention_utils_backward_compat.py deleted file mode 100644 index 9e9fb46ca4d..00000000000 --- a/modelopt/torch/puzzletron/decilm/deci_lm_hf_code/transformers_4_44_2__modeling_flash_attention_utils_backward_compat.py +++ /dev/null @@ -1,363 +0,0 @@ -# coding=utf-8 -# Copyright 2024 The Fairseq Authors and the HuggingFace Inc. team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# -# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# mypy: ignore-errors -import inspect -import os -from typing import Optional, Tuple, Union - - -import torch -import torch.nn.functional as F - -from functools import lru_cache -import importlib.metadata -import importlib.util -from packaging import version - -from transformers.utils import is_flash_attn_2_available - - -if is_flash_attn_2_available(): - try: - from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa - from flash_attn import flash_attn_func, flash_attn_varlen_func - _flash_supports_window_size = "window_size" in list(inspect.signature(flash_attn_func).parameters) - except ImportError: - raise "Unable to import flash_attn" - - -def _is_package_available(pkg_name: str, return_version: bool = False) -> Union[Tuple[bool, str], bool]: - # Check if the package spec exists and grab its version to avoid importing a local directory - package_exists = importlib.util.find_spec(pkg_name) is not None - package_version = "N/A" - if package_exists: - try: - # Primary method to get the package version - package_version = importlib.metadata.version(pkg_name) - except importlib.metadata.PackageNotFoundError: - # Fallback method: Only for "torch" and versions containing "dev" - if pkg_name == "torch": - try: - package = importlib.import_module(pkg_name) - temp_version = getattr(package, "__version__", "N/A") - # Check if the version contains "dev" - if "dev" in temp_version: - package_version = temp_version - package_exists = True - else: - package_exists = False - except ImportError: - # If the package can't be imported, it's not available - package_exists = False - else: - # For packages other than "torch", don't attempt the fallback and set as not available - package_exists = False - if return_version: - return package_exists, package_version - else: - return package_exists - - -@lru_cache() -def is_flash_attn_greater_or_equal(library_version: str): - if not _is_package_available("flash_attn"): - return False - - return version.parse(importlib.metadata.version("flash_attn")) >= version.parse(library_version) - - -def _get_unpad_data(attention_mask: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, int]: - """ - Retrieves indexing data required to repad unpadded (ragged) tensors. - - Arguments: - attention_mask (`torch.Tensor`): - Boolean or int tensor of shape (batch_size, sequence_length), 1 means valid and 0 means not valid. - - Return: - indices (`torch.Tensor`): - The indices of non-masked tokens from the flattened input sequence. - cu_seqlens (`torch.Tensor`): - The cumulative sequence lengths, used to index into ragged (unpadded) tensors. `cu_seqlens` shape is (batch_size + 1,). - max_seqlen_in_batch (`int`): - Maximum sequence length in batch. - """ - seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32) - indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten() - max_seqlen_in_batch = seqlens_in_batch.max().item() - cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0)) - return ( - indices, - cu_seqlens, - max_seqlen_in_batch, - ) - - -def _upad_input( - query_layer: torch.Tensor, - key_layer: torch.Tensor, - value_layer: torch.Tensor, - attention_mask: torch.Tensor, - query_length: int, -): - """ - Unpads query, key, and values tensors, using a single dimension for all tokens even though they belong to different batches. - - This function is used instead of `flash_attn.bert_padding.unpad_input` in order to avoid the recomputation of the same intermediary - tensors for query, key, value tensors. - - Arguments: - query_layer (`torch.Tensor`): - Query state with padding. Shape: (batch_size, query_length, num_heads, head_dim). - key_layer (`torch.Tensor`): - Key state with padding. Shape: (batch_size, kv_seq_len, num_key_value_heads, head_dim). - value_layer (`torch.Tensor`): - Value state with padding. Shape: (batch_size, kv_seq_len, num_key_value_heads, head_dim). - attention_mask (`torch.Tensor`): - Boolean or int tensor of shape (batch_size, sequence_length), 1 means valid and 0 means not valid. - query_length (`int`): - Target length. - - Return: - query_layer (`torch.Tensor`): - Query state without padding. Shape: (total_target_length, num_heads, head_dim). - key_layer (`torch.Tensor`): - Key state with padding. Shape: (total_source_length, num_key_value_heads, head_dim). - value_layer (`torch.Tensor`): - Value state with padding. Shape: (total_source_length, num_key_value_heads, head_dim). - indices_q (`torch.Tensor`): - The indices of non-masked tokens from the flattened input target sequence. - (cu_seqlens_q, cu_seqlens_k) (`Tuple[int]`): - The cumulative sequence lengths for the target (query) and source (key, value), used to index into ragged (unpadded) tensors. `cu_seqlens` shape is (batch_size + 1,). - (max_seqlen_in_batch_q, max_seqlen_in_batch_k) (`Tuple[int]`): - Maximum sequence length in batch (`max_seqlen_in_batch_q` for the target sequence i.e. query, `max_seqlen_in_batch_k` for the source sequence i.e. key/value). - """ - indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(attention_mask) - batch_size, kv_seq_len, num_key_value_heads, head_dim = key_layer.shape - - key_layer = index_first_axis(key_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k) - value_layer = index_first_axis( - value_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k - ) - if query_length == kv_seq_len: - query_layer = index_first_axis(query_layer.reshape(batch_size * kv_seq_len, -1, head_dim), indices_k) - cu_seqlens_q = cu_seqlens_k - max_seqlen_in_batch_q = max_seqlen_in_batch_k - indices_q = indices_k - elif query_length == 1: - max_seqlen_in_batch_q = 1 - cu_seqlens_q = torch.arange( - batch_size + 1, dtype=torch.int32, device=query_layer.device - ) # There is a memcpy here, that is very bad. - indices_q = cu_seqlens_q[:-1] - query_layer = query_layer.squeeze(1) - else: - # The -q_len: slice assumes left padding. - attention_mask = attention_mask[:, -query_length:] - query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input(query_layer, attention_mask) - - return ( - query_layer, - key_layer, - value_layer, - indices_q, - (cu_seqlens_q, cu_seqlens_k), - (max_seqlen_in_batch_q, max_seqlen_in_batch_k), - ) - - -def prepare_fa2_from_position_ids(query, key, value, position_ids): - """ - This function returns necessary arguments to call `flash_attn_varlen_func`. - All three query, key, value states will be flattened. - Cummulative lengths of each examples in the batch will be extracted from position_ids. - - NOTE: ideally cummulative lengths should be prepared at the data collator stage - - Arguments: - query (`torch.Tensor`): - Query state with padding. Shape: (batch_size, query_length, num_heads, head_dim). - key (`torch.Tensor`): - Key state with padding. Shape: (batch_size, kv_seq_len, num_key_value_heads, head_dim). - value (`torch.Tensor`): - Value state with padding. Shape: (batch_size, kv_seq_len, num_key_value_heads, head_dim). - position_ids (`torch.Tensor`): - Boolean or int tensor of shape (batch_size, sequence_length), 1 means valid and 0 means not valid. - - Return: - query (`torch.Tensor`): - Query state without padding. Shape: (total_target_length, num_heads, head_dim). - key (`torch.Tensor`): - Key state with padding. Shape: (total_source_length, num_key_value_heads, head_dim). - value (`torch.Tensor`): - Value state with padding. Shape: (total_source_length, num_key_value_heads, head_dim). - indices_q (`torch.Tensor`): - The indices of non-masked tokens from the flattened input target sequence. - (cu_seqlens_q, cu_seqlens_k) (`Tuple[int]`): - The cumulative sequence lengths for the target (query) and source (key, value), used to index into ragged (unpadded) tensors. `cu_seqlens` shape is (batch_size + 1,). - (max_seqlen_in_batch_q, max_seqlen_in_batch_k) (`Tuple[int]`): - Maximum sequence length in batch (`max_seqlen_in_batch_q` for the target sequence i.e. query, `max_seqlen_in_batch_k` for the source sequence i.e. key/value). - """ - query = query.view(-1, query.size(-2), query.size(-1)) - key = key.view(-1, key.size(-2), key.size(-1)) - value = value.view(-1, value.size(-2), value.size(-1)) - position_ids = position_ids.flatten() - indices_q = torch.arange(position_ids.size(0), device=position_ids.device, dtype=torch.int32) - - cu_seq_lens = torch.cat( - ( - indices_q[position_ids == 0], - torch.tensor(position_ids.size(), device=position_ids.device, dtype=torch.int32), - ) - ) - - max_length = position_ids.max() + 1 - - return (query, key, value, indices_q, (cu_seq_lens, cu_seq_lens), (max_length, max_length)) - - -def _flash_attention_forward( - query_states: torch.Tensor, - key_states: torch.Tensor, - value_states: torch.Tensor, - attention_mask: torch.Tensor, - query_length: int, - is_causal: bool, - dropout: float = 0.0, - position_ids: Optional[torch.Tensor] = None, - softmax_scale: Optional[float] = None, - sliding_window: Optional[int] = None, - use_top_left_mask: bool = False, - softcap: Optional[float] = None, - deterministic: bool = None, -): - """ - Calls the forward method of Flash Attention - if the input hidden states contain at least one padding token - first unpad the input, then computes the attention scores and pad the final attention scores. - - Args: - query_states (`torch.Tensor`): - Input query states to be passed to Flash Attention API - key_states (`torch.Tensor`): - Input key states to be passed to Flash Attention API - value_states (`torch.Tensor`): - Input value states to be passed to Flash Attention API - attention_mask (`torch.Tensor`): - The padding mask - corresponds to a tensor of size `(batch_size, seq_len)` where 0 stands for the - position of padding tokens and 1 for the position of non-padding tokens. - dropout (`float`): - Attention dropout - softmax_scale (`float`, *optional*): - The scaling of QK^T before applying softmax. Default to 1 / sqrt(head_dim) - use_top_left_mask (`bool`, defaults to `False`): - flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignement, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. - softcap (`float`, *optional*): - Softcap for the attention logits, used e.g. in gemma2. - deterministic (`bool`, *optional*): - Determines if the deterministic option introduced in flash_attn>=2.4.1 is enabled. - """ - if not use_top_left_mask: - causal = is_causal - else: - # TODO: Remove the `query_length != 1` check once Flash Attention for RoCm is bumped to 2.1. For details, please see the comment in transformers.models.llama.modeling_llama.LlamaFlashAttention2.__init__. - causal = is_causal and query_length != 1 - - # Assuming 4D tensors, key_states.shape[1] is the key/value sequence length (source length). - use_sliding_windows = ( - _flash_supports_window_size and sliding_window is not None and key_states.shape[1] > sliding_window - ) - flash_kwargs = {"window_size": (sliding_window, sliding_window)} if use_sliding_windows else {} - - if is_flash_attn_greater_or_equal("2.4.1"): - if deterministic is None: - deterministic = os.environ.get("FLASH_ATTENTION_DETERMINISTIC", "0") == "1" - flash_kwargs["deterministic"] = deterministic - - if softcap is not None: - flash_kwargs["softcap"] = softcap - - # Contains at least one padding token in the sequence - if attention_mask is not None: - batch_size = query_states.shape[0] - query_states, key_states, value_states, indices_q, cu_seq_lens, max_seq_lens = _upad_input( - query_states, key_states, value_states, attention_mask, query_length - ) - cu_seqlens_q, cu_seqlens_k = cu_seq_lens - max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens - - attn_output_unpad = flash_attn_varlen_func( - query_states, - key_states, - value_states, - cu_seqlens_q=cu_seqlens_q, - cu_seqlens_k=cu_seqlens_k, - max_seqlen_q=max_seqlen_in_batch_q, - max_seqlen_k=max_seqlen_in_batch_k, - dropout_p=dropout, - softmax_scale=softmax_scale, - causal=causal, - **flash_kwargs, - ) - attn_output = pad_input(attn_output_unpad, indices_q, batch_size, query_length) - - # If position_ids is provided and check all examples do not contain only 1 sequence, If tensor in increasing - # then we probably have one sequence, otherwise it is packed. Additionally check we are in pre-fill/training stage. - # Use `flash_attn_varlen_func` to prevent cross-example attention and also allow padding free approach - elif position_ids is not None and query_length != 1 and not (torch.diff(position_ids, dim=-1) >= 0).all(): - batch_size = query_states.size(0) - query_states, key_states, value_states, indices_q, cu_seq_lens, max_seq_lens = prepare_fa2_from_position_ids( - query_states, key_states, value_states, position_ids - ) - - cu_seqlens_q, cu_seqlens_k = cu_seq_lens - max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens - - attn_output = flash_attn_varlen_func( - query_states, - key_states, - value_states, - cu_seqlens_q=cu_seqlens_q, - cu_seqlens_k=cu_seqlens_k, - max_seqlen_q=max_seqlen_in_batch_q, - max_seqlen_k=max_seqlen_in_batch_k, - dropout_p=dropout, - softmax_scale=softmax_scale, - causal=causal, - **flash_kwargs, - ) - - attn_output = attn_output.view(batch_size, -1, attn_output.size(-2), attn_output.size(-1)) - - else: - attn_output = flash_attn_func( - query_states, key_states, value_states, dropout, softmax_scale=softmax_scale, causal=causal, **flash_kwargs - ) - - return attn_output diff --git a/modelopt/torch/puzzletron/decilm/deci_lm_hf_code/transformers_4_44_2__modeling_outputs.py b/modelopt/torch/puzzletron/decilm/deci_lm_hf_code/transformers_4_44_2__modeling_outputs.py deleted file mode 100644 index aa9f07b8797..00000000000 --- a/modelopt/torch/puzzletron/decilm/deci_lm_hf_code/transformers_4_44_2__modeling_outputs.py +++ /dev/null @@ -1,1768 +0,0 @@ -# Copyright 2020 The HuggingFace Team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# -# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import warnings -from dataclasses import dataclass -from typing import Optional, Tuple - -import torch - -from transformers.utils import ModelOutput - - -@dataclass -class BaseModelOutput(ModelOutput): - """ - Base class for model's outputs, with potential hidden states and attentions. - - Args: - last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`): - Sequence of hidden-states at the output of the last layer of the model. - hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): - Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, + - one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. - - Hidden-states of the model at the output of each layer plus the optional initial embedding outputs. - attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): - Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, - sequence_length)`. - - Attentions weights after the attention softmax, used to compute the weighted average in the self-attention - heads. - """ - - last_hidden_state: torch.FloatTensor = None - hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None - attentions: Optional[Tuple[torch.FloatTensor, ...]] = None - - -@dataclass -class BaseModelOutputWithNoAttention(ModelOutput): - """ - Base class for model's outputs, with potential hidden states. - - Args: - last_hidden_state (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`): - Sequence of hidden-states at the output of the last layer of the model. - hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): - Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, + - one for the output of each layer) of shape `(batch_size, num_channels, height, width)`. - - Hidden-states of the model at the output of each layer plus the optional initial embedding outputs. - """ - - last_hidden_state: torch.FloatTensor = None - hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None - - -@dataclass -class BaseModelOutputWithPooling(ModelOutput): - """ - Base class for model's outputs that also contains a pooling of the last hidden states. - - Args: - last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`): - Sequence of hidden-states at the output of the last layer of the model. - pooler_output (`torch.FloatTensor` of shape `(batch_size, hidden_size)`): - Last layer hidden-state of the first token of the sequence (classification token) after further processing - through the layers used for the auxiliary pretraining task. E.g. for BERT-family of models, this returns - the classification token after processing through a linear layer and a tanh activation function. The linear - layer weights are trained from the next sentence prediction (classification) objective during pretraining. - hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): - Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, + - one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. - - Hidden-states of the model at the output of each layer plus the optional initial embedding outputs. - attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): - Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, - sequence_length)`. - - Attentions weights after the attention softmax, used to compute the weighted average in the self-attention - heads. - """ - - last_hidden_state: torch.FloatTensor = None - pooler_output: torch.FloatTensor = None - hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None - attentions: Optional[Tuple[torch.FloatTensor, ...]] = None - - -@dataclass -class BaseModelOutputWithPoolingAndNoAttention(ModelOutput): - """ - Base class for model's outputs that also contains a pooling of the last hidden states. - - Args: - last_hidden_state (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`): - Sequence of hidden-states at the output of the last layer of the model. - pooler_output (`torch.FloatTensor` of shape `(batch_size, hidden_size)`): - Last layer hidden-state after a pooling operation on the spatial dimensions. - hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): - Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, + - one for the output of each layer) of shape `(batch_size, num_channels, height, width)`. - - Hidden-states of the model at the output of each layer plus the optional initial embedding outputs. - """ - - last_hidden_state: torch.FloatTensor = None - pooler_output: torch.FloatTensor = None - hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None - - -@dataclass -class BaseModelOutputWithPast(ModelOutput): - """ - Base class for model's outputs that may also contain a past key/values (to speed up sequential decoding). - - Args: - last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`): - Sequence of hidden-states at the output of the last layer of the model. - - If `past_key_values` is used only the last hidden-state of the sequences of shape `(batch_size, 1, - hidden_size)` is output. - past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): - Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape - `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and optionally if - `config.is_encoder_decoder=True` 2 additional tensors of shape `(batch_size, num_heads, - encoder_sequence_length, embed_size_per_head)`. - - Contains pre-computed hidden-states (key and values in the self-attention blocks and optionally if - `config.is_encoder_decoder=True` in the cross-attention blocks) that can be used (see `past_key_values` - input) to speed up sequential decoding. - hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): - Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, + - one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. - - Hidden-states of the model at the output of each layer plus the optional initial embedding outputs. - attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): - Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, - sequence_length)`. - - Attentions weights after the attention softmax, used to compute the weighted average in the self-attention - heads. - """ - - last_hidden_state: torch.FloatTensor = None - past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None - hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None - attentions: Optional[Tuple[torch.FloatTensor, ...]] = None - - -@dataclass -class BaseModelOutputWithCrossAttentions(ModelOutput): - """ - Base class for model's outputs, with potential hidden states and attentions. - - Args: - last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`): - Sequence of hidden-states at the output of the last layer of the model. - hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): - Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, + - one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. - - Hidden-states of the model at the output of each layer plus the optional initial embedding outputs. - attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): - Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, - sequence_length)`. - - Attentions weights after the attention softmax, used to compute the weighted average in the self-attention - heads. - cross_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` and `config.add_cross_attention=True` is passed or when `config.output_attentions=True`): - Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, - sequence_length)`. - - Attentions weights of the decoder's cross-attention layer, after the attention softmax, used to compute the - weighted average in the cross-attention heads. - """ - - last_hidden_state: torch.FloatTensor = None - hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None - attentions: Optional[Tuple[torch.FloatTensor, ...]] = None - cross_attentions: Optional[Tuple[torch.FloatTensor, ...]] = None - - -@dataclass -class BaseModelOutputWithPoolingAndCrossAttentions(ModelOutput): - """ - Base class for model's outputs that also contains a pooling of the last hidden states. - - Args: - last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`): - Sequence of hidden-states at the output of the last layer of the model. - pooler_output (`torch.FloatTensor` of shape `(batch_size, hidden_size)`): - Last layer hidden-state of the first token of the sequence (classification token) after further processing - through the layers used for the auxiliary pretraining task. E.g. for BERT-family of models, this returns - the classification token after processing through a linear layer and a tanh activation function. The linear - layer weights are trained from the next sentence prediction (classification) objective during pretraining. - hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): - Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, + - one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. - - Hidden-states of the model at the output of each layer plus the optional initial embedding outputs. - attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): - Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, - sequence_length)`. - - Attentions weights after the attention softmax, used to compute the weighted average in the self-attention - heads. - cross_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` and `config.add_cross_attention=True` is passed or when `config.output_attentions=True`): - Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, - sequence_length)`. - - Attentions weights of the decoder's cross-attention layer, after the attention softmax, used to compute the - weighted average in the cross-attention heads. - past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): - Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape - `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and optionally if - `config.is_encoder_decoder=True` 2 additional tensors of shape `(batch_size, num_heads, - encoder_sequence_length, embed_size_per_head)`. - - Contains pre-computed hidden-states (key and values in the self-attention blocks and optionally if - `config.is_encoder_decoder=True` in the cross-attention blocks) that can be used (see `past_key_values` - input) to speed up sequential decoding. - """ - - last_hidden_state: torch.FloatTensor = None - pooler_output: torch.FloatTensor = None - hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None - past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None - attentions: Optional[Tuple[torch.FloatTensor, ...]] = None - cross_attentions: Optional[Tuple[torch.FloatTensor, ...]] = None - - -@dataclass -class BaseModelOutputWithPastAndCrossAttentions(ModelOutput): - """ - Base class for model's outputs that may also contain a past key/values (to speed up sequential decoding). - - Args: - last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`): - Sequence of hidden-states at the output of the last layer of the model. - - If `past_key_values` is used only the last hidden-state of the sequences of shape `(batch_size, 1, - hidden_size)` is output. - past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): - Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape - `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and optionally if - `config.is_encoder_decoder=True` 2 additional tensors of shape `(batch_size, num_heads, - encoder_sequence_length, embed_size_per_head)`. - - Contains pre-computed hidden-states (key and values in the self-attention blocks and optionally if - `config.is_encoder_decoder=True` in the cross-attention blocks) that can be used (see `past_key_values` - input) to speed up sequential decoding. - hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): - Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, + - one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. - - Hidden-states of the model at the output of each layer plus the optional initial embedding outputs. - attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): - Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, - sequence_length)`. - - Attentions weights after the attention softmax, used to compute the weighted average in the self-attention - heads. - cross_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` and `config.add_cross_attention=True` is passed or when `config.output_attentions=True`): - Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, - sequence_length)`. - - Attentions weights of the decoder's cross-attention layer, after the attention softmax, used to compute the - weighted average in the cross-attention heads. - """ - - last_hidden_state: torch.FloatTensor = None - past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None - hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None - attentions: Optional[Tuple[torch.FloatTensor, ...]] = None - cross_attentions: Optional[Tuple[torch.FloatTensor, ...]] = None - - -@dataclass -class MoECausalLMOutputWithPast(ModelOutput): - """ - Base class for causal language model (or autoregressive) outputs as well as Mixture of Expert's router hidden - states terms, to train a MoE model. - - Args: - loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided): - Language modeling loss (for next-token prediction). - logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`): - Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax). - past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): - Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape - `(batch_size, num_heads, sequence_length, embed_size_per_head)`) - - Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see - `past_key_values` input) to speed up sequential decoding. - hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): - Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, + - one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. - - Hidden-states of the model at the output of each layer plus the optional initial embedding outputs. - attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): - Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, - sequence_length)`. - - Attentions weights after the attention softmax, used to compute the weighted average in the self-attention - heads. - z_loss (`torch.FloatTensor`, *optional*, returned when `labels` is provided): - z_loss for the sparse modules. - aux_loss (`torch.FloatTensor`, *optional*, returned when `labels` is provided): - aux_loss for the sparse modules. - router_logits (`tuple(torch.FloatTensor)`, *optional*, returned when `output_router_logits=True` is passed or when `config.add_router_probs=True`): - Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, sequence_length, num_experts)`. - - Router logits of the encoder model, useful to compute the auxiliary loss and the z_loss for the sparse - modules. - """ - - loss: Optional[torch.FloatTensor] = None - logits: torch.FloatTensor = None - past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None - hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None - attentions: Optional[Tuple[torch.FloatTensor, ...]] = None - z_loss: torch.FloatTensor = None - aux_loss: torch.FloatTensor = None - router_logits: Optional[Tuple[torch.FloatTensor]] = None - - -@dataclass -class MoEModelOutput(ModelOutput): - """ - Base class for model's outputs, with potential hidden states and attentions. - - Args: - last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`): - Sequence of hidden-states at the output of the last layer of the model. - hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): - Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, + - one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. - - Hidden-states of the model at the output of each layer plus the optional initial embedding outputs. - attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): - Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, - sequence_length)`. - - Attentions weights after the attention softmax, used to compute the weighted average in the self-attention - heads. - router_probs (`tuple(torch.FloatTensor)`, *optional*, returned when `output_router_probs=True` and `config.add_router_probs=True` is passed or when `config.output_router_probs=True`): - Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, sequence_length, num_experts)`. - - Raw router probabilities that are computed by MoE routers, these terms are used to compute the auxiliary - loss and the z_loss for Mixture of Experts models. - """ - - last_hidden_state: torch.FloatTensor = None - hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None - attentions: Optional[Tuple[torch.FloatTensor, ...]] = None - router_probs: Optional[Tuple[torch.FloatTensor]] = None - - -@dataclass -class MoeModelOutputWithPast(ModelOutput): - """ - Base class for model's outputs, with potential hidden states and attentions. - - Args: - last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`): - Sequence of hidden-states at the output of the last layer of the model. - past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): - Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape - `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and optionally if - `config.is_encoder_decoder=True` 2 additional tensors of shape `(batch_size, num_heads, - encoder_sequence_length, embed_size_per_head)`. - - Contains pre-computed hidden-states (key and values in the self-attention blocks and optionally if - `config.is_encoder_decoder=True` in the cross-attention blocks) that can be used (see `past_key_values` - input) to speed up sequential decoding. - hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): - Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, + - one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. - - Hidden-states of the model at the output of each layer plus the optional initial embedding outputs. - attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): - Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, - sequence_length)`. - - Attentions weights after the attention softmax, used to compute the weighted average in the self-attention - heads. - router_logits (`tuple(torch.FloatTensor)`, *optional*, returned when `output_router_probs=True` and `config.add_router_probs=True` is passed or when `config.output_router_probs=True`): - Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, sequence_length, num_experts)`. - - Raw router logtis (post-softmax) that are computed by MoE routers, these terms are used to compute the auxiliary - loss for Mixture of Experts models. - """ - - last_hidden_state: torch.FloatTensor = None - past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None - hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None - attentions: Optional[Tuple[torch.FloatTensor, ...]] = None - router_logits: Optional[Tuple[torch.FloatTensor]] = None - - -@dataclass -class MoeCausalLMOutputWithPast(ModelOutput): - """ - Base class for causal language model (or autoregressive) with mixture of experts outputs. - - Args: - loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided): - Language modeling loss (for next-token prediction). - - logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`): - Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax). - - aux_loss (`torch.FloatTensor`, *optional*, returned when `labels` is provided): - aux_loss for the sparse modules. - - router_logits (`tuple(torch.FloatTensor)`, *optional*, returned when `output_router_probs=True` and `config.add_router_probs=True` is passed or when `config.output_router_probs=True`): - Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, sequence_length, num_experts)`. - - Raw router logtis (post-softmax) that are computed by MoE routers, these terms are used to compute the auxiliary - loss for Mixture of Experts models. - - past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): - Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape - `(batch_size, num_heads, sequence_length, embed_size_per_head)`) - - Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see - `past_key_values` input) to speed up sequential decoding. - hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): - Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, + - one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. - - Hidden-states of the model at the output of each layer plus the optional initial embedding outputs. - attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): - Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, - sequence_length)`. - - Attentions weights after the attention softmax, used to compute the weighted average in the self-attention - heads. - """ - - loss: Optional[torch.FloatTensor] = None - aux_loss: Optional[torch.FloatTensor] = None - logits: torch.FloatTensor = None - past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None - hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None - attentions: Optional[Tuple[torch.FloatTensor, ...]] = None - router_logits: Optional[Tuple[torch.FloatTensor]] = None - - -@dataclass -class MoEModelOutputWithPastAndCrossAttentions(ModelOutput): - """ - Base class for model's outputs that may also contain a past key/values (to speed up sequential decoding) as well as - Mixture of Expert's router hidden states terms, to train a MoE model. - - Args: - last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`): - Sequence of hidden-states at the output of the last layer of the model. - - If `past_key_values` is used only the last hidden-state of the sequences of shape `(batch_size, 1, - hidden_size)` is output. - past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): - Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape - `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and optionally if - `config.is_encoder_decoder=True` 2 additional tensors of shape `(batch_size, num_heads, - encoder_sequence_length, embed_size_per_head)`. - - Contains pre-computed hidden-states (key and values in the self-attention blocks and optionally if - `config.is_encoder_decoder=True` in the cross-attention blocks) that can be used (see `past_key_values` - input) to speed up sequential decoding. - hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): - Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, + - one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. - - Hidden-states of the model at the output of each layer plus the optional initial embedding outputs. - attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): - Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, - sequence_length)`. - - Attentions weights after the attention softmax, used to compute the weighted average in the self-attention - heads. - cross_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` and `config.add_cross_attention=True` is passed or when `config.output_attentions=True`): - Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, - sequence_length)`. - - Attentions weights of the decoder's cross-attention layer, after the attention softmax, used to compute the - weighted average in the cross-attention heads. - router_probs (`tuple(torch.FloatTensor)`, *optional*, returned when `output_router_probs=True` and `config.add_router_probs=True` is passed or when `config.output_router_probs=True`): - Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, sequence_length, num_experts)`. - - Raw router probabilities that are computed by MoE routers, these terms are used to compute the auxiliary - loss and the z_loss for Mixture of Experts models. - """ - - last_hidden_state: torch.FloatTensor = None - past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None - hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None - attentions: Optional[Tuple[torch.FloatTensor, ...]] = None - cross_attentions: Optional[Tuple[torch.FloatTensor, ...]] = None - router_probs: Optional[Tuple[torch.FloatTensor]] = None - - -@dataclass -class Seq2SeqModelOutput(ModelOutput): - """ - Base class for model encoder's outputs that also contains : pre-computed hidden states that can speed up sequential - decoding. - - Args: - last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`): - Sequence of hidden-states at the output of the last layer of the decoder of the model. - - If `past_key_values` is used only the last hidden-state of the sequences of shape `(batch_size, 1, - hidden_size)` is output. - past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): - Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape - `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of shape - `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`. - - Contains pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention - blocks) that can be used (see `past_key_values` input) to speed up sequential decoding. - decoder_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): - Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, + - one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. - - Hidden-states of the decoder at the output of each layer plus the optional initial embedding outputs. - decoder_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): - Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, - sequence_length)`. - - Attentions weights of the decoder, after the attention softmax, used to compute the weighted average in the - self-attention heads. - cross_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): - Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, - sequence_length)`. - - Attentions weights of the decoder's cross-attention layer, after the attention softmax, used to compute the - weighted average in the cross-attention heads. - encoder_last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): - Sequence of hidden-states at the output of the last layer of the encoder of the model. - encoder_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): - Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, + - one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. - - Hidden-states of the encoder at the output of each layer plus the optional initial embedding outputs. - encoder_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): - Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, - sequence_length)`. - - Attentions weights of the encoder, after the attention softmax, used to compute the weighted average in the - self-attention heads. - """ - - last_hidden_state: torch.FloatTensor = None - past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None - decoder_hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None - decoder_attentions: Optional[Tuple[torch.FloatTensor, ...]] = None - cross_attentions: Optional[Tuple[torch.FloatTensor, ...]] = None - encoder_last_hidden_state: Optional[torch.FloatTensor] = None - encoder_hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None - encoder_attentions: Optional[Tuple[torch.FloatTensor, ...]] = None - - -@dataclass -class Seq2SeqMoEModelOutput(ModelOutput): - """ - Base class for model encoder's outputs that also contains : pre-computed hidden states that can speed up sequential - decoding. - - Args: - last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`): - Sequence of hidden-states at the output of the last layer of the decoder of the model. - - If `past_key_values` is used only the last hidden-state of the sequences of shape `(batch_size, 1, - hidden_size)` is output. - past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): - Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape - `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of shape - `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`. - - Contains pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention - blocks) that can be used (see `past_key_values` input) to speed up sequential decoding. - decoder_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): - Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, + - one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. - - Hidden-states of the decoder at the output of each layer plus the optional initial embedding outputs. - decoder_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): - Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, - sequence_length)`. - - Attentions weights of the decoder, after the attention softmax, used to compute the weighted average in the - self-attention heads. - decoder_router_logits (`tuple(torch.FloatTensor)`, *optional*, returned when `output_router_logits=True` is passed or when `config.add_router_probs=True`): - Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, sequence_length, num_experts)`. - - Router logits of the decoder model, useful to compute the auxiliary loss for Mixture of Experts models. - cross_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): - Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, - sequence_length)`. - - Attentions weights of the decoder's cross-attention layer, after the attention softmax, used to compute the - weighted average in the cross-attention heads. - encoder_last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): - Sequence of hidden-states at the output of the last layer of the encoder of the model. - encoder_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): - Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, + - one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. - - Hidden-states of the encoder at the output of each layer plus the optional initial embedding outputs. - encoder_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): - Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, - sequence_length)`. - - Attentions weights of the encoder, after the attention softmax, used to compute the weighted average in the - self-attention heads. - encoder_router_logits (`tuple(torch.FloatTensor)`, *optional*, returned when `output_router_logits=True` is passed or when `config.add_router_probs=True`): - Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, sequence_length, num_experts)`. - - Router logits of the encoder model, useful to compute the auxiliary loss and the z_loss for the sparse - modules. - """ - - last_hidden_state: torch.FloatTensor = None - past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None - decoder_hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None - decoder_attentions: Optional[Tuple[torch.FloatTensor, ...]] = None - decoder_router_logits: Optional[Tuple[torch.FloatTensor]] = None - cross_attentions: Optional[Tuple[torch.FloatTensor, ...]] = None - encoder_last_hidden_state: Optional[torch.FloatTensor] = None - encoder_hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None - encoder_attentions: Optional[Tuple[torch.FloatTensor, ...]] = None - encoder_router_logits: Optional[Tuple[torch.FloatTensor]] = None - - -@dataclass -class CausalLMOutput(ModelOutput): - """ - Base class for causal language model (or autoregressive) outputs. - - Args: - loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided): - Language modeling loss (for next-token prediction). - logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`): - Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax). - hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): - Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, + - one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. - - Hidden-states of the model at the output of each layer plus the optional initial embedding outputs. - attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): - Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, - sequence_length)`. - - Attentions weights after the attention softmax, used to compute the weighted average in the self-attention - heads. - """ - - loss: Optional[torch.FloatTensor] = None - logits: torch.FloatTensor = None - hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None - attentions: Optional[Tuple[torch.FloatTensor, ...]] = None - - -@dataclass -class CausalLMOutputWithPast(ModelOutput): - """ - Base class for causal language model (or autoregressive) outputs. - - Args: - loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided): - Language modeling loss (for next-token prediction). - logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`): - Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax). - past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): - Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape - `(batch_size, num_heads, sequence_length, embed_size_per_head)`) - - Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see - `past_key_values` input) to speed up sequential decoding. - hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): - Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, + - one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. - - Hidden-states of the model at the output of each layer plus the optional initial embedding outputs. - attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): - Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, - sequence_length)`. - - Attentions weights after the attention softmax, used to compute the weighted average in the self-attention - heads. - """ - - loss: Optional[torch.FloatTensor] = None - logits: torch.FloatTensor = None - past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None - hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None - attentions: Optional[Tuple[torch.FloatTensor, ...]] = None - - -@dataclass -class CausalLMOutputWithCrossAttentions(ModelOutput): - """ - Base class for causal language model (or autoregressive) outputs. - - Args: - loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided): - Language modeling loss (for next-token prediction). - logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`): - Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax). - hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): - Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, + - one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. - - Hidden-states of the model at the output of each layer plus the optional initial embedding outputs. - attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): - Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, - sequence_length)`. - - Attentions weights after the attention softmax, used to compute the weighted average in the self-attention - heads. - cross_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): - Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, - sequence_length)`. - - Cross attentions weights after the attention softmax, used to compute the weighted average in the - cross-attention heads. - past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): - Tuple of `torch.FloatTensor` tuples of length `config.n_layers`, with each tuple containing the cached key, - value states of the self-attention and the cross-attention layers if model is used in encoder-decoder - setting. Only relevant if `config.is_decoder = True`. - - Contains pre-computed hidden-states (key and values in the attention blocks) that can be used (see - `past_key_values` input) to speed up sequential decoding. - """ - - loss: Optional[torch.FloatTensor] = None - logits: torch.FloatTensor = None - past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None - hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None - attentions: Optional[Tuple[torch.FloatTensor, ...]] = None - cross_attentions: Optional[Tuple[torch.FloatTensor, ...]] = None - - -@dataclass -class SequenceClassifierOutputWithPast(ModelOutput): - """ - Base class for outputs of sentence classification models. - - Args: - loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided): - Classification (or regression if config.num_labels==1) loss. - logits (`torch.FloatTensor` of shape `(batch_size, config.num_labels)`): - Classification (or regression if config.num_labels==1) scores (before SoftMax). - past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): - Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape - `(batch_size, num_heads, sequence_length, embed_size_per_head)`) - - Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see - `past_key_values` input) to speed up sequential decoding. - hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): - Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, + - one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. - - Hidden-states of the model at the output of each layer plus the optional initial embedding outputs. - attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): - Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, - sequence_length)`. - - Attentions weights after the attention softmax, used to compute the weighted average in the self-attention - heads. - """ - - loss: Optional[torch.FloatTensor] = None - logits: torch.FloatTensor = None - past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None - hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None - attentions: Optional[Tuple[torch.FloatTensor, ...]] = None - - -@dataclass -class MaskedLMOutput(ModelOutput): - """ - Base class for masked language models outputs. - - Args: - loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided): - Masked language modeling (MLM) loss. - logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`): - Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax). - hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): - Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, + - one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. - - Hidden-states of the model at the output of each layer plus the optional initial embedding outputs. - attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): - Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, - sequence_length)`. - - Attentions weights after the attention softmax, used to compute the weighted average in the self-attention - heads. - """ - - loss: Optional[torch.FloatTensor] = None - logits: torch.FloatTensor = None - hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None - attentions: Optional[Tuple[torch.FloatTensor, ...]] = None - - -@dataclass -class Seq2SeqLMOutput(ModelOutput): - """ - Base class for sequence-to-sequence language models outputs. - - Args: - loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided): - Language modeling loss. - logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`): - Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax). - past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): - Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape - `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of shape - `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`. - - Contains pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention - blocks) that can be used (see `past_key_values` input) to speed up sequential decoding. - decoder_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): - Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, + - one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. - - Hidden-states of the decoder at the output of each layer plus the initial embedding outputs. - decoder_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): - Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, - sequence_length)`. - - Attentions weights of the decoder, after the attention softmax, used to compute the weighted average in the - self-attention heads. - cross_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): - Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, - sequence_length)`. - - Attentions weights of the decoder's cross-attention layer, after the attention softmax, used to compute the - weighted average in the cross-attention heads. - encoder_last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): - Sequence of hidden-states at the output of the last layer of the encoder of the model. - encoder_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): - Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, + - one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. - - Hidden-states of the encoder at the output of each layer plus the initial embedding outputs. - encoder_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): - Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, - sequence_length)`. - - Attentions weights of the encoder, after the attention softmax, used to compute the weighted average in the - self-attention heads. - """ - - loss: Optional[torch.FloatTensor] = None - logits: torch.FloatTensor = None - past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None - decoder_hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None - decoder_attentions: Optional[Tuple[torch.FloatTensor, ...]] = None - cross_attentions: Optional[Tuple[torch.FloatTensor, ...]] = None - encoder_last_hidden_state: Optional[torch.FloatTensor] = None - encoder_hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None - encoder_attentions: Optional[Tuple[torch.FloatTensor, ...]] = None - - -@dataclass -class Seq2SeqMoEOutput(ModelOutput): - """ - Base class for sequence-to-sequence language models outputs. - - Args: - loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided): - Language modeling loss. - logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`): - Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax). - past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): - Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape - `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of shape - `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`. - - Contains pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention - blocks) that can be used (see `past_key_values` input) to speed up sequential decoding. - decoder_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): - Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, + - one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. - - Hidden-states of the decoder at the output of each layer plus the initial embedding outputs. - decoder_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): - Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, - sequence_length)`. - - Attentions weights of the decoder, after the attention softmax, used to compute the weighted average in the - self-attention heads. - decoder_router_logits (`tuple(torch.FloatTensor)`, *optional*, returned when `output_router_logits=True` is passed or when `config.add_router_probs=True`): - Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, sequence_length, num_experts)`. - - Router logits of the decoder model, useful to compute the auxiliary loss for Mixture of Experts models. - cross_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): - Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, - sequence_length)`. - - Attentions weights of the decoder's cross-attention layer, after the attention softmax, used to compute the - weighted average in the cross-attention heads. - encoder_last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): - Sequence of hidden-states at the output of the last layer of the encoder of the model. - encoder_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): - Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, + - one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. - - Hidden-states of the encoder at the output of each layer plus the initial embedding outputs. - encoder_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): - Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, - sequence_length)`. - - Attentions weights of the encoder, after the attention softmax, used to compute the weighted average in the - self-attention heads. - encoder_router_logits (`tuple(torch.FloatTensor)`, *optional*, returned when `output_router_logits=True` is passed or when `config.add_router_probs=True`): - Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, sequence_length, num_experts)`. - - Router logits of the encoder model, useful to compute the auxiliary loss and z_loss for Mixture of Experts - models. - """ - - loss: Optional[torch.FloatTensor] = None - logits: torch.FloatTensor = None - encoder_z_loss: torch.FloatTensor = None - decoder_z_loss: torch.FloatTensor = None - encoder_aux_loss: torch.FloatTensor = None - decoder_aux_loss: torch.FloatTensor = None - past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None - decoder_hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None - decoder_attentions: Optional[Tuple[torch.FloatTensor, ...]] = None - decoder_router_logits: Optional[Tuple[torch.FloatTensor]] = None - cross_attentions: Optional[Tuple[torch.FloatTensor, ...]] = None - encoder_last_hidden_state: Optional[torch.FloatTensor] = None - encoder_hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None - encoder_attentions: Optional[Tuple[torch.FloatTensor, ...]] = None - encoder_router_logits: Optional[Tuple[torch.FloatTensor]] = None - - -@dataclass -class NextSentencePredictorOutput(ModelOutput): - """ - Base class for outputs of models predicting if two sentences are consecutive or not. - - Args: - loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `next_sentence_label` is provided): - Next sequence prediction (classification) loss. - logits (`torch.FloatTensor` of shape `(batch_size, 2)`): - Prediction scores of the next sequence prediction (classification) head (scores of True/False continuation - before SoftMax). - hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): - Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, + - one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. - - Hidden-states of the model at the output of each layer plus the optional initial embedding outputs. - attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): - Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, - sequence_length)`. - - Attentions weights after the attention softmax, used to compute the weighted average in the self-attention - heads. - """ - - loss: Optional[torch.FloatTensor] = None - logits: torch.FloatTensor = None - hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None - attentions: Optional[Tuple[torch.FloatTensor, ...]] = None - - -@dataclass -class SequenceClassifierOutput(ModelOutput): - """ - Base class for outputs of sentence classification models. - - Args: - loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided): - Classification (or regression if config.num_labels==1) loss. - logits (`torch.FloatTensor` of shape `(batch_size, config.num_labels)`): - Classification (or regression if config.num_labels==1) scores (before SoftMax). - hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): - Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, + - one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. - - Hidden-states of the model at the output of each layer plus the optional initial embedding outputs. - attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): - Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, - sequence_length)`. - - Attentions weights after the attention softmax, used to compute the weighted average in the self-attention - heads. - """ - - loss: Optional[torch.FloatTensor] = None - logits: torch.FloatTensor = None - hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None - attentions: Optional[Tuple[torch.FloatTensor, ...]] = None - - -@dataclass -class Seq2SeqSequenceClassifierOutput(ModelOutput): - """ - Base class for outputs of sequence-to-sequence sentence classification models. - - Args: - loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `label` is provided): - Classification (or regression if config.num_labels==1) loss. - logits (`torch.FloatTensor` of shape `(batch_size, config.num_labels)`): - Classification (or regression if config.num_labels==1) scores (before SoftMax). - past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): - Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape - `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of shape - `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`. - - Contains pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention - blocks) that can be used (see `past_key_values` input) to speed up sequential decoding. - decoder_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): - Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, + - one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. - - Hidden-states of the decoder at the output of each layer plus the initial embedding outputs. - decoder_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): - Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, - sequence_length)`. - - Attentions weights of the decoder, after the attention softmax, used to compute the weighted average in the - self-attention heads. - cross_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): - Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, - sequence_length)`. - - Attentions weights of the decoder's cross-attention layer, after the attention softmax, used to compute the - weighted average in the cross-attention heads. - encoder_last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): - Sequence of hidden-states at the output of the last layer of the encoder of the model. - encoder_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): - Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, + - one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. - - Hidden-states of the encoder at the output of each layer plus the initial embedding outputs. - encoder_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): - Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, - sequence_length)`. - - Attentions weights of the encoder, after the attention softmax, used to compute the weighted average in the - self-attention heads. - """ - - loss: Optional[torch.FloatTensor] = None - logits: torch.FloatTensor = None - past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None - decoder_hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None - decoder_attentions: Optional[Tuple[torch.FloatTensor, ...]] = None - cross_attentions: Optional[Tuple[torch.FloatTensor, ...]] = None - encoder_last_hidden_state: Optional[torch.FloatTensor] = None - encoder_hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None - encoder_attentions: Optional[Tuple[torch.FloatTensor, ...]] = None - - -@dataclass -class MultipleChoiceModelOutput(ModelOutput): - """ - Base class for outputs of multiple choice models. - - Args: - loss (`torch.FloatTensor` of shape *(1,)*, *optional*, returned when `labels` is provided): - Classification loss. - logits (`torch.FloatTensor` of shape `(batch_size, num_choices)`): - *num_choices* is the second dimension of the input tensors. (see *input_ids* above). - - Classification scores (before SoftMax). - hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): - Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, + - one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. - - Hidden-states of the model at the output of each layer plus the optional initial embedding outputs. - attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): - Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, - sequence_length)`. - - Attentions weights after the attention softmax, used to compute the weighted average in the self-attention - heads. - """ - - loss: Optional[torch.FloatTensor] = None - logits: torch.FloatTensor = None - hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None - attentions: Optional[Tuple[torch.FloatTensor, ...]] = None - - -@dataclass -class TokenClassifierOutput(ModelOutput): - """ - Base class for outputs of token classification models. - - Args: - loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided) : - Classification loss. - logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.num_labels)`): - Classification scores (before SoftMax). - hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): - Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, + - one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. - - Hidden-states of the model at the output of each layer plus the optional initial embedding outputs. - attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): - Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, - sequence_length)`. - - Attentions weights after the attention softmax, used to compute the weighted average in the self-attention - heads. - """ - - loss: Optional[torch.FloatTensor] = None - logits: torch.FloatTensor = None - hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None - attentions: Optional[Tuple[torch.FloatTensor, ...]] = None - - -@dataclass -class QuestionAnsweringModelOutput(ModelOutput): - """ - Base class for outputs of question answering models. - - Args: - loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided): - Total span extraction loss is the sum of a Cross-Entropy for the start and end positions. - start_logits (`torch.FloatTensor` of shape `(batch_size, sequence_length)`): - Span-start scores (before SoftMax). - end_logits (`torch.FloatTensor` of shape `(batch_size, sequence_length)`): - Span-end scores (before SoftMax). - hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): - Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, + - one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. - - Hidden-states of the model at the output of each layer plus the optional initial embedding outputs. - attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): - Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, - sequence_length)`. - - Attentions weights after the attention softmax, used to compute the weighted average in the self-attention - heads. - """ - - loss: Optional[torch.FloatTensor] = None - start_logits: torch.FloatTensor = None - end_logits: torch.FloatTensor = None - hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None - attentions: Optional[Tuple[torch.FloatTensor, ...]] = None - - -@dataclass -class Seq2SeqQuestionAnsweringModelOutput(ModelOutput): - """ - Base class for outputs of sequence-to-sequence question answering models. - - Args: - loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided): - Total span extraction loss is the sum of a Cross-Entropy for the start and end positions. - start_logits (`torch.FloatTensor` of shape `(batch_size, sequence_length)`): - Span-start scores (before SoftMax). - end_logits (`torch.FloatTensor` of shape `(batch_size, sequence_length)`): - Span-end scores (before SoftMax). - past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): - Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape - `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of shape - `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`. - - Contains pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention - blocks) that can be used (see `past_key_values` input) to speed up sequential decoding. - decoder_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): - Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, + - one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. - - Hidden-states of the decoder at the output of each layer plus the initial embedding outputs. - decoder_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): - Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, - sequence_length)`. - - Attentions weights of the decoder, after the attention softmax, used to compute the weighted average in the - self-attention heads. - cross_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): - Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, - sequence_length)`. - - Attentions weights of the decoder's cross-attention layer, after the attention softmax, used to compute the - weighted average in the cross-attention heads. - encoder_last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): - Sequence of hidden-states at the output of the last layer of the encoder of the model. - encoder_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): - Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, + - one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. - - Hidden-states of the encoder at the output of each layer plus the initial embedding outputs. - encoder_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): - Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, - sequence_length)`. - - Attentions weights of the encoder, after the attention softmax, used to compute the weighted average in the - self-attention heads. - """ - - loss: Optional[torch.FloatTensor] = None - start_logits: torch.FloatTensor = None - end_logits: torch.FloatTensor = None - past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None - decoder_hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None - decoder_attentions: Optional[Tuple[torch.FloatTensor, ...]] = None - cross_attentions: Optional[Tuple[torch.FloatTensor, ...]] = None - encoder_last_hidden_state: Optional[torch.FloatTensor] = None - encoder_hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None - encoder_attentions: Optional[Tuple[torch.FloatTensor, ...]] = None - - -@dataclass -class SemanticSegmenterOutput(ModelOutput): - """ - Base class for outputs of semantic segmentation models. - - Args: - loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided): - Classification (or regression if config.num_labels==1) loss. - logits (`torch.FloatTensor` of shape `(batch_size, config.num_labels, logits_height, logits_width)`): - Classification scores for each pixel. - - - - The logits returned do not necessarily have the same size as the `pixel_values` passed as inputs. This is - to avoid doing two interpolations and lose some quality when a user needs to resize the logits to the - original image size as post-processing. You should always check your logits shape and resize as needed. - - - - hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): - Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, + - one for the output of each layer) of shape `(batch_size, patch_size, hidden_size)`. - - Hidden-states of the model at the output of each layer plus the optional initial embedding outputs. - attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): - Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, patch_size, - sequence_length)`. - - Attentions weights after the attention softmax, used to compute the weighted average in the self-attention - heads. - """ - - loss: Optional[torch.FloatTensor] = None - logits: torch.FloatTensor = None - hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None - attentions: Optional[Tuple[torch.FloatTensor, ...]] = None - - -@dataclass -class ImageClassifierOutput(ModelOutput): - """ - Base class for outputs of image classification models. - - Args: - loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided): - Classification (or regression if config.num_labels==1) loss. - logits (`torch.FloatTensor` of shape `(batch_size, config.num_labels)`): - Classification (or regression if config.num_labels==1) scores (before SoftMax). - hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): - Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, + - one for the output of each stage) of shape `(batch_size, sequence_length, hidden_size)`. Hidden-states - (also called feature maps) of the model at the output of each stage. - attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): - Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, patch_size, - sequence_length)`. - - Attentions weights after the attention softmax, used to compute the weighted average in the self-attention - heads. - """ - - loss: Optional[torch.FloatTensor] = None - logits: torch.FloatTensor = None - hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None - attentions: Optional[Tuple[torch.FloatTensor, ...]] = None - - -@dataclass -class ImageClassifierOutputWithNoAttention(ModelOutput): - """ - Base class for outputs of image classification models. - - Args: - loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided): - Classification (or regression if config.num_labels==1) loss. - logits (`torch.FloatTensor` of shape `(batch_size, config.num_labels)`): - Classification (or regression if config.num_labels==1) scores (before SoftMax). - hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): - Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, + - one for the output of each stage) of shape `(batch_size, num_channels, height, width)`. Hidden-states (also - called feature maps) of the model at the output of each stage. - """ - - loss: Optional[torch.FloatTensor] = None - logits: torch.FloatTensor = None - hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None - - -@dataclass -class DepthEstimatorOutput(ModelOutput): - """ - Base class for outputs of depth estimation models. - - Args: - loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided): - Classification (or regression if config.num_labels==1) loss. - predicted_depth (`torch.FloatTensor` of shape `(batch_size, height, width)`): - Predicted depth for each pixel. - - hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): - Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, + - one for the output of each layer) of shape `(batch_size, num_channels, height, width)`. - - Hidden-states of the model at the output of each layer plus the optional initial embedding outputs. - attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): - Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, patch_size, - sequence_length)`. - - Attentions weights after the attention softmax, used to compute the weighted average in the self-attention - heads. - """ - - loss: Optional[torch.FloatTensor] = None - predicted_depth: torch.FloatTensor = None - hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None - attentions: Optional[Tuple[torch.FloatTensor, ...]] = None - - -@dataclass -class ImageSuperResolutionOutput(ModelOutput): - """ - Base class for outputs of image super resolution models. - - Args: - loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided): - Reconstruction loss. - reconstruction (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`): - Reconstructed images, possibly upscaled. - hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): - Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, + - one for the output of each stage) of shape `(batch_size, sequence_length, hidden_size)`. Hidden-states - (also called feature maps) of the model at the output of each stage. - attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): - Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, patch_size, - sequence_length)`. - - Attentions weights after the attention softmax, used to compute the weighted average in the self-attention - heads. - """ - - loss: Optional[torch.FloatTensor] = None - reconstruction: torch.FloatTensor = None - hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None - attentions: Optional[Tuple[torch.FloatTensor, ...]] = None - - -@dataclass -class Wav2Vec2BaseModelOutput(ModelOutput): - """ - Base class for models that have been trained with the Wav2Vec2 loss objective. - - Args: - last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`): - Sequence of hidden-states at the output of the last layer of the model. - extract_features (`torch.FloatTensor` of shape `(batch_size, sequence_length, conv_dim[-1])`): - Sequence of extracted feature vectors of the last convolutional layer of the model. - hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): - Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of - shape `(batch_size, sequence_length, hidden_size)`. - - Hidden-states of the model at the output of each layer plus the initial embedding outputs. - attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): - Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, - sequence_length)`. - - Attentions weights after the attention softmax, used to compute the weighted average in the self-attention - heads. - """ - - last_hidden_state: torch.FloatTensor = None - extract_features: torch.FloatTensor = None - hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None - attentions: Optional[Tuple[torch.FloatTensor, ...]] = None - - -@dataclass -class XVectorOutput(ModelOutput): - """ - Output type of [`Wav2Vec2ForXVector`]. - - Args: - loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided): - Classification loss. - logits (`torch.FloatTensor` of shape `(batch_size, config.xvector_output_dim)`): - Classification hidden states before AMSoftmax. - embeddings (`torch.FloatTensor` of shape `(batch_size, config.xvector_output_dim)`): - Utterance embeddings used for vector similarity-based retrieval. - hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): - Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of - shape `(batch_size, sequence_length, hidden_size)`. - - Hidden-states of the model at the output of each layer plus the initial embedding outputs. - attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): - Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, - sequence_length)`. - - Attentions weights after the attention softmax, used to compute the weighted average in the self-attention - heads. - """ - - loss: Optional[torch.FloatTensor] = None - logits: torch.FloatTensor = None - embeddings: torch.FloatTensor = None - hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None - attentions: Optional[Tuple[torch.FloatTensor, ...]] = None - - -@dataclass -class BackboneOutput(ModelOutput): - """ - Base class for outputs of backbones. - - Args: - feature_maps (`tuple(torch.FloatTensor)` of shape `(batch_size, num_channels, height, width)`): - Feature maps of the stages. - hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): - Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of - shape `(batch_size, sequence_length, hidden_size)` or `(batch_size, num_channels, height, width)`, - depending on the backbone. - - Hidden-states of the model at the output of each stage plus the initial embedding outputs. - attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): - Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, - sequence_length)`. Only applicable if the backbone uses attention. - - Attentions weights after the attention softmax, used to compute the weighted average in the self-attention - heads. - """ - - feature_maps: Tuple[torch.FloatTensor] = None - hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None - attentions: Optional[Tuple[torch.FloatTensor, ...]] = None - - -@dataclass -class BaseModelOutputWithPoolingAndProjection(ModelOutput): - """ - Base class for model's outputs that also contains a pooling of the last hidden states. - - Args: - last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`): - Sequence of hidden-states at the output of the last layer of the model. - pooler_output (`torch.FloatTensor` of shape `(batch_size, hidden_size)`): - Last layer hidden-state of the first token of the sequence (classification token) after further processing - through the layers used for the auxiliary pretraining task. E.g. for BERT-family of models, this returns - the classification token after processing through a linear layer and a tanh activation function. The linear - layer weights are trained from the next sentence prediction (classification) objective during pretraining. - hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): - Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, + - one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. - - Hidden-states of the model at the output of each layer plus the optional initial embedding outputs. - attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): - Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, - sequence_length)`. - - Attentions weights after the attention softmax, used to compute the weighted average in the self-attention - heads. - projection_state (`tuple(torch.FloatTensor)`, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): - Tuple of `torch.FloatTensor` of shape `(batch_size,config.project_dim)`. - - Text embeddings before the projection layer, used to mimic the last hidden state of the teacher encoder. - """ - - last_hidden_state: torch.FloatTensor = None - pooler_output: torch.FloatTensor = None - hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None - attentions: Optional[Tuple[torch.FloatTensor, ...]] = None - projection_state: Optional[Tuple[torch.FloatTensor]] = None - - -@dataclass -class Seq2SeqSpectrogramOutput(ModelOutput): - """ - Base class for sequence-to-sequence spectrogram outputs. - - Args: - loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided): - Spectrogram generation loss. - spectrogram (`torch.FloatTensor` of shape `(batch_size, sequence_length, num_bins)`): - The predicted spectrogram. - past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): - Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape - `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of shape - `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`. - - Contains pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention - blocks) that can be used (see `past_key_values` input) to speed up sequential decoding. - decoder_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): - Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, + - one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. - - Hidden-states of the decoder at the output of each layer plus the initial embedding outputs. - decoder_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): - Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, - sequence_length)`. - - Attentions weights of the decoder, after the attention softmax, used to compute the weighted average in the - self-attention heads. - cross_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): - Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, - sequence_length)`. - - Attentions weights of the decoder's cross-attention layer, after the attention softmax, used to compute the - weighted average in the cross-attention heads. - encoder_last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): - Sequence of hidden-states at the output of the last layer of the encoder of the model. - encoder_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): - Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, + - one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. - - Hidden-states of the encoder at the output of each layer plus the initial embedding outputs. - encoder_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): - Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, - sequence_length)`. - - Attentions weights of the encoder, after the attention softmax, used to compute the weighted average in the - self-attention heads. - """ - - loss: Optional[torch.FloatTensor] = None - spectrogram: torch.FloatTensor = None - past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None - decoder_hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None - decoder_attentions: Optional[Tuple[torch.FloatTensor, ...]] = None - cross_attentions: Optional[Tuple[torch.FloatTensor, ...]] = None - encoder_last_hidden_state: Optional[torch.FloatTensor] = None - encoder_hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None - encoder_attentions: Optional[Tuple[torch.FloatTensor, ...]] = None - - -@dataclass -class Seq2SeqTSModelOutput(ModelOutput): - """ - Base class for time series model's encoder outputs that also contains pre-computed hidden states that can speed up - sequential decoding. - - Args: - last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`): - Sequence of hidden-states at the output of the last layer of the decoder of the model. - - If `past_key_values` is used only the last hidden-state of the sequences of shape `(batch_size, 1, - hidden_size)` is output. - past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): - Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape - `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of shape - `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`. - - Contains pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention - blocks) that can be used (see `past_key_values` input) to speed up sequential decoding. - decoder_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): - Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, + - one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. - - Hidden-states of the decoder at the output of each layer plus the optional initial embedding outputs. - decoder_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): - Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, - sequence_length)`. - - Attentions weights of the decoder, after the attention softmax, used to compute the weighted average in the - self-attention heads. - cross_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): - Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, - sequence_length)`. - - Attentions weights of the decoder's cross-attention layer, after the attention softmax, used to compute the - weighted average in the cross-attention heads. - encoder_last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): - Sequence of hidden-states at the output of the last layer of the encoder of the model. - encoder_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): - Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, + - one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. - - Hidden-states of the encoder at the output of each layer plus the optional initial embedding outputs. - encoder_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): - Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, - sequence_length)`. - - Attentions weights of the encoder, after the attention softmax, used to compute the weighted average in the - self-attention heads. - loc (`torch.FloatTensor` of shape `(batch_size,)` or `(batch_size, input_size)`, *optional*): - Shift values of each time series' context window which is used to give the model inputs of the same - magnitude and then used to shift back to the original magnitude. - scale (`torch.FloatTensor` of shape `(batch_size,)` or `(batch_size, input_size)`, *optional*): - Scaling values of each time series' context window which is used to give the model inputs of the same - magnitude and then used to rescale back to the original magnitude. - static_features (`torch.FloatTensor` of shape `(batch_size, feature size)`, *optional*): - Static features of each time series' in a batch which are copied to the covariates at inference time. - """ - - last_hidden_state: torch.FloatTensor = None - past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None - decoder_hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None - decoder_attentions: Optional[Tuple[torch.FloatTensor, ...]] = None - cross_attentions: Optional[Tuple[torch.FloatTensor, ...]] = None - encoder_last_hidden_state: Optional[torch.FloatTensor] = None - encoder_hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None - encoder_attentions: Optional[Tuple[torch.FloatTensor, ...]] = None - loc: Optional[torch.FloatTensor] = None - scale: Optional[torch.FloatTensor] = None - static_features: Optional[torch.FloatTensor] = None - - -@dataclass -class Seq2SeqTSPredictionOutput(ModelOutput): - """ - Base class for time series model's decoder outputs that also contain the loss as well as the parameters of the - chosen distribution. - - Args: - loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when a `future_values` is provided): - Distributional loss. - params (`torch.FloatTensor` of shape `(batch_size, num_samples, num_params)`): - Parameters of the chosen distribution. - past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): - Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape - `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of shape - `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`. - - Contains pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention - blocks) that can be used (see `past_key_values` input) to speed up sequential decoding. - decoder_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): - Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, + - one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. - - Hidden-states of the decoder at the output of each layer plus the initial embedding outputs. - decoder_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): - Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, - sequence_length)`. - - Attentions weights of the decoder, after the attention softmax, used to compute the weighted average in the - self-attention heads. - cross_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): - Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, - sequence_length)`. - - Attentions weights of the decoder's cross-attention layer, after the attention softmax, used to compute the - weighted average in the cross-attention heads. - encoder_last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): - Sequence of hidden-states at the output of the last layer of the encoder of the model. - encoder_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): - Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, + - one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. - - Hidden-states of the encoder at the output of each layer plus the initial embedding outputs. - encoder_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): - Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, - sequence_length)`. - - Attentions weights of the encoder, after the attention softmax, used to compute the weighted average in the - self-attention heads. - loc (`torch.FloatTensor` of shape `(batch_size,)` or `(batch_size, input_size)`, *optional*): - Shift values of each time series' context window which is used to give the model inputs of the same - magnitude and then used to shift back to the original magnitude. - scale (`torch.FloatTensor` of shape `(batch_size,)` or `(batch_size, input_size)`, *optional*): - Scaling values of each time series' context window which is used to give the model inputs of the same - magnitude and then used to rescale back to the original magnitude. - static_features (`torch.FloatTensor` of shape `(batch_size, feature size)`, *optional*): - Static features of each time series' in a batch which are copied to the covariates at inference time. - """ - - loss: Optional[torch.FloatTensor] = None - params: Optional[Tuple[torch.FloatTensor]] = None - past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None - decoder_hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None - decoder_attentions: Optional[Tuple[torch.FloatTensor, ...]] = None - cross_attentions: Optional[Tuple[torch.FloatTensor, ...]] = None - encoder_last_hidden_state: Optional[torch.FloatTensor] = None - encoder_hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None - encoder_attentions: Optional[Tuple[torch.FloatTensor, ...]] = None - loc: Optional[torch.FloatTensor] = None - scale: Optional[torch.FloatTensor] = None - static_features: Optional[torch.FloatTensor] = None - - -@dataclass -class SampleTSPredictionOutput(ModelOutput): - """ - Base class for time series model's predictions outputs that contains the sampled values from the chosen - distribution. - - Args: - sequences (`torch.FloatTensor` of shape `(batch_size, num_samples, prediction_length)` or `(batch_size, num_samples, prediction_length, input_size)`): - Sampled values from the chosen distribution. - """ - - sequences: torch.FloatTensor = None - - -@dataclass -class MaskedImageModelingOutput(ModelOutput): - """ - Base class for outputs of masked image completion / in-painting models. - - Args: - loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `bool_masked_pos` is provided): - Reconstruction loss. - reconstruction (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`): - Reconstructed / completed images. - hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or - when `config.output_hidden_states=True`): - Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, + - one for the output of each stage) of shape `(batch_size, sequence_length, hidden_size)`. Hidden-states - (also called feature maps) of the model at the output of each stage. - attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when - `config.output_attentions=True`): - Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, patch_size, - sequence_length)`. Attentions weights after the attention softmax, used to compute the weighted average in - the self-attention heads. - """ - - loss: Optional[torch.FloatTensor] = None - reconstruction: torch.FloatTensor = None - hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None - attentions: Optional[Tuple[torch.FloatTensor, ...]] = None - - @property - def logits(self): - warnings.warn( - "logits attribute is deprecated and will be removed in version 5 of Transformers." - " Please use the reconstruction attribute to retrieve the final output instead.", - FutureWarning, - ) - return self.reconstruction diff --git a/modelopt/torch/puzzletron/decilm/deci_lm_hf_code/transformers_4_44_2__pytorch_utils.py b/modelopt/torch/puzzletron/decilm/deci_lm_hf_code/transformers_4_44_2__pytorch_utils.py deleted file mode 100644 index a1b413b0e08..00000000000 --- a/modelopt/torch/puzzletron/decilm/deci_lm_hf_code/transformers_4_44_2__pytorch_utils.py +++ /dev/null @@ -1,32 +0,0 @@ -# Copyright 2022 The HuggingFace Team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# -# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from torch import nn - -ALL_LAYERNORM_LAYERS = [nn.LayerNorm] \ No newline at end of file diff --git a/modelopt/torch/puzzletron/decilm/deci_lm_hf_code/transformers_4_51_3__cache_utils.py b/modelopt/torch/puzzletron/decilm/deci_lm_hf_code/transformers_4_51_3__cache_utils.py deleted file mode 100644 index 3dac4a51c6e..00000000000 --- a/modelopt/torch/puzzletron/decilm/deci_lm_hf_code/transformers_4_51_3__cache_utils.py +++ /dev/null @@ -1,2535 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# mypy: ignore-errors -import copy -import importlib.metadata -import json -import os -from collections.abc import Iterable -from dataclasses import dataclass -from typing import Any - -import torch -from packaging import version -from transformers.configuration_utils import PretrainedConfig -from transformers.pytorch_utils import is_torch_greater_or_equal_than_2_6 -from transformers.utils import ( - is_hqq_available, - is_optimum_quanto_available, - is_torch_greater_or_equal, - logging, -) - -if is_hqq_available(): - from hqq.core.quantize import Quantizer as HQQQuantizer - -logger = logging.get_logger(__name__) - - -class Cache: - """ - Base, abstract class for all caches. The actual data structure is specific to each subclass. - """ - - is_compileable = False - - def __init__(self): - super().__init__() - - def update( - self, - key_states: torch.Tensor, - value_states: torch.Tensor, - layer_idx: int, - cache_kwargs: dict[str, Any] | None = None, - ) -> tuple[torch.Tensor, torch.Tensor]: - """ - Updates the cache with the new `key_states` and `value_states` for the layer `layer_idx`. - - Parameters: - key_states (`torch.Tensor`): - The new key states to cache. - value_states (`torch.Tensor`): - The new value states to cache. - layer_idx (`int`): - The index of the layer to cache the states for. - cache_kwargs (`Dict[str, Any]`, `optional`): - Additional arguments for the cache subclass. These are specific to each subclass and allow new types of - cache to be created. - - Return: - A tuple containing the updated key and value states. - """ - raise NotImplementedError("Make sure to implement `update` in a subclass.") - - def get_seq_length(self, layer_idx: int | None = 0) -> int: - """Returns the sequence length of the cached states. A layer index can be optionally passed.""" - # TODO: deprecate this function in favor of `cache_position` - raise NotImplementedError("Make sure to implement `get_seq_length` in a subclass.") - - def get_max_cache_shape(self) -> int | None: - """Returns the maximum sequence length (i.e. max capacity) of the cache object""" - raise NotImplementedError("Make sure to implement `get_max_cache_shape` in a subclass.") - - def get_usable_length(self, new_seq_length: int, layer_idx: int | None = 0) -> int: - """Given the sequence length of the new inputs, returns the usable length of the cache.""" - # Cache without size limit -> all cache is usable - # Cache with size limit -> if the length cache plus the length of the new inputs is larger the maximum cache - # length, we will need to evict part of the cache (and thus not all cache is usable) - max_length = self.get_max_cache_shape() - previous_seq_length = self.get_seq_length(layer_idx) - if max_length is not None and previous_seq_length + new_seq_length > max_length: - return max_length - new_seq_length - return previous_seq_length - - def reorder_cache(self, beam_idx: torch.LongTensor): - """Reorders the cache for beam search, given the selected beam indices.""" - for layer_idx in range(len(self.key_cache)): - if self.key_cache[layer_idx].numel(): - device = self.key_cache[layer_idx].device - self.key_cache[layer_idx] = self.key_cache[layer_idx].index_select( - 0, beam_idx.to(device) - ) - if self.value_cache[layer_idx].numel(): - device = self.value_cache[layer_idx].device - self.value_cache[layer_idx] = self.value_cache[layer_idx].index_select( - 0, beam_idx.to(device) - ) - - @property - def seen_tokens(self): - logger.warning_once( - "The `seen_tokens` attribute is deprecated and will be removed in v4.41. Use the `cache_position` " - "model input instead." - ) - if hasattr(self, "_seen_tokens"): - return self._seen_tokens - else: - return None - - -@dataclass -class CacheConfig: - """ - Base class for cache configs - """ - - cache_implementation: None - - @classmethod - def from_dict(cls, config_dict, **kwargs): - """ - Constructs a CacheConfig instance from a dictionary of parameters. - Args: - config_dict (Dict[str, Any]): Dictionary containing configuration parameters. - **kwargs: Additional keyword arguments to override dictionary values. - - Returns: - CacheConfig: Instance of CacheConfig constructed from the dictionary. - """ - config = cls(**config_dict) - to_remove = [] - for key, value in kwargs.items(): - if hasattr(config, key): - setattr(config, key, value) - to_remove.append(key) - for key in to_remove: - kwargs.pop(key, None) - return config - - # Copied from transformers.utils.quantization_config.QuantizationConfigMixin.to_json_file - def to_json_file(self, json_file_path: str | os.PathLike): - """ - Save this instance to a JSON file. - - Args: - json_file_path (`str` or `os.PathLike`): - Path to the JSON file in which this configuration instance's parameters will be saved. - use_diff (`bool`, *optional*, defaults to `True`): - If set to `True`, only the difference between the config instance and the default - `QuantizationConfig()` is serialized to JSON file. - """ - with open(json_file_path, "w", encoding="utf-8") as writer: - config_dict = self.to_dict() - json_string = json.dumps(config_dict, indent=2, sort_keys=True) + "\n" - - writer.write(json_string) - - # Copied from transformers.utils.quantization_config.QuantizationConfigMixin.to_dict - def to_dict(self) -> dict[str, Any]: - """ - Serializes this instance to a Python dictionary. Returns: - `Dict[str, Any]`: Dictionary of all the attributes that make up this configuration instance. - """ - return copy.deepcopy(self.__dict__) - - # Copied from transformers.utils.quantization_config.QuantizationConfigMixin.__iter__ - def __iter__(self): - """allows `dict(obj)` for situations where obj may be a dict or QuantizationConfigMixin""" - for attr, value in copy.deepcopy(self.__dict__).items(): - yield attr, value - - # Copied from transformers.utils.quantization_config.QuantizationConfigMixin.__repr__ - def __repr__(self): - return f"{self.__class__.__name__} {self.to_json_string()}" - - def to_json_string(self): - """ - Serializes this instance to a JSON formatted string. - Returns: - str: JSON formatted string representing the configuration instance. - """ - return json.dumps(self.__dict__, indent=2) + "\n" - - # Copied from transformers.utils.quantization_config.QuantizationConfigMixin.update - def update(self, **kwargs): - """ - Updates attributes of this class instance with attributes from `kwargs` if they match existing attributes, - returning all the unused kwargs. - - Args: - kwargs (`Dict[str, Any]`): - Dictionary of attributes to tentatively update this class. - - Returns: - `Dict[str, Any]`: Dictionary containing all the key-value pairs that were not used to update the instance. - """ - to_remove = [] - for key, value in kwargs.items(): - if hasattr(self, key): - setattr(self, key, value) - to_remove.append(key) - - # Remove all the attributes that were updated, without modifying the input dict - unused_kwargs = {key: value for key, value in kwargs.items() if key not in to_remove} - return unused_kwargs - - -@dataclass -class QuantizedCacheConfig(CacheConfig): - """ - Configuration class for quantized cache settings. - - Attributes: - backend (`str`, *optional*, defaults to `"quanto"`): - Backend to use when performing quantization, Can be one of [`quanto`, `HQQ`] - nbits (`Optional[int]`, *optional*, defaults to 4): - Number of bits, can be 2 or 4 for the `quanto` backend and one of [1, 2, 3, 4, 8] for the `HQQ` backend. Defaults to 2. - axis_key (`int`, *optional*, defaults to 0): - Axis over which to perform grouping for the key tensors. Can be [0, -1] for `quanto` backend and [0, 1] for `HQQ` backend. - axis_value (`int`, *optional*, defaults to 0): - Axis over which to perform grouping for the value tensors. Can be [0, -1] for `quanto` backend and [0, 1] for `HQQ` backend. - q_group_size (`Optional[int]`, *optional*, defaults to 64): - Size of the quantization group, should be a divisor of the model's hidden dimension. - Defaults to 64. - residual_length (`Optional[int]`, *optional*, defaults to 128): - Length of the residual cache which will always be stored in original precision. - Defaults to 128. - compute_dtype (`torch.dtype`, *optional*, defaults to `torch.float16`): - The default dtype used for computations in the model. Keys and Values will be cast to this dtype after dequantization. - device (`str`, *optional*, defaults to `"cpu"`): - Device on which to perform computations, should be same as the model's device. - """ - - def __init__( - self, - backend: str = "quanto", - nbits: int | None = 4, - axis_key: int | None = 0, - axis_value: int | None = 0, - q_group_size: int | None = 64, - residual_length: int | None = 128, - compute_dtype: torch.dtype | None = torch.float16, - device: str | None = "cpu", - ): - self.backend = backend - self.nbits = nbits - self.axis_key = axis_key - self.axis_value = axis_value - self.q_group_size = q_group_size - self.residual_length = residual_length - self.compute_dtype = compute_dtype - self.device = device - - def validate(self): - """Validates if the arguments passed are correct""" - - incorrect_arg_msg = ( - "Some of the keys in `cache_config` are defined incorrectly. `{key}` should be {correct_value}` " - "but found {found_value}" - ) - # Check that the values are reasonable in general (nbits, axis) - # Later in QuantizedCache init we check if they are supported for that particular backend - if self.nbits not in [1, 2, 3, 4, 8]: - raise ValueError( - incorrect_arg_msg.format( - key="nbits", - correct_value="2 or 4 or 8", - found_value=self.nbits, - ), - ) - if self.q_group_size <= 0: - raise ValueError( - incorrect_arg_msg.format( - key="q_group_size", - correct_value="a positive integer", - found_value=self.q_group_size, - ), - ) - if self.residual_length < 0: - raise ValueError( - incorrect_arg_msg.format( - key="residual_length", - correct_value="a positive integer", - found_value=self.residual_length, - ), - ) - - if self.axis_key not in [0, 1, -1]: - raise ValueError( - incorrect_arg_msg.format( - key="axis_key", - correct_value="`1` or `0`, `-1`", - found_value=self.axis_key, - ), - ) - - if self.axis_value not in [0, 1, -1]: - raise ValueError( - incorrect_arg_msg.format( - key="axis_value", - correct_value="`1` or `0` or `-1`", - found_value=self.axis_value, - ), - ) - - -@dataclass -class StaticCacheConfig(CacheConfig): - """ - Configuration class for static cache settings. - """ - - cache_implementation = "static" - - def __init__(self, batch_size: int, max_cache_len: int, device="cpu"): - self.batch_size = batch_size - self.max_cache_len = max_cache_len - self.device = device - - def validate(self): - """Validates if the arguments passed are correct""" - - incorrect_arg_msg = ( - "Some of the keys in `cache_config` are defined incorrectly. `{key}` should be {correct_value}` " - "but found {found_value}" - ) - - if self.batch_size <= 0: - raise ValueError( - incorrect_arg_msg.format( - key="batch_size", - correct_value="> 0", - found_value=self.batch_size, - ), - ) - - if self.max_cache_len <= 0: - raise ValueError( - incorrect_arg_msg.format( - key="max_cache_len", - correct_value="> 0", - found_value=self.max_cache_len, - ), - ) - - -class DynamicCache(Cache): - """ - A cache that grows dynamically as more tokens are generated. This is the default for generative models. - - It stores the Key and Value states as a list of tensors, one for each layer. The expected shape for each tensor is - `[batch_size, num_heads, seq_len, head_dim]`. - - Example: - - ```python - >>> from transformers import AutoTokenizer, AutoModelForCausalLM, DynamicCache - - >>> model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2-0.5B-Instruct") - >>> tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2-0.5B-Instruct") - - >>> inputs = tokenizer(text="My name is Qwen2", return_tensors="pt") - - >>> # Prepare a cache class and pass it to model's forward - >>> past_key_values = DynamicCache() - >>> outputs = model(**inputs, past_key_values=past_key_values, use_cache=True) - >>> outputs.past_key_values # access cache filled with key/values from generation - DynamicCache() - ``` - """ - - def __init__(self, _distributed_cache_data: Iterable = None) -> None: - super().__init__() - self._seen_tokens = ( - 0 # Used in `generate` to keep tally of how many tokens the cache has seen - ) - self.key_cache: list[torch.Tensor] = [] - self.value_cache: list[torch.Tensor] = [] - - # `_distributed_cache_data` was originally added for compatibility with `torch.distributed` (DDP). See #36121 - # and #36373 for more information. In a nutshell, it is `map(gather_map, zip(*caches))`, i.e. each item in the - # iterable contains the key and value states for a layer gathered across replicas by torch.distributed - # (shape=[global batch size, num_heads, seq_len, head_dim]). - # WARNING: `_distributed_cache_data` must be the first argument in `__init__`, otherwise we'll break - # compatibility. The name of the argument doesn't matter. - if _distributed_cache_data is not None: - for key_states, value_states in _distributed_cache_data: - self.key_cache.append(key_states) - self.value_cache.append(value_states) - - def __getitem__(self, layer_idx: int) -> list[tuple[torch.Tensor]]: - """ - Support for backwards-compatible `past_key_value` indexing, e.g. `past_key_value[0][0].shape[2]` to get the - sequence length. - """ - if layer_idx < len(self): - return (self.key_cache[layer_idx], self.value_cache[layer_idx]) - else: - raise KeyError( - f"Cache only has {len(self)} layers, attempted to access layer with index {layer_idx}" - ) - - def __iter__(self): - """ - Support for backwards-compatible `past_key_value` iteration, e.g. `for x in past_key_value:` to iterate over - keys and values - """ - for layer_idx in range(len(self)): - yield (self.key_cache[layer_idx], self.value_cache[layer_idx]) - - def __len__(self): - """ - Support for backwards-compatible `past_key_value` length, e.g. `len(past_key_value)`. This value corresponds - to the number of layers in the model. - """ - return len(self.key_cache) - - def update( - self, - key_states: torch.Tensor, - value_states: torch.Tensor, - layer_idx: int, - cache_kwargs: dict[str, Any] | None = None, - ) -> tuple[torch.Tensor, torch.Tensor]: - """ - Updates the cache with the new `key_states` and `value_states` for the layer `layer_idx`. - - Parameters: - key_states (`torch.Tensor`): - The new key states to cache. - value_states (`torch.Tensor`): - The new value states to cache. - layer_idx (`int`): - The index of the layer to cache the states for. - cache_kwargs (`Dict[str, Any]`, `optional`): - Additional arguments for the cache subclass. No additional arguments are used in `DynamicCache`. - - Return: - A tuple containing the updated key and value states. - """ - # Update the number of seen tokens - if layer_idx == 0: - self._seen_tokens += key_states.shape[-2] - - # Update the cache - if key_states is not None: - if len(self.key_cache) <= layer_idx: - # There may be skipped layers, fill them with empty lists - for _ in range(len(self.key_cache), layer_idx): - self.key_cache.append(torch.tensor([])) - self.value_cache.append(torch.tensor([])) - self.key_cache.append(key_states) - self.value_cache.append(value_states) - elif ( - not self.key_cache[ - layer_idx - ].numel() # prefers not t.numel() to len(t) == 0 to export the model - ): # fills previously skipped layers; checking for tensor causes errors - self.key_cache[layer_idx] = key_states - self.value_cache[layer_idx] = value_states - else: - self.key_cache[layer_idx] = torch.cat( - [self.key_cache[layer_idx], key_states], dim=-2 - ) - self.value_cache[layer_idx] = torch.cat( - [self.value_cache[layer_idx], value_states], dim=-2 - ) - - return self.key_cache[layer_idx], self.value_cache[layer_idx] - - def get_seq_length(self, layer_idx: int | None = 0) -> int: - """Returns the sequence length of the cached states. A layer index can be optionally passed.""" - # TODO: deprecate this function in favor of `cache_position` - is_empty_layer = ( - len(self.key_cache) == 0 # no cache in any layer - or len(self.key_cache) - <= layer_idx # skipped `layer_idx` and hasn't run a layer with cache after it - or not self.key_cache[layer_idx].numel() # the layer has no cache - ) - layer_seq_length = self.key_cache[layer_idx].shape[-2] if not is_empty_layer else 0 - return layer_seq_length - - def get_max_cache_shape(self) -> int | None: - """Returns the maximum sequence length of the cache object. DynamicCache does not have a maximum length.""" - return None - - def to_legacy_cache(self) -> tuple[tuple[torch.Tensor], tuple[torch.Tensor]]: - """Converts the `DynamicCache` instance into the its equivalent in the legacy cache format. Used for - backward compatibility.""" - legacy_cache = () - for layer_idx in range(len(self)): - legacy_cache += ((self.key_cache[layer_idx], self.value_cache[layer_idx]),) - return legacy_cache - - @classmethod - def from_legacy_cache( - cls, past_key_values: tuple[tuple[torch.FloatTensor]] | None = None - ) -> "DynamicCache": - """Converts a cache in the legacy cache format into an equivalent `DynamicCache`. Used for - backward compatibility.""" - cache = cls() - if past_key_values is not None: - for layer_idx in range(len(past_key_values)): - key_states, value_states = past_key_values[layer_idx] - cache.update(key_states, value_states, layer_idx) - return cache - - def crop(self, max_length: int): - """Crop the past key values up to a new `max_length` in terms of tokens. `max_length` can also be - negative to remove `max_length` tokens. This is used in assisted decoding and contrastive search.""" - # In case it is negative - if max_length < 0: - max_length = self.get_seq_length() - abs(max_length) - - if self.get_seq_length() <= max_length: - return - - self._seen_tokens = max_length - for idx in range(len(self.key_cache)): - if self.key_cache[idx].numel(): - self.key_cache[idx] = self.key_cache[idx][..., :max_length, :] - self.value_cache[idx] = self.value_cache[idx][..., :max_length, :] - - def batch_split(self, full_batch_size: int, split_size: int) -> list["DynamicCache"]: - """Split the current instance into a list of `DynamicCache` by the batch size. This will be used by - `_split_model_inputs()` in `generation.utils`""" - out = [] - for i in range(0, full_batch_size, split_size): - current_split = DynamicCache() - current_split._seen_tokens = self._seen_tokens - current_split.key_cache = [tensor[i : i + split_size] for tensor in self.key_cache] - current_split.value_cache = [tensor[i : i + split_size] for tensor in self.value_cache] - out.append(current_split) - return out - - @classmethod - def from_batch_splits(cls, splits: list["DynamicCache"]) -> "DynamicCache": - """This is the opposite of the above `batch_split()` method. This will be used by `stack_model_outputs` in - `generation.utils`""" - cache = cls() - for idx in range(len(splits[0])): - key_cache = [ - current.key_cache[idx] for current in splits if current.key_cache[idx].numel() - ] - value_cache = [ - current.value_cache[idx] for current in splits if current.value_cache[idx].numel() - ] - if key_cache != []: - layer_keys = torch.cat(key_cache, dim=0) - layer_values = torch.cat(value_cache, dim=0) - cache.update(layer_keys, layer_values, idx) - return cache - - def batch_repeat_interleave(self, repeats: int): - """Repeat the cache `repeats` times in the batch dimension. Used in contrastive search.""" - for layer_idx in range(len(self)): - self.key_cache[layer_idx] = self.key_cache[layer_idx].repeat_interleave(repeats, dim=0) - self.value_cache[layer_idx] = self.value_cache[layer_idx].repeat_interleave( - repeats, dim=0 - ) - - def batch_select_indices(self, indices: torch.Tensor): - """Only keep the `indices` in the batch dimension of the cache. Used in contrastive search.""" - for layer_idx in range(len(self)): - self.key_cache[layer_idx] = self.key_cache[layer_idx][indices, ...] - self.value_cache[layer_idx] = self.value_cache[layer_idx][indices, ...] - - -# Utilities for `DynamicCache` <> torch.export support -def _flatten_dynamic_cache( - dynamic_cache: DynamicCache, -): - """Flattens DynamicCache into flat list of tensors for `torch.export.export` to consume""" - if not isinstance(dynamic_cache, DynamicCache): - raise RuntimeError("This pytree flattening function should only be applied to DynamicCache") - - if not is_torch_greater_or_equal_than_2_6: - logger.warning_once( - "DynamicCache + torch.export is tested on torch 2.6.0+ and may not work on earlier versions." - ) - - # NOTE it seems _seen_tokens is deprecated, so probably doesn't need tracking - dictionary = { - "key_cache": getattr(dynamic_cache, "key_cache"), - "value_cache": getattr(dynamic_cache, "value_cache"), - } - return torch.utils._pytree._dict_flatten(dictionary) - - -def _flatten_with_keys_dynamic_cache(dynamic_cache: DynamicCache): - dictionary = { - "key_cache": getattr(dynamic_cache, "key_cache"), - "value_cache": getattr(dynamic_cache, "value_cache"), - } - return torch.utils._pytree._dict_flatten_with_keys(dictionary) - - -def _unflatten_dynamic_cache( - values, - context: torch.utils._pytree.Context, -): - dictionary = torch.utils._pytree._dict_unflatten(values, context) - cache = DynamicCache() - for k, v in dictionary.items(): - setattr(cache, k, v) - return cache - - -def _flatten_dynamic_cache_for_fx(cache, spec): - dictionary = { - "key_cache": getattr(cache, "key_cache"), - "value_cache": getattr(cache, "value_cache"), - } - return torch.utils._pytree.tree_flatten(dictionary)[0] - - -if is_torch_greater_or_equal("2.3"): - torch.utils._pytree.register_pytree_node( - DynamicCache, - _flatten_dynamic_cache, - _unflatten_dynamic_cache, - serialized_type_name=f"{DynamicCache.__module__}.{DynamicCache.__name__}", - flatten_with_keys_fn=_flatten_with_keys_dynamic_cache, - ) - # TODO (tmanlaibaatar) This won't be needed in torch 2.7. - torch.fx._pytree.register_pytree_flatten_spec(DynamicCache, _flatten_dynamic_cache_for_fx) - - -class OffloadedCache(DynamicCache): - """ - A drop-in replacement for DynamicCache that conserves accelerator(GPU, XPU) memory at the expense of more CPU memory. - Useful for generating from models with very long context. - - In addition to the default accelerator stream, where all forward() computations happen, - this class uses another stream, the prefetch stream, which it creates itself. - Since scheduling of operations on separate streams happens independently, this class uses - the prefetch stream to asynchronously prefetch the KV cache of layer k+1 when layer k is executing. - The movement of the layer k-1 cache to the CPU is handled by the default stream as a simple way to - ensure the eviction is scheduled after all computations on that cache are finished. - """ - - def __init__(self) -> None: - if not ( - torch.cuda.is_available() - or (is_torch_greater_or_equal("2.7", accept_dev=True) and torch.xpu.is_available()) - ): - raise RuntimeError( - "OffloadedCache can only be used with a GPU" - + (" or XPU" if is_torch_greater_or_equal("2.7", accept_dev=True) else "") - ) - - super().__init__() - self.original_device = [] - self.prefetch_stream = None - self.prefetch_stream = ( - torch.Stream() - if is_torch_greater_or_equal("2.7", accept_dev=True) - else torch.cuda.Stream() - ) - self.beam_idx = None # used to delay beam search operations - - def prefetch_layer(self, layer_idx: int): - "Starts prefetching the next layer cache" - if layer_idx < len(self): - with ( - self.prefetch_stream - if is_torch_greater_or_equal("2.7", accept_dev=True) - else torch.cuda.stream(self.prefetch_stream) - ): - # Prefetch next layer tensors to GPU - device = self.original_device[layer_idx] - self.key_cache[layer_idx] = self.key_cache[layer_idx].to(device, non_blocking=True) - self.value_cache[layer_idx] = self.value_cache[layer_idx].to( - device, non_blocking=True - ) - - def evict_previous_layer(self, layer_idx: int): - "Moves the previous layer cache to the CPU" - if len(self) > 2: - # We do it on the default stream so it occurs after all earlier computations on these tensors are done - prev_layer_idx = (layer_idx - 1) % len(self) - self.key_cache[prev_layer_idx] = self.key_cache[prev_layer_idx].to( - "cpu", non_blocking=True - ) - self.value_cache[prev_layer_idx] = self.value_cache[prev_layer_idx].to( - "cpu", non_blocking=True - ) - - def __getitem__(self, layer_idx: int) -> list[tuple[torch.Tensor]]: - "Gets the cache for this layer to the device. Prefetches the next and evicts the previous layer." - if layer_idx < len(self): - # Evict the previous layer if necessary - if is_torch_greater_or_equal("2.7", accept_dev=True): - torch.accelerator.current_stream().synchronize() - else: - torch.cuda.current_stream().synchronize() - self.evict_previous_layer(layer_idx) - # Load current layer cache to its original device if not already there - original_device = self.original_device[layer_idx] - self.prefetch_stream.synchronize() - key_tensor = self.key_cache[layer_idx] - value_tensor = self.value_cache[layer_idx] - # Now deal with beam search ops which were delayed - if self.beam_idx is not None: - self.beam_idx = self.beam_idx.to(original_device) - key_tensor = key_tensor.index_select(0, self.beam_idx) - value_tensor = value_tensor.index_select(0, self.beam_idx) - # Prefetch the next layer - self.prefetch_layer((layer_idx + 1) % len(self)) - return (key_tensor, value_tensor) - else: - raise KeyError( - f"Cache only has {len(self)} layers, attempted to access layer with index {layer_idx}" - ) - - def reorder_cache(self, beam_idx: torch.LongTensor): - """Saves the beam indices and reorders the cache when the tensor is back to its device.""" - # We delay this operation until the tensors are back to their original - # device because performing torch.index_select on the CPU is very slow - del self.beam_idx - self.beam_idx = beam_idx.clone() - - def update( - self, - key_states: torch.Tensor, - value_states: torch.Tensor, - layer_idx: int, - cache_kwargs: dict[str, Any] | None = None, - ) -> tuple[torch.Tensor, torch.Tensor]: - """ - Updates the cache with the new `key_states` and `value_states` for the layer `layer_idx`. - Parameters: - key_states (`torch.Tensor`): - The new key states to cache. - value_states (`torch.Tensor`): - The new value states to cache. - layer_idx (`int`): - The index of the layer to cache the states for. - cache_kwargs (`Dict[str, Any]`, `optional`): - Additional arguments for the cache subclass. No additional arguments are used in `OffloadedCache`. - Return: - A tuple containing the updated key and value states. - """ - # Update the number of seen tokens - if layer_idx == 0: - self._seen_tokens += key_states.shape[-2] - - # Update the cache - if len(self.key_cache) < layer_idx: - raise ValueError( - "OffloadedCache does not support model usage where layers are skipped. Use DynamicCache." - ) - elif len(self.key_cache) == layer_idx: - self.key_cache.append(key_states) - self.value_cache.append(value_states) - self.original_device.append(key_states.device) - self.evict_previous_layer(layer_idx) - else: - key_tensor, value_tensor = self[layer_idx] - self.key_cache[layer_idx] = torch.cat([key_tensor, key_states], dim=-2) - self.value_cache[layer_idx] = torch.cat([value_tensor, value_states], dim=-2) - - return self.key_cache[layer_idx], self.value_cache[layer_idx] - - # According to https://docs.python.org/3/library/exceptions.html#NotImplementedError - # if a method is not supposed to be supported in a subclass we should set it to None - from_legacy_cache = None - - to_legacy_cache = None - - -class QuantizedCache(DynamicCache): - """ - A quantizer cache similar to what is described in the [KIVI: A Tuning-Free Asymmetric 2bit Quantization for KV Cache paper](https://arxiv.org/abs/2402.02750). - It allows the model to generate longer sequence length without allocating too much memory for Key and Value cache by applying quantization. - - The cache has two types of storage, one for original precision and one for the quantized cache. A `residual length` is set as a maximum capacity for the - original precision cache. When the length goes beyond maximum capacity, the original precision cache is discarded and moved into the quantized cache. The - quantization is done per-channel with a set `q_group_size` for both Keys and Values, in contrast to what was described in the paper. - - It stores Keys and Values a list of quantized tensors (tuples in case we need to store metadata), one for each layer. Additionally, it stores the Key and - Value in original precision states as a list of tensors, one for each layer. The size of each tensor - is `[batch_size, num_heads, seq_len - residual_length, head_dim]` - """ - - def __init__(self, cache_config: QuantizedCacheConfig) -> None: - super().__init__() - self._quantized_key_cache: list[torch.Tensor] = [] - self._quantized_value_cache: list[torch.Tensor] = [] - - self.nbits = cache_config.nbits - self.residual_length = cache_config.residual_length - self.q_group_size = cache_config.q_group_size - self.axis_key = cache_config.axis_key - self.axis_value = cache_config.axis_value - self.compute_dtype = cache_config.compute_dtype - self.device = cache_config.device - - super().__init__() - - def update( - self, - key_states: torch.Tensor, - value_states: torch.Tensor, - layer_idx: int, - cache_kwargs: dict[str, Any] | None = None, - ) -> tuple[torch.Tensor, torch.Tensor]: - # Update the number of seen tokens - if layer_idx == 0: - self._seen_tokens += key_states.shape[-2] - - if len(self.key_cache) < layer_idx: - raise ValueError( - "QuantizedCache does not support model usage where layers are skipped. Use DynamicCache." - ) - elif len(self.key_cache) == layer_idx: - self._quantized_key_cache.append( - self._quantize(key_states.contiguous(), axis=self.axis_key) - ) - self._quantized_value_cache.append( - self._quantize(value_states.contiguous(), axis=self.axis_value) - ) - self.key_cache.append(torch.zeros(0, dtype=key_states.dtype, device=key_states.device)) - self.value_cache.append( - torch.zeros(0, dtype=key_states.dtype, device=key_states.device) - ) - keys_to_return, values_to_return = key_states, value_states - else: - dequant_key = self._dequantize(self._quantized_key_cache[layer_idx]) - dequant_value = self._dequantize(self._quantized_value_cache[layer_idx]) - keys_to_return = [dequant_key, self.key_cache[layer_idx], key_states] - values_to_return = [dequant_value, self.value_cache[layer_idx], value_states] - - keys_to_return = torch.cat(keys_to_return, dim=-2) - values_to_return = torch.cat(values_to_return, dim=-2) - if ( - self.key_cache[layer_idx].dim() == 4 - and self.key_cache[layer_idx].shape[-2] + 1 >= self.residual_length - ): - self._quantized_key_cache[layer_idx] = self._quantize( - keys_to_return.contiguous(), axis=self.axis_key - ) - self._quantized_value_cache[layer_idx] = self._quantize( - values_to_return.contiguous(), axis=self.axis_value - ) - self.key_cache[layer_idx] = torch.zeros( - 0, dtype=key_states.dtype, device=key_states.device - ) - self.value_cache[layer_idx] = torch.zeros( - 0, dtype=key_states.dtype, device=key_states.device - ) - else: - self.key_cache[layer_idx] = torch.cat( - [self.key_cache[layer_idx], key_states], dim=-2 - ) - self.value_cache[layer_idx] = torch.cat( - [self.value_cache[layer_idx], value_states], dim=-2 - ) - - return keys_to_return, values_to_return - - def get_seq_length(self, layer_idx: int | None = 0) -> int: - """Returns the sequence length of the cached states. A layer index can be optionally passed.""" - if len(self.key_cache) <= layer_idx: - return 0 - # since we cannot get the seq_length of each layer directly and rely on `_seen_tokens` which is - # updated every "layer_idx" == 0, this is a hack to get the actual seq_length for the given layer_idx - # this part of code otherwise fails when used to verify attn_weight shape in some models - return self._seen_tokens if layer_idx == 0 else self._seen_tokens - 1 - - def _quantize(self, tensor, axis): - """Quantizes a key/value using a defined quantization method.""" - raise NotImplementedError("Make sure to implement `_quantize` in a subclass.") - - def _dequantize(self, q_tensor): - """Dequantizes back the tensor that was quantized by `self._quantize()`""" - raise NotImplementedError("Make sure to implement `_dequantize` in a subclass.") - - -class QuantoQuantizedCache(QuantizedCache): - """ - Quantized Cache class that uses `quanto` as a backend to perform quantization. Current implementation supports `int2` and `int4` dtypes only. - - Parameters: - cache_config (`QuantizedCacheConfig`): - A configuration containing all the arguments to be used by the quantizer, including axis, qtype and group size. - - Example: - - ```python - >>> # Run pip install quanto first if you don't have it yet - >>> from transformers import AutoTokenizer, AutoModelForCausalLM, QuantoQuantizedCache, QuantizedCacheConfig - - >>> model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2-0.5B-Instruct") - >>> tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2-0.5B-Instruct") - - >>> inputs = tokenizer(text="My name is Qwen2", return_tensors="pt") - - >>> # Prepare a cache class and pass it to model's forward - >>> cache_config = QuantizedCacheConfig(nbits=4) - >>> past_key_values = QuantoQuantizedCache(cache_config=cache_config) - >>> outputs = model(**inputs, past_key_values=past_key_values, use_cache=True) - >>> outputs.past_key_values # access cache filled with key/values from generation - QuantoQuantizedCache() - ``` - """ - - def __init__(self, cache_config: CacheConfig) -> None: - super().__init__(cache_config) - - if is_optimum_quanto_available(): - optimum_quanto_version = version.parse(importlib.metadata.version("optimum-quanto")) - if optimum_quanto_version <= version.parse("0.2.5"): - raise ImportError( - f"You need optimum-quanto package version to be greater or equal than 0.2.5 to use `QuantoQuantizedCache`. Detected version {optimum_quanto_version}." - ) - from optimum.quanto import MaxOptimizer, qint2, qint4 - - if self.nbits not in [2, 4]: - raise ValueError( - f"`nbits` for `quanto` backend has to be one of [`2`, `4`] but got {self.nbits}" - ) - - if self.axis_key not in [0, -1]: - raise ValueError( - f"`axis_key` for `quanto` backend has to be one of [`0`, `-1`] but got {self.axis_key}" - ) - - if self.axis_value not in [0, -1]: - raise ValueError( - f"`axis_value` for `quanto` backend has to be one of [`0`, `-1`] but got {self.axis_value}" - ) - - self.qtype = qint4 if self.nbits == 4 else qint2 - self.optimizer = ( - MaxOptimizer() - ) # hardcode as it's the only one for per-channel quantization - - def _quantize(self, tensor, axis): - # We have two different API since in optimum-quanto, we don't use AffineQuantizer anymore - if is_optimum_quanto_available(): - from optimum.quanto import quantize_weight - - scale, zeropoint = self.optimizer(tensor, self.qtype, axis, self.q_group_size) - qtensor = quantize_weight(tensor, self.qtype, axis, scale, zeropoint, self.q_group_size) - return qtensor - - def _dequantize(self, qtensor): - return qtensor.dequantize() - - -class HQQQuantizedCache(QuantizedCache): - """ - Quantized Cache class that uses `HQQ` as a backend to perform quantization. Current implementation supports `int2`, `int4`, `int8` dtypes. - - Parameters: - cache_config (`QuantizedCacheConfig`): - A configuration containing all the arguments to be used by the quantizer, including axis, qtype and group size. - - Example: - - ```python - >>> # Run pip install hqq first if you don't have it yet - >>> from transformers import AutoTokenizer, AutoModelForCausalLM, HQQQuantizedCache, QuantizedCacheConfig - - >>> model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2-0.5B-Instruct") - >>> tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2-0.5B-Instruct") - - >>> inputs = tokenizer(text="My name is Qwen2", return_tensors="pt") - - >>> # Prepare a cache class and pass it to model's forward - >>> cache_config = QuantizedCacheConfig(nbits=4, axis_key=1, axis_value=1) - >>> past_key_values = HQQQuantizedCache(cache_config=cache_config) - >>> outputs = model(**inputs, past_key_values=past_key_values, use_cache=True) - >>> outputs.past_key_values # access cache filled with key/values from generation - HQQQuantizedCache() - ``` - """ - - def __init__(self, cache_config: CacheConfig) -> None: - super().__init__(cache_config) - if self.nbits not in [1, 2, 3, 4, 8]: - raise ValueError( - f"`nbits` for `HQQ` backend has to be one of [`1`, `2`, `3`, `4`, `8`] but got {self.nbits}" - ) - - if self.axis_key not in [0, 1]: - raise ValueError( - f"`axis_key` for `HQQ` backend has to be one of [`0`, `1`] but got {self.axis_key}" - ) - - if self.axis_value not in [0, 1]: - raise ValueError( - f"`axis_value` for `HQQ` backend has to be one of [`0`, `1`] but got {self.axis_value}" - ) - - self.quantizer = HQQQuantizer - - def _quantize(self, tensor, axis): - qtensor, meta = self.quantizer.quantize( - tensor, - axis=axis, - device=self.device, - compute_dtype=self.compute_dtype, - nbits=self.nbits, - group_size=self.q_group_size, - ) - meta["compute_dtype"] = self.compute_dtype - self.quantizer.cuda( - qtensor, meta=meta, device=self.device - ) # Move to device and cast to dtype - return qtensor, meta - - def _dequantize(self, qtensor): - quant_tensor, meta = qtensor - tensor = self.quantizer.dequantize(quant_tensor, meta) - return tensor - - -class SinkCache(Cache): - """ - A cache that as described in the [Attention Sinks paper](https://arxiv.org/abs/2309.17453). It allows the model to - generate beyond the length of its context window, without losing fluency in the conversation. As it discards past - tokens, the model will lose the ability to generate tokens that depend on the context that was discarded. - - It stores the Key and Value states as a list of tensors, one for each layer. The expected shape for each tensor is - `[batch_size, num_heads, seq_len, head_dim]`. - - Parameters: - window_length (`int`): - The length of the context window. - num_sink_tokens (`int`): - The number of sink tokens. See the original paper for more information. - - Example: - - ```python - >>> from transformers import AutoTokenizer, AutoModelForCausalLM, SinkCache - - >>> model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2-0.5B-Instruct") - >>> tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2-0.5B-Instruct") - - >>> inputs = tokenizer(text="My name is Qwen2", return_tensors="pt") - - >>> # Prepare a cache class and pass it to model's forward - >>> past_key_values = SinkCache(window_length=256, num_sink_tokens=4) - >>> outputs = model(**inputs, past_key_values=past_key_values, use_cache=True) - >>> outputs.past_key_values # access cache filled with key/values from generation - SinkCache() - ``` - """ - - is_sliding = True - - def __init__(self, window_length: int, num_sink_tokens: int) -> None: - super().__init__() - self.key_cache: list[torch.Tensor] = [] - self.value_cache: list[torch.Tensor] = [] - self.window_length = window_length - self.num_sink_tokens = num_sink_tokens - self.cos_sin_rerotation_cache = {} - self._cos_cache = None - self._sin_cache = None - self._seen_tokens = ( - 0 # Used in `generate` to keep tally of how many tokens the cache has seen - ) - - @staticmethod - def _rotate_half(x): - x1 = x[..., : x.shape[-1] // 2] - x2 = x[..., x.shape[-1] // 2 :] - return torch.cat((-x2, x1), dim=-1) - - def _apply_key_rotary_pos_emb( - self, key_states: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor - ) -> torch.Tensor: - rotated_key_states = (key_states * cos) + (self._rotate_half(key_states) * sin) - return rotated_key_states - - def _get_rerotation_cos_sin( - self, key_states: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor - ) -> tuple[torch.Tensor, torch.Tensor]: - if key_states.shape[-2] not in self.cos_sin_rerotation_cache: - # Upcast to float32 temporarily for better accuracy - cos = cos.to(torch.float32) - sin = sin.to(torch.float32) - - # Compute the cos and sin required for back- and forward-rotating to one position earlier in the sequence - original_cos = cos[self.num_sink_tokens + key_states.shape[-2] :] - shifted_cos = cos[self.num_sink_tokens : -key_states.shape[-2]] - original_sin = sin[self.num_sink_tokens + key_states.shape[-2] :] - shifted_sin = sin[self.num_sink_tokens : -key_states.shape[-2]] - rerotation_cos = original_cos * shifted_cos + original_sin * shifted_sin - rerotation_sin = -original_sin * shifted_cos + original_cos * shifted_sin - - self.cos_sin_rerotation_cache[key_states.shape[-2]] = ( - rerotation_cos.to(key_states.dtype).unsqueeze(0), - rerotation_sin.to(key_states.dtype).unsqueeze(0), - ) - return self.cos_sin_rerotation_cache[key_states.shape[-2]] - - def get_seq_length(self, layer_idx: int | None = 0) -> int: - """Returns the sequence length of the cached states. A layer index can be optionally passed.""" - # TODO: deprecate this function in favor of `cache_position` - # Workaround to make 'key_states.shape[-2] + past_key_value.get_seq_length(self.layer_idx)' <= window_length - if len(self.key_cache) <= layer_idx: - return 0 - return self.key_cache[layer_idx].shape[-2] - - def get_max_cache_shape(self) -> int | None: - """Returns the maximum sequence length of the cache object, in case of SinkCache it is the window length.""" - return self.window_length - - def update( - self, - key_states: torch.Tensor, - value_states: torch.Tensor, - layer_idx: int, - cache_kwargs: dict[str, Any] | None = None, - ) -> tuple[torch.Tensor, torch.Tensor]: - """ - Updates the cache with the new `key_states` and `value_states` for the layer `layer_idx`. - - Parameters: - key_states (`torch.Tensor`): - The new key states to cache. - value_states (`torch.Tensor`): - The new value states to cache. - layer_idx (`int`): - The index of the layer to cache the states for. - cache_kwargs (`Dict[str, Any]`, `optional`): - Additional arguments for the cache subclass. The following arguments can be used in `SinkCache`: `sin`, - `cos` and `partial_rotation_size`. These arguments are used with models using RoPE, to recompute the - rotation as the tokens are shifted. - - Return: - A tuple containing the updated key and value states. - """ - # Optional kwargs for `SinkCache` -- needed on models using RoPE. `partial_rotation_size` is used on models - # with partially rotated position embeddings, like Phi or Persimmon. - if cache_kwargs is None: - cache_kwargs = {} - sin = cache_kwargs.get("sin") - cos = cache_kwargs.get("cos") - partial_rotation_size = cache_kwargs.get("partial_rotation_size") - using_rope = cos is not None and sin is not None - - # Update the number of seen tokens - if layer_idx == 0: - self._seen_tokens += key_states.shape[-2] - - # Update the sin/cos cache, which holds sin/cos values for all possible positions - if using_rope and layer_idx == 0: - # BC: some models still pass `sin`/`cos` with 2 dims. In those models, they are the full sin/cos. Remove - # after all RoPE models have a llama-like cache utilization. - if cos.dim() == 2: - self._cos_cache = cos - self._sin_cache = sin - elif self._cos_cache is None: - self._cos_cache = cos[0, ...] - self._sin_cache = sin[0, ...] - elif self._cos_cache.shape[0] < self.window_length: - self._cos_cache = torch.cat([self._cos_cache, cos[0, ...]], dim=0) - self._sin_cache = torch.cat([self._sin_cache, sin[0, ...]], dim=0) - - # [bsz, num_heads, seq_len, head_dim] - if len(self.key_cache) <= layer_idx: - # Empty cache - self.key_cache.append(key_states) - self.value_cache.append(value_states) - - elif key_states.shape[-2] + self.get_seq_length(layer_idx) < self.window_length: - # Growing cache - self.key_cache[layer_idx] = torch.cat([self.key_cache[layer_idx], key_states], dim=-2) - self.value_cache[layer_idx] = torch.cat( - [self.value_cache[layer_idx], value_states], dim=-2 - ) - - else: - # Shifting cache - keys_to_keep = self.key_cache[layer_idx][ - :, :, -self.window_length + self.num_sink_tokens + key_states.shape[-2] : - ] - - # On RoPE models, we need to recompute the Key rotation as the tokens are shifted - if using_rope: - rerotation_cos, rerotation_sin = self._get_rerotation_cos_sin( - key_states, - self._cos_cache[: self.window_length], - self._sin_cache[: self.window_length], - ) - if partial_rotation_size is not None: - keys_to_keep, keys_pass = ( - keys_to_keep[..., :partial_rotation_size], - keys_to_keep[..., partial_rotation_size:], - ) - keys_to_keep = self._apply_key_rotary_pos_emb( - keys_to_keep, rerotation_cos, rerotation_sin - ) - if partial_rotation_size is not None: - keys_to_keep = torch.cat((keys_to_keep, keys_pass), dim=-1) - - # Concatenate sink tokens, shifted & rotated tokens (if needed), and new tokens - sink_keys = self.key_cache[layer_idx][:, :, : self.num_sink_tokens] - self.key_cache[layer_idx] = torch.cat([sink_keys, keys_to_keep, key_states], dim=-2) - - sink_values = self.value_cache[layer_idx][:, :, : self.num_sink_tokens] - values_to_keep = self.value_cache[layer_idx][ - :, :, -self.window_length + self.num_sink_tokens + value_states.shape[-2] : - ] - self.value_cache[layer_idx] = torch.cat( - [sink_values, values_to_keep, value_states], dim=-2 - ) - - return self.key_cache[layer_idx], self.value_cache[layer_idx] - - -class StaticCache(Cache): - """ - Static Cache class to be used with `torch.compile(model)` and `torch.export()`. - - Parameters: - config (`PretrainedConfig`): - The configuration file defining the shape-related attributes required to initialize the static cache. - max_batch_size (`int`): - The maximum batch size with which the model will be used. Note that a new instance must be instantiated if a - smaller batch size is used. If you are manually setting the batch size, make sure to take into account the - number of beams if you are running beam search - max_cache_len (`int`, *optional*): - The maximum sequence length with which the model will be used. - device (`torch.device` or `str`, *optional*): - The device on which the cache should be initialized. If you're using more than 1 computation device, you - should pass the `layer_device_map` argument instead. - dtype (`torch.dtype`, *optional*, defaults to `torch.float32`): - The default `dtype` to use when initializing the layer. - layer_device_map (`Optional[Dict[int, Union[str, torch.device, int]]]]`, *optional*): - Mapping between the layers and its device. This is required when you are manually initializing the cache - and the model is split between different gpus. You can know which layers mapped to which device by - checking the associated device_map: `model.hf_device_map`. - - - Example: - - ```python - >>> from transformers import AutoTokenizer, AutoModelForCausalLM, StaticCache - - >>> model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b-chat-hf") - >>> tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-chat-hf") - - >>> inputs = tokenizer(text="My name is Llama", return_tensors="pt") - - >>> # Prepare a cache class and pass it to model's forward - >>> # Leave empty space for 10 new tokens, which can be used when calling forward iteratively 10 times to generate - >>> max_generated_length = inputs.input_ids.shape[1] + 10 - >>> past_key_values = StaticCache(config=model.config, max_batch_size=1, max_cache_len=max_generated_length, device=model.device, dtype=model.dtype) - >>> outputs = model(**inputs, past_key_values=past_key_values, use_cache=True) - >>> outputs.past_key_values # access cache filled with key/values from generation - StaticCache() - ``` - """ - - is_compileable = True - - def __init__( - self, - config: PretrainedConfig, - max_batch_size: int, - max_cache_len: int | None = None, - device: torch.device | str | None = None, - dtype: torch.dtype = torch.float32, - layer_device_map: dict[int, str | torch.device | int] | None = None, - ) -> None: - super().__init__() - self.max_batch_size = max_batch_size - self.max_cache_len = ( - config.max_position_embeddings if max_cache_len is None else max_cache_len - ) - - # Some model define a custom `head_dim` != config.hidden_size // config.num_attention_heads - self.head_dim = ( - config.head_dim - if hasattr(config, "head_dim") - else config.hidden_size // config.num_attention_heads - ) - - self._dtype = dtype - self.num_key_value_heads = ( - config.num_attention_heads - if getattr(config, "num_key_value_heads", None) is None - else config.num_key_value_heads - ) - - self.key_cache: list[torch.Tensor] = [] - self.value_cache: list[torch.Tensor] = [] - # Note: There will be significant perf decrease if switching to use 5D tensors instead. - cache_shape = ( - self.max_batch_size, - self.num_key_value_heads, - self.max_cache_len, - self.head_dim, - ) - device = torch.device(device) if device is not None else None - for idx in range(config.num_hidden_layers): - if layer_device_map is not None: - layer_device = layer_device_map[idx] - else: - layer_device = device - new_layer_key_cache = torch.zeros(cache_shape, dtype=self._dtype, device=layer_device) - new_layer_value_cache = torch.zeros(cache_shape, dtype=self._dtype, device=layer_device) - # Note: `mark_static_address` is used to tag the cache as a fixed data pointer, - # preventing compiled graph breaks when updating the cache. - torch._dynamo.mark_static_address(new_layer_key_cache) - torch._dynamo.mark_static_address(new_layer_value_cache) - self.key_cache.append(new_layer_key_cache) - self.value_cache.append(new_layer_value_cache) - - def update( - self, - key_states: torch.Tensor, - value_states: torch.Tensor, - layer_idx: int, - cache_kwargs: dict[str, Any] | None = None, - ) -> tuple[torch.Tensor, torch.Tensor]: - """ - Updates the cache with the new `key_states` and `value_states` for the layer `layer_idx`. - It is VERY important to index using a tensor, otherwise you introduce a copy to the device. - - Parameters: - key_states (`torch.Tensor`): - The new key states to cache. - value_states (`torch.Tensor`): - The new value states to cache. - layer_idx (`int`): - The index of the layer to cache the states for. - cache_kwargs (`Dict[str, Any]`, `optional`): - Additional arguments for the cache subclass. The `StaticCache` needs the `cache_position` input - to know how where to write in the cache. - - Return: - A tuple containing the updated key and value states. - """ - if cache_kwargs is None: - cache_kwargs = {} - cache_position = cache_kwargs.get("cache_position") - k_out = self.key_cache[layer_idx] - v_out = self.value_cache[layer_idx] - key_states = key_states.to(k_out.dtype) - value_states = value_states.to(v_out.dtype) - - if cache_position is None: - k_out.copy_(key_states) - v_out.copy_(value_states) - else: - # Note: here we use `tensor.index_copy_(dim, index, tensor)` that is equivalent to - # `tensor[:, :, index] = tensor`, but the first one is compile-friendly and it does explicitly an in-place - # operation, that avoids copies and uses less memory. - try: - k_out.index_copy_(2, cache_position, key_states) - v_out.index_copy_(2, cache_position, value_states) - except NotImplementedError: - # The operator 'aten::index_copy.out' is not currently implemented for the MPS device. - k_out[:, :, cache_position] = key_states - v_out[:, :, cache_position] = value_states - - return k_out, v_out - - def get_seq_length(self, layer_idx: int | None = 0) -> int: - """Returns the sequence length of the cached states that were seen by the model.""" - # Occupied cache == any slot in the 3rd dim (sequence length) holds a non-zero value. To save on compute, let's - # limit the check to the first batch member and head dimension. - # TODO: deprecate this function in favor of `cache_position` - return (self.key_cache[layer_idx][0, 0].any(dim=-1)).sum() - - def get_max_cache_shape(self) -> int | None: - return self.max_cache_len - - def reset(self): - """Resets the cache values while preserving the objects""" - for layer_idx in range(len(self.key_cache)): - # In-place ops prevent breaking the static address - self.key_cache[layer_idx].zero_() - self.value_cache[layer_idx].zero_() - - -class SlidingWindowCache(StaticCache): - """ - Sliding Window Cache class to be used with `torch.compile` for models like Mistral that support sliding window attention. - Every time when we try to update the cache, we compute the `indices` based on `cache_position >= self.config.sliding_window - 1`, - if true(which means the cache can not hold all the old key value states and new states together because of the sliding window constraint), - we need to do a cycle shift based on `indices` to replace the oldest states by the new key value states passed in. - - The `to_shift` is only true once we are above sliding_window. Thus with `sliding_window==64`: - - indices = (slicing + to_shift[-1].int()-1) % self.config.sliding_window - tensor([ 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, - 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, - 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, - 55, 56, 57, 58, 59, 60, 61, 62, 63, 0]) - - We overwrite the cache using these, then we always write at cache_position (clamped to `sliding_window`) - - Parameters: - config (`PretrainedConfig`): - The configuration file defining the shape-related attributes required to initialize the static cache. - max_batch_size (`int`): - The maximum batch size with which the model will be used. Note that a new instance must be instantiated if a - smaller batch size is used. - max_cache_len (`int`, *optional*): - The maximum sequence length with which the model will be used. - device (`torch.device` or `str`, *optional*): - The device on which the cache should be initialized. If you're using more than 1 computation device, you - should pass the `layer_device_map` argument instead. - dtype (`torch.dtype`, *optional*, defaults to `torch.float32`): - The default `dtype` to use when initializing the layer. - layer_device_map (`Optional[Dict[int, Union[str, torch.device, int]]]]`, *optional*): - Mapping between the layers and its device. This is required when you are manually initializing the cache - and the model is split between different gpus. You can know which layers mapped to which device by - checking the associated device_map: `model.hf_device_map`. - - Example: - - ```python - >>> from transformers import AutoTokenizer, AutoModelForCausalLM, SlidingWindowCache - - >>> model = AutoModelForCausalLM.from_pretrained("mistralai/Mistral-7B-Instruct-v0.3") - >>> tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-Instruct-v0.3") - - >>> inputs = tokenizer(text="My name is Mistral", return_tensors="pt") - - >>> # Prepare a cache class and pass it to model's forward - >>> # Leave empty space for 10 new tokens, which can be used when calling forward iteratively 10 times to generate - >>> max_generated_length = inputs.input_ids.shape[1] + 10 - >>> past_key_values = SlidingWindowCache(config=model.config, max_batch_size=1, max_cache_len=max_generated_length, device=model.device, dtype=model.dtype) - >>> outputs = model(**inputs, past_key_values=past_key_values, use_cache=True) - >>> outputs.past_key_values # access cache filled with key/values from generation - SlidingWindowCache() - ``` - """ - - is_sliding = True - is_compileable = True - - def __init__( - self, - config: PretrainedConfig, - max_batch_size: int, - max_cache_len: int | None = None, - device: torch.device | str | None = None, - dtype: torch.dtype = torch.float32, - layer_device_map: dict[int, str | torch.device | int] | None = None, - ) -> None: - if not hasattr(config, "sliding_window") or config.sliding_window is None: - raise ValueError( - "Setting `cache_implementation` to 'sliding_window' requires the model config supporting " - "sliding window attention, please check if there is a `sliding_window` field in the model " - "config and it's not set to None." - ) - max_cache_len = min(config.sliding_window, max_cache_len) - super().__init__( - config=config, - max_batch_size=max_batch_size, - max_cache_len=max_cache_len, - device=device, - dtype=dtype, - layer_device_map=layer_device_map, - ) - - def update( - self, - key_states: torch.Tensor, - value_states: torch.Tensor, - layer_idx: int, - cache_kwargs: dict[str, Any] | None = None, - ) -> tuple[torch.Tensor, torch.Tensor]: - if cache_kwargs is None: - cache_kwargs = {} - cache_position = cache_kwargs.get("cache_position") - k_out = self.key_cache[layer_idx] - v_out = self.value_cache[layer_idx] - key_states = key_states.to(k_out.dtype) - value_states = value_states.to(v_out.dtype) - - # assume this only happens in prefill phase when prompt length > sliding_window_size (= max_cache_len) - if cache_position.shape[0] > self.max_cache_len: - k_out = key_states[:, :, -self.max_cache_len :, :] - v_out = value_states[:, :, -self.max_cache_len :, :] - # Assumption: caches are all zeros at this point, `+=` is equivalent to `=` but compile-friendly - self.key_cache[layer_idx] += k_out - self.value_cache[layer_idx] += v_out - # we should return the whole states instead of k_out, v_out to take the whole prompt - # into consideration when building kv cache instead of just throwing away tokens outside of the window - return key_states, value_states - - slicing = torch.ones( - self.max_cache_len, dtype=torch.long, device=value_states.device - ).cumsum(0) - cache_position = cache_position.clamp(0, self.max_cache_len - 1) - to_shift = cache_position >= self.max_cache_len - 1 - indices = (slicing + to_shift[-1].int() - 1) % self.max_cache_len - - k_out = k_out[:, :, indices] - v_out = v_out[:, :, indices] - - try: - k_out.index_copy_(2, cache_position, key_states) - v_out.index_copy_(2, cache_position, value_states) - except NotImplementedError: - # The operator 'aten::index_copy.out' is not currently implemented for the MPS device. - k_out[:, :, cache_position] = key_states - v_out[:, :, cache_position] = value_states - - # `_.zero()` followed by `+=` is equivalent `=`, but compile-friendly (without graph breaks due to assignment) - self.key_cache[layer_idx].zero_() - self.value_cache[layer_idx].zero_() - - self.key_cache[layer_idx] += k_out - self.value_cache[layer_idx] += v_out - - return k_out, v_out - - def get_max_cache_shape(self) -> int | None: - return self.max_cache_len - - def reset(self): - for layer_idx in range(len(self.key_cache)): - # In-place ops prevent breaking the static address - self.key_cache[layer_idx].zero_() - self.value_cache[layer_idx].zero_() - - -class EncoderDecoderCache(Cache): - """ - Base, abstract class for all encoder-decoder caches. Can be used to hold combinations of self-attention and - cross-attention caches. - - Example: - - ```python - >>> from transformers import AutoProcessor, AutoModelForCausalLM, DynamicCache, EncoderDecoderCache - - >>> model = AutoModelForCausalLM.from_pretrained("openai/whisper-small") - >>> processor = AutoProcessor.from_pretrained("openai/whisper-small") - - >>> inputs = processor(audio=YOUR-AUDIO, return_tensors="pt") - - >>> # Prepare cache classes for encoder and decoder and pass it to model's forward - >>> self_attention_cache = DynamicCache() - >>> cross_attention_cache = DynamicCache() - >>> past_key_values = EncoderDecoderCache(self_attention_cache, cross_attention_cache) - >>> outputs = model(**inputs, past_key_values=past_key_values, use_cache=True) - >>> outputs.past_key_values # access cache filled with key/values from generation - EncoderDecoderCache() - ``` - - """ - - def __init__(self, self_attention_cache: Cache, cross_attention_cache: Cache): - super().__init__() - self.self_attention_cache = self_attention_cache - self.cross_attention_cache = cross_attention_cache - self.is_compileable = getattr(self.self_attention_cache, "is_compileable", False) - - self.is_updated = {} - for layer_idx in range(len(cross_attention_cache.key_cache)): - self.is_updated[layer_idx] = bool(cross_attention_cache.get_seq_length(layer_idx) > 0) - - def __getitem__(self, layer_idx: int) -> list[tuple[torch.Tensor]]: - """ - Support for backwards-compatible `past_key_value` indexing, e.g. `past_key_value[0][0].shape[2]` to get the - sequence length. - """ - if layer_idx < len(self): - return ( - self.self_attention_cache.key_cache[layer_idx], - self.self_attention_cache.value_cache[layer_idx], - self.cross_attention_cache.key_cache[layer_idx], - self.cross_attention_cache.value_cache[layer_idx], - ) - else: - raise KeyError( - f"Cache only has {len(self)} layers, attempted to access layer with index {layer_idx}" - ) - - def __len__(self): - """ - Support for backwards-compatible `past_key_value` length, e.g. `len(past_key_value)`. This value corresponds - to the number of layers in the model. - """ - return len(self.self_attention_cache) - - def to_legacy_cache(self) -> tuple[tuple[torch.Tensor], tuple[torch.Tensor]]: - """Converts the `EncoderDecoderCache` instance into its equivalent in the legacy cache format.""" - legacy_cache = () - if len(self.cross_attention_cache) > 0: - for self_attn, cross_attn in zip( - self.self_attention_cache.to_legacy_cache(), - self.cross_attention_cache.to_legacy_cache(), - ): - legacy_cache += (self_attn + cross_attn,) - else: - legacy_cache = self.self_attention_cache.to_legacy_cache() - return legacy_cache - - @classmethod - def from_legacy_cache( - cls, past_key_values: tuple[tuple[torch.FloatTensor]] | None = None - ) -> "EncoderDecoderCache": - """Converts a cache in the legacy cache format into an equivalent `EncoderDecoderCache`.""" - cache = cls( - self_attention_cache=DynamicCache(), - cross_attention_cache=DynamicCache(), - ) - if past_key_values is not None: - for layer_idx in range(len(past_key_values)): - key_states, value_states = past_key_values[layer_idx][:2] - cache.self_attention_cache.update(key_states, value_states, layer_idx) - if len(past_key_values[layer_idx]) > 2: - key_states, value_states = past_key_values[layer_idx][2:] - cache.cross_attention_cache.update(key_states, value_states, layer_idx) - cache.is_updated[layer_idx] = True - return cache - - def get_seq_length(self, layer_idx: int | None = 0) -> int: - """Returns the sequence length of the cached states. A layer index can be optionally passed.""" - # check if empty list because in case of static cache it will be a tensors and we can't check `if not torch.Tensor` - return self.self_attention_cache.get_seq_length(layer_idx) - - def reset(self): - if hasattr(self.self_attention_cache, "reset"): - self.self_attention_cache.reset() - if hasattr(self.cross_attention_cache, "reset"): - self.cross_attention_cache.reset() - elif not hasattr(self.self_attention_cache, "reset") and not hasattr( - self.cross_attention_cache, "reset" - ): - raise ValueError( - "Neither self nor cross-attention cache have valid `.reset()` methods. `.reset()` should " - "only be called on compatible cache classes, such as `StaticCache` or `SlidingWindowCache`. " - f"Got {self.self_attention_cache.__str__()} for the self attention cache and " - f"{self.cross_attention_cache.__str__()} for the cross attention cache." - ) - for layer_idx in self.is_updated: - self.is_updated[layer_idx] = False - - def reorder_cache(self, beam_idx: torch.LongTensor): - """Reorders the cache for beam search, given the selected beam indices.""" - self.self_attention_cache.reorder_cache(beam_idx) - self.cross_attention_cache.reorder_cache(beam_idx) - - def check_dynamic_cache(self, method: str): - if not ( - isinstance(self.self_attention_cache, DynamicCache) - and isinstance(self.cross_attention_cache, DynamicCache) - ): - raise ValueError( - f"`{method}` is only defined for dynamic cache, got {self.self_attention_cache.__str__()} for the self " - f"attention cache and {self.cross_attention_cache.__str__()} for the cross attention cache." - ) - - # TODO(gante, sanchit-gandhi): move following functionality into `.generate` - def crop(self, maximum_length: int): - """Crop the past key values up to a new `maximum_length` in terms of tokens. `maximum_length` can also be - negative to remove `maximum_length` tokens. This is used in assisted decoding and contrastive search.""" - self.check_dynamic_cache(self.crop.__name__) - self.self_attention_cache.crop(maximum_length) - - def batch_split(self, full_batch_size: int, split_size: int) -> "list[EncoderDecoderCache]": - """Split the current instance into a list of `DynamicCache` by the batch size. This will be used by - `_split_model_inputs()` in `generation.utils`""" - self.check_dynamic_cache(self.batch_split.__name__) - self_attention_cache = self.self_attention_cache.batch_split(full_batch_size, split_size) - cross_attention_cache = self.cross_attention_cache.batch_split(full_batch_size, split_size) - - out = [] - for self_attn, cross_attn in zip(self_attention_cache, cross_attention_cache): - out.append(EncoderDecoderCache(self_attn, cross_attn)) - return out - - @classmethod - def from_batch_splits(cls, splits: list["EncoderDecoderCache"]) -> "EncoderDecoderCache": - """This is the opposite of the above `batch_split()` method. This will be used by `stack_model_outputs` in - `generation.utils`""" - self_attention_cache = DynamicCache() - cross_attention_cache = DynamicCache() - for idx in range(len(splits[0])): - layer_keys = torch.cat( - [current.self_attention_cache.key_cache[idx] for current in splits], dim=0 - ) - layer_values = torch.cat( - [current.self_attention_cache.value_cache[idx] for current in splits], dim=0 - ) - self_attention_cache.update(layer_keys, layer_values, idx) - - layer_keys = torch.cat( - [current.cross_attention_cache.key_cache[idx] for current in splits], dim=0 - ) - layer_values = torch.cat( - [current.cross_attention_cache.value_cache[idx] for current in splits], dim=0 - ) - cross_attention_cache.update(layer_keys, layer_values, idx) - return cls(self_attention_cache, cross_attention_cache) - - def batch_repeat_interleave(self, repeats: int): - """Repeat the cache `repeats` times in the batch dimension. Used in contrastive search.""" - self.check_dynamic_cache(self.batch_repeat_interleave.__name__) - self.self_attention_cache.batch_repeat_interleave(repeats) - self.cross_attention_cache.batch_repeat_interleave(repeats) - - def batch_select_indices(self, indices: torch.Tensor): - """Only keep the `indices` in the batch dimension of the cache. Used in contrastive search.""" - self.check_dynamic_cache(self.batch_select_indices.__name__) - self.self_attention_cache.batch_select_indices(indices) - self.cross_attention_cache.batch_select_indices(indices) - - -class HybridCache(Cache): - """ - Hybrid Cache class to be used with `torch.compile` for Gemma2 models that alternate between a local sliding window attention - and global attention in every other layer. Under the hood, Hybrid Cache leverages ["SlidingWindowCache"] for sliding window attention - and ["StaticCache"] for global attention. For more information, see the documentation of each subcomponeent cache class. - - Parameters: - config (`PretrainedConfig): - The configuration file defining the shape-related attributes required to initialize the static cache. - max_batch_size (`int`): - The maximum batch size with which the model will be used. Note that a new instance must be instantiated if a - smaller batch size is used. - max_cache_len (`int`, *optional*): - The maximum sequence length with which the model will be used. - device (`torch.device` or `str`, *optional*): - The device on which the cache should be initialized. If you're using more than 1 computation device, you - should pass the `layer_device_map` argument instead. - dtype (torch.dtype, *optional*, defaults to `torch.float32`): - The default `dtype` to use when initializing the layer. - layer_device_map (`Optional[Dict[int, Union[str, torch.device, int]]]]`, *optional*): - Mapping between the layers and its device. This is required when you are manually initializing the cache - and the model is split between different gpus. You can know which layers mapped to which device by - checking the associated device_map: `model.hf_device_map`. - - Example: - - ```python - >>> from transformers import AutoTokenizer, AutoModelForCausalLM, HybridCache - - >>> model = AutoModelForCausalLM.from_pretrained("google/gemma-2-2b") - >>> tokenizer = AutoTokenizer.from_pretrained("google/gemma-2-2b") - - >>> inputs = tokenizer(text="My name is Gemma", return_tensors="pt") - - >>> # Prepare a cache class and pass it to model's forward - >>> # Leave empty space for 10 new tokens, which can be used when calling forward iteratively 10 times to generate - >>> max_generated_length = inputs.input_ids.shape[1] + 10 - >>> past_key_values = HybridCache(config=model.config, max_batch_size=1, max_cache_len=max_generated_length, device=model.device, dtype=model.dtype) - >>> outputs = model(**inputs, past_key_values=past_key_values, use_cache=True) - >>> outputs.past_key_values # access cache filled with key/values from generation - HybridCache() - ``` - """ - - # TODO (joao): dive deeper into gemma2 and paligemma -- there are reports of speed loss with compilation. Revert - # ALL changes from the PR that commented the line below when reactivating it. - # is_compileable = True - - def __init__( - self, - config: PretrainedConfig, - max_batch_size: int, - max_cache_len: int | None = None, - device: torch.device | str | None = None, - dtype: torch.dtype = torch.float32, - layer_device_map: dict[int, str | torch.device | int] | None = None, - ) -> None: - super().__init__() - if not hasattr(config, "sliding_window") or config.sliding_window is None: - raise ValueError( - "Setting `cache_implementation` to 'sliding_window' requires the model config supporting " - "sliding window attention, please check if there is a `sliding_window` field in the model " - "config and it's not set to None." - ) - self.max_cache_len = max_cache_len - self.max_batch_size = max_batch_size - # Some model define a custom `head_dim` != config.hidden_size // config.num_attention_heads - self.head_dim = ( - config.head_dim - if hasattr(config, "head_dim") - else config.hidden_size // config.num_attention_heads - ) - - self._dtype = dtype - self.num_key_value_heads = ( - config.num_attention_heads - if config.num_key_value_heads is None - else config.num_key_value_heads - ) - - layer_switch = ( - config.sliding_window_pattern if hasattr(config, "sliding_window_pattern") else 2 - ) # 2 is for BC - self.is_sliding = torch.tensor( - [bool((i + 1) % layer_switch) for i in range(config.num_hidden_layers)], - dtype=torch.bool, - ) - self.key_cache: list[torch.Tensor] = [] - self.value_cache: list[torch.Tensor] = [] - global_cache_shape = ( - self.max_batch_size, - self.num_key_value_heads, - max_cache_len, - self.head_dim, - ) - sliding_cache_shape = ( - self.max_batch_size, - self.num_key_value_heads, - min(config.sliding_window, max_cache_len), - self.head_dim, - ) - device = torch.device(device) if device is not None and isinstance(device, str) else None - for i in range(config.num_hidden_layers): - if layer_device_map is not None: - layer_device = layer_device_map[i] - else: - layer_device = device - # Note: `mark_static_address` is used to tag the cache as an fixed data pointer, preventing cuda graph - # breaks when updating the cache. - cache_shape = global_cache_shape if not self.is_sliding[i] else sliding_cache_shape - new_layer_key_cache = torch.zeros(cache_shape, dtype=self._dtype, device=layer_device) - new_layer_value_cache = torch.zeros(cache_shape, dtype=self._dtype, device=layer_device) - torch._dynamo.mark_static_address(new_layer_key_cache) - torch._dynamo.mark_static_address(new_layer_value_cache) - self.key_cache.append(new_layer_key_cache) - self.value_cache.append(new_layer_value_cache) - - def _sliding_update( - self, cache_position, layer_idx, key_states, value_states, k_out, v_out, max_cache_len - ): - if cache_position.shape[0] > max_cache_len: - k_out = key_states[:, :, -max_cache_len:, :] - v_out = value_states[:, :, -max_cache_len:, :] - # Assumption: caches are all zeros at this point, `+=` is equivalent to `=` but compile-friendly - self.key_cache[layer_idx] += k_out - self.value_cache[layer_idx] += v_out - # we should return the whole states instead of k_out, v_out to take the whole prompt - # into consideration when building kv cache instead of just throwing away tokens outside of the window - return key_states, value_states - - slicing = torch.ones(max_cache_len, dtype=torch.long, device=value_states.device).cumsum(0) - cache_position = cache_position.clamp(0, max_cache_len - 1) - to_shift = cache_position >= max_cache_len - 1 - indices = (slicing + to_shift[-1].int() - 1) % max_cache_len - k_out = k_out[:, :, indices] - v_out = v_out[:, :, indices] - - k_out[:, :, cache_position] = key_states - v_out[:, :, cache_position] = value_states - # `_.zero()` followed by `+=` is equivalent `=`, but compile-friendly (without graph breaks due to assignment) - self.key_cache[layer_idx].zero_() - self.value_cache[layer_idx].zero_() - - self.key_cache[layer_idx] += k_out - self.value_cache[layer_idx] += v_out - return k_out, v_out - - def _static_update( - self, cache_position, layer_idx, key_states, value_states, k_out, v_out, max_cache_len - ): - k_out[:, :, cache_position] = key_states - v_out[:, :, cache_position] = value_states - - self.key_cache[layer_idx] = k_out - self.value_cache[layer_idx] = v_out - return k_out, v_out - - def update( - self, - key_states: torch.Tensor, - value_states: torch.Tensor, - layer_idx: int, - cache_kwargs: dict[str, Any] | None = None, - ) -> tuple[torch.Tensor, torch.Tensor]: - if cache_kwargs is None: - cache_kwargs = {} - cache_position = cache_kwargs.get("cache_position") - sliding_window = cache_kwargs.get("sliding_window") - - # These two `if` blocks are only reached in multigpu and if `layer_device_map` is not passed. They are used - # when the cache is initialized in the forward pass (e.g. Gemma2) - if self.key_cache[layer_idx].device != key_states.device: - self.key_cache[layer_idx] = self.key_cache[layer_idx].to(key_states.device) - if self.value_cache[layer_idx].device != value_states.device: - self.value_cache[layer_idx] = self.value_cache[layer_idx].to(value_states.device) - - k_out = self.key_cache[layer_idx] - v_out = self.value_cache[layer_idx] - key_states = key_states.to(k_out.dtype) - value_states = value_states.to(v_out.dtype) - - if sliding_window: - update_fn = self._sliding_update - else: - update_fn = self._static_update - - return update_fn( - cache_position, - layer_idx, - key_states, - value_states, - k_out, - v_out, - k_out.shape[2], - ) - - def get_max_cache_shape(self) -> int | None: - return self.max_cache_len - - def get_seq_length(self, layer_idx: int | None = 0): - # Occupied cache == any slot in the 3rd dim (sequence length) holds a non-zero value. To save on compute, let's - # limit the check to the first batch member and head dimension. - # TODO: deprecate this function in favor of `cache_position` - if layer_idx != 0: - raise ValueError( - "`get_seq_length` on `HybridCache` may get inconsistent results depending on the layer index. " - "Using the `layer_idx` argument is not supported." - ) - return (self.key_cache[layer_idx][0, 0].any(dim=-1)).sum() - - def reset(self): - """Resets the cache values while preserving the objects""" - for layer_idx in range(len(self.key_cache)): - # In-place ops prevent breaking the static address - self.key_cache[layer_idx].zero_() - self.value_cache[layer_idx].zero_() - - -class HybridChunkedCache(Cache): - """ - Hybrid Cache class to be used with `torch.compile` for Gemma2 models that alternate between a local sliding window attention - and global attention in every other layer. Under the hood, Hybrid Cache leverages ["SlidingWindowCache"] for sliding window attention - and ["StaticCache"] for global attention. For more information, see the documentation of each subcomponeent cache class. - - Parameters: - config (`PretrainedConfig): - The configuration file defining the shape-related attributes required to initialize the static cache. - max_batch_size (`int`): - The maximum batch size with which the model will be used. Note that a new instance must be instantiated if a - smaller batch size is used. - max_cache_len (`int`, *optional*): - The maximum sequence length with which the model will be used. - device (`torch.device` or `str`, *optional*): - The device on which the cache should be initialized. If you're using more than 1 computation device, you - should pass the `layer_device_map` argument instead. - dtype (torch.dtype, *optional*, defaults to `torch.bfloat16`): - The default `dtype` to use when initializing the layer. - layer_device_map (`Optional[Dict[int, Union[str, torch.device, int]]]]`, *optional*): - Mapping between the layers and its device. This is required when you are manually initializing the cache - and the model is split between different gpus. You can know which layers mapped to which device by - checking the associated device_map: `model.hf_device_map`. - - Example: - - ```python - >>> from transformers import AutoTokenizer, AutoModelForCausalLM, HybridCache - - >>> model = AutoModelForCausalLM.from_pretrained("google/gemma-2-2b") - >>> tokenizer = AutoTokenizer.from_pretrained("google/gemma-2-2b") - - >>> inputs = tokenizer(text="My name is Gemma", return_tensors="pt") - - >>> # Prepare a cache class and pass it to model's forward - >>> # Leave empty space for 10 new tokens, which can be used when calling forward iteratively 10 times to generate - >>> max_generated_length = inputs.input_ids.shape[1] + 10 - >>> past_key_values = HybridCache(config=model.config, max_batch_size=1, max_cache_len=max_generated_length, device=model.device, dtype=model.dtype) - >>> outputs = model(**inputs, past_key_values=past_key_values, use_cache=True) - >>> outputs.past_key_values # access cache filled with key/values from generation - HybridCache() - ``` - """ - - # TODO (joao): dive deeper into gemma2 and paligemma -- there are reports of speed loss with compilation. Revert - # ALL changes from the PR that commented the line below when reactivating it. - is_compileable = True - - def __init__( - self, - config: PretrainedConfig, - max_batch_size: int, - max_cache_len: int | None = None, - device: torch.device | str | None = None, - dtype: torch.dtype = torch.bfloat16, - layer_device_map: dict[int, str | torch.device | int] | None = None, - ) -> None: - super().__init__() - if not hasattr(config, "sliding_window") or config.sliding_window is None: - self.sliding_window = getattr(config.get_text_config(), "attention_chunk_size", 8192) - else: - self.sliding_window = config.sliding_window - self.max_cache_len = max_cache_len - self.max_batch_size = max_batch_size - self.head_dim = getattr( - config, "head_dim", config.hidden_size // config.num_attention_heads - ) - self._dtype = dtype - - if hasattr(config.get_text_config(), "no_rope_layers"): - self.is_sliding = config.no_rope_layers - else: - layer_switch = getattr(config, "sliding_window_pattern", 2) - self.is_sliding = [ - bool((i + 1) % layer_switch) for i in range(config.num_hidden_layers) - ] - - self.key_cache: list[torch.Tensor] = [] - self.value_cache: list[torch.Tensor] = [] - self.cumulative_length = [0 for _ in range(config.num_hidden_layers)] - - def initialise_cache_layer(self, layer_idx, key_states): - if len(self.key_cache) > layer_idx: - return - - num_key_value_heads = key_states.shape[1] - device = key_states.device - global_cache_shape = ( - self.max_batch_size, - num_key_value_heads, - self.max_cache_len, - self.head_dim, - ) - sliding_cache_shape = ( - self.max_batch_size, - num_key_value_heads, - self.sliding_window, - self.head_dim, - ) - # Note: `mark_static_address` is used to tag the cache as an fixed data pointer, preventing cuda graph - # breaks when updating the cache. - cache_shape = sliding_cache_shape if self.is_sliding[layer_idx] else global_cache_shape - new_layer_key_cache = torch.zeros(cache_shape, dtype=self._dtype, device=device) - new_layer_value_cache = torch.zeros(cache_shape, dtype=self._dtype, device=device) - torch._dynamo.mark_static_address(new_layer_key_cache) - torch._dynamo.mark_static_address(new_layer_value_cache) - self.key_cache.append(new_layer_key_cache) - self.value_cache.append(new_layer_value_cache) - - def _sliding_update( - self, cache_position, layer_idx, key_states, value_states, k_out, v_out, max_cache_len - ): - cumulative_length = self.cumulative_length[layer_idx] - # Update it now that we saved the value above - self.cumulative_length[layer_idx] += key_states.shape[-2] - is_full = cumulative_length >= max_cache_len - if is_full: - full_key_states = torch.cat((k_out[:, :, 1:, :], key_states), dim=-2) - full_value_states = torch.cat((v_out[:, :, 1:, :], value_states), dim=-2) - # Fast decoding path -> here as the effective size is still sliding window, it is extremely important - # to return `self.key_cache[layer_idx]` and `self.value_cache[layer_idx]`, as they have the fixed adress - # in memory (the values are the same as the full states, but not the address!!) - if key_states.shape[-2] == 1: - self.key_cache[layer_idx].copy_(full_key_states) - self.value_cache[layer_idx].copy_(full_value_states) - return self.key_cache[layer_idx], self.value_cache[layer_idx] - elif not is_full and cumulative_length + key_states.shape[2] > max_cache_len: - # Fast prefill path, no need to cat() in this case (which creates a copy even if cating from 0 dim) - if cumulative_length == 0: - full_key_states = key_states - full_value_states = value_states - else: - full_key_states = torch.cat( - (k_out[:, :, :cumulative_length, :], key_states), dim=-2 - ) - full_value_states = torch.cat( - (v_out[:, :, :cumulative_length, :], value_states), dim=-2 - ) - else: - self.key_cache[layer_idx].index_copy_(2, cache_position, key_states) - self.value_cache[layer_idx].index_copy_(2, cache_position, value_states) - return self.key_cache[layer_idx], self.value_cache[layer_idx] - - self.key_cache[layer_idx].copy_(full_key_states[:, :, -max_cache_len:, :]) - self.value_cache[layer_idx].copy_(full_value_states[:, :, -max_cache_len:, :]) - # we should return the whole states instead of k_out, v_out to take the whole prompt - # into consideration when building kv cache instead of just throwing away tokens outside of the window - return full_key_states, full_value_states - - def _static_update( - self, cache_position, layer_idx, key_states, value_states, k_out, v_out, max_cache_len - ): - k_out[:, :, cache_position] = key_states - v_out[:, :, cache_position] = value_states - - self.key_cache[layer_idx] = k_out - self.value_cache[layer_idx] = v_out - return k_out, v_out - - def update( - self, - key_states: torch.Tensor, - value_states: torch.Tensor, - layer_idx: int, - cache_kwargs: dict[str, Any] | None = None, - ) -> tuple[torch.Tensor, torch.Tensor]: - if cache_kwargs is None: - cache_kwargs = {} - cache_position = cache_kwargs.get("cache_position") - self.initialise_cache_layer(layer_idx, key_states) - - k_out = self.key_cache[layer_idx] - v_out = self.value_cache[layer_idx] - key_states = key_states.to(k_out.dtype) - value_states = value_states.to(v_out.dtype) - - if self.is_sliding[layer_idx]: - update_fn = self._sliding_update - else: - update_fn = self._static_update - - return update_fn( - cache_position, - layer_idx, - key_states, - value_states, - k_out, - v_out, - k_out.shape[2], - ) - - def get_max_cache_shape(self) -> int | None: - return self.max_cache_len - - def get_seq_length(self, layer_idx: int | None = 0): - # Occupied cache == any slot in the 3rd dim (sequence length) holds a non-zero value. To save on compute, let's - # limit the check to the first batch member and head dimension. - # TODO: deprecate this function in favor of `cache_position` - if layer_idx != 0: - raise ValueError( - "`get_seq_length` on `HybridCache` may get inconsistent results depending on the layer index. " - "Using the `layer_idx` argument is not supported." - ) - if len(self.key_cache) == 0: - return 0 - return (self.key_cache[layer_idx][0, 0].any(dim=-1)).sum() - - def reset(self): - """Resets the cache values while preserving the objects""" - for layer_idx in range(len(self.key_cache)): - # In-place ops prevent breaking the static address - self.key_cache[layer_idx].zero_() - self.value_cache[layer_idx].zero_() - self.cumulative_length = [0 for _ in range(len(self.cumulative_length))] - - -class MambaCache: - """ - Cache for mamba model which does not have attention mechanism and key value states. - - Arguments: - config (`PretrainedConfig): - The configuration file defining the shape-related attributes required to initialize the static cache. - max_batch_size (`int`): - The maximum batch size with which the model will be used. Note that a new instance must be instantiated if a smaller batch size is used. - dtype (`torch.dtype`, *optional*, defaults to `torch.float16`): - The default `dtype` to use when initializing the layer. - device (`torch.device` or `str`, *optional*): - The device on which the cache should be initialized. Should be the same as the layer. - - Example: - - ```python - >>> from transformers import AutoTokenizer, MambaForCausalLM, MambaCache - - >>> model = MambaForCausalLM.from_pretrained("state-spaces/mamba-130m-hf") - >>> tokenizer = AutoTokenizer.from_pretrained("state-spaces/mamba-130m-hf") - - >>> inputs = tokenizer(text="My name is Mamba", return_tensors="pt") - - >>> # Prepare a cache class and pass it to model's forward - >>> past_key_values = MambaCache(config=model.config, max_batch_size=1, device=model.device, dtype=model.dtype) - >>> outputs = model(**inputs, past_key_values=past_key_values, use_cache=True) - >>> outputs.past_key_values - MambaCache() - ``` - """ - - is_compileable = True - - # TODO (joao): add layer_device_map arg and update code in `generate` accordingly - def __init__( - self, - config: PretrainedConfig, - max_batch_size: int, - dtype: torch.dtype = torch.float16, - device: torch.device | str | None = None, - ): - self.max_batch_size = max_batch_size - self._dtype = dtype - self.intermediate_size = config.intermediate_size - self.ssm_state_size = config.state_size - self.conv_kernel_size = config.conv_kernel - - self.conv_states: list[torch.Tensor] = [] - self.ssm_states: list[torch.Tensor] = [] - device = torch.device(device) if device is not None else None - for _ in range(config.num_hidden_layers): - conv_state: torch.Tensor = torch.zeros( - self.max_batch_size, - self.intermediate_size, - self.conv_kernel_size, - device=device, - dtype=self._dtype, - ) - ssm_state: torch.Tensor = torch.zeros( - self.max_batch_size, - self.intermediate_size, - self.ssm_state_size, - device=device, - dtype=self._dtype, - ) - - torch._dynamo.mark_static_address(conv_state) - torch._dynamo.mark_static_address(ssm_state) - self.conv_states.append(conv_state) - self.ssm_states.append(ssm_state) - - def update_conv_state( - self, layer_idx: int, new_conv_state: torch.Tensor, cache_position: torch.LongTensor - ) -> torch.Tensor: - # This `if` blocks is only reached in multigpu and if `layer_device_map` is not passed. It is used - # when the cache is initialized in the forward pass (e.g. Mamba) - if self.conv_states[layer_idx].device != new_conv_state.device: - self.conv_states[layer_idx] = self.conv_states[layer_idx].to(new_conv_state.device) - - conv_state = self.conv_states[layer_idx] - cache_position = cache_position.clamp(0, self.conv_kernel_size - 1) - - conv_state = conv_state.roll(shifts=-1, dims=-1) - conv_state[:, :, cache_position] = new_conv_state.to( - device=conv_state.device, dtype=conv_state.dtype - ) - self.conv_states[layer_idx].zero_() - self.conv_states[layer_idx] += conv_state - return self.conv_states[layer_idx] - - def update_ssm_state(self, layer_idx: int, new_ssm_state: torch.Tensor): - self.ssm_states[layer_idx] = new_ssm_state.to(self.ssm_states[layer_idx].device) - return self.ssm_states[layer_idx] - - def reset(self): - for layer_idx in range(len(self.conv_states)): - # In-place ops prevent breaking the static address - self.conv_states[layer_idx].zero_() - self.ssm_states[layer_idx].zero_() - - -class OffloadedStaticCache(StaticCache): - """ - Static cache class to be used with `torch.compile(model)` that offloads to the CPU or - another device. - - Args: - config (`PretrainedConfig): - The configuration file defining the shape-related attributes required to initialize - the static cache. - max_batch_size (`int`): - The maximum batch size with which the model will be used. - max_cache_len (`int`): - The maximum sequence length with which the model will be used. - device (`Union[str, torch.device]`): - The device on which the cache should be initialized. If you're using more than 1 computation device, you - should pass the `layer_device_map` argument instead. - dtype (`torch.dtype`, *optional*): - The default `dtype` to use when initializing the cache. - offload_device (`Union[str, torch.device]`, *optional*, defaults to `cpu`): - The device to offload to. Defaults to CPU. - layer_device_map (`Dict[int, Union[str, torch.device, int]]`, *optional*): - Mapping between the layers and its device. This is required when you are manually initializing the cache - and the model is splitted between differents gpus. You can know which layers mapped to which device by - checking the associated device_map: `model.hf_device_map`. - - Example: - - ```python - >>> from transformers import AutoTokenizer, AutoModelForCausalLM, OffloadedStaticCache - - >>> model = AutoModelForCausalLM.from_pretrained("openai-community/gpt2") - >>> tokenizer = AutoTokenizer.from_pretrained("openai-community/gpt2") - - >>> inputs = tokenizer(text="My name is GPT2", return_tensors="pt") - - >>> # Prepare a cache class and pass it to model's forward - >>> # Leave empty space for 10 new tokens, which can be used when calling forward iteratively 10 times to generate - >>> max_generated_length = inputs.input_ids.shape[1] + 10 - >>> past_key_values = OffloadedStaticCache(config=model.config, max_batch_size=1, max_cache_len=max_generated_length, device=model.device, dtype=model.dtype) - >>> outputs = model(**inputs, past_key_values=past_key_values, use_cache=True) - >>> past_kv_length = outputs.past_key_values # access cache filled with key/values from generation - ``` - """ - - is_compileable = True - - def __init__( - self, - config: PretrainedConfig, - max_batch_size: int, - max_cache_len: int | None, - device: str | torch.device, - dtype: torch.dtype | None = None, - offload_device: str | torch.device = torch.device("cpu"), - layer_device_map: dict[int, str | torch.device | int] | None = None, - ) -> None: - super(Cache, self).__init__() - self.max_batch_size = max_batch_size - self.max_cache_len = ( - config.max_position_embeddings if max_cache_len is None else max_cache_len - ) - self.device = ( - torch.device(device) if layer_device_map is None else torch.device(layer_device_map[0]) - ) - self.offload_device = torch.device(offload_device) - self._dtype = dtype if dtype is not None else torch.float32 - - # Some model define a custom `head_dim` != config.hidden_size // config.num_attention_heads - head_dim = ( - config.head_dim - if hasattr(config, "head_dim") - else config.hidden_size // config.num_attention_heads - ) - - num_key_value_heads = ( - config.num_attention_heads - if getattr(config, "num_key_value_heads", None) is None - else config.num_key_value_heads - ) - - cache_shape = (max_batch_size, num_key_value_heads, self.max_cache_len, head_dim) - - # Create offloaded CPU tensors. - self.key_cache: list[torch.Tensor] = [] - self.value_cache: list[torch.Tensor] = [] - - for i in range(config.num_hidden_layers): - # First layer is always on-device. - device = self.device if i == 0 else self.offload_device - - key_cache, value_cache = self._create_key_value_cache_tensors(cache_shape, device) - - self.key_cache.append(key_cache) - self.value_cache.append(value_cache) - - # Create device tensors. - self._device_key_cache: list[torch.Tensor] = [] - self._device_value_cache: list[torch.Tensor] = [] - - for i in range(2): - key_cache, value_cache = self._create_key_value_cache_tensors(cache_shape, self.device) - - self._device_key_cache.append(key_cache) - self._device_value_cache.append(value_cache) - - # For backwards compatibility. - # TODO(gante): Remove this. - self._seen_tokens = 0 - - # Create new CUDA stream for parallel prefetching. - self._prefetch_stream = torch.cuda.Stream() if self.device.type == "cuda" else None - - def update( - self, - key_states: torch.Tensor, - value_states: torch.Tensor, - layer_idx: int, - cache_kwargs: dict[str, Any] | None = None, - ) -> tuple[torch.Tensor, torch.Tensor]: - """ - Updates the cache with the new `key_states` and `value_states` for the layer `layer_idx`. - It is VERY important to index using a tensor, otherwise you introduce a copy to the device. - - Parameters: - key_states (`torch.Tensor`): - The new key states to cache. - value_states (`torch.Tensor`): - The new value states to cache. - layer_idx (`int`): - The index of the layer to cache the states for. - cache_kwargs (`Dict[str, Any]`, *optional*): - Additional arguments for the cache subclass. The `OffloadedStaticCache` needs the - `cache_position` input to know how where to write in the cache. - - Return: - A tuple containing the updated key and value states. - """ - - if layer_idx == 0: - # Update seen tokens. - # TODO(gante): Remove this. - self._seen_tokens += key_states.shape[-2] - - # Always there. - k_out = self.key_cache[0] - v_out = self.value_cache[0] - else: - # Wait for prefetch stream. - if self._prefetch_stream is not None: - torch.cuda.default_stream(self.device).wait_stream(self._prefetch_stream) - - k_out = self._device_key_cache[layer_idx & 1] - v_out = self._device_value_cache[layer_idx & 1] - - self._prefetch_layer(layer_idx + 1) - - cache_position = cache_kwargs.get("cache_position") if cache_kwargs is not None else None - if cache_position is None: - k_out.copy_(key_states) - v_out.copy_(value_states) - - # Copy the values to the offloaded device as well. - if layer_idx == 0: - self.key_cache[layer_idx].copy_(key_states.to(self.offload_device)) - self.value_cache[layer_idx].copy_(value_states.to(self.offload_device)) - else: - # Note: here we use `tensor.index_copy_(dim, index, tensor)` that is equivalent to - # `tensor[:, :, index] = tensor`, but the first one is compile-friendly and it does - # explicitly an in-place operation, that avoids copies and uses less memory. - try: - k_out.index_copy_(2, cache_position, key_states) - v_out.index_copy_(2, cache_position, value_states) - except NotImplementedError: - # The operator 'aten::index_copy.out' is not currently implemented for the MPS - # device. - k_out[:, :, cache_position] = key_states - v_out[:, :, cache_position] = value_states - - # Copy the values to the offloaded device as well. - if layer_idx != 0: - cache_position = cache_position.to(self.offload_device) - key_states = key_states.to(self.offload_device) - value_states = value_states.to(self.offload_device) - - try: - self.key_cache[layer_idx].index_copy_(2, cache_position, key_states) - self.value_cache[layer_idx].index_copy_(2, cache_position, value_states) - except NotImplementedError: - # The operator 'aten::index_copy.out' is not currently implemented for the MPS - # device. - self.key_cache[layer_idx][:, :, cache_position] = key_states - self.value_cache[layer_idx][:, :, cache_position] = value_states - - return k_out, v_out - - def get_seq_length(self, layer_idx: int | None = 0) -> int: - """Returns the sequence length of the cached states that were seen by the model.""" - - # TODO(gante): Remove this. - return self._seen_tokens - - def get_max_cache_shape(self) -> int | None: - """Returns the maximum sequence length of the cached states.""" - - return self.max_cache_len - - def reset(self) -> None: - """Resets the cache values while preserving the objects.""" - - # For backwards compatibility. - # TODO(gante): Remove this. - self._seen_tokens = 0 - - # Zero out cache. - for layer_idx in range(len(self.key_cache)): - # In-place ops prevent breaking the static address. - self.key_cache[layer_idx].zero_() - self.value_cache[layer_idx].zero_() - - @property - def seen_tokens(self) -> int: - # For backwards compatibility. - # TODO(gante): Remove this. - return self._seen_tokens - - def _create_key_value_cache_tensors( - self, shape: tuple[int, ...], device: torch.device - ) -> tuple[torch.Tensor, torch.Tensor]: - """Creates K/V cache tensors on a device. Pins memory for CPU tensors. Marks them as static - addresses for non-CPU tensors. - - Args: - shape (`Tuple[int, ...]`): Shape. - device (`torch.device`): Device. - - Returns: - Key and value cache tensors as a tuple. - """ - - is_cpu_device = device == torch.device("cpu") - - key_cache = torch.zeros(shape, dtype=self._dtype, device=device, pin_memory=is_cpu_device) - value_cache = torch.zeros(shape, dtype=self._dtype, device=device, pin_memory=is_cpu_device) - - # Note: `mark_static_address` is used to tag the cache as a fixed data pointer, - # preventing compiled graph breaks when updating the cache. - torch._dynamo.mark_static_address(key_cache) - torch._dynamo.mark_static_address(value_cache) - - return key_cache, value_cache - - def _prefetch_layer(self, layer_idx: int) -> None: - """Prefetch a layer to the device. Needs to be called in order of layer indices.""" - - # Don't fetch layers that do not exist. - if layer_idx >= len(self.key_cache): - return - - # Alternate between two on-device caches. - if self._prefetch_stream is not None: - with torch.cuda.stream(self._prefetch_stream): - self._prefetch_layer_in_context(layer_idx) - else: - self._prefetch_layer_in_context(layer_idx) - - def _prefetch_layer_in_context(self, layer_idx: int) -> None: - """Performs the actual copy of the layer to device cache.""" - - self._device_key_cache[layer_idx & 1].copy_(self.key_cache[layer_idx], non_blocking=True) - self._device_value_cache[layer_idx & 1].copy_( - self.value_cache[layer_idx], non_blocking=True - ) diff --git a/modelopt/torch/puzzletron/decilm/deci_lm_hf_code/transformers_4_51_3__modeling_llama4_attention.py b/modelopt/torch/puzzletron/decilm/deci_lm_hf_code/transformers_4_51_3__modeling_llama4_attention.py deleted file mode 100644 index b17883628ff..00000000000 --- a/modelopt/torch/puzzletron/decilm/deci_lm_hf_code/transformers_4_51_3__modeling_llama4_attention.py +++ /dev/null @@ -1,289 +0,0 @@ -# coding=utf-8 -# Copyright 2025 The LLAMA4 and HuggingFace Inc. team. All rights reserved. -# -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# -# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# mypy: ignore-errors -import math -from typing import Callable, List, Optional, Tuple, Union - -import torch -import torch.nn as nn -import torch.nn.functional as F -import torch.utils.checkpoint - -from transformers.cache_utils import Cache -from transformers.modeling_flash_attention_utils import FlashAttentionKwargs -from transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS -from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS -from transformers.processing_utils import Unpack -from transformers.utils import ( - is_torch_flex_attn_available, - logging, -) -from .transformers_4_51_3__configuration_llama4 import Llama4TextConfig - - -if is_torch_flex_attn_available(): - from torch.nn.attention.flex_attention import BlockMask - - from transformers.integrations.flex_attention import make_flex_block_causal_mask - -logger = logging.get_logger(__name__) - - -class Llama4TextL2Norm(torch.nn.Module): - def __init__(self, eps: float = 1e-6): - super().__init__() - self.eps = eps - - def _norm(self, x): - return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps) - - def forward(self, x): - return self._norm(x.float()).type_as(x) - - def extra_repr(self): - return f"eps={self.eps}" - - -class Llama4TextRotaryEmbedding(nn.Module): - def __init__(self, config: Llama4TextConfig, device=None): - super().__init__() - # BC: "rope_type" was originally "type" - self.rope_type = "llama3" if config.rope_scaling is not None else "default" - - self.max_seq_len_cached = config.max_position_embeddings - self.original_max_seq_len = config.max_position_embeddings - - self.config = config - self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type] - - inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device) - self.register_buffer("inv_freq", inv_freq, persistent=False) - self.original_inv_freq = self.inv_freq - - def _dynamic_frequency_update(self, position_ids, device): - """ - dynamic RoPE layers should recompute `inv_freq` in the following situations: - 1 - growing beyond the cached sequence length (allow scaling) - 2 - the current sequence length is in the original scale (avoid losing precision with small sequences) - """ - seq_len = torch.max(position_ids) + 1 - if seq_len > self.max_seq_len_cached: # growth - inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device, seq_len=seq_len) - self.register_buffer("inv_freq", inv_freq, persistent=False) # TODO joao: may break with compilation - self.max_seq_len_cached = seq_len - - if seq_len < self.original_max_seq_len and self.max_seq_len_cached > self.original_max_seq_len: # reset - # This .to() is needed if the model has been moved to a device after being initialized (because - # the buffer is automatically moved, but not the original copy) - self.original_inv_freq = self.original_inv_freq.to(device) - self.register_buffer("inv_freq", self.original_inv_freq, persistent=False) - self.max_seq_len_cached = self.original_max_seq_len - - @torch.no_grad() - def forward(self, x, position_ids): - if "dynamic" in self.rope_type: - self._dynamic_frequency_update(position_ids, device=x.device) - # Core RoPE block - inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1) - position_ids_expanded = position_ids[:, None, :].float() - # Force float32 (see https://github.com/huggingface/transformers/pull/29285) - device_type = x.device.type - device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu" - with torch.autocast(device_type=device_type, enabled=False): - freqs = (inv_freq_expanded.to(x.device) @ position_ids_expanded).transpose(1, 2) - freqs_cis = torch.polar(torch.ones_like(freqs), freqs) # Convert to complex representation - - # Advanced RoPE types (e.g. yarn) apply a post-processing scaling factor, equivalent to scaling attention - freqs_cis = freqs_cis * self.attention_scaling - return freqs_cis - - -def apply_rotary_emb( - xq: torch.Tensor, - xk: torch.Tensor, - freqs_cis: torch.Tensor, -) -> Tuple[torch.Tensor, torch.Tensor]: - xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2)) - xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2)) - xq_out = torch.view_as_real(xq_ * freqs_cis[:, :, None, :]).flatten(3) - xk_out = torch.view_as_real(xk_ * freqs_cis[:, :, None, :]).flatten(3) - return xq_out.type_as(xq), xk_out.type_as(xk) - - -def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: - """ - This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch, - num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim) - """ - batch, num_key_value_heads, slen, head_dim = hidden_states.shape - if n_rep == 1: - return hidden_states - hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim) - return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) - - -def eager_attention_forward( - module: nn.Module, - query: torch.Tensor, - key: torch.Tensor, - value: torch.Tensor, - attention_mask: Optional[torch.Tensor], - scaling: float, - dropout: float = 0.0, - **kwargs, -): - # print(f"{module.layer_idx=} {module.num_key_value_groups=}") - # print(f"{module.layer_idx=} {module.head_dim=}") - # print(f"{module.layer_idx=} {module.training=}") - # print(f"{scaling=}") - # print(f"{dropout=}") - key_states = repeat_kv(key, module.num_key_value_groups) - value_states = repeat_kv(value, module.num_key_value_groups) - attn_weights = torch.matmul(query, key_states.transpose(2, 3)) / math.sqrt(module.head_dim) - if attention_mask is not None: - causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] - attn_weights = attn_weights + causal_mask - - attn_weights = nn.functional.softmax(attn_weights.float(), dim=-1).to(query.dtype) - attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training) - attn_output = torch.matmul(attn_weights, value_states) - attn_output = attn_output.transpose(1, 2).contiguous() - - return attn_output, attn_weights - - -class Llama4TextAttention(nn.Module): - """Multi-headed attention from 'Attention Is All You Need' paper""" - - def __init__(self, config: Llama4TextConfig, layer_idx, use_rope: bool): # we added use_rope to not be dependent on the layer index - super().__init__() - self.config = config - self.layer_idx = layer_idx - self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads) - self.num_attention_heads = config.num_attention_heads - self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads - self.num_key_value_heads = config.num_key_value_heads - self.scaling = self.head_dim**-0.5 - self.attn_scale = config.attn_scale - self.floor_scale = config.floor_scale - self.attn_temperature_tuning = config.attn_temperature_tuning - self.attention_dropout = config.attention_dropout - self.is_causal = True - # self.use_rope = int((layer_idx + 1) % 4 != 0) # rope unused for dense layers - self.use_rope = use_rope - self.q_proj = nn.Linear( - config.hidden_size, config.num_attention_heads * self.head_dim, bias=config.attention_bias - ) - self.k_proj = nn.Linear( - config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias - ) - self.v_proj = nn.Linear( - config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias - ) - self.o_proj = nn.Linear( - config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias - ) - if self.config.use_qk_norm and self.use_rope: - self.qk_norm = Llama4TextL2Norm(config.rms_norm_eps) - - def forward( - self, - hidden_states: torch.Tensor, - position_embeddings: Tuple[torch.Tensor, torch.Tensor], - attention_mask: Optional[torch.Tensor], - past_key_value: Optional[Cache] = None, - cache_position: Optional[torch.LongTensor] = None, - **kwargs: Unpack[FlashAttentionKwargs], - ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: - input_shape = hidden_states.shape[:-1] - hidden_shape = (*input_shape, -1, self.head_dim) - - query_states = self.q_proj(hidden_states).view(hidden_shape) - key_states = self.k_proj(hidden_states).view(*input_shape, -1, self.head_dim) - value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2) - - if self.use_rope: # the 16E model skips rope for long context on certain layers - query_states, key_states = apply_rotary_emb( - query_states, key_states, position_embeddings.to(query_states.device) - ) - - if hasattr(self, "qk_norm"): # the 128E model does not use qk_norm - query_states = self.qk_norm(query_states) - key_states = self.qk_norm(key_states) - - # Use temperature tuning from https://arxiv.org/abs/2501.19399) to NoROPE layers - if self.attn_temperature_tuning and not self.use_rope: - device = query_states.device - attn_scales = ( - torch.log(torch.floor((cache_position.float() + 1.0) / self.floor_scale) + 1.0) * self.attn_scale + 1.0 - ).to(device) - attn_scales = attn_scales.view((1, input_shape[-1], 1, 1)).expand((*input_shape, 1, 1)) # batch size > 1 - query_states = (query_states * attn_scales).to(query_states.dtype) - - query_states = query_states.transpose(1, 2) - key_states = key_states.transpose(1, 2) - - if past_key_value is not None: - # sin and cos are specific to RoPE models; cache_position needed for the static cache - cache_kwargs = {"cache_position": cache_position} - key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) - - attention_interface: Callable = eager_attention_forward - if self.config._attn_implementation != "eager": - if self.config._attn_implementation == "sdpa" and kwargs.get("output_attentions", False): - logger.warning_once( - "`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to " - 'eager attention. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.' - ) - else: - attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] - - # print(f"{self.layer_idx=} {cache_position=} {attention_mask=}") - # print(f"{self.layer_idx=} {query_states.flatten()[:10]=}") - # print(f"{self.layer_idx=} {key_states.flatten()[:10]=}") - # print(f"{self.layer_idx=} {value_states.flatten()[:10]=}") - # print(f"{self.layer_idx=} {kwargs=}") - # print(f"{self.layer_idx=} {attention_interface=}") - - attn_output, attn_weights = attention_interface( - self, - query_states, - key_states, - value_states, - attention_mask, - dropout=0.0 if not self.training else self.attention_dropout, - scaling=self.scaling, - **kwargs, - ) - - attn_output = attn_output.reshape(*input_shape, -1).contiguous() - attn_output = self.o_proj(attn_output) - return attn_output, attn_weights diff --git a/modelopt/torch/puzzletron/decilm/deci_lm_hf_code/variable_cache.py b/modelopt/torch/puzzletron/decilm/deci_lm_hf_code/variable_cache.py deleted file mode 100644 index 9acc27eb9f5..00000000000 --- a/modelopt/torch/puzzletron/decilm/deci_lm_hf_code/variable_cache.py +++ /dev/null @@ -1,213 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -# mypy: ignore-errors -from copy import deepcopy -from typing import Any - -import torch -from transformers.cache_utils import ( - Cache, # used to let GenerationMixin know that we use a Cache object -) - -from .configuration_decilm import DeciLMConfig -from .transformers_4_44_2__cache_utils import Cache as Cache_4_44_2 -from .transformers_4_44_2__cache_utils import SinkCache, SlidingWindowCache, StaticCache -from .transformers_4_51_3__cache_utils import HybridChunkedCache - -LayerIndex = tuple[ - int, ... -] # supports both regular transformer blocks and parallel transformer multi-blocks - - -class VariableCache(Cache_4_44_2, Cache): - """ - A Cache object that supports a different Cache implementation for every layer, - including layers without any kv-cache. - Implemented using a list of Cache objects, each represents a "model" with 1 layer. - The default implementation for the layer caches is StaticCache. - The cache of each layer is allocated to the same gpu as the layer itself. - """ - - def __init__( - self, - *, # key-word only, no positional args allowed to avoid mix-ups with newer transformers versions - config: DeciLMConfig, - batch_size: int | None = None, - max_cache_len: int | None = None, - dtype: torch.dtype = torch.get_default_dtype(), - max_batch_size: int | None = None, - **kwargs, - ) -> None: - Cache_4_44_2.__init__(self) - - self.config = deepcopy(config) - self.max_batch_size = batch_size or max_batch_size - self.batch_size = self.max_batch_size - self.max_cache_len = ( - config.max_position_embeddings if (max_cache_len is None) else max_cache_len - ) - self.dtype = dtype - - self.layer_caches: dict[LayerIndex, Cache_4_44_2] = {} - self.layer_devices: dict[LayerIndex, torch.device] = {} - - def __repr__(self): - return ( - f"VariableCache:\n" - f"==============\n" - f"max_batch_size={self.max_batch_size}\n" - f"batch_size={self.batch_size}\n" - f"max_cache_len={self.max_cache_len}\n" - f"dtype={self.dtype}\n" - f"layer_caches={self.layer_caches}\n" - f"layer_devices={self.layer_devices}\n" - f"==============\n" - ) - - def update( - self, - key_states: torch.Tensor, - value_states: torch.Tensor, - layer_idx: int | LayerIndex, - cache_kwargs: dict[str, Any] | None = None, - ) -> tuple[torch.Tensor, torch.Tensor]: - if isinstance(layer_idx, int): - layer_idx = _int_to_layer_index(layer_idx) - - if layer_idx not in self.layer_caches: - self.layer_devices[layer_idx] = key_states.device - self._init_layer_cache(layer_idx) - - layer_cache = self.layer_caches[layer_idx] - assert layer_cache is not None, ( - f"Trying to update the cache of a cache-less layer: {layer_idx=}" - ) - - k_out, v_out = layer_cache.update( - key_states=key_states, value_states=value_states, layer_idx=0, cache_kwargs=cache_kwargs - ) - - input_seq_len = key_states.shape[2] # [batch_size, num_kv_heads, seq_len, hidden_size] - cache_seq_len = self.get_seq_length(layer_idx) - seq_len = max(input_seq_len, cache_seq_len) - - k_out = k_out[:, :, :seq_len, :] - v_out = v_out[:, :, :seq_len, :] - return k_out, v_out - - def _init_layer_cache(self, layer_idx: LayerIndex) -> None: - block_config = self.config.get_block_config(layer_idx) - attention_config = block_config.attention - - if attention_config.no_op or attention_config.replace_with_linear: - return None - - device = self.layer_devices[layer_idx] - assert device is not None, f"Trying to init layer cache for {layer_idx=} without device" - - config = deepcopy(self.config) - config.num_hidden_layers = 1 - config.num_key_value_heads = ( - self.config.num_attention_heads // attention_config.n_heads_in_group - ) - - if attention_config.is_llama4: - attention_chunk_size = attention_config.llama4.attention_chunk_size - is_chunked = attention_chunk_size is not None - config.no_rope_layers = [int(is_chunked)] - config.attention_chunk_size = ( - attention_chunk_size if is_chunked else config.get_min_attention_chunk_size() - ) - self.layer_caches[layer_idx] = HybridChunkedCache( - config=config, - max_batch_size=self.max_batch_size, - max_cache_len=self.max_cache_len, - dtype=self.dtype, - ) - return - - if attention_config.window_length is not None: - if not attention_config.is_sink: - config.sliding_window = attention_config.window_length - self.layer_caches[layer_idx] = SlidingWindowCache( - config=config, - max_batch_size=self.max_batch_size, - max_cache_len=self.max_cache_len, - device=device, - dtype=self.dtype, - ) - return - elif not attention_config.unshifted_sink: - self.layer_caches[layer_idx] = SinkCache( - window_length=attention_config.window_length, - num_sink_tokens=attention_config.num_sink_tokens, - ) - return - - self.layer_caches[layer_idx] = StaticCache( - config=config, - max_batch_size=self.max_batch_size, - max_cache_len=self.max_cache_len, - device=device, - dtype=self.dtype, - ) - - def _get_arbitrary_cache(self) -> Cache_4_44_2: - if len(self.layer_caches) == 0: - raise NoCacheFoundError() - layer_cache = next(iter(self.layer_caches.values())) - return layer_cache - - def get_seq_length(self, layer_idx: int | LayerIndex | None = 0) -> int: - """default 0 to match standard HF implementation""" - if (layer_idx is None) or ( - layer_idx == 0 and _int_to_layer_index(0) not in self.layer_caches - ): - try: - layer_cache = self._get_arbitrary_cache() - return layer_cache.get_seq_length() - except NoCacheFoundError: - return 0 - - if isinstance(layer_idx, int): - layer_idx = _int_to_layer_index(layer_idx) - - layer_cache = self.layer_caches[layer_idx] - return layer_cache.get_seq_length() - - def get_max_length(self) -> int | None: - """Returns the maximum sequence length of the cached states.""" - return self.max_cache_len - - def get_max_cache_shape(self) -> int | None: - return self.max_cache_len - - def reset(self): - for layer_idx, layer_cache in self.layer_caches.items(): - if hasattr(layer_cache, "reset"): - layer_cache.reset() - else: - self.layer_caches[layer_idx] = None - self.layer_devices[layer_idx] = None - # self._init_layer_cache(layer_idx) - - -class NoCacheFoundError(Exception): - pass - - -def _int_to_layer_index(layer_idx: int) -> LayerIndex: - return (layer_idx,) diff --git a/modelopt/torch/puzzletron/decilm/deci_lm_hf_code/vllm_yarn_utils.py b/modelopt/torch/puzzletron/decilm/deci_lm_hf_code/vllm_yarn_utils.py deleted file mode 100644 index 4c8f86cdbca..00000000000 --- a/modelopt/torch/puzzletron/decilm/deci_lm_hf_code/vllm_yarn_utils.py +++ /dev/null @@ -1,210 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import math - -import torch -import torch.nn as nn - - -def _apply_rotary_emb_torch( - x: torch.Tensor, - cos: torch.Tensor, - sin: torch.Tensor, - is_neox_style: bool, -) -> torch.Tensor: - cos = cos.unsqueeze(-2).to(x.dtype) - sin = sin.unsqueeze(-2).to(x.dtype) - if is_neox_style: - x1, x2 = torch.chunk(x, 2, dim=-1) - else: - x1 = x[..., ::2] - x2 = x[..., 1::2] - o1 = x1 * cos - x2 * sin - o2 = x2 * cos + x1 * sin - if is_neox_style: - return torch.cat((o1, o2), dim=-1) - else: - return torch.stack((o1, o2), dim=-1).flatten(-2) - - -class RotaryEmbedding(nn.Module): - """Original rotary positional embedding.""" - - def __init__( - self, - head_size: int, - rotary_dim: int, - max_position_embeddings: int, - base: int, - is_neox_style: bool, - dtype: torch.dtype, - ) -> None: - super().__init__() - self.head_size = head_size - self.rotary_dim = rotary_dim - self.max_position_embeddings = max_position_embeddings - self.base = base - self.is_neox_style = is_neox_style - self.dtype = dtype - - cache = self._compute_cos_sin_cache() - cache = cache.to(dtype) - self.cos_sin_cache: torch.Tensor - self.register_buffer("cos_sin_cache", cache, persistent=False) - - def _compute_inv_freq(self, base: int | float) -> torch.Tensor: - """Compute the inverse frequency.""" - # NOTE(woosuk): To exactly match the HF implementation, we need to - # use CPU to compute the cache and then move it to GPU. However, we - # create the cache on GPU for faster initialization. This may cause - # a slight numerical difference between the HF implementation and ours. - inv_freq = 1.0 / ( - base ** (torch.arange(0, self.rotary_dim, 2, dtype=torch.float) / self.rotary_dim) - ) - return inv_freq - - def _compute_cos_sin_cache(self) -> torch.Tensor: - """Compute the cos and sin cache.""" - inv_freq = self._compute_inv_freq(self.base) - t = torch.arange(self.max_position_embeddings, dtype=torch.float) - - freqs = torch.einsum("i,j -> ij", t, inv_freq) - cos = freqs.cos() - sin = freqs.sin() - cache = torch.cat((cos, sin), dim=-1) - return cache - - def forward_native( - self, - positions: torch.Tensor, - query: torch.Tensor, - key: torch.Tensor, - offsets: torch.Tensor | None = None, - ) -> tuple[torch.Tensor, torch.Tensor]: - """A PyTorch-native implementation of forward().""" - if offsets is not None: - positions = positions + offsets - positions = positions.flatten() - num_tokens = positions.shape[0] - cos_sin = self.cos_sin_cache.index_select(0, positions) - cos, sin = cos_sin.chunk(2, dim=-1) - - query_shape = query.shape - query = query.view(num_tokens, -1, self.head_size) - query_rot = query[..., : self.rotary_dim] - query_pass = query[..., self.rotary_dim :] - query_rot = _apply_rotary_emb_torch(query_rot, cos, sin, self.is_neox_style) - query = torch.cat((query_rot, query_pass), dim=-1).reshape(query_shape) - - key_shape = key.shape - key = key.view(num_tokens, -1, self.head_size) - key_rot = key[..., : self.rotary_dim] - key_pass = key[..., self.rotary_dim :] - key_rot = _apply_rotary_emb_torch(key_rot, cos, sin, self.is_neox_style) - key = torch.cat((key_rot, key_pass), dim=-1).reshape(key_shape) - return query, key - - -def _yarn_get_mscale(scale: float = 1) -> float: - if scale <= 1: - return 1.0 - return 0.1 * math.log(scale) + 1.0 - - -# Inverse dim formula to find dim based on number of rotations -def _yarn_find_correction_dim( - num_rotations: int, dim: int, base: float = 10000, max_position_embeddings: int = 2048 -) -> float: - return (dim * math.log(max_position_embeddings / (num_rotations * 2 * math.pi))) / ( - 2 * math.log(base) - ) - - -def _yarn_find_correction_range( - low_rot: int, high_rot: int, dim: int, base: float = 10000, max_position_embeddings: int = 2048 -) -> tuple[int, int]: - low = math.floor(_yarn_find_correction_dim(low_rot, dim, base, max_position_embeddings)) - high = math.ceil(_yarn_find_correction_dim(high_rot, dim, base, max_position_embeddings)) - return max(low, 0), min(high, dim - 1) # Clamp values just in case - - -def _yarn_linear_ramp_mask(low: float, high: float, dim: int, dtype: torch.dtype) -> torch.Tensor: - if low == high: - high += 0.001 # Prevent singularity - - linear_func = (torch.arange(dim, dtype=dtype) - low) / (high - low) - ramp_func = torch.clamp(linear_func, 0, 1) - return ramp_func - - -class YaRNScalingRotaryEmbedding(RotaryEmbedding): - """RotaryEmbedding extended with YaRN method. - - Credits to Peng et al. github.com/jquesnelle/yarn - """ - - def __init__( - self, - head_size: int, - rotary_dim: int, - max_position_embeddings: int, - base: int, - is_neox_style: bool, - scaling_factor: float, - dtype: torch.dtype, - *, - extrapolation_factor: float = 1, - attn_factor: float = 1, - beta_fast: int = 32, - beta_slow: int = 1, - ) -> None: - self.scaling_factor = scaling_factor - self.extrapolation_factor = extrapolation_factor - self.attn_factor = attn_factor - self.beta_fast = beta_fast - self.beta_slow = beta_slow - # Get n-d magnitude scaling corrected for interpolation - self.mscale = float(_yarn_get_mscale(self.scaling_factor) * attn_factor) - super().__init__(head_size, rotary_dim, max_position_embeddings, base, is_neox_style, dtype) - - def _compute_inv_freq(self, scaling_factor: float) -> torch.Tensor: - pos_freqs = self.base ** ( - torch.arange(0, self.rotary_dim, 2, dtype=torch.float) / self.rotary_dim - ) - inv_freq_extrapolation = 1.0 / pos_freqs - inv_freq_interpolation = 1.0 / (scaling_factor * pos_freqs) - - low, high = _yarn_find_correction_range( - self.beta_fast, self.beta_slow, self.rotary_dim, self.base, self.max_position_embeddings - ) - # print(f"low: {low}, high: {high}") - # Get n-d rotational scaling corrected for extrapolation - inv_freq_mask = ( - 1 - _yarn_linear_ramp_mask(low, high, self.rotary_dim // 2, dtype=torch.float) - ) * self.extrapolation_factor - inv_freq = ( - inv_freq_interpolation * (1 - inv_freq_mask) + inv_freq_extrapolation * inv_freq_mask - ) - return inv_freq - - def _compute_cos_sin_cache(self) -> torch.Tensor: - inv_freq = self._compute_inv_freq(self.scaling_factor) - t = torch.arange(self.max_position_embeddings * self.scaling_factor, dtype=torch.float32) - freqs = torch.einsum("i,j -> ij", t, inv_freq) - cos = freqs.cos() * self.mscale - sin = freqs.sin() * self.mscale - cache = torch.cat((cos, sin), dim=-1) - return cache diff --git a/modelopt/torch/puzzletron/replacement_library/replacement_library.py b/modelopt/torch/puzzletron/replacement_library/replacement_library.py index 73661edba55..f0d5bb05839 100644 --- a/modelopt/torch/puzzletron/replacement_library/replacement_library.py +++ b/modelopt/torch/puzzletron/replacement_library/replacement_library.py @@ -13,55 +13,33 @@ # See the License for the specific language governing permissions and # limitations under the License. """ -Replacement library for efficiently loading and managing layer-replaced DeciLM models. -- Uses replacement_utils for parsing, sorting, and analyzing layer replacement configurations +Replacement library for loading models with layer replacements (AnyModel / sharded HF checkpoints). """ # mypy: ignore-errors import copy import json -import re import tempfile from pathlib import Path from typing import List, Optional -import torch from immutabledict import immutabledict -from lru import LRU from safetensors import safe_open -from safetensors.torch import load_file as safe_load_file -from torch import nn from transformers import PretrainedConfig, PreTrainedModel -import modelopt.torch.utils.distributed as dist from modelopt.torch.puzzletron.anymodel.converter.converter import Converter from modelopt.torch.puzzletron.decilm.deci_lm_hf_code.configuration_decilm import DeciLMConfig -from modelopt.torch.puzzletron.decilm.deci_lm_hf_code.modeling_decilm import ( - DeciLMDecoderLayer, - DeciLMMultiDecoderLayer, - DeciLMRMSNorm, - LMHead, -) from modelopt.torch.puzzletron.replacement_library.replacement_utils import ( extract_block_configs_and_locations, parse_layer_replacement, - sort_replacements, weights_path_to_checkpoint_dir, ) from modelopt.torch.puzzletron.tools.checkpoint_utils import ( - PTH_SUBBLOCKS_DIR_NAME, SAFETENSORS_SUBBLOCKS_DIR_NAME, - infer_weights_dtype, - init_empty_module, - init_module_with_state_dict, load_model_config, ) from modelopt.torch.puzzletron.tools.checkpoint_utils_hf import save_model_config -from modelopt.torch.puzzletron.tools.sharded_checkpoint_utils import ( - is_in_safetensors_format, - load_and_shard_model, - load_sharded_state_dict, -) +from modelopt.torch.puzzletron.tools.sharded_checkpoint_utils import load_and_shard_model class ReplacementLibrary: @@ -78,17 +56,7 @@ def __init__( immutabledict(model_config_overrides) if (model_config_overrides is not None) else None ) - self._loaded_replacements: dict[str, nn.ModuleList] = LRU( - size=256 - ) # least-recently-used dict: a dict of fixed size that evicts old items - - self._dtype = None - - self.teacher_dir = Path(replacement_library_path).parent / "ckpts" / "teacher" self._model_config = None - self._embedding = None - self._ln_f = None - self._lm_head = None self._arbitrary_checkpoint_dir = None @staticmethod @@ -107,17 +75,6 @@ def _ensure_all_checkpoints_are_split(self) -> None: unsplit_checkpoints.append(checkpoint_dir) assert len(unsplit_checkpoints) == 0, f"Found unsplit checkpoints: {unsplit_checkpoints}" - @property - def dtype(self) -> torch.dtype: - if self._dtype is None: - ln_f = self.get_ln_f() - self._dtype = ln_f.weight.dtype - return self._dtype - - @property - def n_layer(self) -> int: - return self.model_config.get_num_hidden_layers() - @property def model_config(self) -> DeciLMConfig: if self._model_config is None: @@ -137,7 +94,7 @@ def create_model_config(self, layer_replacements: list[dict]): model_config.num_hidden_layers = len(block_configs) return model_config - def _get_arbitrary_block_checkpoint_paths(self): + def _get_arbitrary_non_block_checkpoint_paths(self): checkpoint_dir = Path(self.get_arbitrary_checkpoint_dir()) subblocks_dir = checkpoint_dir / SAFETENSORS_SUBBLOCKS_DIR_NAME non_block_paths = [p for p in subblocks_dir.glob("*.safetensors") if "block_" not in p.name] @@ -161,7 +118,7 @@ def prepare_tmp_checkpoint_dir( ): arbitrary_checkpoint_dir = Path(self.get_arbitrary_checkpoint_dir()) - weight_paths = self._get_arbitrary_block_checkpoint_paths() + weight_paths = self._get_arbitrary_non_block_checkpoint_paths() for layer_replacement in layer_replacements: weight_paths += layer_replacement["weight_paths"] @@ -194,194 +151,11 @@ def load_model( model = load_and_shard_model(descriptor=self.descriptor, checkpoint_path=tmpdir) return model - def load_checkpoint(self, checkpoint_dir: str | Path) -> PreTrainedModel: - checkpoint_dir = Path(checkpoint_dir).resolve() - layer_replacements = self._locate_replacements_of_entire_checkpoint(checkpoint_dir) - model = self.load_model(layer_replacements) - return model - - def _locate_replacements_of_entire_checkpoint(self, checkpoint_dir: str | Path) -> list[dict]: - weight_paths_located = [] - layer_replacements = [] - for layer_replacement in self.replacement_library: - weight_paths = layer_replacement["weight_paths"] - weight_paths = [Path(p).absolute().resolve() for p in weight_paths] - layer_replacement["weight_paths"] = weight_paths - if len(weight_paths) > 0 and all( - p.is_relative_to(checkpoint_dir) for p in weight_paths - ): - layer_replacements.append(layer_replacement) - weight_paths_located.extend(weight_paths) - - all_block_weight_paths = [ - p - for p in list((checkpoint_dir / SAFETENSORS_SUBBLOCKS_DIR_NAME).iterdir()) - if p.name not in ("embeddings.safetensors", "lm_head.safetensors") - ] - missing_paths = set(all_block_weight_paths) - set(weight_paths_located) - assert len(missing_paths) == 0, ( - f"Couldn't locate replacements for the entire checkpoint {checkpoint_dir}, missing weights: {missing_paths}" - ) - - dedupped_layer_replacements = [] - for weights_path in all_block_weight_paths: - replacements_with_path = [ - rep for rep in layer_replacements if weights_path in rep["weight_paths"] - ] - largets_replacement_with_path = max( - replacements_with_path, key=lambda rep: len(rep["weight_paths"]) - ) - if largets_replacement_with_path not in dedupped_layer_replacements: - dedupped_layer_replacements.append(largets_replacement_with_path) - - dedupped_layer_replacements = sort_replacements(dedupped_layer_replacements) - return dedupped_layer_replacements - - def get_block( - self, layer_replacement: dict, block_idx_in_replacement: int - ) -> DeciLMDecoderLayer | DeciLMMultiDecoderLayer: - if str(layer_replacement) not in self._loaded_replacements.keys(): - self._loaded_replacements[str(layer_replacement)] = self._load_layer_replacement( - layer_replacement - ) - module_list = self._loaded_replacements[str(layer_replacement)] - block = module_list[block_idx_in_replacement] - return block - - def _load_layer_replacement(self, layer_replacement: dict) -> nn.ModuleList: - state_dict = dict() - for weights_path in layer_replacement["weight_paths"]: - if weights_path.suffix == ".safetensors": - curr_state_dict = safe_load_file(weights_path) - elif weights_path.suffix == ".pth": - curr_state_dict = torch.load(weights_path, weights_only=True) - else: - raise ValueError(f"Unrecognized suffix of {weights_path=}") - for param_name in curr_state_dict.keys(): - assert param_name not in state_dict, ( - f"Duplicate entries for {param_name=} in {layer_replacement=}" - ) - state_dict.update(curr_state_dict) - - if len(state_dict) > 0: - block_indices = [ - int(re.findall(r"^model\.layers\.(\d+)\.", param_name)[0]) - for param_name in state_dict.keys() - ] - assert sorted(set(block_indices)) == list( - range(min(block_indices), max(block_indices) + 1) - ), ( - f"Block indices in loaded weight files must be consecutive, but found {sorted(set(block_indices))} in {layer_replacement=}" - ) - - min_block_idx = min(block_indices) - - state_dict = { - param_name.replace( - f"model.layers.{block_idx}.", f"{block_idx - min_block_idx}." - ): param_weight - for block_idx, (param_name, param_weight) in zip(block_indices, state_dict.items()) - } - - dtype = infer_weights_dtype(state_dict) - model_config = copy.deepcopy(self.model_config) - model_config.block_configs = layer_replacement["child_block_configs"] - model_config.num_hidden_layers = len(layer_replacement["child_block_configs"]) - - module_list = nn.ModuleList( - [ - ( - init_empty_module(DeciLMDecoderLayer, dtype, model_config, layer_idx) - if (block_config.parallel_blocks is None) - else init_empty_module(DeciLMMultiDecoderLayer, dtype, model_config, layer_idx) - ) - for layer_idx, block_config in enumerate(layer_replacement["child_block_configs"]) - ] - ) - - module_list.load_state_dict(state_dict, strict=True) - return module_list - - def _move_inactive_blocks_to_cpu(self, active_blocks: list[nn.Module]) -> None: - for module_list in self._loaded_replacements.values(): - for module in module_list: - if module not in active_blocks: - module.to("cpu") - - def get_embedding(self) -> nn.Embedding: - if self._embedding is None: - state_dict = { - "weight": self._get_arbitrary_non_block_param( - self.model_config.get_embedding_layer_name() + ".weight" - ) - } - self._embedding = init_module_with_state_dict( - state_dict, - nn.Embedding, - num_embeddings=self.model_config.vocab_size, - embedding_dim=self.model_config.hidden_size, - ) - return self._embedding - - def get_ln_f(self) -> DeciLMRMSNorm: - if self._ln_f is None: - state_dict = { - "weight": self._get_arbitrary_non_block_param( - self.model_config.get_final_layer_norm_layer_name() + ".weight" - ) - } - self._ln_f = init_module_with_state_dict( - state_dict, - DeciLMRMSNorm, - hidden_size=self.model_config.hidden_size, - eps=self.model_config.rms_norm_eps, - ) - return self._ln_f - - def get_lm_head(self) -> nn.Linear: - if self._lm_head is None: - state_dict = { - "weight": self._get_arbitrary_non_block_param( - self.model_config.get_lm_head_layer_name() + ".weight" - ) - } - self._lm_head = init_module_with_state_dict( - state_dict, - LMHead, - out_features=self.model_config.vocab_size, - in_features=self.model_config.hidden_size, - bias=False, - ) - return self._lm_head - - def _get_arbitrary_non_block_param(self, param_name: str) -> torch.Tensor: - checkpoint_dir = self.get_arbitrary_checkpoint_dir() - if ( - is_in_safetensors_format(checkpoint_dir) - or (checkpoint_dir / SAFETENSORS_SUBBLOCKS_DIR_NAME).exists() - ): - partial_state_dict = load_sharded_state_dict(checkpoint_dir, [param_name]) - return partial_state_dict[param_name] - - non_block_pth_path = checkpoint_dir / PTH_SUBBLOCKS_DIR_NAME / f"non_block.pth" - assert non_block_pth_path.exists(), _error_message_ensure_split(checkpoint_dir) - non_block_state_dict = torch.load(non_block_pth_path) - return non_block_state_dict[param_name] - def get_arbitrary_checkpoint_dir(self) -> Path: if self._arbitrary_checkpoint_dir is None: self._arbitrary_checkpoint_dir = self._get_arbitrary_checkpoint_dir() return self._arbitrary_checkpoint_dir - def get_teacher_dir(self) -> Path: - return self.teacher_dir - - def get_teacher_lm_head_path(self) -> Path: - return self.get_teacher_dir() / SAFETENSORS_SUBBLOCKS_DIR_NAME / "lm_head.safetensors" - - def get_teacher_embedding_path(self) -> Path: - return self.get_teacher_dir() / SAFETENSORS_SUBBLOCKS_DIR_NAME / "embeddings.safetensors" - def _get_arbitrary_checkpoint_dir(self) -> Path: for layer_replacement in self.replacement_library: weight_paths = layer_replacement["weight_paths"] @@ -396,27 +170,3 @@ def _get_all_checkpoint_dirs(self) -> list[Path]: checkpoint_dir = weights_path_to_checkpoint_dir(weights_path) checkpoint_dirs.add(checkpoint_dir) return list(checkpoint_dirs) - - -def _error_message_ensure_split(checkpoint_dir: Path) -> str: - return ( - f"Encountered unsplit checkpoint dir '{checkpoint_dir}', " - f"please call `ensure_all_checkpoints_are_split`" - ) - - -def _get_owned_block_indexes(n_layer: int) -> list[int]: - last_process_blocks = np.array([n_layer - 1]) # less params in last gpu, leave room for logits - - if dist.size() == 1: - # Only one process: assign everything (including the "last process" block) to rank 0 - owned_block_indexes_per_process = [ - np.concatenate([np.arange(n_layer - 1), last_process_blocks]) - ] - else: - # Multiple processes: split n_layer-1 blocks, reserve the last for "last process" - owned_block_indexes_per_process = np.array_split(range(n_layer - 1), dist.size() - 1) - owned_block_indexes_per_process.append(last_process_blocks) - - owned_block_indexes = owned_block_indexes_per_process[dist.rank()].tolist() - return owned_block_indexes diff --git a/modelopt/torch/puzzletron/replacement_library/replacement_utils.py b/modelopt/torch/puzzletron/replacement_library/replacement_utils.py index 68ba0b5fc36..269e5e63ea7 100644 --- a/modelopt/torch/puzzletron/replacement_library/replacement_utils.py +++ b/modelopt/torch/puzzletron/replacement_library/replacement_utils.py @@ -21,8 +21,9 @@ from copy import deepcopy from pathlib import Path +from transformers import PretrainedConfig + from modelopt.torch.puzzletron.decilm.deci_lm_hf_code.block_config import BlockConfig -from modelopt.torch.puzzletron.decilm.deci_lm_hf_code.configuration_decilm import DeciLMConfig from modelopt.torch.puzzletron.mip.utils import sort_replacements @@ -73,7 +74,7 @@ def weights_path_to_checkpoint_dir(weights_path: Path) -> Path: def replacement_is_teacher( layer_replacement: dict, - teacher_model_config: DeciLMConfig, + teacher_model_config: PretrainedConfig, teacher_checkpoint_dir: Path, ) -> bool: paths_all_teacher = all( @@ -86,7 +87,7 @@ def replacement_is_teacher( def is_replacement_identical_to_teacher( layer_replacement: dict, - teacher_model_config: DeciLMConfig, + teacher_model_config: PretrainedConfig, ) -> bool: if len(layer_replacement["parent_layer_indices"]) == 1: block_idx = layer_replacement["parent_layer_indices"][0] @@ -109,7 +110,7 @@ def is_replacement_identical_to_teacher( def split_replacements_to_teacher_and_student( replacements: list[dict], - teacher_model_config: DeciLMConfig, + teacher_model_config: PretrainedConfig, teacher_checkpoint_dir: Path, ) -> tuple[list[dict], list[dict]]: teacher_replacements, student_replacements = [], []