From 2394afd0a9376c021f80e27309fa6ae59e71be4e Mon Sep 17 00:00:00 2001 From: regisss <15324346+regisss@users.noreply.github.com> Date: Tue, 16 Apr 2024 13:41:20 +0000 Subject: [PATCH 1/6] Add linear and dynamic RoPE to Mistral and Mixtral --- optimum/habana/transformers/modeling_utils.py | 12 +- .../habana/transformers/models/__init__.py | 6 +- .../transformers/models/mistral/__init__.py | 1 + .../models/mistral/modeling_mistral.py | 33 ++ .../transformers/models/mixtral/__init__.py | 5 +- .../models/mixtral/modeling_mixtral.py | 322 ++++++++++-------- 6 files changed, 236 insertions(+), 143 deletions(-) diff --git a/optimum/habana/transformers/modeling_utils.py b/optimum/habana/transformers/modeling_utils.py index 6dc40a73bf..5ef0478cd0 100644 --- a/optimum/habana/transformers/modeling_utils.py +++ b/optimum/habana/transformers/modeling_utils.py @@ -49,12 +49,16 @@ GaudiMistralDecoderLayer, GaudiMistralForCausalLM, GaudiMistralModel, + GaudiMixtralAttention, + GaudiMixtralDecoderLayer, GaudiMixtralForCausalLM, GaudiMptForCausalLM, GaudiMptModel, GaudiOPTForCausalLM, GaudiOPTLearnedPositionalEmbedding, GaudiPhiForCausalLM, + MistralConfig, + MixtralConfig, _gaudi_wav2vec2_compute_mask_indices, _gaudi_wav2vec2_mask_hidden_states, gaudi_albert_forward, @@ -104,9 +108,7 @@ gaudi_invert_attention_mask, gaudi_llama_rmsnorm_forward, gaudi_mistral_rmsnorm_forward, - gaudi_mixtral_attention_forward, gaudi_mixtral_block_sparse_moe_forward, - gaudi_mixtral_decoder_layer_forward, gaudi_mixtral_model_forward, gaudi_mixtral_rmsnorm_forward, gaudi_mpt_attention_forward, @@ -331,6 +333,7 @@ def adapt_transformers_to_gaudi(): transformers.models.mistral.modeling_mistral.MistralDecoderLayer = GaudiMistralDecoderLayer transformers.models.mistral.modeling_mistral.MistralModel = GaudiMistralModel transformers.models.mistral.modeling_mistral.MistralRMSNorm.forward = gaudi_mistral_rmsnorm_forward + transformers.models.mistral.configuration_mistral.MistralConfig = MistralConfig # Optimization for phi on Gaudi transformers.models.phi.modeling_phi.PhiForCausalLM = GaudiPhiForCausalLM @@ -352,12 +355,13 @@ def adapt_transformers_to_gaudi(): transformers.models.blip.BlipForConditionalGeneration.generate = gaudi_BlipForConditionalGeneration_generate # Optimization for mixtral on Gaudi + transformers.models.mixtral.modeling_mixtral.MixtralAttention = GaudiMixtralAttention transformers.models.mixtral.modeling_mixtral.MixtralForCausalLM = GaudiMixtralForCausalLM transformers.models.mixtral.modeling_mixtral.MixtralModel.forward = gaudi_mixtral_model_forward - transformers.models.mixtral.modeling_mixtral.MixtralAttention.forward = gaudi_mixtral_attention_forward transformers.models.mixtral.modeling_mixtral.MixtralSparseMoeBlock.forward = gaudi_mixtral_block_sparse_moe_forward - transformers.models.mixtral.modeling_mixtral.MixtralDecoderLayer.forward = gaudi_mixtral_decoder_layer_forward + transformers.models.mixtral.modeling_mixtral.MixtralDecoderLayer = GaudiMixtralDecoderLayer transformers.models.mixtral.modeling_mixtral.MixtralRMSNorm.forward = gaudi_mixtral_rmsnorm_forward + transformers.models.mixtral.configuration_mixtral.MixtralConfig = MixtralConfig # Optimization for speecht5 on Gaudi transformers.models.speecht5.modeling_speecht5.SpeechT5Decoder.forward = gaudi_SpeechT5Decoder_forward diff --git a/optimum/habana/transformers/models/__init__.py b/optimum/habana/transformers/models/__init__.py index 1582d3f09e..97d4e5a45f 100644 --- a/optimum/habana/transformers/models/__init__.py +++ b/optimum/habana/transformers/models/__init__.py @@ -86,13 +86,15 @@ GaudiMistralDecoderLayer, GaudiMistralForCausalLM, GaudiMistralModel, + MistralConfig, gaudi_mistral_rmsnorm_forward, ) from .mixtral import ( + GaudiMixtralAttention, + GaudiMixtralDecoderLayer, GaudiMixtralForCausalLM, - gaudi_mixtral_attention_forward, + MixtralConfig, gaudi_mixtral_block_sparse_moe_forward, - gaudi_mixtral_decoder_layer_forward, gaudi_mixtral_model_forward, gaudi_mixtral_rmsnorm_forward, ) diff --git a/optimum/habana/transformers/models/mistral/__init__.py b/optimum/habana/transformers/models/mistral/__init__.py index 192c267791..a4cdf47bff 100644 --- a/optimum/habana/transformers/models/mistral/__init__.py +++ b/optimum/habana/transformers/models/mistral/__init__.py @@ -1,3 +1,4 @@ +from .configuration_mistral import MistralConfig from .modeling_mistral import ( GaudiMistralAttention, GaudiMistralDecoderLayer, diff --git a/optimum/habana/transformers/models/mistral/modeling_mistral.py b/optimum/habana/transformers/models/mistral/modeling_mistral.py index cf5fa6f2c0..4198e1fc43 100644 --- a/optimum/habana/transformers/models/mistral/modeling_mistral.py +++ b/optimum/habana/transformers/models/mistral/modeling_mistral.py @@ -42,6 +42,11 @@ from ...modeling_attn_mask_utils import ( _gaudi_prepare_4d_causal_attention_mask, ) +from ..llama.modeling_llama import ( + GaudiLlamaDynamicNTKScalingRotaryEmbedding, + GaudiLlamaLinearScalingRotaryEmbedding, + GaudiLlamaRotaryEmbedding, +) try: @@ -128,6 +133,34 @@ def __init__(self, config: MistralConfig, layer_idx: Optional[int] = None): super().__init__(config, layer_idx) self.past_key = None self.past_value = None + self._init_rope() + + def _init_rope(self): + if self.config.rope_scaling is None: + self.rotary_emb = GaudiLlamaRotaryEmbedding( + self.head_dim, + max_position_embeddings=self.max_position_embeddings, + base=self.rope_theta, + ) + else: + scaling_type = self.config.rope_scaling["type"] + scaling_factor = self.config.rope_scaling["factor"] + if scaling_type == "linear": + self.rotary_emb = GaudiLlamaLinearScalingRotaryEmbedding( + self.head_dim, + max_position_embeddings=self.max_position_embeddings, + scaling_factor=scaling_factor, + base=self.rope_theta, + ) + elif scaling_type == "dynamic": + self.rotary_emb = GaudiLlamaDynamicNTKScalingRotaryEmbedding( + self.head_dim, + max_position_embeddings=self.max_position_embeddings, + scaling_factor=scaling_factor, + base=self.rope_theta, + ) + else: + raise ValueError(f"Unknown RoPE scaling type {scaling_type}") def allocate_kv_cache(self, batch_size, seq_len): kv_shape = (batch_size, self.num_key_value_heads, seq_len, self.head_dim) diff --git a/optimum/habana/transformers/models/mixtral/__init__.py b/optimum/habana/transformers/models/mixtral/__init__.py index fd1829bbe2..4687fe0e20 100644 --- a/optimum/habana/transformers/models/mixtral/__init__.py +++ b/optimum/habana/transformers/models/mixtral/__init__.py @@ -1,8 +1,9 @@ +from .configuration_mixtral import MixtralConfig from .modeling_mixtral import ( + GaudiMixtralAttention, + GaudiMixtralDecoderLayer, GaudiMixtralForCausalLM, - gaudi_mixtral_attention_forward, gaudi_mixtral_block_sparse_moe_forward, - gaudi_mixtral_decoder_layer_forward, gaudi_mixtral_model_forward, gaudi_mixtral_rmsnorm_forward, ) diff --git a/optimum/habana/transformers/models/mixtral/modeling_mixtral.py b/optimum/habana/transformers/models/mixtral/modeling_mixtral.py index c6f5f51ab7..8fd0919b00 100644 --- a/optimum/habana/transformers/models/mixtral/modeling_mixtral.py +++ b/optimum/habana/transformers/models/mixtral/modeling_mixtral.py @@ -37,12 +37,21 @@ ) from transformers.modeling_outputs import MoeCausalLMOutputWithPast, MoeModelOutputWithPast from transformers.models.mixtral.modeling_mixtral import ( + MixtralAttention, + MixtralDecoderLayer, MixtralForCausalLM, apply_rotary_pos_emb, load_balancing_loss_func, ) from transformers.utils import logging +from ..llama.modeling_llama import ( + GaudiLlamaDynamicNTKScalingRotaryEmbedding, + GaudiLlamaLinearScalingRotaryEmbedding, + GaudiLlamaRotaryEmbedding, +) +from .configuration_mixtral import MixtralConfig + try: from habana_frameworks.torch.hpex.kernels import RotaryPosEmbeddingHelperV2 as FusedRoPE @@ -177,105 +186,143 @@ def forward(self, cur, dim, idx): return update(self.cache, cur, dim, idx, self.inp_seq_len) -def gaudi_mixtral_attention_forward( - self, - hidden_states: torch.Tensor, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[Cache] = None, - output_attentions: bool = False, - use_cache: bool = False, - token_idx: Optional[torch.Tensor] = None, - **kwargs, -) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: - """ - Copied from MixtralAttention.forward: https://github.com/huggingface/transformers/blob/v4.37.0/src/transformers/models/mixtral/modeling_mixtral.py - The only differences are: - - add new args token_idx - - optimize KV cache - """ - if "padding_mask" in kwargs: - warnings.warn( - "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`" - ) - bsz, q_len, _ = hidden_states.size() - - query_states = self.q_proj(hidden_states) - key_states = self.k_proj(hidden_states) - value_states = self.v_proj(hidden_states) - - query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) - key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) - value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) +class GaudiMixtralAttention(MixtralAttention): + def __init__(self, config: MixtralConfig, layer_idx: Optional[int] = None): + super().__init__(config, layer_idx) + self._init_rope() - kv_seq_len = key_states.shape[-2] - if past_key_value is not None: - if self.layer_idx is None: - raise ValueError( - f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} " - "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class " - "with a layer index." + def _init_rope(self): + if self.config.rope_scaling is None: + self.rotary_emb = GaudiLlamaRotaryEmbedding( + self.head_dim, + max_position_embeddings=self.max_position_embeddings, + base=self.rope_theta, ) - if token_idx is not None: - if 0 <= self.layer_idx < len(past_key_value.key_cache): - kv_seq_len = past_key_value.key_cache[self.layer_idx].shape[-2] else: - kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx) - cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) - query_states, key_states = apply_customized_rope(query_states, key_states, cos, sin, position_ids) - - if past_key_value is not None: - cache_kwargs = {"sin": sin, "cos": cos} # Specific to RoPE models - if token_idx is not None: - if 0 <= self.layer_idx < len(past_key_value.key_cache): - past_key_value.key_cache[self.layer_idx].index_copy_(2, token_idx - 1, key_states) - past_key_value.value_cache[self.layer_idx].index_copy_(2, token_idx - 1, value_states) - key_states = past_key_value.key_cache[self.layer_idx] - value_states = past_key_value.value_cache[self.layer_idx] + scaling_type = self.config.rope_scaling["type"] + scaling_factor = self.config.rope_scaling["factor"] + if scaling_type == "linear": + self.rotary_emb = GaudiLlamaLinearScalingRotaryEmbedding( + self.head_dim, + max_position_embeddings=self.max_position_embeddings, + scaling_factor=scaling_factor, + base=self.rope_theta, + ) + elif scaling_type == "dynamic": + self.rotary_emb = GaudiLlamaDynamicNTKScalingRotaryEmbedding( + self.head_dim, + max_position_embeddings=self.max_position_embeddings, + scaling_factor=scaling_factor, + base=self.rope_theta, + ) else: - past_key_value.key_cache.append(key_states) - past_key_value.value_cache.append(value_states) - else: - key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) - - if FusedSDPA: - import habana_frameworks.torch.hpu as ht + raise ValueError(f"Unknown RoPE scaling type {scaling_type}") - if q_len == 1: - # next token - with ht.sdp_kernel(enable_recompute=False): - attn_output = FusedSDPA.apply(query_states, key_states, value_states, attention_mask, 0.0, False, None) + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + output_attentions: bool = False, + use_cache: bool = False, + token_idx: Optional[torch.Tensor] = None, + **kwargs, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + """ + Copied from MixtralAttention.forward: https://github.com/huggingface/transformers/blob/v4.37.0/src/transformers/models/mixtral/modeling_mixtral.py + The only differences are: + - add new args token_idx + - optimize KV cache + """ + if "padding_mask" in kwargs: + warnings.warn( + "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`" + ) + bsz, q_len, _ = hidden_states.size() + + query_states = self.q_proj(hidden_states) + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) + + query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + + kv_seq_len = key_states.shape[-2] + if past_key_value is not None: + if self.layer_idx is None: + raise ValueError( + f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} " + "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class " + "with a layer index." + ) + if token_idx is not None: + if 0 <= self.layer_idx < len(past_key_value.key_cache): + kv_seq_len = past_key_value.key_cache[self.layer_idx].shape[-2] + else: + kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx) + cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) + query_states, key_states = apply_customized_rope(query_states, key_states, cos, sin, position_ids) + + if past_key_value is not None: + cache_kwargs = {"sin": sin, "cos": cos} # Specific to RoPE models + if token_idx is not None: + if 0 <= self.layer_idx < len(past_key_value.key_cache): + past_key_value.key_cache[self.layer_idx].index_copy_(2, token_idx - 1, key_states) + past_key_value.value_cache[self.layer_idx].index_copy_(2, token_idx - 1, value_states) + key_states = past_key_value.key_cache[self.layer_idx] + value_states = past_key_value.value_cache[self.layer_idx] + else: + past_key_value.key_cache.append(key_states) + past_key_value.value_cache.append(value_states) + else: + key_states, value_states = past_key_value.update( + key_states, value_states, self.layer_idx, cache_kwargs + ) + + if FusedSDPA: + import habana_frameworks.torch.hpu as ht + + if q_len == 1: + # next token + with ht.sdp_kernel(enable_recompute=False): + attn_output = FusedSDPA.apply( + query_states, key_states, value_states, attention_mask, 0.0, False, None + ) + else: + # first token + with ht.sdp_kernel(enable_recompute=False): # inference: flash_attention_recompute = False + attn_output = FusedSDPA.apply( + query_states, key_states, value_states, attention_mask, 0.0, False, None + ) else: - # first token - with ht.sdp_kernel(enable_recompute=False): # inference: flash_attention_recompute = False - attn_output = FusedSDPA.apply(query_states, key_states, value_states, attention_mask, 0.0, False, None) - else: - query_states, key_states, value_states, attention_mask = gaudi_mixtral_repeat_kv( - query_states, key_states, value_states, attention_mask, self.num_key_value_groups - ) + query_states, key_states, value_states, attention_mask = gaudi_mixtral_repeat_kv( + query_states, key_states, value_states, attention_mask, self.num_key_value_groups + ) - attn_weights = torch.matmul(query_states, key_states.transpose(-2, -1)) / math.sqrt(self.head_dim) + attn_weights = torch.matmul(query_states, key_states.transpose(-2, -1)) / math.sqrt(self.head_dim) - if attention_mask is not None: - attention_mask = attention_mask.unsqueeze(2) - attn_weights = attn_weights + attention_mask + if attention_mask is not None: + attention_mask = attention_mask.unsqueeze(2) + attn_weights = attn_weights + attention_mask - # upcast attention to fp32 - attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype) - attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training) - attn_output = torch.matmul(attn_weights, value_states) + # upcast attention to fp32 + attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype) + attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training) + attn_output = torch.matmul(attn_weights, value_states) - attn_output = attn_output.reshape(bsz, self.num_heads, q_len, self.head_dim).contiguous() + attn_output = attn_output.reshape(bsz, self.num_heads, q_len, self.head_dim).contiguous() - attn_output = attn_output.transpose(1, 2).contiguous() - attn_output = attn_output.reshape(bsz, q_len, self.hidden_size) + attn_output = attn_output.transpose(1, 2).contiguous() + attn_output = attn_output.reshape(bsz, q_len, self.hidden_size) - attn_output = self.o_proj(attn_output) + attn_output = self.o_proj(attn_output) - if not output_attentions: - attn_weights = None + if not output_attentions: + attn_weights = None - return attn_output, attn_weights, past_key_value + return attn_output, attn_weights, past_key_value def gaudi_mixtral_block_sparse_moe_forward(self, hidden_states: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: @@ -327,65 +374,70 @@ def gaudi_mixtral_block_sparse_moe_forward(self, hidden_states: torch.Tensor) -> return final_hidden_states, router_logits -def gaudi_mixtral_decoder_layer_forward( - self, - hidden_states: torch.Tensor, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[Tuple[torch.Tensor]] = None, - output_attentions: Optional[bool] = False, - output_router_logits: Optional[bool] = False, - use_cache: Optional[bool] = False, - token_idx: Optional[torch.Tensor] = None, - **kwargs, -) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: - """ - Copied from MixtralDecoderLayer.forward: https://github.com/huggingface/transformers/blob/v4.37.0/src/transformers/models/mixtral/modeling_mixtral.py - The only differences are: - - add new args token_idx - """ - if "padding_mask" in kwargs: - warnings.warn( - "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`" - ) +class GaudiMixtralDecoderLayer(MixtralDecoderLayer): + def __init__(self, config: MixtralConfig, layer_idx: int): + super().__init__(config, layer_idx) + self.self_attn = GaudiMixtralAttention(config, layer_idx) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Tuple[torch.Tensor]] = None, + output_attentions: Optional[bool] = False, + output_router_logits: Optional[bool] = False, + use_cache: Optional[bool] = False, + token_idx: Optional[torch.Tensor] = None, + **kwargs, + ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: + """ + Copied from MixtralDecoderLayer.forward: https://github.com/huggingface/transformers/blob/v4.37.0/src/transformers/models/mixtral/modeling_mixtral.py + The only differences are: + - add new args token_idx + """ + if "padding_mask" in kwargs: + warnings.warn( + "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`" + ) - htcore.mark_step() - residual = hidden_states + htcore.mark_step() + residual = hidden_states - hidden_states = self.input_layernorm(hidden_states) + hidden_states = self.input_layernorm(hidden_states) - # Self Attention - hidden_states, self_attn_weights, present_key_value = self.self_attn( - hidden_states=hidden_states, - attention_mask=attention_mask, - position_ids=position_ids, - past_key_value=past_key_value, - output_attentions=output_attentions, - use_cache=use_cache, - token_idx=token_idx, - ) - hidden_states = residual + hidden_states - htcore.mark_step() + # Self Attention + hidden_states, self_attn_weights, present_key_value = self.self_attn( + hidden_states=hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_value, + output_attentions=output_attentions, + use_cache=use_cache, + token_idx=token_idx, + ) + hidden_states = residual + hidden_states + htcore.mark_step() - # Fully Connected - residual = hidden_states - hidden_states = self.post_attention_layernorm(hidden_states) - hidden_states, router_logits = self.block_sparse_moe(hidden_states) - hidden_states = residual + hidden_states - htcore.mark_step() + # Fully Connected + residual = hidden_states + hidden_states = self.post_attention_layernorm(hidden_states) + hidden_states, router_logits = self.block_sparse_moe(hidden_states) + hidden_states = residual + hidden_states + htcore.mark_step() - outputs = (hidden_states,) + outputs = (hidden_states,) - if output_attentions: - outputs += (self_attn_weights,) + if output_attentions: + outputs += (self_attn_weights,) - if use_cache: - outputs += (present_key_value,) + if use_cache: + outputs += (present_key_value,) - if output_router_logits: - outputs += (router_logits,) + if output_router_logits: + outputs += (router_logits,) - return outputs + return outputs def gaudi_mixtral_model_forward( From b004ff34e98de43bf3eff6467b52d132602fd045 Mon Sep 17 00:00:00 2001 From: regisss <15324346+regisss@users.noreply.github.com> Date: Tue, 16 Apr 2024 13:47:47 +0000 Subject: [PATCH 2/6] Add references --- .../models/mistral/configuration_mistral.py | 78 +++++++++++++++++ .../models/mistral/modeling_mistral.py | 3 + .../models/mixtral/configuration_mixtral.py | 86 +++++++++++++++++++ .../models/mixtral/modeling_mixtral.py | 3 + 4 files changed, 170 insertions(+) create mode 100644 optimum/habana/transformers/models/mistral/configuration_mistral.py create mode 100644 optimum/habana/transformers/models/mixtral/configuration_mixtral.py diff --git a/optimum/habana/transformers/models/mistral/configuration_mistral.py b/optimum/habana/transformers/models/mistral/configuration_mistral.py new file mode 100644 index 0000000000..6e5968a00f --- /dev/null +++ b/optimum/habana/transformers/models/mistral/configuration_mistral.py @@ -0,0 +1,78 @@ +from transformers.models.mistral.configuration_mistral import MistralConfig + + +class MistralConfig(MistralConfig): + """ + Copied from: https://github.com/huggingface/transformers/blob/v4.38.2/src/transformers/models/mistral/configuration_mistral.py#L29 + Changes: + - add `rope_scaling` and `_rope_scaling_validation` (inspired from Llama) + """ + + def __init__( + self, + vocab_size=32000, + hidden_size=4096, + intermediate_size=14336, + num_hidden_layers=32, + num_attention_heads=32, + num_key_value_heads=8, + hidden_act="silu", + max_position_embeddings=4096 * 32, + initializer_range=0.02, + rms_norm_eps=1e-6, + use_cache=True, + pad_token_id=None, + bos_token_id=1, + eos_token_id=2, + tie_word_embeddings=False, + rope_theta=10000.0, + sliding_window=4096, + attention_dropout=0.0, + rope_scaling=None, + **kwargs, + ): + super().__init__( + vocab_size, + hidden_size, + intermediate_size, + num_hidden_layers, + num_attention_heads, + num_key_value_heads, + hidden_act, + max_position_embeddings, + initializer_range, + rms_norm_eps, + use_cache, + pad_token_id, + bos_token_id, + eos_token_id, + tie_word_embeddings, + rope_theta, + sliding_window, + attention_dropout, + **kwargs, + ) + + self.rope_scaling = rope_scaling + self._rope_scaling_validation() + + def _rope_scaling_validation(self): + """ + Taken from: https://github.com/huggingface/transformers/blob/v4.38.2/src/transformers/models/llama/configuration_llama.py#L172 + """ + if self.rope_scaling is None: + return + + if not isinstance(self.rope_scaling, dict) or len(self.rope_scaling) != 2: + raise ValueError( + "`rope_scaling` must be a dictionary with with two fields, `type` and `factor`, " + f"got {self.rope_scaling}" + ) + rope_scaling_type = self.rope_scaling.get("type", None) + rope_scaling_factor = self.rope_scaling.get("factor", None) + if rope_scaling_type is None or rope_scaling_type not in ["linear", "dynamic"]: + raise ValueError( + f"`rope_scaling`'s type field must be one of ['linear', 'dynamic'], got {rope_scaling_type}" + ) + if rope_scaling_factor is None or not isinstance(rope_scaling_factor, float) or rope_scaling_factor <= 1.0: + raise ValueError(f"`rope_scaling`'s factor field must be a float > 1, got {rope_scaling_factor}") diff --git a/optimum/habana/transformers/models/mistral/modeling_mistral.py b/optimum/habana/transformers/models/mistral/modeling_mistral.py index 4198e1fc43..f2a153d5d4 100644 --- a/optimum/habana/transformers/models/mistral/modeling_mistral.py +++ b/optimum/habana/transformers/models/mistral/modeling_mistral.py @@ -136,6 +136,9 @@ def __init__(self, config: MistralConfig, layer_idx: Optional[int] = None): self._init_rope() def _init_rope(self): + """ + Copied from: https://github.com/huggingface/transformers/blob/v4.38.2/src/transformers/models/llama/modeling_llama.py#L294 + """ if self.config.rope_scaling is None: self.rotary_emb = GaudiLlamaRotaryEmbedding( self.head_dim, diff --git a/optimum/habana/transformers/models/mixtral/configuration_mixtral.py b/optimum/habana/transformers/models/mixtral/configuration_mixtral.py new file mode 100644 index 0000000000..e878889669 --- /dev/null +++ b/optimum/habana/transformers/models/mixtral/configuration_mixtral.py @@ -0,0 +1,86 @@ +from transformers.models.mixtral.configuration_mixtral import MixtralConfig + + +class MixtralConfig(MixtralConfig): + """ + Copied from: https://github.com/huggingface/transformers/blob/v4.38.2/src/transformers/models/mixtral/configuration_mixtral.py#L28 + Changes: + - add `rope_scaling` and `_rope_scaling_validation` (inspired from Llama) + """ + + def __init__( + self, + vocab_size=32000, + hidden_size=4096, + intermediate_size=14336, + num_hidden_layers=32, + num_attention_heads=32, + num_key_value_heads=8, + hidden_act="silu", + max_position_embeddings=4096 * 32, + initializer_range=0.02, + rms_norm_eps=1e-5, + use_cache=True, + pad_token_id=None, + bos_token_id=1, + eos_token_id=2, + tie_word_embeddings=False, + rope_theta=1e6, + sliding_window=None, + attention_dropout=0.0, + num_experts_per_tok=2, + num_local_experts=8, + output_router_logits=False, + router_aux_loss_coef=0.001, + rope_scaling=None, + **kwargs, + ): + super().__init__( + vocab_size, + hidden_size, + intermediate_size, + num_hidden_layers, + num_attention_heads, + num_key_value_heads, + hidden_act, + max_position_embeddings, + initializer_range, + rms_norm_eps, + use_cache, + pad_token_id, + bos_token_id, + eos_token_id, + tie_word_embeddings, + rope_theta, + sliding_window, + attention_dropout, + num_experts_per_tok, + num_local_experts, + output_router_logits, + router_aux_loss_coef, + **kwargs, + ) + + self.rope_scaling = rope_scaling + self._rope_scaling_validation() + + def _rope_scaling_validation(self): + """ + Taken from: https://github.com/huggingface/transformers/blob/v4.38.2/src/transformers/models/llama/configuration_llama.py#L172 + """ + if self.rope_scaling is None: + return + + if not isinstance(self.rope_scaling, dict) or len(self.rope_scaling) != 2: + raise ValueError( + "`rope_scaling` must be a dictionary with with two fields, `type` and `factor`, " + f"got {self.rope_scaling}" + ) + rope_scaling_type = self.rope_scaling.get("type", None) + rope_scaling_factor = self.rope_scaling.get("factor", None) + if rope_scaling_type is None or rope_scaling_type not in ["linear", "dynamic"]: + raise ValueError( + f"`rope_scaling`'s type field must be one of ['linear', 'dynamic'], got {rope_scaling_type}" + ) + if rope_scaling_factor is None or not isinstance(rope_scaling_factor, float) or rope_scaling_factor <= 1.0: + raise ValueError(f"`rope_scaling`'s factor field must be a float > 1, got {rope_scaling_factor}") diff --git a/optimum/habana/transformers/models/mixtral/modeling_mixtral.py b/optimum/habana/transformers/models/mixtral/modeling_mixtral.py index 8fd0919b00..872a8ce0dd 100644 --- a/optimum/habana/transformers/models/mixtral/modeling_mixtral.py +++ b/optimum/habana/transformers/models/mixtral/modeling_mixtral.py @@ -192,6 +192,9 @@ def __init__(self, config: MixtralConfig, layer_idx: Optional[int] = None): self._init_rope() def _init_rope(self): + """ + Copied from: https://github.com/huggingface/transformers/blob/v4.38.2/src/transformers/models/llama/modeling_llama.py#L294 + """ if self.config.rope_scaling is None: self.rotary_emb = GaudiLlamaRotaryEmbedding( self.head_dim, From 43d194d06a279942865ff14f2ac7bd57d3ae43ad Mon Sep 17 00:00:00 2001 From: Jinyan Chen Date: Wed, 17 Apr 2024 15:28:20 +0000 Subject: [PATCH 3/6] Support mixtral kv-cache reuse and remove kv_cache_fp8 --- .../habana/transformers/generation/utils.py | 3 +- optimum/habana/transformers/modeling_utils.py | 4 +- .../habana/transformers/models/__init__.py | 2 +- .../transformers/models/mixtral/__init__.py | 2 +- .../models/mixtral/modeling_mixtral.py | 466 +++++++++++------- 5 files changed, 281 insertions(+), 196 deletions(-) diff --git a/optimum/habana/transformers/generation/utils.py b/optimum/habana/transformers/generation/utils.py index 55683ebf66..eceb64da53 100755 --- a/optimum/habana/transformers/generation/utils.py +++ b/optimum/habana/transformers/generation/utils.py @@ -597,7 +597,8 @@ def generate( "llama", "mistral", "falcon", - ], "reuse_cache only supported by llama, mistral and falcon at the moment" + "mixtral", + ], "reuse_cache only supported by llama, mistral, falcon and mixtral at the moment" if not generation_config.bucket_internal: assert ( generation_config.bucket_size <= 0 diff --git a/optimum/habana/transformers/modeling_utils.py b/optimum/habana/transformers/modeling_utils.py index 5ef0478cd0..28d2ca9674 100644 --- a/optimum/habana/transformers/modeling_utils.py +++ b/optimum/habana/transformers/modeling_utils.py @@ -52,6 +52,7 @@ GaudiMixtralAttention, GaudiMixtralDecoderLayer, GaudiMixtralForCausalLM, + GaudiMixtralModel, GaudiMptForCausalLM, GaudiMptModel, GaudiOPTForCausalLM, @@ -109,7 +110,6 @@ gaudi_llama_rmsnorm_forward, gaudi_mistral_rmsnorm_forward, gaudi_mixtral_block_sparse_moe_forward, - gaudi_mixtral_model_forward, gaudi_mixtral_rmsnorm_forward, gaudi_mpt_attention_forward, gaudi_mpt_block_forward, @@ -357,7 +357,7 @@ def adapt_transformers_to_gaudi(): # Optimization for mixtral on Gaudi transformers.models.mixtral.modeling_mixtral.MixtralAttention = GaudiMixtralAttention transformers.models.mixtral.modeling_mixtral.MixtralForCausalLM = GaudiMixtralForCausalLM - transformers.models.mixtral.modeling_mixtral.MixtralModel.forward = gaudi_mixtral_model_forward + transformers.models.mixtral.modeling_mixtral.MixtralModel = GaudiMixtralModel transformers.models.mixtral.modeling_mixtral.MixtralSparseMoeBlock.forward = gaudi_mixtral_block_sparse_moe_forward transformers.models.mixtral.modeling_mixtral.MixtralDecoderLayer = GaudiMixtralDecoderLayer transformers.models.mixtral.modeling_mixtral.MixtralRMSNorm.forward = gaudi_mixtral_rmsnorm_forward diff --git a/optimum/habana/transformers/models/__init__.py b/optimum/habana/transformers/models/__init__.py index 97d4e5a45f..05d8e89ca2 100644 --- a/optimum/habana/transformers/models/__init__.py +++ b/optimum/habana/transformers/models/__init__.py @@ -93,9 +93,9 @@ GaudiMixtralAttention, GaudiMixtralDecoderLayer, GaudiMixtralForCausalLM, + GaudiMixtralModel, MixtralConfig, gaudi_mixtral_block_sparse_moe_forward, - gaudi_mixtral_model_forward, gaudi_mixtral_rmsnorm_forward, ) from .modeling_all_models import ( diff --git a/optimum/habana/transformers/models/mixtral/__init__.py b/optimum/habana/transformers/models/mixtral/__init__.py index 4687fe0e20..ab34977c37 100644 --- a/optimum/habana/transformers/models/mixtral/__init__.py +++ b/optimum/habana/transformers/models/mixtral/__init__.py @@ -3,7 +3,7 @@ GaudiMixtralAttention, GaudiMixtralDecoderLayer, GaudiMixtralForCausalLM, + GaudiMixtralModel, gaudi_mixtral_block_sparse_moe_forward, - gaudi_mixtral_model_forward, gaudi_mixtral_rmsnorm_forward, ) diff --git a/optimum/habana/transformers/models/mixtral/modeling_mixtral.py b/optimum/habana/transformers/models/mixtral/modeling_mixtral.py index 872a8ce0dd..3e22f49629 100644 --- a/optimum/habana/transformers/models/mixtral/modeling_mixtral.py +++ b/optimum/habana/transformers/models/mixtral/modeling_mixtral.py @@ -29,7 +29,7 @@ import torch.nn.functional as F from torch import nn from torch.nn import CrossEntropyLoss -from transformers.cache_utils import Cache, DynamicCache +from transformers.cache_utils import Cache, DynamicCache, StaticCache from transformers.integrations.deepspeed import is_deepspeed_available from transformers.modeling_attn_mask_utils import ( _prepare_4d_causal_attention_mask, @@ -40,6 +40,7 @@ MixtralAttention, MixtralDecoderLayer, MixtralForCausalLM, + MixtralModel, apply_rotary_pos_emb, load_balancing_loss_func, ) @@ -74,25 +75,6 @@ logger = logging.get_logger(__name__) -def update(prev, cur, dim, idx, inp_seq_len): - orig_cur = cur - if prev.dtype == torch.float8_e4m3fn: - from habana_frameworks.torch.hpex.kernels.Fp8Ops import cast_to_fp8_v2 - - cur = cast_to_fp8_v2(cur, None, False, False, prev.dtype)[0] - if cur.shape[2] > 1 and cur.shape[2] <= prev.shape[2]: - # Initialize - prev[:, :, :inp_seq_len, :].copy_(cur) - return orig_cur - assert cur.shape[2] == 1, f"Cannot update kv-cache. Unsupported shapes. prev:{prev.shape} cur:{cur.shape}" - if idx is not None: - prev.index_copy_(dim, idx - 1, cur) - prev_cast = prev.to(orig_cur.dtype) - return prev_cast - else: - return torch.cat((prev, cur), dim=dim) - - def apply_customized_rope(q, k, cos, sin, position_ids): if q.device.type == "hpu" and FusedRoPE: return FusedRoPE.apply( @@ -165,11 +147,9 @@ def __init__(self): self.cache = None self.inp_seq_len = -1 - def allocate(self, inp_seq_len, kv_cache_fp8, dtype, device, shape): + def allocate(self, inp_seq_len, dtype, device, shape): if self.cache is None or self.cache.shape != shape: self.inp_seq_len = inp_seq_len - if kv_cache_fp8: - dtype = torch.float8_e4m3fn self.cache = torch.zeros(shape, dtype=dtype, device=device) else: assert ( @@ -177,19 +157,39 @@ def allocate(self, inp_seq_len, kv_cache_fp8, dtype, device, shape): ), f"inp_seq_len must be the same. self.inp_seq_len:{self.inp_seq_len} inp_seq_len:{inp_seq_len}" self.cache.fill_(0) + def update(self, prev, cur, dim, idx, inp_seq_len): + orig_cur = cur + if prev.shape == cur.shape: + prev.copy_(cur) + return orig_cur + if cur.shape[2] > 1 and cur.shape[2] <= prev.shape[2]: + # Initialize + prev[:, :, :inp_seq_len, :].copy_(cur) + return orig_cur + assert cur.shape[2] == 1, f"Cannot update kv-cache. Unsupported shapes. prev:{prev.shape} cur:{cur.shape}" + if idx is not None: + prev.index_copy_(dim, idx - 1, cur) + return prev + else: + return torch.cat((prev, cur), dim=dim) + def get_shape(self): if self.cache is None: return None return self.cache.shape def forward(self, cur, dim, idx): - return update(self.cache, cur, dim, idx, self.inp_seq_len) + return self.update(self.cache, cur, dim, idx, self.inp_seq_len) class GaudiMixtralAttention(MixtralAttention): def __init__(self, config: MixtralConfig, layer_idx: Optional[int] = None): super().__init__(config, layer_idx) self._init_rope() + self.k_cache = KVCache() + self.v_cache = KVCache() + self.inp_seq_len = -1 + self.norm_factor = 1.0 / math.sqrt(self.head_dim) def _init_rope(self): """ @@ -221,6 +221,13 @@ def _init_rope(self): else: raise ValueError(f"Unknown RoPE scaling type {scaling_type}") + def allocate_kv_cache(self, batch_size, max_seq_len, inp_seq_len): + cache_shape = (batch_size, self.num_key_value_heads, max_seq_len, self.head_dim) + device = self.k_proj.weight.device + dtype = self.config.torch_dtype + self.k_cache.allocate(inp_seq_len, dtype, device, cache_shape) + self.v_cache.allocate(inp_seq_len, dtype, device, cache_shape) + def forward( self, hidden_states: torch.Tensor, @@ -230,6 +237,9 @@ def forward( output_attentions: bool = False, use_cache: bool = False, token_idx: Optional[torch.Tensor] = None, + reuse_cache: Optional[bool] = False, + flash_attention_recompute: Optional[bool] = False, + cache_idx: int = None, **kwargs, ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: """ @@ -237,6 +247,9 @@ def forward( The only differences are: - add new args token_idx - optimize KV cache + - add new args reuse_cache + - add new args flash_attention_recompute + - add new args cache_idx """ if "padding_mask" in kwargs: warnings.warn( @@ -260,29 +273,44 @@ def forward( "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class " "with a layer index." ) - if token_idx is not None: - if 0 <= self.layer_idx < len(past_key_value.key_cache): - kv_seq_len = past_key_value.key_cache[self.layer_idx].shape[-2] + if token_idx is None: + if hasattr(past_key_value, "get_usable_length"): + kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx) + else: + kv_seq_len += past_key_value[0].shape[-2] else: - kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx) + if reuse_cache: + kv_seq_len = past_key_value[0][-2] + else: + kv_seq_len = past_key_value[0].shape[-2] cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) query_states, key_states = apply_customized_rope(query_states, key_states, cos, sin, position_ids) - if past_key_value is not None: - cache_kwargs = {"sin": sin, "cos": cos} # Specific to RoPE models - if token_idx is not None: - if 0 <= self.layer_idx < len(past_key_value.key_cache): - past_key_value.key_cache[self.layer_idx].index_copy_(2, token_idx - 1, key_states) - past_key_value.value_cache[self.layer_idx].index_copy_(2, token_idx - 1, value_states) - key_states = past_key_value.key_cache[self.layer_idx] - value_states = past_key_value.value_cache[self.layer_idx] - else: - past_key_value.key_cache.append(key_states) - past_key_value.value_cache.append(value_states) + if use_cache: + if reuse_cache: + key_states = self.k_cache(key_states, 2, token_idx) + value_states = self.v_cache(value_states, 2, token_idx) + past_key_value = (self.k_cache.get_shape(), self.v_cache.get_shape()) else: - key_states, value_states = past_key_value.update( - key_states, value_states, self.layer_idx, cache_kwargs - ) + if past_key_value is None: + past_key = torch.zeros(key_states.shape, dtype=self.k_proj.weight.dtype, device=key_states.device) + past_value = torch.zeros( + key_states.shape, dtype=self.k_proj.weight.dtype, device=key_states.device + ) + past_key_value = (past_key, past_value) + key_states = self.k_cache.update(past_key_value[0], key_states, 2, token_idx, self.inp_seq_len) + value_states = self.k_cache.update(past_key_value[1], value_states, 2, token_idx, self.inp_seq_len) + if token_idx is None: + past_key_value = (key_states, value_states) + + if cache_idx is not None and q_len == 1: + key_states = key_states[:, :, :cache_idx, :] + value_states = value_states[:, :, :cache_idx, :] + if attention_mask is not None: + attention_mask = attention_mask[:, :, :, :cache_idx] + kv_seq_len = key_states.shape[-2] + else: + past_key_value = None if FusedSDPA: import habana_frameworks.torch.hpu as ht @@ -304,7 +332,7 @@ def forward( query_states, key_states, value_states, attention_mask, self.num_key_value_groups ) - attn_weights = torch.matmul(query_states, key_states.transpose(-2, -1)) / math.sqrt(self.head_dim) + attn_weights = torch.matmul(query_states, key_states.transpose(-2, -1)) * self.norm_factor if attention_mask is not None: attention_mask = attention_mask.unsqueeze(2) @@ -382,6 +410,9 @@ def __init__(self, config: MixtralConfig, layer_idx: int): super().__init__(config, layer_idx) self.self_attn = GaudiMixtralAttention(config, layer_idx) + def allocate_kv_cache(self, batch_size, max_seq_len, inp_seq_len): + self.self_attn.allocate_kv_cache(batch_size, max_seq_len, inp_seq_len) + def forward( self, hidden_states: torch.Tensor, @@ -392,12 +423,18 @@ def forward( output_router_logits: Optional[bool] = False, use_cache: Optional[bool] = False, token_idx: Optional[torch.Tensor] = None, + reuse_cache: Optional[bool] = False, + flash_attention_recompute: Optional[bool] = False, + cache_idx: int = None, **kwargs, ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: """ Copied from MixtralDecoderLayer.forward: https://github.com/huggingface/transformers/blob/v4.37.0/src/transformers/models/mixtral/modeling_mixtral.py The only differences are: - add new args token_idx + - add new args reuse_cache + - add new args flash_attention_recompute + - add new args cache_idx """ if "padding_mask" in kwargs: warnings.warn( @@ -418,6 +455,9 @@ def forward( output_attentions=output_attentions, use_cache=use_cache, token_idx=token_idx, + reuse_cache=reuse_cache, + flash_attention_recompute=flash_attention_recompute, + cache_idx=cache_idx, ) hidden_states = residual + hidden_states htcore.mark_step() @@ -443,173 +483,198 @@ def forward( return outputs -def gaudi_mixtral_model_forward( - self, - input_ids: torch.LongTensor = None, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[List[torch.FloatTensor]] = None, - inputs_embeds: Optional[torch.FloatTensor] = None, - use_cache: Optional[bool] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - output_router_logits: Optional[bool] = None, - return_dict: Optional[bool] = None, - token_idx: Optional[torch.Tensor] = None, -) -> Union[Tuple, MoeModelOutputWithPast]: - """ - Copied from MixtralModel.forward: https://github.com/huggingface/transformers/blob/v4.37.0/src/transformers/models/mixtral/modeling_mixtral.py#L1069 - The only differences are: - - add new args token_idx - """ - output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions - output_router_logits = ( - output_router_logits if output_router_logits is not None else self.config.output_router_logits - ) - output_hidden_states = ( - output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states - ) - use_cache = use_cache if use_cache is not None else self.config.use_cache - - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - - # retrieve input_ids and inputs_embeds - if input_ids is not None and inputs_embeds is not None: - raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time") - elif input_ids is not None: - batch_size, seq_length = input_ids.shape - elif inputs_embeds is not None: - batch_size, seq_length, _ = inputs_embeds.shape - else: - raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds") - - past_key_values_length = 0 +class GaudiMixtralModel(MixtralModel): + def __init__(self, config: MixtralConfig): + super().__init__(config) - if self.gradient_checkpointing and self.training: - if use_cache: - logger.warning_once( - "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." - ) - use_cache = False - - if use_cache: - use_legacy_cache = not isinstance(past_key_values, Cache) - if use_legacy_cache: - past_key_values = DynamicCache.from_legacy_cache(past_key_values) - past_key_values_length = past_key_values.get_usable_length(seq_length) - - if position_ids is None: - device = input_ids.device if input_ids is not None else inputs_embeds.device - position_ids = torch.arange( - past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device - ) - position_ids = position_ids.unsqueeze(0).view(-1, seq_length) - else: - position_ids = position_ids.view(-1, seq_length).long() - - if inputs_embeds is None: - inputs_embeds = self.embed_tokens(input_ids) - - if attention_mask is not None and self.config._attn_implementation == "flash_attention_2" and use_cache: - is_padding_right = attention_mask[:, -1].sum().item() != batch_size - if is_padding_right: - raise ValueError( - "You are attempting to perform batched generation with padding_side='right'" - " this may lead to unexpected behaviour for Flash Attention version of Mixtral. Make sure to " - " call `tokenizer.padding_side = 'left'` before tokenizing the input. " - ) + def allocate_kv_cache(self, batch_size, max_seq_len, inp_seq_len): + for layer in self.layers: + layer.allocate_kv_cache(batch_size, max_seq_len, inp_seq_len) - if self.config._attn_implementation == "flash_attention_2": - # 2d mask is passed through the layers - attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None - elif self.config._attn_implementation == "sdpa" and not output_attentions: - # output_attentions=True can not be supported when using SDPA, and we fall back on - # the manual implementation that requires a 4D causal mask in all cases. - attention_mask = _prepare_4d_causal_attention_mask_for_sdpa( - attention_mask, - (batch_size, seq_length), - inputs_embeds, - past_key_values_length, + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + output_router_logits: Optional[bool] = None, + return_dict: Optional[bool] = None, + token_idx: Optional[torch.Tensor] = None, + reuse_cache: Optional[bool] = False, + flash_attention_recompute: Optional[bool] = False, + cache_idx: int = None, + ) -> Union[Tuple, MoeModelOutputWithPast]: + """ + Copied from MixtralModel.forward: https://github.com/huggingface/transformers/blob/v4.37.0/src/transformers/models/mixtral/modeling_mixtral.py#L1069 + The only differences are: + - add new args token_idx + - add new args reuse_cache + - add new args flash_attention_recompute + - add new args cache_idx + """ + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_router_logits = ( + output_router_logits if output_router_logits is not None else self.config.output_router_logits ) - else: - # 4d mask is passed through the layers - attention_mask = _prepare_4d_causal_attention_mask( - attention_mask, - (batch_size, seq_length), - inputs_embeds, - past_key_values_length, - sliding_window=self.config.sliding_window, + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states ) + use_cache = use_cache if use_cache is not None else self.config.use_cache - hidden_states = inputs_embeds + return_dict = return_dict if return_dict is not None else self.config.use_return_dict - # decoder layers - all_hidden_states = () if output_hidden_states else None - all_self_attns = () if output_attentions else None - all_router_logits = () if output_router_logits else None - next_decoder_cache = None + # retrieve input_ids and inputs_embeds + if input_ids is not None and inputs_embeds is not None: + raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time") + elif input_ids is not None: + batch_size, seq_length = input_ids.shape + elif inputs_embeds is not None: + batch_size, seq_length, _ = inputs_embeds.shape + else: + raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds") - for decoder_layer in self.layers: - if output_hidden_states: - all_hidden_states += (hidden_states,) + past_key_values_length = 0 + use_new_cache = False # Ignoring new Cache path for HPU if self.gradient_checkpointing and self.training: - layer_outputs = self._gradient_checkpointing_func( - decoder_layer.__call__, - hidden_states, + if use_cache: + logger.warning_once( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." + ) + use_cache = False + + if past_key_values is not None and use_cache: + if reuse_cache: + past_key_values_length = past_key_values[0][0][2] + else: + if use_new_cache: + if not isinstance(past_key_values, StaticCache): + past_key_values = DynamicCache.from_legacy_cache(past_key_values) + past_key_values_length = past_key_values.get_usable_length() + else: + past_key_values_length = past_key_values[0][0].shape[2] + + if position_ids is None: + device = input_ids.device if input_ids is not None else inputs_embeds.device + position_ids = torch.arange( + past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device + ) + position_ids = position_ids.unsqueeze(0).view(-1, seq_length) + else: + position_ids = position_ids.view(-1, seq_length).long() + + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) + + if attention_mask is not None and self.config._attn_implementation == "flash_attention_2" and use_cache: + is_padding_right = attention_mask[:, -1].sum().item() != batch_size + if is_padding_right: + raise ValueError( + "You are attempting to perform batched generation with padding_side='right'" + " this may lead to unexpected behaviour for Flash Attention version of Mixtral. Make sure to " + " call `tokenizer.padding_side = 'left'` before tokenizing the input. " + ) + + if self.config._attn_implementation == "flash_attention_2": + # 2d mask is passed through the layers + attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None + elif self.config._attn_implementation == "sdpa" and not output_attentions: + # output_attentions=True can not be supported when using SDPA, and we fall back on + # the manual implementation that requires a 4D causal mask in all cases. + attention_mask = _prepare_4d_causal_attention_mask_for_sdpa( attention_mask, - position_ids, - past_key_values, - output_attentions, - output_router_logits, - use_cache, + (batch_size, seq_length), + inputs_embeds, + past_key_values_length, ) else: - layer_outputs = decoder_layer( - hidden_states, - attention_mask=attention_mask, - position_ids=position_ids, - past_key_value=past_key_values, - output_attentions=output_attentions, - output_router_logits=output_router_logits, - use_cache=use_cache, - token_idx=token_idx, + # 4d mask is passed through the layers + attention_mask = _prepare_4d_causal_attention_mask( + attention_mask, + (batch_size, seq_length), + inputs_embeds, + past_key_values_length, + sliding_window=self.config.sliding_window, ) - hidden_states = layer_outputs[0] + hidden_states = inputs_embeds + + # decoder layers + all_hidden_states = () if output_hidden_states else None + all_self_attns = () if output_attentions else None + all_router_logits = () if output_router_logits else None + next_decoder_cache = () if not use_new_cache else None + + for layer_idx, decoder_layer in enumerate(self.layers): + if output_hidden_states: + all_hidden_states += (hidden_states,) + + if self.gradient_checkpointing and self.training: + layer_outputs = self._gradient_checkpointing_func( + decoder_layer.__call__, + hidden_states, + attention_mask, + position_ids, + past_key_values, + output_attentions, + output_router_logits, + use_cache, + ) + else: + layer_outputs = decoder_layer( + hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=None if past_key_values is None else past_key_values[layer_idx], + output_attentions=output_attentions, + output_router_logits=output_router_logits, + use_cache=use_cache, + token_idx=token_idx, + reuse_cache=reuse_cache, + flash_attention_recompute=flash_attention_recompute, + cache_idx=cache_idx, + ) - if use_cache: - next_decoder_cache = layer_outputs[2 if output_attentions else 1] + hidden_states = layer_outputs[0] - if output_attentions: - all_self_attns += (layer_outputs[1],) + if use_cache: + next_decoder_cache += (layer_outputs[2 if output_attentions else 1],) - if output_router_logits: - all_router_logits += (layer_outputs[-1],) + if output_attentions: + all_self_attns += (layer_outputs[1],) + + if output_router_logits: + all_router_logits += (layer_outputs[-1],) - hidden_states = self.norm(hidden_states) + hidden_states = self.norm(hidden_states) - # add hidden states from the last decoder layer - if output_hidden_states: - all_hidden_states += (hidden_states,) + # add hidden states from the last decoder layer + if output_hidden_states: + all_hidden_states += (hidden_states,) - next_cache = None - if use_cache: - next_cache = next_decoder_cache.to_legacy_cache() if use_legacy_cache else next_decoder_cache + next_cache = None + if use_cache: + next_cache = ( + next_decoder_cache.to_legacy_cache() if isinstance(next_decoder_cache, Cache) else next_decoder_cache + ) - if not return_dict: - return tuple( - v - for v in [hidden_states, next_cache, all_hidden_states, all_self_attns, all_router_logits] - if v is not None + if not return_dict: + return tuple( + v + for v in [hidden_states, next_cache, all_hidden_states, all_self_attns, all_router_logits] + if v is not None + ) + return MoeModelOutputWithPast( + last_hidden_state=hidden_states, + past_key_values=next_cache, + hidden_states=all_hidden_states, + attentions=all_self_attns, + router_logits=all_router_logits, ) - return MoeModelOutputWithPast( - last_hidden_state=hidden_states, - past_key_values=next_cache, - hidden_states=all_hidden_states, - attentions=all_self_attns, - router_logits=all_router_logits, - ) class GaudiMixtralForCausalLM(MixtralForCausalLM): @@ -622,6 +687,10 @@ class GaudiMixtralForCausalLM(MixtralForCausalLM): - from step2 when enable KV cache, slice next_position_ids from position_ids base on the token_idx """ + def allocate_kv_cache(self, batch_size, max_seq_len, inp_seq_len): + self.model.allocate_kv_cache(batch_size, max_seq_len, inp_seq_len) + self.kv_cache_len = max_seq_len + def forward( self, input_ids: torch.LongTensor = None, @@ -636,6 +705,9 @@ def forward( output_router_logits: Optional[bool] = None, return_dict: Optional[bool] = None, token_idx: Optional[torch.Tensor] = None, + reuse_cache: Optional[bool] = None, + flash_attention_recompute: Optional[bool] = False, + cache_idx: int = None, ) -> Union[Tuple, MoeCausalLMOutputWithPast]: output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_router_logits = ( @@ -660,6 +732,9 @@ def forward( output_router_logits=output_router_logits, return_dict=return_dict, token_idx=token_idx, + reuse_cache=reuse_cache, + flash_attention_recompute=flash_attention_recompute, + cache_idx=cache_idx, ) hidden_states = outputs[0] @@ -709,11 +784,15 @@ def forward( def prepare_inputs_for_generation( self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs ): + past_length = 0 + reuse_cache = kwargs.get("reuse_cache") token_idx = kwargs.get("token_idx", None) # Omit tokens covered by past_key_values if past_key_values is not None: - if token_idx is None: + if token_idx is not None: + input_ids = torch.index_select(input_ids, 1, token_idx - 1) + else: if isinstance(past_key_values, Cache): cache_length = past_key_values.get_seq_length() past_length = past_key_values.seen_tokens @@ -741,8 +820,10 @@ def prepare_inputs_for_generation( and cache_length + input_ids.shape[1] > max_cache_length ): attention_mask = attention_mask[:, -max_cache_length:] - else: - input_ids = torch.index_select(input_ids, 1, token_idx - 1) + elif reuse_cache and token_idx is not None: + # With reuse_cache, KV cache is pre allocated hence for the 1st token we can slice the inputs till token idx for the fwd pass + input_ids = input_ids[:, :token_idx] + attention_mask = attention_mask[:, :token_idx] position_ids = kwargs.get("position_ids", None) if attention_mask is not None and position_ids is None: @@ -768,6 +849,9 @@ def prepare_inputs_for_generation( "use_cache": kwargs.get("use_cache"), "attention_mask": attention_mask, "token_idx": token_idx, + "reuse_cache": reuse_cache, + "flash_attention_recompute": kwargs.get("flash_attention_recompute"), + "cache_idx": kwargs.get("cache_idx"), } ) return model_inputs From a05b986834ef2fd6801d15a4909998fb98fe3df6 Mon Sep 17 00:00:00 2001 From: Jinyan Chen Date: Wed, 17 Apr 2024 16:03:46 +0000 Subject: [PATCH 4/6] Modify white/black list to allow/block list --- .../quantization_config/maxabs_quant_mixtral.json | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/examples/text-generation/quantization_config/maxabs_quant_mixtral.json b/examples/text-generation/quantization_config/maxabs_quant_mixtral.json index 737edcc413..b3fd2e26db 100644 --- a/examples/text-generation/quantization_config/maxabs_quant_mixtral.json +++ b/examples/text-generation/quantization_config/maxabs_quant_mixtral.json @@ -3,8 +3,8 @@ "mode": "QUANTIZE", "observer": "maxabs", "scale_method": "maxabs_hw", - "whitelist": {"types": [], "names": ["gate","w1","w3","w2"]}, - "blacklist": {"types": [], "names": [ + "allowlist": {"types": [], "names": ["gate","w1","w3","w2"]}, + "blocklist": {"types": [], "names": [ "model.layers.1.block_sparse_moe.experts.(3|4).w2", "model.layers.[29-31].block_sparse_moe.experts.[0-7].w2" ]}, From af8e9cddb0072de80aa71c1a0b2d7cef63b0c30e Mon Sep 17 00:00:00 2001 From: Jinyan Chen Date: Tue, 23 Apr 2024 02:30:40 +0000 Subject: [PATCH 5/6] add mixtral fp8 test case --- tests/test_text_generation_example.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/test_text_generation_example.py b/tests/test_text_generation_example.py index 48150d635c..4819ef8e87 100644 --- a/tests/test_text_generation_example.py +++ b/tests/test_text_generation_example.py @@ -29,6 +29,7 @@ ], "fp8": [ ("tiiuae/falcon-180B", 52.85086442722326), + ("mistralai/Mixtral-8x7B-v0.1", 39.26845661768185), ], "deepspeed": [ ("bigscience/bloomz", 36.77314954096159), From 7ca1876e5e72f3693889a321590c7fa83947731d Mon Sep 17 00:00:00 2001 From: Jinyan Chen Date: Sun, 28 Apr 2024 04:49:11 +0000 Subject: [PATCH 6/6] fix merge conflict --- optimum/habana/transformers/models/mixtral/modeling_mixtral.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/optimum/habana/transformers/models/mixtral/modeling_mixtral.py b/optimum/habana/transformers/models/mixtral/modeling_mixtral.py index 0896b40aea..3e22f49629 100644 --- a/optimum/habana/transformers/models/mixtral/modeling_mixtral.py +++ b/optimum/habana/transformers/models/mixtral/modeling_mixtral.py @@ -353,6 +353,8 @@ def forward( if not output_attentions: attn_weights = None + return attn_output, attn_weights, past_key_value + def gaudi_mixtral_block_sparse_moe_forward(self, hidden_states: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: """