diff --git a/examples/text-generation/quantization_config/maxabs_quant_mixtral.json b/examples/text-generation/quantization_config/maxabs_quant_mixtral.json index 737edcc413..b3fd2e26db 100644 --- a/examples/text-generation/quantization_config/maxabs_quant_mixtral.json +++ b/examples/text-generation/quantization_config/maxabs_quant_mixtral.json @@ -3,8 +3,8 @@ "mode": "QUANTIZE", "observer": "maxabs", "scale_method": "maxabs_hw", - "whitelist": {"types": [], "names": ["gate","w1","w3","w2"]}, - "blacklist": {"types": [], "names": [ + "allowlist": {"types": [], "names": ["gate","w1","w3","w2"]}, + "blocklist": {"types": [], "names": [ "model.layers.1.block_sparse_moe.experts.(3|4).w2", "model.layers.[29-31].block_sparse_moe.experts.[0-7].w2" ]}, diff --git a/optimum/habana/transformers/generation/utils.py b/optimum/habana/transformers/generation/utils.py index 426ed4e687..98b65ae289 100755 --- a/optimum/habana/transformers/generation/utils.py +++ b/optimum/habana/transformers/generation/utils.py @@ -604,7 +604,8 @@ def generate( "llama", "mistral", "falcon", - ], "reuse_cache only supported by llama, mistral and falcon at the moment" + "mixtral", + ], "reuse_cache only supported by llama, mistral, falcon and mixtral at the moment" if not generation_config.bucket_internal: assert ( generation_config.bucket_size <= 0 diff --git a/optimum/habana/transformers/modeling_utils.py b/optimum/habana/transformers/modeling_utils.py index 88d662e655..5e1cbc290d 100644 --- a/optimum/habana/transformers/modeling_utils.py +++ b/optimum/habana/transformers/modeling_utils.py @@ -56,6 +56,7 @@ GaudiMixtralAttention, GaudiMixtralDecoderLayer, GaudiMixtralForCausalLM, + GaudiMixtralModel, GaudiMptForCausalLM, GaudiMptModel, GaudiOPTForCausalLM, @@ -120,7 +121,6 @@ gaudi_llama_rmsnorm_forward, gaudi_mistral_rmsnorm_forward, gaudi_mixtral_block_sparse_moe_forward, - gaudi_mixtral_model_forward, gaudi_mixtral_rmsnorm_forward, gaudi_mpt_attention_forward, gaudi_mpt_block_forward, @@ -392,7 +392,7 @@ def adapt_transformers_to_gaudi(): # Optimization for mixtral on Gaudi transformers.models.mixtral.modeling_mixtral.MixtralAttention = GaudiMixtralAttention transformers.models.mixtral.modeling_mixtral.MixtralForCausalLM = GaudiMixtralForCausalLM - transformers.models.mixtral.modeling_mixtral.MixtralModel.forward = gaudi_mixtral_model_forward + transformers.models.mixtral.modeling_mixtral.MixtralModel = GaudiMixtralModel transformers.models.mixtral.modeling_mixtral.MixtralSparseMoeBlock.forward = gaudi_mixtral_block_sparse_moe_forward transformers.models.mixtral.modeling_mixtral.MixtralDecoderLayer = GaudiMixtralDecoderLayer transformers.models.mixtral.modeling_mixtral.MixtralRMSNorm.forward = gaudi_mixtral_rmsnorm_forward diff --git a/optimum/habana/transformers/models/__init__.py b/optimum/habana/transformers/models/__init__.py index 3566e7736b..351b482d5a 100644 --- a/optimum/habana/transformers/models/__init__.py +++ b/optimum/habana/transformers/models/__init__.py @@ -101,9 +101,9 @@ GaudiMixtralAttention, GaudiMixtralDecoderLayer, GaudiMixtralForCausalLM, + GaudiMixtralModel, MixtralConfig, gaudi_mixtral_block_sparse_moe_forward, - gaudi_mixtral_model_forward, gaudi_mixtral_rmsnorm_forward, ) from .modeling_all_models import ( diff --git a/optimum/habana/transformers/models/mixtral/__init__.py b/optimum/habana/transformers/models/mixtral/__init__.py index 4687fe0e20..ab34977c37 100644 --- a/optimum/habana/transformers/models/mixtral/__init__.py +++ b/optimum/habana/transformers/models/mixtral/__init__.py @@ -3,7 +3,7 @@ GaudiMixtralAttention, GaudiMixtralDecoderLayer, GaudiMixtralForCausalLM, + GaudiMixtralModel, gaudi_mixtral_block_sparse_moe_forward, - gaudi_mixtral_model_forward, gaudi_mixtral_rmsnorm_forward, ) diff --git a/optimum/habana/transformers/models/mixtral/modeling_mixtral.py b/optimum/habana/transformers/models/mixtral/modeling_mixtral.py index 872a8ce0dd..3e22f49629 100644 --- a/optimum/habana/transformers/models/mixtral/modeling_mixtral.py +++ b/optimum/habana/transformers/models/mixtral/modeling_mixtral.py @@ -29,7 +29,7 @@ import torch.nn.functional as F from torch import nn from torch.nn import CrossEntropyLoss -from transformers.cache_utils import Cache, DynamicCache +from transformers.cache_utils import Cache, DynamicCache, StaticCache from transformers.integrations.deepspeed import is_deepspeed_available from transformers.modeling_attn_mask_utils import ( _prepare_4d_causal_attention_mask, @@ -40,6 +40,7 @@ MixtralAttention, MixtralDecoderLayer, MixtralForCausalLM, + MixtralModel, apply_rotary_pos_emb, load_balancing_loss_func, ) @@ -74,25 +75,6 @@ logger = logging.get_logger(__name__) -def update(prev, cur, dim, idx, inp_seq_len): - orig_cur = cur - if prev.dtype == torch.float8_e4m3fn: - from habana_frameworks.torch.hpex.kernels.Fp8Ops import cast_to_fp8_v2 - - cur = cast_to_fp8_v2(cur, None, False, False, prev.dtype)[0] - if cur.shape[2] > 1 and cur.shape[2] <= prev.shape[2]: - # Initialize - prev[:, :, :inp_seq_len, :].copy_(cur) - return orig_cur - assert cur.shape[2] == 1, f"Cannot update kv-cache. Unsupported shapes. prev:{prev.shape} cur:{cur.shape}" - if idx is not None: - prev.index_copy_(dim, idx - 1, cur) - prev_cast = prev.to(orig_cur.dtype) - return prev_cast - else: - return torch.cat((prev, cur), dim=dim) - - def apply_customized_rope(q, k, cos, sin, position_ids): if q.device.type == "hpu" and FusedRoPE: return FusedRoPE.apply( @@ -165,11 +147,9 @@ def __init__(self): self.cache = None self.inp_seq_len = -1 - def allocate(self, inp_seq_len, kv_cache_fp8, dtype, device, shape): + def allocate(self, inp_seq_len, dtype, device, shape): if self.cache is None or self.cache.shape != shape: self.inp_seq_len = inp_seq_len - if kv_cache_fp8: - dtype = torch.float8_e4m3fn self.cache = torch.zeros(shape, dtype=dtype, device=device) else: assert ( @@ -177,19 +157,39 @@ def allocate(self, inp_seq_len, kv_cache_fp8, dtype, device, shape): ), f"inp_seq_len must be the same. self.inp_seq_len:{self.inp_seq_len} inp_seq_len:{inp_seq_len}" self.cache.fill_(0) + def update(self, prev, cur, dim, idx, inp_seq_len): + orig_cur = cur + if prev.shape == cur.shape: + prev.copy_(cur) + return orig_cur + if cur.shape[2] > 1 and cur.shape[2] <= prev.shape[2]: + # Initialize + prev[:, :, :inp_seq_len, :].copy_(cur) + return orig_cur + assert cur.shape[2] == 1, f"Cannot update kv-cache. Unsupported shapes. prev:{prev.shape} cur:{cur.shape}" + if idx is not None: + prev.index_copy_(dim, idx - 1, cur) + return prev + else: + return torch.cat((prev, cur), dim=dim) + def get_shape(self): if self.cache is None: return None return self.cache.shape def forward(self, cur, dim, idx): - return update(self.cache, cur, dim, idx, self.inp_seq_len) + return self.update(self.cache, cur, dim, idx, self.inp_seq_len) class GaudiMixtralAttention(MixtralAttention): def __init__(self, config: MixtralConfig, layer_idx: Optional[int] = None): super().__init__(config, layer_idx) self._init_rope() + self.k_cache = KVCache() + self.v_cache = KVCache() + self.inp_seq_len = -1 + self.norm_factor = 1.0 / math.sqrt(self.head_dim) def _init_rope(self): """ @@ -221,6 +221,13 @@ def _init_rope(self): else: raise ValueError(f"Unknown RoPE scaling type {scaling_type}") + def allocate_kv_cache(self, batch_size, max_seq_len, inp_seq_len): + cache_shape = (batch_size, self.num_key_value_heads, max_seq_len, self.head_dim) + device = self.k_proj.weight.device + dtype = self.config.torch_dtype + self.k_cache.allocate(inp_seq_len, dtype, device, cache_shape) + self.v_cache.allocate(inp_seq_len, dtype, device, cache_shape) + def forward( self, hidden_states: torch.Tensor, @@ -230,6 +237,9 @@ def forward( output_attentions: bool = False, use_cache: bool = False, token_idx: Optional[torch.Tensor] = None, + reuse_cache: Optional[bool] = False, + flash_attention_recompute: Optional[bool] = False, + cache_idx: int = None, **kwargs, ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: """ @@ -237,6 +247,9 @@ def forward( The only differences are: - add new args token_idx - optimize KV cache + - add new args reuse_cache + - add new args flash_attention_recompute + - add new args cache_idx """ if "padding_mask" in kwargs: warnings.warn( @@ -260,29 +273,44 @@ def forward( "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class " "with a layer index." ) - if token_idx is not None: - if 0 <= self.layer_idx < len(past_key_value.key_cache): - kv_seq_len = past_key_value.key_cache[self.layer_idx].shape[-2] + if token_idx is None: + if hasattr(past_key_value, "get_usable_length"): + kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx) + else: + kv_seq_len += past_key_value[0].shape[-2] else: - kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx) + if reuse_cache: + kv_seq_len = past_key_value[0][-2] + else: + kv_seq_len = past_key_value[0].shape[-2] cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) query_states, key_states = apply_customized_rope(query_states, key_states, cos, sin, position_ids) - if past_key_value is not None: - cache_kwargs = {"sin": sin, "cos": cos} # Specific to RoPE models - if token_idx is not None: - if 0 <= self.layer_idx < len(past_key_value.key_cache): - past_key_value.key_cache[self.layer_idx].index_copy_(2, token_idx - 1, key_states) - past_key_value.value_cache[self.layer_idx].index_copy_(2, token_idx - 1, value_states) - key_states = past_key_value.key_cache[self.layer_idx] - value_states = past_key_value.value_cache[self.layer_idx] - else: - past_key_value.key_cache.append(key_states) - past_key_value.value_cache.append(value_states) + if use_cache: + if reuse_cache: + key_states = self.k_cache(key_states, 2, token_idx) + value_states = self.v_cache(value_states, 2, token_idx) + past_key_value = (self.k_cache.get_shape(), self.v_cache.get_shape()) else: - key_states, value_states = past_key_value.update( - key_states, value_states, self.layer_idx, cache_kwargs - ) + if past_key_value is None: + past_key = torch.zeros(key_states.shape, dtype=self.k_proj.weight.dtype, device=key_states.device) + past_value = torch.zeros( + key_states.shape, dtype=self.k_proj.weight.dtype, device=key_states.device + ) + past_key_value = (past_key, past_value) + key_states = self.k_cache.update(past_key_value[0], key_states, 2, token_idx, self.inp_seq_len) + value_states = self.k_cache.update(past_key_value[1], value_states, 2, token_idx, self.inp_seq_len) + if token_idx is None: + past_key_value = (key_states, value_states) + + if cache_idx is not None and q_len == 1: + key_states = key_states[:, :, :cache_idx, :] + value_states = value_states[:, :, :cache_idx, :] + if attention_mask is not None: + attention_mask = attention_mask[:, :, :, :cache_idx] + kv_seq_len = key_states.shape[-2] + else: + past_key_value = None if FusedSDPA: import habana_frameworks.torch.hpu as ht @@ -304,7 +332,7 @@ def forward( query_states, key_states, value_states, attention_mask, self.num_key_value_groups ) - attn_weights = torch.matmul(query_states, key_states.transpose(-2, -1)) / math.sqrt(self.head_dim) + attn_weights = torch.matmul(query_states, key_states.transpose(-2, -1)) * self.norm_factor if attention_mask is not None: attention_mask = attention_mask.unsqueeze(2) @@ -382,6 +410,9 @@ def __init__(self, config: MixtralConfig, layer_idx: int): super().__init__(config, layer_idx) self.self_attn = GaudiMixtralAttention(config, layer_idx) + def allocate_kv_cache(self, batch_size, max_seq_len, inp_seq_len): + self.self_attn.allocate_kv_cache(batch_size, max_seq_len, inp_seq_len) + def forward( self, hidden_states: torch.Tensor, @@ -392,12 +423,18 @@ def forward( output_router_logits: Optional[bool] = False, use_cache: Optional[bool] = False, token_idx: Optional[torch.Tensor] = None, + reuse_cache: Optional[bool] = False, + flash_attention_recompute: Optional[bool] = False, + cache_idx: int = None, **kwargs, ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: """ Copied from MixtralDecoderLayer.forward: https://github.com/huggingface/transformers/blob/v4.37.0/src/transformers/models/mixtral/modeling_mixtral.py The only differences are: - add new args token_idx + - add new args reuse_cache + - add new args flash_attention_recompute + - add new args cache_idx """ if "padding_mask" in kwargs: warnings.warn( @@ -418,6 +455,9 @@ def forward( output_attentions=output_attentions, use_cache=use_cache, token_idx=token_idx, + reuse_cache=reuse_cache, + flash_attention_recompute=flash_attention_recompute, + cache_idx=cache_idx, ) hidden_states = residual + hidden_states htcore.mark_step() @@ -443,173 +483,198 @@ def forward( return outputs -def gaudi_mixtral_model_forward( - self, - input_ids: torch.LongTensor = None, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[List[torch.FloatTensor]] = None, - inputs_embeds: Optional[torch.FloatTensor] = None, - use_cache: Optional[bool] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - output_router_logits: Optional[bool] = None, - return_dict: Optional[bool] = None, - token_idx: Optional[torch.Tensor] = None, -) -> Union[Tuple, MoeModelOutputWithPast]: - """ - Copied from MixtralModel.forward: https://github.com/huggingface/transformers/blob/v4.37.0/src/transformers/models/mixtral/modeling_mixtral.py#L1069 - The only differences are: - - add new args token_idx - """ - output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions - output_router_logits = ( - output_router_logits if output_router_logits is not None else self.config.output_router_logits - ) - output_hidden_states = ( - output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states - ) - use_cache = use_cache if use_cache is not None else self.config.use_cache - - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - - # retrieve input_ids and inputs_embeds - if input_ids is not None and inputs_embeds is not None: - raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time") - elif input_ids is not None: - batch_size, seq_length = input_ids.shape - elif inputs_embeds is not None: - batch_size, seq_length, _ = inputs_embeds.shape - else: - raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds") - - past_key_values_length = 0 +class GaudiMixtralModel(MixtralModel): + def __init__(self, config: MixtralConfig): + super().__init__(config) - if self.gradient_checkpointing and self.training: - if use_cache: - logger.warning_once( - "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." - ) - use_cache = False - - if use_cache: - use_legacy_cache = not isinstance(past_key_values, Cache) - if use_legacy_cache: - past_key_values = DynamicCache.from_legacy_cache(past_key_values) - past_key_values_length = past_key_values.get_usable_length(seq_length) - - if position_ids is None: - device = input_ids.device if input_ids is not None else inputs_embeds.device - position_ids = torch.arange( - past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device - ) - position_ids = position_ids.unsqueeze(0).view(-1, seq_length) - else: - position_ids = position_ids.view(-1, seq_length).long() - - if inputs_embeds is None: - inputs_embeds = self.embed_tokens(input_ids) - - if attention_mask is not None and self.config._attn_implementation == "flash_attention_2" and use_cache: - is_padding_right = attention_mask[:, -1].sum().item() != batch_size - if is_padding_right: - raise ValueError( - "You are attempting to perform batched generation with padding_side='right'" - " this may lead to unexpected behaviour for Flash Attention version of Mixtral. Make sure to " - " call `tokenizer.padding_side = 'left'` before tokenizing the input. " - ) + def allocate_kv_cache(self, batch_size, max_seq_len, inp_seq_len): + for layer in self.layers: + layer.allocate_kv_cache(batch_size, max_seq_len, inp_seq_len) - if self.config._attn_implementation == "flash_attention_2": - # 2d mask is passed through the layers - attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None - elif self.config._attn_implementation == "sdpa" and not output_attentions: - # output_attentions=True can not be supported when using SDPA, and we fall back on - # the manual implementation that requires a 4D causal mask in all cases. - attention_mask = _prepare_4d_causal_attention_mask_for_sdpa( - attention_mask, - (batch_size, seq_length), - inputs_embeds, - past_key_values_length, + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + output_router_logits: Optional[bool] = None, + return_dict: Optional[bool] = None, + token_idx: Optional[torch.Tensor] = None, + reuse_cache: Optional[bool] = False, + flash_attention_recompute: Optional[bool] = False, + cache_idx: int = None, + ) -> Union[Tuple, MoeModelOutputWithPast]: + """ + Copied from MixtralModel.forward: https://github.com/huggingface/transformers/blob/v4.37.0/src/transformers/models/mixtral/modeling_mixtral.py#L1069 + The only differences are: + - add new args token_idx + - add new args reuse_cache + - add new args flash_attention_recompute + - add new args cache_idx + """ + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_router_logits = ( + output_router_logits if output_router_logits is not None else self.config.output_router_logits ) - else: - # 4d mask is passed through the layers - attention_mask = _prepare_4d_causal_attention_mask( - attention_mask, - (batch_size, seq_length), - inputs_embeds, - past_key_values_length, - sliding_window=self.config.sliding_window, + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states ) + use_cache = use_cache if use_cache is not None else self.config.use_cache - hidden_states = inputs_embeds + return_dict = return_dict if return_dict is not None else self.config.use_return_dict - # decoder layers - all_hidden_states = () if output_hidden_states else None - all_self_attns = () if output_attentions else None - all_router_logits = () if output_router_logits else None - next_decoder_cache = None + # retrieve input_ids and inputs_embeds + if input_ids is not None and inputs_embeds is not None: + raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time") + elif input_ids is not None: + batch_size, seq_length = input_ids.shape + elif inputs_embeds is not None: + batch_size, seq_length, _ = inputs_embeds.shape + else: + raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds") - for decoder_layer in self.layers: - if output_hidden_states: - all_hidden_states += (hidden_states,) + past_key_values_length = 0 + use_new_cache = False # Ignoring new Cache path for HPU if self.gradient_checkpointing and self.training: - layer_outputs = self._gradient_checkpointing_func( - decoder_layer.__call__, - hidden_states, + if use_cache: + logger.warning_once( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." + ) + use_cache = False + + if past_key_values is not None and use_cache: + if reuse_cache: + past_key_values_length = past_key_values[0][0][2] + else: + if use_new_cache: + if not isinstance(past_key_values, StaticCache): + past_key_values = DynamicCache.from_legacy_cache(past_key_values) + past_key_values_length = past_key_values.get_usable_length() + else: + past_key_values_length = past_key_values[0][0].shape[2] + + if position_ids is None: + device = input_ids.device if input_ids is not None else inputs_embeds.device + position_ids = torch.arange( + past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device + ) + position_ids = position_ids.unsqueeze(0).view(-1, seq_length) + else: + position_ids = position_ids.view(-1, seq_length).long() + + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) + + if attention_mask is not None and self.config._attn_implementation == "flash_attention_2" and use_cache: + is_padding_right = attention_mask[:, -1].sum().item() != batch_size + if is_padding_right: + raise ValueError( + "You are attempting to perform batched generation with padding_side='right'" + " this may lead to unexpected behaviour for Flash Attention version of Mixtral. Make sure to " + " call `tokenizer.padding_side = 'left'` before tokenizing the input. " + ) + + if self.config._attn_implementation == "flash_attention_2": + # 2d mask is passed through the layers + attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None + elif self.config._attn_implementation == "sdpa" and not output_attentions: + # output_attentions=True can not be supported when using SDPA, and we fall back on + # the manual implementation that requires a 4D causal mask in all cases. + attention_mask = _prepare_4d_causal_attention_mask_for_sdpa( attention_mask, - position_ids, - past_key_values, - output_attentions, - output_router_logits, - use_cache, + (batch_size, seq_length), + inputs_embeds, + past_key_values_length, ) else: - layer_outputs = decoder_layer( - hidden_states, - attention_mask=attention_mask, - position_ids=position_ids, - past_key_value=past_key_values, - output_attentions=output_attentions, - output_router_logits=output_router_logits, - use_cache=use_cache, - token_idx=token_idx, + # 4d mask is passed through the layers + attention_mask = _prepare_4d_causal_attention_mask( + attention_mask, + (batch_size, seq_length), + inputs_embeds, + past_key_values_length, + sliding_window=self.config.sliding_window, ) - hidden_states = layer_outputs[0] + hidden_states = inputs_embeds + + # decoder layers + all_hidden_states = () if output_hidden_states else None + all_self_attns = () if output_attentions else None + all_router_logits = () if output_router_logits else None + next_decoder_cache = () if not use_new_cache else None + + for layer_idx, decoder_layer in enumerate(self.layers): + if output_hidden_states: + all_hidden_states += (hidden_states,) + + if self.gradient_checkpointing and self.training: + layer_outputs = self._gradient_checkpointing_func( + decoder_layer.__call__, + hidden_states, + attention_mask, + position_ids, + past_key_values, + output_attentions, + output_router_logits, + use_cache, + ) + else: + layer_outputs = decoder_layer( + hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=None if past_key_values is None else past_key_values[layer_idx], + output_attentions=output_attentions, + output_router_logits=output_router_logits, + use_cache=use_cache, + token_idx=token_idx, + reuse_cache=reuse_cache, + flash_attention_recompute=flash_attention_recompute, + cache_idx=cache_idx, + ) - if use_cache: - next_decoder_cache = layer_outputs[2 if output_attentions else 1] + hidden_states = layer_outputs[0] - if output_attentions: - all_self_attns += (layer_outputs[1],) + if use_cache: + next_decoder_cache += (layer_outputs[2 if output_attentions else 1],) - if output_router_logits: - all_router_logits += (layer_outputs[-1],) + if output_attentions: + all_self_attns += (layer_outputs[1],) + + if output_router_logits: + all_router_logits += (layer_outputs[-1],) - hidden_states = self.norm(hidden_states) + hidden_states = self.norm(hidden_states) - # add hidden states from the last decoder layer - if output_hidden_states: - all_hidden_states += (hidden_states,) + # add hidden states from the last decoder layer + if output_hidden_states: + all_hidden_states += (hidden_states,) - next_cache = None - if use_cache: - next_cache = next_decoder_cache.to_legacy_cache() if use_legacy_cache else next_decoder_cache + next_cache = None + if use_cache: + next_cache = ( + next_decoder_cache.to_legacy_cache() if isinstance(next_decoder_cache, Cache) else next_decoder_cache + ) - if not return_dict: - return tuple( - v - for v in [hidden_states, next_cache, all_hidden_states, all_self_attns, all_router_logits] - if v is not None + if not return_dict: + return tuple( + v + for v in [hidden_states, next_cache, all_hidden_states, all_self_attns, all_router_logits] + if v is not None + ) + return MoeModelOutputWithPast( + last_hidden_state=hidden_states, + past_key_values=next_cache, + hidden_states=all_hidden_states, + attentions=all_self_attns, + router_logits=all_router_logits, ) - return MoeModelOutputWithPast( - last_hidden_state=hidden_states, - past_key_values=next_cache, - hidden_states=all_hidden_states, - attentions=all_self_attns, - router_logits=all_router_logits, - ) class GaudiMixtralForCausalLM(MixtralForCausalLM): @@ -622,6 +687,10 @@ class GaudiMixtralForCausalLM(MixtralForCausalLM): - from step2 when enable KV cache, slice next_position_ids from position_ids base on the token_idx """ + def allocate_kv_cache(self, batch_size, max_seq_len, inp_seq_len): + self.model.allocate_kv_cache(batch_size, max_seq_len, inp_seq_len) + self.kv_cache_len = max_seq_len + def forward( self, input_ids: torch.LongTensor = None, @@ -636,6 +705,9 @@ def forward( output_router_logits: Optional[bool] = None, return_dict: Optional[bool] = None, token_idx: Optional[torch.Tensor] = None, + reuse_cache: Optional[bool] = None, + flash_attention_recompute: Optional[bool] = False, + cache_idx: int = None, ) -> Union[Tuple, MoeCausalLMOutputWithPast]: output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_router_logits = ( @@ -660,6 +732,9 @@ def forward( output_router_logits=output_router_logits, return_dict=return_dict, token_idx=token_idx, + reuse_cache=reuse_cache, + flash_attention_recompute=flash_attention_recompute, + cache_idx=cache_idx, ) hidden_states = outputs[0] @@ -709,11 +784,15 @@ def forward( def prepare_inputs_for_generation( self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs ): + past_length = 0 + reuse_cache = kwargs.get("reuse_cache") token_idx = kwargs.get("token_idx", None) # Omit tokens covered by past_key_values if past_key_values is not None: - if token_idx is None: + if token_idx is not None: + input_ids = torch.index_select(input_ids, 1, token_idx - 1) + else: if isinstance(past_key_values, Cache): cache_length = past_key_values.get_seq_length() past_length = past_key_values.seen_tokens @@ -741,8 +820,10 @@ def prepare_inputs_for_generation( and cache_length + input_ids.shape[1] > max_cache_length ): attention_mask = attention_mask[:, -max_cache_length:] - else: - input_ids = torch.index_select(input_ids, 1, token_idx - 1) + elif reuse_cache and token_idx is not None: + # With reuse_cache, KV cache is pre allocated hence for the 1st token we can slice the inputs till token idx for the fwd pass + input_ids = input_ids[:, :token_idx] + attention_mask = attention_mask[:, :token_idx] position_ids = kwargs.get("position_ids", None) if attention_mask is not None and position_ids is None: @@ -768,6 +849,9 @@ def prepare_inputs_for_generation( "use_cache": kwargs.get("use_cache"), "attention_mask": attention_mask, "token_idx": token_idx, + "reuse_cache": reuse_cache, + "flash_attention_recompute": kwargs.get("flash_attention_recompute"), + "cache_idx": kwargs.get("cache_idx"), } ) return model_inputs diff --git a/tests/test_text_generation_example.py b/tests/test_text_generation_example.py index 1677319e8f..4deaa72bcf 100644 --- a/tests/test_text_generation_example.py +++ b/tests/test_text_generation_example.py @@ -37,6 +37,7 @@ ], "fp8": [ ("tiiuae/falcon-180B", 52.85086442722326), + ("mistralai/Mixtral-8x7B-v0.1", 39.26845661768185), ("meta-llama/Llama-2-7b-hf", 0.0), ("meta-llama/Llama-2-70b-hf", 0.0), ],