diff --git a/examples/text-generation/quantization_config/maxabs_quant_mixtral.json b/examples/text-generation/quantization_config/maxabs_quant_mixtral.json index 737edcc413..b3fd2e26db 100644 --- a/examples/text-generation/quantization_config/maxabs_quant_mixtral.json +++ b/examples/text-generation/quantization_config/maxabs_quant_mixtral.json @@ -3,8 +3,8 @@ "mode": "QUANTIZE", "observer": "maxabs", "scale_method": "maxabs_hw", - "whitelist": {"types": [], "names": ["gate","w1","w3","w2"]}, - "blacklist": {"types": [], "names": [ + "allowlist": {"types": [], "names": ["gate","w1","w3","w2"]}, + "blocklist": {"types": [], "names": [ "model.layers.1.block_sparse_moe.experts.(3|4).w2", "model.layers.[29-31].block_sparse_moe.experts.[0-7].w2" ]}, diff --git a/optimum/habana/transformers/generation/utils.py b/optimum/habana/transformers/generation/utils.py index 0d50470532..6d2dd992c9 100755 --- a/optimum/habana/transformers/generation/utils.py +++ b/optimum/habana/transformers/generation/utils.py @@ -584,8 +584,9 @@ def generate( assert self.config.model_type in [ "llama", "mistral", + "mixtral", "falcon", - ], "reuse_cache only supported by llama, mistral and falcon at the moment" + ], "reuse_cache only supported by llama, mistral, falcon and mixtral at the moment" if not generation_config.bucket_internal: assert ( generation_config.bucket_size <= 0 diff --git a/optimum/habana/transformers/modeling_utils.py b/optimum/habana/transformers/modeling_utils.py index 6dc40a73bf..b165d3edbc 100644 --- a/optimum/habana/transformers/modeling_utils.py +++ b/optimum/habana/transformers/modeling_utils.py @@ -49,7 +49,10 @@ GaudiMistralDecoderLayer, GaudiMistralForCausalLM, GaudiMistralModel, + GaudiMixtralAttention, + GaudiMixtralDecoderLayer, GaudiMixtralForCausalLM, + GaudiMixtralModel, GaudiMptForCausalLM, GaudiMptModel, GaudiOPTForCausalLM, @@ -104,10 +107,7 @@ gaudi_invert_attention_mask, gaudi_llama_rmsnorm_forward, gaudi_mistral_rmsnorm_forward, - gaudi_mixtral_attention_forward, gaudi_mixtral_block_sparse_moe_forward, - gaudi_mixtral_decoder_layer_forward, - gaudi_mixtral_model_forward, gaudi_mixtral_rmsnorm_forward, gaudi_mpt_attention_forward, gaudi_mpt_block_forward, @@ -353,10 +353,10 @@ def adapt_transformers_to_gaudi(): # Optimization for mixtral on Gaudi transformers.models.mixtral.modeling_mixtral.MixtralForCausalLM = GaudiMixtralForCausalLM - transformers.models.mixtral.modeling_mixtral.MixtralModel.forward = gaudi_mixtral_model_forward - transformers.models.mixtral.modeling_mixtral.MixtralAttention.forward = gaudi_mixtral_attention_forward + transformers.models.mixtral.modeling_mixtral.MixtralModel = GaudiMixtralModel + transformers.models.mixtral.modeling_mixtral.MixtralAttention = GaudiMixtralAttention transformers.models.mixtral.modeling_mixtral.MixtralSparseMoeBlock.forward = gaudi_mixtral_block_sparse_moe_forward - transformers.models.mixtral.modeling_mixtral.MixtralDecoderLayer.forward = gaudi_mixtral_decoder_layer_forward + transformers.models.mixtral.modeling_mixtral.MixtralDecoderLayer = GaudiMixtralDecoderLayer transformers.models.mixtral.modeling_mixtral.MixtralRMSNorm.forward = gaudi_mixtral_rmsnorm_forward # Optimization for speecht5 on Gaudi diff --git a/optimum/habana/transformers/models/__init__.py b/optimum/habana/transformers/models/__init__.py index 1582d3f09e..da0bdcc22e 100644 --- a/optimum/habana/transformers/models/__init__.py +++ b/optimum/habana/transformers/models/__init__.py @@ -89,11 +89,11 @@ gaudi_mistral_rmsnorm_forward, ) from .mixtral import ( + GaudiMixtralAttention, + GaudiMixtralDecoderLayer, GaudiMixtralForCausalLM, - gaudi_mixtral_attention_forward, + GaudiMixtralModel, gaudi_mixtral_block_sparse_moe_forward, - gaudi_mixtral_decoder_layer_forward, - gaudi_mixtral_model_forward, gaudi_mixtral_rmsnorm_forward, ) from .modeling_all_models import ( diff --git a/optimum/habana/transformers/models/mixtral/__init__.py b/optimum/habana/transformers/models/mixtral/__init__.py index fd1829bbe2..d95c0c2cdd 100644 --- a/optimum/habana/transformers/models/mixtral/__init__.py +++ b/optimum/habana/transformers/models/mixtral/__init__.py @@ -1,8 +1,8 @@ from .modeling_mixtral import ( + GaudiMixtralAttention, + GaudiMixtralDecoderLayer, GaudiMixtralForCausalLM, - gaudi_mixtral_attention_forward, + GaudiMixtralModel, gaudi_mixtral_block_sparse_moe_forward, - gaudi_mixtral_decoder_layer_forward, - gaudi_mixtral_model_forward, gaudi_mixtral_rmsnorm_forward, ) diff --git a/optimum/habana/transformers/models/mixtral/modeling_mixtral.py b/optimum/habana/transformers/models/mixtral/modeling_mixtral.py index c6f5f51ab7..68ba890fe6 100644 --- a/optimum/habana/transformers/models/mixtral/modeling_mixtral.py +++ b/optimum/habana/transformers/models/mixtral/modeling_mixtral.py @@ -20,7 +20,9 @@ """PyTorch Mixtral model.""" +import contextlib import math +import os import warnings from typing import List, Optional, Tuple, Union @@ -29,7 +31,7 @@ import torch.nn.functional as F from torch import nn from torch.nn import CrossEntropyLoss -from transformers.cache_utils import Cache, DynamicCache +from transformers.cache_utils import Cache, DynamicCache, StaticCache from transformers.integrations.deepspeed import is_deepspeed_available from transformers.modeling_attn_mask_utils import ( _prepare_4d_causal_attention_mask, @@ -37,9 +39,14 @@ ) from transformers.modeling_outputs import MoeCausalLMOutputWithPast, MoeModelOutputWithPast from transformers.models.mixtral.modeling_mixtral import ( + MixtralAttention, + MixtralConfig, + MixtralDecoderLayer, MixtralForCausalLM, + MixtralModel, apply_rotary_pos_emb, load_balancing_loss_func, + repeat_kv, ) from transformers.utils import logging @@ -62,33 +69,23 @@ print("Not using HPU fused scaled dot-product attention kernel.") FusedSDPA = None -logger = logging.get_logger(__name__) +try: + from habana_frameworks.torch.hpu import sdp_kernel + SDPContext = True +except ImportError: + SDPContext = False -def update(prev, cur, dim, idx, inp_seq_len): - orig_cur = cur - if prev.dtype == torch.float8_e4m3fn: - from habana_frameworks.torch.hpex.kernels.Fp8Ops import cast_to_fp8_v2 - - cur = cast_to_fp8_v2(cur, None, False, False, prev.dtype)[0] - if cur.shape[2] > 1 and cur.shape[2] <= prev.shape[2]: - # Initialize - prev[:, :, :inp_seq_len, :].copy_(cur) - return orig_cur - assert cur.shape[2] == 1, f"Cannot update kv-cache. Unsupported shapes. prev:{prev.shape} cur:{cur.shape}" - if idx is not None: - prev.index_copy_(dim, idx - 1, cur) - prev_cast = prev.to(orig_cur.dtype) - return prev_cast - else: - return torch.cat((prev, cur), dim=dim) +logger = logging.get_logger(__name__) def apply_customized_rope(q, k, cos, sin, position_ids): if q.device.type == "hpu" and FusedRoPE: return FusedRoPE.apply( - q, cos.unsqueeze(0).unsqueeze(0), sin.unsqueeze(0).unsqueeze(0), position_ids - ), FusedRoPE.apply(k, cos.unsqueeze(0).unsqueeze(0), sin.unsqueeze(0).unsqueeze(0), position_ids) + q, cos.unsqueeze(0).unsqueeze(0).clone(), sin.unsqueeze(0).unsqueeze(0).clone(), position_ids + ), FusedRoPE.apply( + k, cos.unsqueeze(0).unsqueeze(0).clone(), sin.unsqueeze(0).unsqueeze(0).clone(), position_ids + ) else: return apply_rotary_pos_emb(q, k, cos, sin, position_ids) @@ -143,24 +140,97 @@ def gaudi_mixtral_repeat_kv( new_q_shape = (batch, num_key_value_heads, n_rep, q_len, head_dim) query_states = query_states.reshape(new_q_shape) - if attention_mask is not None: - # Add groups dim and set to 1 - attention_mask = attention_mask.unsqueeze(1) - return query_states, key_states, value_states, attention_mask +class Softmax(nn.Module): + def __init__(self): + super().__init__() + + def forward(self, x, dim=None, invAttnHead=None): + return torch.ops.hpu.softmax_fp8(x, dim, None, None, invAttnHead) + + +class Matmul(nn.Module): + def __init__(self): + super().__init__() + + def forward(self, *args, **kwargs): + return torch.matmul(*args, **kwargs) + + +class ScaledDotProductAttention(nn.Module): + def __init__(self, config: MixtralConfig): + super().__init__() + # self.num_heads = config.num_attention_heads + # self.num_kv_groups = config.num_attention_heads // config.num_key_value_heads + self.head_dim = config.hidden_size // config.num_attention_heads + self.bmm1 = Matmul() + self.bmm2 = Matmul() + self.softmax = Softmax() + + def forward(self, query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False, scale=None) -> torch.Tensor: + L, S = query.size(-2), key.size(-2) + scale_factor = 1 / math.sqrt(self.head_dim) + invAttnHead = torch.tensor(scale_factor, dtype=torch.float32).to("hpu") + + if is_causal: + assert attn_mask is None + attn_bias = torch.zeros(L, S, dtype=query.dtype) + temp_mask = torch.ones(L, S, dtype=torch.bool).tril(diagonal=0) + attn_bias.masked_fill_(temp_mask.logical_not(), float("-inf")) + attn_bias.to(query.dtype) + + if attn_mask is not None: + if attn_mask.dtype == torch.bool: + attn_mask.masked_fill_(attn_mask.logical_not(), float("-inf")) + + attn_weight = self.bmm1(query, key.transpose(-2, -1)) + + attn_weight += attn_mask + attn_weight = self.softmax(attn_weight, dim=-1, invAttnHead=invAttnHead) + attn_weight = torch.dropout(attn_weight, dropout_p, train=True) + return self.bmm2(attn_weight, value) + + +class NaiveFlashAttention(nn.Module): + @staticmethod + def forward(q, k, v, mask, causal, q_bucket_size): + """ + Support long sequence prompt + """ + q_len = q.size(-2) + q_tiles = (q_len // q_bucket_size) if (q_len % q_bucket_size == 0) else math.ceil(q_len / q_bucket_size) + q_padding = q_tiles * q_bucket_size - q_len + q = F.pad(q, (0, 0, 0, q_padding), "constant", 0) + if mask is not None: + mask = F.pad(mask, (0, 0, 0, q_padding), "constant", -10000.0) + attn_output = torch.zeros_like(q) + + row_tiles = zip( + q.split(q_bucket_size, dim=-2), + attn_output.split(q_bucket_size, dim=-2), + mask.split(q_bucket_size, dim=-2), + ) + + for row_q, row_o, row_mask in row_tiles: + row_o.fill_(FusedSDPA.apply(row_q, k, v, row_mask, 0.0, causal, None)) + + if q_padding != 0: + attn_output = attn_output[:, :, :-q_padding, :] + + return attn_output + + class KVCache(torch.nn.Module): def __init__(self): super(KVCache, self).__init__() self.cache = None self.inp_seq_len = -1 - def allocate(self, inp_seq_len, kv_cache_fp8, dtype, device, shape): + def allocate(self, inp_seq_len, dtype, device, shape): if self.cache is None or self.cache.shape != shape: self.inp_seq_len = inp_seq_len - if kv_cache_fp8: - dtype = torch.float8_e4m3fn self.cache = torch.zeros(shape, dtype=dtype, device=device) else: assert ( @@ -168,114 +238,191 @@ def allocate(self, inp_seq_len, kv_cache_fp8, dtype, device, shape): ), f"inp_seq_len must be the same. self.inp_seq_len:{self.inp_seq_len} inp_seq_len:{inp_seq_len}" self.cache.fill_(0) + def update(self, prev, cur, dim, idx, inp_seq_len): + orig_cur = cur + if prev.shape == cur.shape: + prev.copy_(cur) + return orig_cur + if cur.shape[2] > 1 and cur.shape[2] <= prev.shape[2]: + # Initialize + prev[:, :, :inp_seq_len, :].copy_(cur) + return orig_cur + assert cur.shape[2] == 1, f"Cannot update kv-cache. Unsupported shapes. prev:{prev.shape} cur:{cur.shape}" + if idx is not None: + prev.index_copy_(dim, idx - 1, cur) + return prev + else: + return torch.cat((prev, cur), dim=dim) + def get_shape(self): if self.cache is None: return None return self.cache.shape def forward(self, cur, dim, idx): - return update(self.cache, cur, dim, idx, self.inp_seq_len) - - -def gaudi_mixtral_attention_forward( - self, - hidden_states: torch.Tensor, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[Cache] = None, - output_attentions: bool = False, - use_cache: bool = False, - token_idx: Optional[torch.Tensor] = None, - **kwargs, -) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: - """ - Copied from MixtralAttention.forward: https://github.com/huggingface/transformers/blob/v4.37.0/src/transformers/models/mixtral/modeling_mixtral.py - The only differences are: - - add new args token_idx - - optimize KV cache - """ - if "padding_mask" in kwargs: - warnings.warn( - "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`" - ) - bsz, q_len, _ = hidden_states.size() - - query_states = self.q_proj(hidden_states) - key_states = self.k_proj(hidden_states) - value_states = self.v_proj(hidden_states) - - query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) - key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) - value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) - - kv_seq_len = key_states.shape[-2] - if past_key_value is not None: - if self.layer_idx is None: - raise ValueError( - f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} " - "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class " - "with a layer index." + return self.update(self.cache, cur, dim, idx, self.inp_seq_len) + + +class GaudiMixtralAttention(MixtralAttention): + def __init__(self, config: MixtralConfig, layer_idx: Optional[int] = None): + super().__init__(config, layer_idx) + + if os.getenv("QUANT_CONFIG", ""): + self.sdpa = ScaledDotProductAttention(config) + + self.k_cache = KVCache() + self.v_cache = KVCache() + self.inp_seq_len = -1 + self.norm_factor = 1.0 / math.sqrt(self.head_dim) + self.bucket_size = 1024 # 1024 + + def allocate_kv_cache(self, batch_size, max_seq_len, inp_seq_len): + cache_shape = (batch_size, self.num_key_value_heads, max_seq_len, self.head_dim) + device = self.k_proj.weight.device + dtype = self.config.torch_dtype + self.k_cache.allocate(inp_seq_len, dtype, device, cache_shape) + self.v_cache.allocate(inp_seq_len, dtype, device, cache_shape) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + output_attentions: bool = False, + use_cache: bool = False, + token_idx: Optional[torch.Tensor] = None, + reuse_cache: Optional[bool] = False, + flash_attention_recompute: Optional[bool] = False, + cache_idx: int = None, + **kwargs, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + """ + Copied from MixtralAttention.forward: https://github.com/huggingface/transformers/blob/v4.37.0/src/transformers/models/mixtral/modeling_mixtral.py + The only differences are: + - add new args token_idx + - optimize KV cache + - add new args reuse_cache + - add new args flash_attention_recompute + """ + if "padding_mask" in kwargs: + warnings.warn( + "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`" ) - if token_idx is not None: - if 0 <= self.layer_idx < len(past_key_value.key_cache): - kv_seq_len = past_key_value.key_cache[self.layer_idx].shape[-2] + bsz, q_len, _ = hidden_states.size() + + query_states = self.q_proj(hidden_states) + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) + + query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + + kv_seq_len = key_states.shape[-2] + if past_key_value is not None: + if self.layer_idx is None: + raise ValueError( + f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} " + "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class " + "with a layer index." + ) + if token_idx is None: + if hasattr(past_key_value, "get_usable_length"): + kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx) + else: + kv_seq_len += past_key_value[0].shape[-2] + else: + if reuse_cache: + kv_seq_len = past_key_value[0][-2] + else: + kv_seq_len = past_key_value[0].shape[-2] + + cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) + query_states, key_states = apply_customized_rope(query_states, key_states, cos, sin, position_ids) + + if use_cache: + if reuse_cache: + key_states = self.k_cache(key_states, 2, token_idx) + value_states = self.v_cache(value_states, 2, token_idx) + past_key_value = (self.k_cache.get_shape(), self.v_cache.get_shape()) + else: + if past_key_value is None: + past_key = torch.zeros(key_states.shape, dtype=self.k_proj.weight.dtype, device=key_states.device) + past_value = torch.zeros( + key_states.shape, dtype=self.k_proj.weight.dtype, device=key_states.device + ) + past_key_value = (past_key, past_value) + key_states = self.k_cache.update(past_key_value[0], key_states, 2, token_idx, self.inp_seq_len) + value_states = self.k_cache.update(past_key_value[1], value_states, 2, token_idx, self.inp_seq_len) + + if token_idx is None: + past_key_value = (key_states, value_states) + + if cache_idx is not None and q_len == 1: + key_states = key_states[:, :, :cache_idx, :] + value_states = value_states[:, :, :cache_idx, :] + if attention_mask is not None: + attention_mask = attention_mask[:, :, :, :cache_idx] + kv_seq_len = key_states.shape[-2] else: - kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx) - cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) - query_states, key_states = apply_customized_rope(query_states, key_states, cos, sin, position_ids) - - if past_key_value is not None: - cache_kwargs = {"sin": sin, "cos": cos} # Specific to RoPE models - if token_idx is not None: - if 0 <= self.layer_idx < len(past_key_value.key_cache): - past_key_value.key_cache[self.layer_idx].index_copy_(2, token_idx - 1, key_states) - past_key_value.value_cache[self.layer_idx].index_copy_(2, token_idx - 1, value_states) - key_states = past_key_value.key_cache[self.layer_idx] - value_states = past_key_value.value_cache[self.layer_idx] + past_key_value = None + + if FusedSDPA: + if not self.training and q_len == key_states.size(-2) and q_len > 8192: + htcore.mark_step() + attn_output = NaiveFlashAttention.forward( + query_states, + key_states, + value_states, + attention_mask, + False, + self.bucket_size, + ) + htcore.mark_step() else: - past_key_value.key_cache.append(key_states) - past_key_value.value_cache.append(value_states) + if os.getenv("QUANT_CONFIG", ""): + # WA for GQA optimization is not supported currently + key_states = repeat_kv(key_states, self.num_key_value_groups) + value_states = repeat_kv(value_states, self.num_key_value_groups) + attn_output = self.sdpa(query_states, key_states, value_states, attention_mask, 0.0, False, None) + else: + with sdp_kernel(enable_recompute=False) if SDPContext else contextlib.nullcontext(): + attn_output = FusedSDPA.apply( + query_states, key_states, value_states, attention_mask, 0.0, False, None + ) else: - key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + query_states, key_states, value_states, attention_mask = gaudi_mixtral_repeat_kv( + query_states, key_states, value_states, attention_mask, self.num_key_value_groups + ) - if FusedSDPA: - import habana_frameworks.torch.hpu as ht + attn_weights = torch.matmul(query_states, key_states.transpose(-2, -1)) * self.norm_factor - if q_len == 1: - # next token - with ht.sdp_kernel(enable_recompute=False): - attn_output = FusedSDPA.apply(query_states, key_states, value_states, attention_mask, 0.0, False, None) - else: - # first token - with ht.sdp_kernel(enable_recompute=False): # inference: flash_attention_recompute = False - attn_output = FusedSDPA.apply(query_states, key_states, value_states, attention_mask, 0.0, False, None) - else: - query_states, key_states, value_states, attention_mask = gaudi_mixtral_repeat_kv( - query_states, key_states, value_states, attention_mask, self.num_key_value_groups - ) + if attention_mask is not None: + attention_mask = attention_mask.unsqueeze(2) + attn_weights = attn_weights + attention_mask - attn_weights = torch.matmul(query_states, key_states.transpose(-2, -1)) / math.sqrt(self.head_dim) + # upcast attention to fp32 + attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype) + attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training) + attn_output = torch.matmul(attn_weights, value_states) - if attention_mask is not None: - attention_mask = attention_mask.unsqueeze(2) - attn_weights = attn_weights + attention_mask + attn_output = attn_output.reshape(bsz, self.num_heads, q_len, self.head_dim).contiguous() - # upcast attention to fp32 - attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype) - attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training) - attn_output = torch.matmul(attn_weights, value_states) + attn_output = attn_output.transpose(1, 2).contiguous() + attn_output = attn_output.reshape(bsz, q_len, self.hidden_size) - attn_output = attn_output.reshape(bsz, self.num_heads, q_len, self.head_dim).contiguous() + attn_output = self.o_proj(attn_output) - attn_output = attn_output.transpose(1, 2).contiguous() - attn_output = attn_output.reshape(bsz, q_len, self.hidden_size) + if not output_attentions: + attn_weights = None - attn_output = self.o_proj(attn_output) + return attn_output, attn_weights, past_key_value - if not output_attentions: - attn_weights = None - return attn_output, attn_weights, past_key_value +MIXTRAL_ATTENTION_CLASSES = { + "eager": GaudiMixtralAttention, +} def gaudi_mixtral_block_sparse_moe_forward(self, hidden_states: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: @@ -327,234 +474,271 @@ def gaudi_mixtral_block_sparse_moe_forward(self, hidden_states: torch.Tensor) -> return final_hidden_states, router_logits -def gaudi_mixtral_decoder_layer_forward( - self, - hidden_states: torch.Tensor, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[Tuple[torch.Tensor]] = None, - output_attentions: Optional[bool] = False, - output_router_logits: Optional[bool] = False, - use_cache: Optional[bool] = False, - token_idx: Optional[torch.Tensor] = None, - **kwargs, -) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: - """ - Copied from MixtralDecoderLayer.forward: https://github.com/huggingface/transformers/blob/v4.37.0/src/transformers/models/mixtral/modeling_mixtral.py - The only differences are: - - add new args token_idx - """ - if "padding_mask" in kwargs: - warnings.warn( - "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`" - ) +class GaudiMixtralDecoderLayer(MixtralDecoderLayer): + def __init__(self, config: MixtralConfig, layer_idx: int): + super().__init__(config, layer_idx) + self.hidden_size = config.hidden_size - htcore.mark_step() - residual = hidden_states + self.self_attn = MIXTRAL_ATTENTION_CLASSES[config._attn_implementation](config, layer_idx) - hidden_states = self.input_layernorm(hidden_states) + def allocate_kv_cache(self, batch_size, max_seq_len, inp_seq_len): + self.self_attn.allocate_kv_cache(batch_size, max_seq_len, inp_seq_len) - # Self Attention - hidden_states, self_attn_weights, present_key_value = self.self_attn( - hidden_states=hidden_states, - attention_mask=attention_mask, - position_ids=position_ids, - past_key_value=past_key_value, - output_attentions=output_attentions, - use_cache=use_cache, - token_idx=token_idx, - ) - hidden_states = residual + hidden_states - htcore.mark_step() - - # Fully Connected - residual = hidden_states - hidden_states = self.post_attention_layernorm(hidden_states) - hidden_states, router_logits = self.block_sparse_moe(hidden_states) - hidden_states = residual + hidden_states - htcore.mark_step() - - outputs = (hidden_states,) - - if output_attentions: - outputs += (self_attn_weights,) - - if use_cache: - outputs += (present_key_value,) - - if output_router_logits: - outputs += (router_logits,) - - return outputs - - -def gaudi_mixtral_model_forward( - self, - input_ids: torch.LongTensor = None, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[List[torch.FloatTensor]] = None, - inputs_embeds: Optional[torch.FloatTensor] = None, - use_cache: Optional[bool] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - output_router_logits: Optional[bool] = None, - return_dict: Optional[bool] = None, - token_idx: Optional[torch.Tensor] = None, -) -> Union[Tuple, MoeModelOutputWithPast]: - """ - Copied from MixtralModel.forward: https://github.com/huggingface/transformers/blob/v4.37.0/src/transformers/models/mixtral/modeling_mixtral.py#L1069 - The only differences are: - - add new args token_idx - """ - output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions - output_router_logits = ( - output_router_logits if output_router_logits is not None else self.config.output_router_logits - ) - output_hidden_states = ( - output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states - ) - use_cache = use_cache if use_cache is not None else self.config.use_cache + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Tuple[torch.Tensor]] = None, + output_attentions: Optional[bool] = False, + output_router_logits: Optional[bool] = False, + use_cache: Optional[bool] = False, + token_idx: Optional[torch.Tensor] = None, + reuse_cache: Optional[bool] = False, + flash_attention_recompute: Optional[bool] = False, + cache_idx: int = None, + **kwargs, + ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: + """ + Copied from MixtralDecoderLayer.forward: https://github.com/huggingface/transformers/blob/v4.37.0/src/transformers/models/mixtral/modeling_mixtral.py + The only differences are: + - add new args token_idx + - add new args reuse_cache + """ + if "padding_mask" in kwargs: + warnings.warn( + "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`" + ) - return_dict = return_dict if return_dict is not None else self.config.use_return_dict + residual = hidden_states - # retrieve input_ids and inputs_embeds - if input_ids is not None and inputs_embeds is not None: - raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time") - elif input_ids is not None: - batch_size, seq_length = input_ids.shape - elif inputs_embeds is not None: - batch_size, seq_length, _ = inputs_embeds.shape - else: - raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds") + hidden_states = self.input_layernorm(hidden_states) - past_key_values_length = 0 + # Self Attention + hidden_states, self_attn_weights, present_key_value = self.self_attn( + hidden_states=hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_value, + output_attentions=output_attentions, + use_cache=use_cache, + token_idx=token_idx, + reuse_cache=reuse_cache, + flash_attention_recompute=flash_attention_recompute, + cache_idx=cache_idx, + ) + hidden_states = residual + hidden_states + + # Fully Connected + residual = hidden_states + hidden_states = self.post_attention_layernorm(hidden_states) + hidden_states, router_logits = self.block_sparse_moe(hidden_states) + hidden_states = residual + hidden_states + + outputs = (hidden_states,) + + if output_attentions: + outputs += (self_attn_weights,) - if self.gradient_checkpointing and self.training: if use_cache: - logger.warning_once( - "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." - ) - use_cache = False - - if use_cache: - use_legacy_cache = not isinstance(past_key_values, Cache) - if use_legacy_cache: - past_key_values = DynamicCache.from_legacy_cache(past_key_values) - past_key_values_length = past_key_values.get_usable_length(seq_length) - - if position_ids is None: - device = input_ids.device if input_ids is not None else inputs_embeds.device - position_ids = torch.arange( - past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device - ) - position_ids = position_ids.unsqueeze(0).view(-1, seq_length) - else: - position_ids = position_ids.view(-1, seq_length).long() - - if inputs_embeds is None: - inputs_embeds = self.embed_tokens(input_ids) - - if attention_mask is not None and self.config._attn_implementation == "flash_attention_2" and use_cache: - is_padding_right = attention_mask[:, -1].sum().item() != batch_size - if is_padding_right: - raise ValueError( - "You are attempting to perform batched generation with padding_side='right'" - " this may lead to unexpected behaviour for Flash Attention version of Mixtral. Make sure to " - " call `tokenizer.padding_side = 'left'` before tokenizing the input. " - ) + outputs += (present_key_value,) + + if output_router_logits: + outputs += (router_logits,) + + return outputs + - if self.config._attn_implementation == "flash_attention_2": - # 2d mask is passed through the layers - attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None - elif self.config._attn_implementation == "sdpa" and not output_attentions: - # output_attentions=True can not be supported when using SDPA, and we fall back on - # the manual implementation that requires a 4D causal mask in all cases. - attention_mask = _prepare_4d_causal_attention_mask_for_sdpa( - attention_mask, - (batch_size, seq_length), - inputs_embeds, - past_key_values_length, +class GaudiMixtralModel(MixtralModel): + def __init__(self, config: MixtralConfig): + super().__init__(config) + + def allocate_kv_cache(self, batch_size, max_seq_len, inp_seq_len): + for layer in self.layers: + layer.allocate_kv_cache(batch_size, max_seq_len, inp_seq_len) + + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + output_router_logits: Optional[bool] = None, + return_dict: Optional[bool] = None, + token_idx: Optional[torch.Tensor] = None, + reuse_cache: Optional[bool] = False, + flash_attention_recompute: Optional[bool] = False, + cache_idx: int = None, + ) -> Union[Tuple, MoeModelOutputWithPast]: + """ + Copied from MixtralModel.forward: https://github.com/huggingface/transformers/blob/v4.37.0/src/transformers/models/mixtral/modeling_mixtral.py#L1069 + The only differences are: + - add new args token_idx + - add new args reuse_cache + """ + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_router_logits = ( + output_router_logits if output_router_logits is not None else self.config.output_router_logits ) - else: - # 4d mask is passed through the layers - attention_mask = _prepare_4d_causal_attention_mask( - attention_mask, - (batch_size, seq_length), - inputs_embeds, - past_key_values_length, - sliding_window=self.config.sliding_window, + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states ) + use_cache = use_cache if use_cache is not None else self.config.use_cache - hidden_states = inputs_embeds + return_dict = return_dict if return_dict is not None else self.config.use_return_dict - # decoder layers - all_hidden_states = () if output_hidden_states else None - all_self_attns = () if output_attentions else None - all_router_logits = () if output_router_logits else None - next_decoder_cache = None + # retrieve input_ids and inputs_embeds + if input_ids is not None and inputs_embeds is not None: + raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time") + elif input_ids is not None: + batch_size, seq_length = input_ids.shape + elif inputs_embeds is not None: + batch_size, seq_length, _ = inputs_embeds.shape + else: + raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds") - for decoder_layer in self.layers: - if output_hidden_states: - all_hidden_states += (hidden_states,) + past_key_values_length = 0 + use_new_cache = False # Ignoring new Cache path for HPU if self.gradient_checkpointing and self.training: - layer_outputs = self._gradient_checkpointing_func( - decoder_layer.__call__, - hidden_states, + if use_cache: + logger.warning_once( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." + ) + use_cache = False + + if past_key_values is not None and use_cache: + if reuse_cache: + past_key_values_length = past_key_values[0][0][2] + else: + if use_new_cache: + if not isinstance(past_key_values, StaticCache): + past_key_values = DynamicCache.from_legacy_cache(past_key_values) + past_key_values_length = past_key_values.get_usable_length() + else: + past_key_values_length = past_key_values[0][0].shape[2] + + if position_ids is None: + device = input_ids.device if input_ids is not None else inputs_embeds.device + position_ids = torch.arange( + past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device + ) + position_ids = position_ids.unsqueeze(0).view(-1, seq_length) + else: + position_ids = position_ids.view(-1, seq_length).long() + + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) + + if attention_mask is not None and self.config._attn_implementation == "flash_attention_2" and use_cache: + is_padding_right = attention_mask[:, -1].sum().item() != batch_size + if is_padding_right: + raise ValueError( + "You are attempting to perform batched generation with padding_side='right'" + " this may lead to unexpected behaviour for Flash Attention version of Mixtral. Make sure to " + " call `tokenizer.padding_side = 'left'` before tokenizing the input. " + ) + + if self.config._attn_implementation == "flash_attention_2": + # 2d mask is passed through the layers + attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None + elif self.config._attn_implementation == "sdpa" and not output_attentions: + # output_attentions=True can not be supported when using SDPA, and we fall back on + # the manual implementation that requires a 4D causal mask in all cases. + attention_mask = _prepare_4d_causal_attention_mask_for_sdpa( attention_mask, - position_ids, - past_key_values, - output_attentions, - output_router_logits, - use_cache, + (batch_size, seq_length), + inputs_embeds, + past_key_values_length, ) else: - layer_outputs = decoder_layer( - hidden_states, - attention_mask=attention_mask, - position_ids=position_ids, - past_key_value=past_key_values, - output_attentions=output_attentions, - output_router_logits=output_router_logits, - use_cache=use_cache, - token_idx=token_idx, + # 4d mask is passed through the layers + attention_mask = _prepare_4d_causal_attention_mask( + attention_mask, + (batch_size, seq_length), + inputs_embeds, + past_key_values_length, + sliding_window=self.config.sliding_window, ) - hidden_states = layer_outputs[0] - - if use_cache: - next_decoder_cache = layer_outputs[2 if output_attentions else 1] - - if output_attentions: - all_self_attns += (layer_outputs[1],) + hidden_states = inputs_embeds + + # decoder layers + all_hidden_states = () if output_hidden_states else None + all_self_attns = () if output_attentions else None + all_router_logits = () if output_router_logits else None + next_decoder_cache = () if not use_new_cache else None + + for layer_idx, decoder_layer in enumerate(self.layers): + if output_hidden_states: + all_hidden_states += (hidden_states,) + + if self.gradient_checkpointing and self.training: + layer_outputs = self._gradient_checkpointing_func( + decoder_layer.__call__, + hidden_states, + attention_mask, + position_ids, + past_key_values, + output_attentions, + output_router_logits, + use_cache, + ) + else: + layer_outputs = decoder_layer( + hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=None if past_key_values is None else past_key_values[layer_idx], + output_attentions=output_attentions, + output_router_logits=output_router_logits, + use_cache=use_cache, + token_idx=token_idx, + reuse_cache=reuse_cache, + flash_attention_recompute=flash_attention_recompute, + cache_idx=cache_idx, + ) + + hidden_states = layer_outputs[0] + + if use_cache: + next_decoder_cache += (layer_outputs[2 if output_attentions else 1],) + + if output_attentions: + all_self_attns += (layer_outputs[1],) - if output_router_logits: - all_router_logits += (layer_outputs[-1],) + if output_router_logits: + all_router_logits += (layer_outputs[-1],) - hidden_states = self.norm(hidden_states) + hidden_states = self.norm(hidden_states) - # add hidden states from the last decoder layer - if output_hidden_states: - all_hidden_states += (hidden_states,) + # add hidden states from the last decoder layer + if output_hidden_states: + all_hidden_states += (hidden_states,) - next_cache = None - if use_cache: - next_cache = next_decoder_cache.to_legacy_cache() if use_legacy_cache else next_decoder_cache + next_cache = None + if use_cache: + next_cache = ( + next_decoder_cache.to_legacy_cache() if isinstance(next_decoder_cache, Cache) else next_decoder_cache + ) - if not return_dict: - return tuple( - v - for v in [hidden_states, next_cache, all_hidden_states, all_self_attns, all_router_logits] - if v is not None + if not return_dict: + return tuple( + v + for v in [hidden_states, next_cache, all_hidden_states, all_self_attns, all_router_logits] + if v is not None + ) + return MoeModelOutputWithPast( + last_hidden_state=hidden_states, + past_key_values=next_cache, + hidden_states=all_hidden_states, + attentions=all_self_attns, + router_logits=all_router_logits, ) - return MoeModelOutputWithPast( - last_hidden_state=hidden_states, - past_key_values=next_cache, - hidden_states=all_hidden_states, - attentions=all_self_attns, - router_logits=all_router_logits, - ) class GaudiMixtralForCausalLM(MixtralForCausalLM): @@ -567,6 +751,10 @@ class GaudiMixtralForCausalLM(MixtralForCausalLM): - from step2 when enable KV cache, slice next_position_ids from position_ids base on the token_idx """ + def allocate_kv_cache(self, batch_size, max_seq_len, inp_seq_len): + self.model.allocate_kv_cache(batch_size, max_seq_len, inp_seq_len) + self.kv_cache_len = max_seq_len + def forward( self, input_ids: torch.LongTensor = None, @@ -581,6 +769,9 @@ def forward( output_router_logits: Optional[bool] = None, return_dict: Optional[bool] = None, token_idx: Optional[torch.Tensor] = None, + reuse_cache: Optional[bool] = None, + flash_attention_recompute: Optional[bool] = False, + cache_idx: int = None, ) -> Union[Tuple, MoeCausalLMOutputWithPast]: output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_router_logits = ( @@ -605,6 +796,9 @@ def forward( output_router_logits=output_router_logits, return_dict=return_dict, token_idx=token_idx, + reuse_cache=reuse_cache, + flash_attention_recompute=flash_attention_recompute, + cache_idx=cache_idx, ) hidden_states = outputs[0] @@ -654,11 +848,14 @@ def forward( def prepare_inputs_for_generation( self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs ): + past_length = 0 + reuse_cache = kwargs.get("reuse_cache") token_idx = kwargs.get("token_idx", None) - # Omit tokens covered by past_key_values if past_key_values is not None: - if token_idx is None: + if token_idx is not None: + input_ids = torch.index_select(input_ids, 1, token_idx - 1) + else: if isinstance(past_key_values, Cache): cache_length = past_key_values.get_seq_length() past_length = past_key_values.seen_tokens @@ -686,8 +883,10 @@ def prepare_inputs_for_generation( and cache_length + input_ids.shape[1] > max_cache_length ): attention_mask = attention_mask[:, -max_cache_length:] - else: - input_ids = torch.index_select(input_ids, 1, token_idx - 1) + elif reuse_cache and token_idx is not None: + # With reuse_cache, KV cache is pre allocated hence for the 1st token we can slice the inputs till token idx for the fwd pass + input_ids = input_ids[:, :token_idx] + attention_mask = attention_mask[:, :token_idx] position_ids = kwargs.get("position_ids", None) if attention_mask is not None and position_ids is None: @@ -713,6 +912,9 @@ def prepare_inputs_for_generation( "use_cache": kwargs.get("use_cache"), "attention_mask": attention_mask, "token_idx": token_idx, + "reuse_cache": reuse_cache, + "flash_attention_recompute": kwargs.get("flash_attention_recompute"), + "cache_idx": kwargs.get("cache_idx"), } ) return model_inputs