From 600c7352634da044a994188ca8e9309df403c3e8 Mon Sep 17 00:00:00 2001 From: Jinyan chen Date: Tue, 26 Mar 2024 15:18:23 +0800 Subject: [PATCH 01/21] add reuse_cache support --- .../habana/transformers/generation/utils.py | 1 + optimum/habana/transformers/modeling_utils.py | 12 +- .../habana/transformers/models/__init__.py | 6 +- .../transformers/models/mixtral/__init__.py | 6 +- .../models/mixtral/modeling_mixtral.py | 670 ++++++++++-------- 5 files changed, 397 insertions(+), 298 deletions(-) diff --git a/optimum/habana/transformers/generation/utils.py b/optimum/habana/transformers/generation/utils.py index bad54641d6..c0d71c2f68 100755 --- a/optimum/habana/transformers/generation/utils.py +++ b/optimum/habana/transformers/generation/utils.py @@ -581,6 +581,7 @@ def generate( assert self.config.model_type in [ "llama", "mistral", + "mixtral", ], "reuse_cache only supported by llama and mistral at the moment" if not generation_config.bucket_internal: assert ( diff --git a/optimum/habana/transformers/modeling_utils.py b/optimum/habana/transformers/modeling_utils.py index 9d4e473aab..f746dd736f 100644 --- a/optimum/habana/transformers/modeling_utils.py +++ b/optimum/habana/transformers/modeling_utils.py @@ -46,6 +46,9 @@ GaudiMistralDecoderLayer, GaudiMistralForCausalLM, GaudiMistralModel, + GaudiMixtralAttention, + GaudiMixtralDecoderLayer, + GaudiMixtralModel, GaudiMixtralForCausalLM, GaudiMptForCausalLM, GaudiMptModel, @@ -103,10 +106,7 @@ gaudi_invert_attention_mask, gaudi_llama_rmsnorm_forward, gaudi_mistral_rmsnorm_forward, - gaudi_mixtral_attention_forward, gaudi_mixtral_block_sparse_moe_forward, - gaudi_mixtral_decoder_layer_forward, - gaudi_mixtral_model_forward, gaudi_mixtral_rmsnorm_forward, gaudi_mpt_attention_forward, gaudi_mpt_block_forward, @@ -349,10 +349,10 @@ def adapt_transformers_to_gaudi(): # Optimization for mixtral on Gaudi transformers.models.mixtral.modeling_mixtral.MixtralForCausalLM = GaudiMixtralForCausalLM - transformers.models.mixtral.modeling_mixtral.MixtralModel.forward = gaudi_mixtral_model_forward - transformers.models.mixtral.modeling_mixtral.MixtralAttention.forward = gaudi_mixtral_attention_forward + transformers.models.mixtral.modeling_mixtral.MixtralModel = GaudiMixtralModel + transformers.models.mixtral.modeling_mixtral.MixtralAttention = GaudiMixtralAttention transformers.models.mixtral.modeling_mixtral.MixtralSparseMoeBlock.forward = gaudi_mixtral_block_sparse_moe_forward - transformers.models.mixtral.modeling_mixtral.MixtralDecoderLayer.forward = gaudi_mixtral_decoder_layer_forward + transformers.models.mixtral.modeling_mixtral.MixtralDecoderLayer = GaudiMixtralDecoderLayer transformers.models.mixtral.modeling_mixtral.MixtralRMSNorm.forward = gaudi_mixtral_rmsnorm_forward # Optimization for speecht5 on Gaudi diff --git a/optimum/habana/transformers/models/__init__.py b/optimum/habana/transformers/models/__init__.py index d0eb8b2dcd..39ca2a57ee 100644 --- a/optimum/habana/transformers/models/__init__.py +++ b/optimum/habana/transformers/models/__init__.py @@ -88,11 +88,11 @@ gaudi_mistral_rmsnorm_forward, ) from .mixtral import ( + GaudiMixtralAttention, + GaudiMixtralDecoderLayer, + GaudiMixtralModel, GaudiMixtralForCausalLM, - gaudi_mixtral_attention_forward, gaudi_mixtral_block_sparse_moe_forward, - gaudi_mixtral_decoder_layer_forward, - gaudi_mixtral_model_forward, gaudi_mixtral_rmsnorm_forward, ) from .modeling_all_models import ( diff --git a/optimum/habana/transformers/models/mixtral/__init__.py b/optimum/habana/transformers/models/mixtral/__init__.py index fd1829bbe2..cd4d0724e9 100644 --- a/optimum/habana/transformers/models/mixtral/__init__.py +++ b/optimum/habana/transformers/models/mixtral/__init__.py @@ -1,8 +1,8 @@ from .modeling_mixtral import ( + GaudiMixtralAttention, + GaudiMixtralDecoderLayer, + GaudiMixtralModel, GaudiMixtralForCausalLM, - gaudi_mixtral_attention_forward, gaudi_mixtral_block_sparse_moe_forward, - gaudi_mixtral_decoder_layer_forward, - gaudi_mixtral_model_forward, gaudi_mixtral_rmsnorm_forward, ) diff --git a/optimum/habana/transformers/models/mixtral/modeling_mixtral.py b/optimum/habana/transformers/models/mixtral/modeling_mixtral.py index c6f5f51ab7..8cb549dcc7 100644 --- a/optimum/habana/transformers/models/mixtral/modeling_mixtral.py +++ b/optimum/habana/transformers/models/mixtral/modeling_mixtral.py @@ -37,9 +37,13 @@ ) from transformers.modeling_outputs import MoeCausalLMOutputWithPast, MoeModelOutputWithPast from transformers.models.mixtral.modeling_mixtral import ( + MixtralAttention, + MixtralDecoderLayer, + MixtralModel, MixtralForCausalLM, apply_rotary_pos_emb, load_balancing_loss_func, + MixtralConfig, ) from transformers.utils import logging @@ -177,105 +181,156 @@ def forward(self, cur, dim, idx): return update(self.cache, cur, dim, idx, self.inp_seq_len) -def gaudi_mixtral_attention_forward( - self, - hidden_states: torch.Tensor, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[Cache] = None, - output_attentions: bool = False, - use_cache: bool = False, - token_idx: Optional[torch.Tensor] = None, - **kwargs, -) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: - """ - Copied from MixtralAttention.forward: https://github.com/huggingface/transformers/blob/v4.37.0/src/transformers/models/mixtral/modeling_mixtral.py - The only differences are: - - add new args token_idx - - optimize KV cache - """ - if "padding_mask" in kwargs: - warnings.warn( - "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`" - ) - bsz, q_len, _ = hidden_states.size() - - query_states = self.q_proj(hidden_states) - key_states = self.k_proj(hidden_states) - value_states = self.v_proj(hidden_states) - - query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) - key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) - value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) - - kv_seq_len = key_states.shape[-2] - if past_key_value is not None: - if self.layer_idx is None: - raise ValueError( - f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} " - "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class " - "with a layer index." +class GaudiMixtralAttention(MixtralAttention): + def __init__(self, config: MixtralConfig, layer_idx: Optional[int] = None): + super().__init__(config, layer_idx) + + self.k_cache = KVCache() + self.v_cache = KVCache() + self.inp_seq_len = -1 + self.norm_factor = 1.0 / math.sqrt(self.head_dim) + + def allocate_kv_cache(self, batch_size, max_seq_len, inp_seq_len, kv_cache_fp8): + cache_shape = (batch_size, self.num_key_value_heads, max_seq_len, self.head_dim) + device = self.k_proj.weight.device + dtype = self.config.torch_dtype + self.k_cache.allocate(inp_seq_len, kv_cache_fp8, dtype, device, cache_shape) + self.v_cache.allocate(inp_seq_len, kv_cache_fp8, dtype, device, cache_shape) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + output_attentions: bool = False, + use_cache: bool = False, + token_idx: Optional[torch.Tensor] = None, + reuse_cache: Optional[bool] = False, + flash_attention_recompute: Optional[bool] = False, + **kwargs, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + """ + Copied from MixtralAttention.forward: https://github.com/huggingface/transformers/blob/v4.37.0/src/transformers/models/mixtral/modeling_mixtral.py + The only differences are: + - add new args token_idx + - optimize KV cache + - add new args reuse_cache + - add new args flash_attention_recompute + """ + if "padding_mask" in kwargs: + warnings.warn( + "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`" ) - if token_idx is not None: - if 0 <= self.layer_idx < len(past_key_value.key_cache): - kv_seq_len = past_key_value.key_cache[self.layer_idx].shape[-2] - else: - kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx) - cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) - query_states, key_states = apply_customized_rope(query_states, key_states, cos, sin, position_ids) - - if past_key_value is not None: - cache_kwargs = {"sin": sin, "cos": cos} # Specific to RoPE models - if token_idx is not None: - if 0 <= self.layer_idx < len(past_key_value.key_cache): - past_key_value.key_cache[self.layer_idx].index_copy_(2, token_idx - 1, key_states) - past_key_value.value_cache[self.layer_idx].index_copy_(2, token_idx - 1, value_states) - key_states = past_key_value.key_cache[self.layer_idx] - value_states = past_key_value.value_cache[self.layer_idx] + bsz, q_len, _ = hidden_states.size() + + query_states = self.q_proj(hidden_states) + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) + + query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + + kv_seq_len = key_states.shape[-2] + if past_key_value is not None: + if self.layer_idx is None: + raise ValueError( + f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} " + "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class " + "with a layer index." + ) + if token_idx is None: + if hasattr(past_key_value, "get_usable_length"): + kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx) + else: + kv_seq_len += past_key_value[0].shape[-2] else: - past_key_value.key_cache.append(key_states) - past_key_value.value_cache.append(value_states) + if reuse_cache: + kv_seq_len = past_key_value[0][0][-2] + else: + kv_seq_len = past_key_value[0].shape[-2] + # import pdb;pdb.set_trace() + # kv_shape = ( + # (past_key_value[0][-2] if reuse_cache else past_key_value[0].shape[-2]) + # if isinstance(past_key_value, tuple) + # else past_key_value.get_usable_length(kv_seq_len, self.layer_idx) + # ) + # if token_idx is not None: + # kv_seq_len = kv_shape + # else: + # kv_seq_len += kv_shape + + cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) + query_states, key_states = apply_customized_rope(query_states, key_states, cos, sin, position_ids) + + if past_key_value is not None or reuse_cache: + cache_kwargs = {"sin": sin, "cos": cos} # Specific to RoPE models + if token_idx is not None: + if reuse_cache: + key_states = self.k_cache(key_states, 2, token_idx) + value_states = self.v_cache(value_states, 2, token_idx) + else: + key_states = update(past_key_value[0], key_states, 2, token_idx, self.inp_seq_len) + value_states = update(past_key_value[1], value_states, 2, token_idx, self.inp_seq_len) + else: + key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + + if use_cache: + if reuse_cache: + past_key_value = (self.k_cache.get_shape(), self.v_cache.get_shape()) + else: + past_key_value = (key_states.contiguous(), value_states.contiguous()) else: - key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + past_key_value = None - if FusedSDPA: - import habana_frameworks.torch.hpu as ht + if FusedSDPA: + import habana_frameworks.torch.hpu as ht - if q_len == 1: - # next token - with ht.sdp_kernel(enable_recompute=False): - attn_output = FusedSDPA.apply(query_states, key_states, value_states, attention_mask, 0.0, False, None) + if q_len == 1: + # next token + with ht.sdp_kernel(enable_recompute=False): + attn_output = FusedSDPA.apply( + query_states, key_states, value_states, attention_mask, 0.0, False, None + ) + else: + # first token + with ht.sdp_kernel(enable_recompute=flash_attention_recompute): + attn_output = FusedSDPA.apply( + query_states, key_states, value_states, attention_mask, 0.0, False, None + ) else: - # first token - with ht.sdp_kernel(enable_recompute=False): # inference: flash_attention_recompute = False - attn_output = FusedSDPA.apply(query_states, key_states, value_states, attention_mask, 0.0, False, None) - else: - query_states, key_states, value_states, attention_mask = gaudi_mixtral_repeat_kv( - query_states, key_states, value_states, attention_mask, self.num_key_value_groups - ) + query_states, key_states, value_states, attention_mask = gaudi_mixtral_repeat_kv( + query_states, key_states, value_states, attention_mask, self.num_key_value_groups + ) - attn_weights = torch.matmul(query_states, key_states.transpose(-2, -1)) / math.sqrt(self.head_dim) + attn_weights = torch.matmul(query_states, key_states.transpose(-2, -1)) * self.norm_factor - if attention_mask is not None: - attention_mask = attention_mask.unsqueeze(2) - attn_weights = attn_weights + attention_mask + if attention_mask is not None: + attention_mask = attention_mask.unsqueeze(2) + attn_weights = attn_weights + attention_mask - # upcast attention to fp32 - attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype) - attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training) - attn_output = torch.matmul(attn_weights, value_states) + # upcast attention to fp32 + attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype) + attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training) + attn_output = torch.matmul(attn_weights, value_states) - attn_output = attn_output.reshape(bsz, self.num_heads, q_len, self.head_dim).contiguous() + attn_output = attn_output.reshape(bsz, self.num_heads, q_len, self.head_dim).contiguous() - attn_output = attn_output.transpose(1, 2).contiguous() - attn_output = attn_output.reshape(bsz, q_len, self.hidden_size) + attn_output = attn_output.transpose(1, 2).contiguous() + attn_output = attn_output.reshape(bsz, q_len, self.hidden_size) - attn_output = self.o_proj(attn_output) + attn_output = self.o_proj(attn_output) - if not output_attentions: - attn_weights = None + if not output_attentions: + attn_weights = None - return attn_output, attn_weights, past_key_value + return attn_output, attn_weights, past_key_value + + +MIXTRAL_ATTENTION_CLASSES = { + "eager": GaudiMixtralAttention, +} def gaudi_mixtral_block_sparse_moe_forward(self, hidden_states: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: @@ -327,234 +382,264 @@ def gaudi_mixtral_block_sparse_moe_forward(self, hidden_states: torch.Tensor) -> return final_hidden_states, router_logits -def gaudi_mixtral_decoder_layer_forward( - self, - hidden_states: torch.Tensor, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[Tuple[torch.Tensor]] = None, - output_attentions: Optional[bool] = False, - output_router_logits: Optional[bool] = False, - use_cache: Optional[bool] = False, - token_idx: Optional[torch.Tensor] = None, - **kwargs, -) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: - """ - Copied from MixtralDecoderLayer.forward: https://github.com/huggingface/transformers/blob/v4.37.0/src/transformers/models/mixtral/modeling_mixtral.py - The only differences are: - - add new args token_idx - """ - if "padding_mask" in kwargs: - warnings.warn( - "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`" - ) +class GaudiMixtralDecoderLayer(MixtralDecoderLayer): + def __init__(self, config: MixtralConfig, layer_idx: int): + super().__init__(config, layer_idx) + self.hidden_size = config.hidden_size - htcore.mark_step() - residual = hidden_states + self.self_attn = MIXTRAL_ATTENTION_CLASSES[config._attn_implementation](config, layer_idx) - hidden_states = self.input_layernorm(hidden_states) + def allocate_kv_cache(self, batch_size, max_seq_len, inp_seq_len, kv_cache_fp8): + self.self_attn.allocate_kv_cache(batch_size, max_seq_len, inp_seq_len, kv_cache_fp8) - # Self Attention - hidden_states, self_attn_weights, present_key_value = self.self_attn( - hidden_states=hidden_states, - attention_mask=attention_mask, - position_ids=position_ids, - past_key_value=past_key_value, - output_attentions=output_attentions, - use_cache=use_cache, - token_idx=token_idx, - ) - hidden_states = residual + hidden_states - htcore.mark_step() - - # Fully Connected - residual = hidden_states - hidden_states = self.post_attention_layernorm(hidden_states) - hidden_states, router_logits = self.block_sparse_moe(hidden_states) - hidden_states = residual + hidden_states - htcore.mark_step() - - outputs = (hidden_states,) - - if output_attentions: - outputs += (self_attn_weights,) - - if use_cache: - outputs += (present_key_value,) - - if output_router_logits: - outputs += (router_logits,) - - return outputs - - -def gaudi_mixtral_model_forward( - self, - input_ids: torch.LongTensor = None, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[List[torch.FloatTensor]] = None, - inputs_embeds: Optional[torch.FloatTensor] = None, - use_cache: Optional[bool] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - output_router_logits: Optional[bool] = None, - return_dict: Optional[bool] = None, - token_idx: Optional[torch.Tensor] = None, -) -> Union[Tuple, MoeModelOutputWithPast]: - """ - Copied from MixtralModel.forward: https://github.com/huggingface/transformers/blob/v4.37.0/src/transformers/models/mixtral/modeling_mixtral.py#L1069 - The only differences are: - - add new args token_idx - """ - output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions - output_router_logits = ( - output_router_logits if output_router_logits is not None else self.config.output_router_logits - ) - output_hidden_states = ( - output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states - ) - use_cache = use_cache if use_cache is not None else self.config.use_cache + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Tuple[torch.Tensor]] = None, + output_attentions: Optional[bool] = False, + output_router_logits: Optional[bool] = False, + use_cache: Optional[bool] = False, + token_idx: Optional[torch.Tensor] = None, + reuse_cache: Optional[bool] = False, + **kwargs, + ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: + """ + Copied from MixtralDecoderLayer.forward: https://github.com/huggingface/transformers/blob/v4.37.0/src/transformers/models/mixtral/modeling_mixtral.py + The only differences are: + - add new args token_idx + - add new args reuse_cache + """ + if "padding_mask" in kwargs: + warnings.warn( + "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`" + ) - return_dict = return_dict if return_dict is not None else self.config.use_return_dict + htcore.mark_step() + residual = hidden_states - # retrieve input_ids and inputs_embeds - if input_ids is not None and inputs_embeds is not None: - raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time") - elif input_ids is not None: - batch_size, seq_length = input_ids.shape - elif inputs_embeds is not None: - batch_size, seq_length, _ = inputs_embeds.shape - else: - raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds") + hidden_states = self.input_layernorm(hidden_states) + + # Self Attention + hidden_states, self_attn_weights, present_key_value = self.self_attn( + hidden_states=hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_value, + output_attentions=output_attentions, + use_cache=use_cache, + token_idx=token_idx, + reuse_cache=reuse_cache, + ) + hidden_states = residual + hidden_states + htcore.mark_step() + + # Fully Connected + residual = hidden_states + hidden_states = self.post_attention_layernorm(hidden_states) + hidden_states, router_logits = self.block_sparse_moe(hidden_states) + hidden_states = residual + hidden_states + htcore.mark_step() - past_key_values_length = 0 + outputs = (hidden_states,) + + if output_attentions: + outputs += (self_attn_weights,) - if self.gradient_checkpointing and self.training: if use_cache: - logger.warning_once( - "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." - ) - use_cache = False - - if use_cache: - use_legacy_cache = not isinstance(past_key_values, Cache) - if use_legacy_cache: - past_key_values = DynamicCache.from_legacy_cache(past_key_values) - past_key_values_length = past_key_values.get_usable_length(seq_length) - - if position_ids is None: - device = input_ids.device if input_ids is not None else inputs_embeds.device - position_ids = torch.arange( - past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device - ) - position_ids = position_ids.unsqueeze(0).view(-1, seq_length) - else: - position_ids = position_ids.view(-1, seq_length).long() - - if inputs_embeds is None: - inputs_embeds = self.embed_tokens(input_ids) - - if attention_mask is not None and self.config._attn_implementation == "flash_attention_2" and use_cache: - is_padding_right = attention_mask[:, -1].sum().item() != batch_size - if is_padding_right: - raise ValueError( - "You are attempting to perform batched generation with padding_side='right'" - " this may lead to unexpected behaviour for Flash Attention version of Mixtral. Make sure to " - " call `tokenizer.padding_side = 'left'` before tokenizing the input. " - ) + outputs += (present_key_value,) + + if output_router_logits: + outputs += (router_logits,) + + return outputs - if self.config._attn_implementation == "flash_attention_2": - # 2d mask is passed through the layers - attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None - elif self.config._attn_implementation == "sdpa" and not output_attentions: - # output_attentions=True can not be supported when using SDPA, and we fall back on - # the manual implementation that requires a 4D causal mask in all cases. - attention_mask = _prepare_4d_causal_attention_mask_for_sdpa( - attention_mask, - (batch_size, seq_length), - inputs_embeds, - past_key_values_length, + +class GaudiMixtralModel(MixtralModel): + def __init__(self, config: MixtralConfig): + super().__init__(config) + + def allocate_kv_cache(self, batch_size, max_seq_len, inp_seq_len, kv_cache_fp8): + for layer in self.layers: + layer.allocate_kv_cache(batch_size, max_seq_len, inp_seq_len, kv_cache_fp8) + + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + output_router_logits: Optional[bool] = None, + return_dict: Optional[bool] = None, + token_idx: Optional[torch.Tensor] = None, + reuse_cache: Optional[bool] = False, + ) -> Union[Tuple, MoeModelOutputWithPast]: + """ + Copied from MixtralModel.forward: https://github.com/huggingface/transformers/blob/v4.37.0/src/transformers/models/mixtral/modeling_mixtral.py#L1069 + The only differences are: + - add new args token_idx + - add new args reuse_cache + """ + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_router_logits = ( + output_router_logits if output_router_logits is not None else self.config.output_router_logits ) - else: - # 4d mask is passed through the layers - attention_mask = _prepare_4d_causal_attention_mask( - attention_mask, - (batch_size, seq_length), - inputs_embeds, - past_key_values_length, - sliding_window=self.config.sliding_window, + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states ) + use_cache = use_cache if use_cache is not None else self.config.use_cache - hidden_states = inputs_embeds + return_dict = return_dict if return_dict is not None else self.config.use_return_dict - # decoder layers - all_hidden_states = () if output_hidden_states else None - all_self_attns = () if output_attentions else None - all_router_logits = () if output_router_logits else None - next_decoder_cache = None + # retrieve input_ids and inputs_embeds + if input_ids is not None and inputs_embeds is not None: + raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time") + elif input_ids is not None: + batch_size, seq_length = input_ids.shape + elif inputs_embeds is not None: + batch_size, seq_length, _ = inputs_embeds.shape + else: + raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds") - for decoder_layer in self.layers: - if output_hidden_states: - all_hidden_states += (hidden_states,) + past_key_values_length = 0 + use_new_cache = False if self.gradient_checkpointing and self.training: - layer_outputs = self._gradient_checkpointing_func( - decoder_layer.__call__, - hidden_states, + if use_cache: + logger.warning_once( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." + ) + use_cache = False + + if past_key_values is not None and use_cache: + if reuse_cache: + past_key_values_length = past_key_values[0][0][2] + else: + use_legacy_cache = not isinstance(past_key_values, Cache) + if use_legacy_cache: + past_key_values = DynamicCache.from_legacy_cache(past_key_values) + past_key_values_length = past_key_values.get_usable_length(seq_length) + + if position_ids is None: + device = input_ids.device if input_ids is not None else inputs_embeds.device + position_ids = torch.arange( + past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device + ) + position_ids = position_ids.unsqueeze(0).view(-1, seq_length) + else: + position_ids = position_ids.view(-1, seq_length).long() + + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) + + if attention_mask is not None and self.config._attn_implementation == "flash_attention_2" and use_cache: + is_padding_right = attention_mask[:, -1].sum().item() != batch_size + if is_padding_right: + raise ValueError( + "You are attempting to perform batched generation with padding_side='right'" + " this may lead to unexpected behaviour for Flash Attention version of Mixtral. Make sure to " + " call `tokenizer.padding_side = 'left'` before tokenizing the input. " + ) + + if self.config._attn_implementation == "flash_attention_2": + # 2d mask is passed through the layers + attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None + elif self.config._attn_implementation == "sdpa" and not output_attentions: + # output_attentions=True can not be supported when using SDPA, and we fall back on + # the manual implementation that requires a 4D causal mask in all cases. + attention_mask = _prepare_4d_causal_attention_mask_for_sdpa( attention_mask, - position_ids, - past_key_values, - output_attentions, - output_router_logits, - use_cache, + (batch_size, seq_length), + inputs_embeds, + past_key_values_length, ) else: - layer_outputs = decoder_layer( - hidden_states, - attention_mask=attention_mask, - position_ids=position_ids, - past_key_value=past_key_values, - output_attentions=output_attentions, - output_router_logits=output_router_logits, - use_cache=use_cache, - token_idx=token_idx, + # 4d mask is passed through the layers + attention_mask = _prepare_4d_causal_attention_mask( + attention_mask, + (batch_size, seq_length), + inputs_embeds, + past_key_values_length, + sliding_window=self.config.sliding_window, ) - hidden_states = layer_outputs[0] - - if use_cache: - next_decoder_cache = layer_outputs[2 if output_attentions else 1] - - if output_attentions: - all_self_attns += (layer_outputs[1],) + hidden_states = inputs_embeds + + # decoder layers + all_hidden_states = () if output_hidden_states else None + all_self_attns = () if output_attentions else None + all_router_logits = () if output_router_logits else None + next_decoder_cache = () if not use_new_cache else None + + for decoder_layer in self.layers: + if output_hidden_states: + all_hidden_states += (hidden_states,) + + if self.gradient_checkpointing and self.training: + layer_outputs = self._gradient_checkpointing_func( + decoder_layer.__call__, + hidden_states, + attention_mask, + position_ids, + past_key_values, + output_attentions, + output_router_logits, + use_cache, + ) + else: + layer_outputs = decoder_layer( + hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_values, + output_attentions=output_attentions, + output_router_logits=output_router_logits, + use_cache=use_cache, + token_idx=token_idx, + reuse_cache=reuse_cache, + ) + + hidden_states = layer_outputs[0] + + if use_cache: + next_decoder_cache += (layer_outputs[2 if output_attentions else 1],) + + if output_attentions: + all_self_attns += (layer_outputs[1],) - if output_router_logits: - all_router_logits += (layer_outputs[-1],) + if output_router_logits: + all_router_logits += (layer_outputs[-1],) - hidden_states = self.norm(hidden_states) + hidden_states = self.norm(hidden_states) - # add hidden states from the last decoder layer - if output_hidden_states: - all_hidden_states += (hidden_states,) + # add hidden states from the last decoder layer + if output_hidden_states: + all_hidden_states += (hidden_states,) - next_cache = None - if use_cache: - next_cache = next_decoder_cache.to_legacy_cache() if use_legacy_cache else next_decoder_cache + next_cache = None + if use_cache: + next_cache = ( + next_decoder_cache.to_legacy_cache() if isinstance(next_decoder_cache, Cache) else next_decoder_cache + ) - if not return_dict: - return tuple( - v - for v in [hidden_states, next_cache, all_hidden_states, all_self_attns, all_router_logits] - if v is not None + if not return_dict: + return tuple( + v + for v in [hidden_states, next_cache, all_hidden_states, all_self_attns, all_router_logits] + if v is not None + ) + return MoeModelOutputWithPast( + last_hidden_state=hidden_states, + past_key_values=next_cache, + hidden_states=all_hidden_states, + attentions=all_self_attns, + router_logits=all_router_logits, ) - return MoeModelOutputWithPast( - last_hidden_state=hidden_states, - past_key_values=next_cache, - hidden_states=all_hidden_states, - attentions=all_self_attns, - router_logits=all_router_logits, - ) class GaudiMixtralForCausalLM(MixtralForCausalLM): @@ -566,6 +651,10 @@ class GaudiMixtralForCausalLM(MixtralForCausalLM): - from step2 when enable KV cache, slice next_input_ids from input_ids base on the token_idx - from step2 when enable KV cache, slice next_position_ids from position_ids base on the token_idx """ + + def allocate_kv_cache(self, batch_size, max_seq_len, inp_seq_len, kv_cache_fp8): + self.model.allocate_kv_cache(batch_size, max_seq_len, inp_seq_len, kv_cache_fp8) + self.kv_cache_len = max_seq_len def forward( self, @@ -581,6 +670,7 @@ def forward( output_router_logits: Optional[bool] = None, return_dict: Optional[bool] = None, token_idx: Optional[torch.Tensor] = None, + reuse_cache: Optional[bool] = None, ) -> Union[Tuple, MoeCausalLMOutputWithPast]: output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_router_logits = ( @@ -605,6 +695,7 @@ def forward( output_router_logits=output_router_logits, return_dict=return_dict, token_idx=token_idx, + reuse_cache=reuse_cache, ) hidden_states = outputs[0] @@ -654,6 +745,8 @@ def forward( def prepare_inputs_for_generation( self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs ): + past_length = 0 + reuse_cache = kwargs.get("reuse_cache") token_idx = kwargs.get("token_idx", None) # Omit tokens covered by past_key_values @@ -688,6 +781,10 @@ def prepare_inputs_for_generation( attention_mask = attention_mask[:, -max_cache_length:] else: input_ids = torch.index_select(input_ids, 1, token_idx - 1) + elif reuse_cache and token_idx is not None: + # With reuse_cache, KV cache is pre allocated hence for the 1st token we can slice the inputs till token idx for the fwd pass + input_ids = input_ids[:, :token_idx] + attention_mask = attention_mask[:, :token_idx] position_ids = kwargs.get("position_ids", None) if attention_mask is not None and position_ids is None: @@ -713,6 +810,7 @@ def prepare_inputs_for_generation( "use_cache": kwargs.get("use_cache"), "attention_mask": attention_mask, "token_idx": token_idx, + "reuse_cache": reuse_cache, } ) return model_inputs From 76906140e53df16bf139cb1c0272d44cfcdeb741 Mon Sep 17 00:00:00 2001 From: Jinyan chen Date: Tue, 26 Mar 2024 15:54:58 +0800 Subject: [PATCH 02/21] make style --- optimum/habana/transformers/modeling_utils.py | 2 +- .../habana/transformers/models/__init__.py | 2 +- .../transformers/models/mixtral/__init__.py | 2 +- .../models/mixtral/modeling_mixtral.py | 29 ++++++------------- 4 files changed, 12 insertions(+), 23 deletions(-) diff --git a/optimum/habana/transformers/modeling_utils.py b/optimum/habana/transformers/modeling_utils.py index f746dd736f..d03dc2261a 100644 --- a/optimum/habana/transformers/modeling_utils.py +++ b/optimum/habana/transformers/modeling_utils.py @@ -48,8 +48,8 @@ GaudiMistralModel, GaudiMixtralAttention, GaudiMixtralDecoderLayer, - GaudiMixtralModel, GaudiMixtralForCausalLM, + GaudiMixtralModel, GaudiMptForCausalLM, GaudiMptModel, GaudiOPTForCausalLM, diff --git a/optimum/habana/transformers/models/__init__.py b/optimum/habana/transformers/models/__init__.py index 39ca2a57ee..731552afb4 100644 --- a/optimum/habana/transformers/models/__init__.py +++ b/optimum/habana/transformers/models/__init__.py @@ -90,8 +90,8 @@ from .mixtral import ( GaudiMixtralAttention, GaudiMixtralDecoderLayer, - GaudiMixtralModel, GaudiMixtralForCausalLM, + GaudiMixtralModel, gaudi_mixtral_block_sparse_moe_forward, gaudi_mixtral_rmsnorm_forward, ) diff --git a/optimum/habana/transformers/models/mixtral/__init__.py b/optimum/habana/transformers/models/mixtral/__init__.py index cd4d0724e9..d95c0c2cdd 100644 --- a/optimum/habana/transformers/models/mixtral/__init__.py +++ b/optimum/habana/transformers/models/mixtral/__init__.py @@ -1,8 +1,8 @@ from .modeling_mixtral import ( GaudiMixtralAttention, GaudiMixtralDecoderLayer, - GaudiMixtralModel, GaudiMixtralForCausalLM, + GaudiMixtralModel, gaudi_mixtral_block_sparse_moe_forward, gaudi_mixtral_rmsnorm_forward, ) diff --git a/optimum/habana/transformers/models/mixtral/modeling_mixtral.py b/optimum/habana/transformers/models/mixtral/modeling_mixtral.py index 8cb549dcc7..6509ff7263 100644 --- a/optimum/habana/transformers/models/mixtral/modeling_mixtral.py +++ b/optimum/habana/transformers/models/mixtral/modeling_mixtral.py @@ -38,12 +38,12 @@ from transformers.modeling_outputs import MoeCausalLMOutputWithPast, MoeModelOutputWithPast from transformers.models.mixtral.modeling_mixtral import ( MixtralAttention, + MixtralConfig, MixtralDecoderLayer, - MixtralModel, MixtralForCausalLM, + MixtralModel, apply_rotary_pos_emb, load_balancing_loss_func, - MixtralConfig, ) from transformers.utils import logging @@ -65,6 +65,7 @@ except ImportError: print("Not using HPU fused scaled dot-product attention kernel.") FusedSDPA = None +FusedSDPA = None logger = logging.get_logger(__name__) @@ -147,10 +148,6 @@ def gaudi_mixtral_repeat_kv( new_q_shape = (batch, num_key_value_heads, n_rep, q_len, head_dim) query_states = query_states.reshape(new_q_shape) - if attention_mask is not None: - # Add groups dim and set to 1 - attention_mask = attention_mask.unsqueeze(1) - return query_states, key_states, value_states, attention_mask @@ -250,16 +247,6 @@ def forward( kv_seq_len = past_key_value[0][0][-2] else: kv_seq_len = past_key_value[0].shape[-2] - # import pdb;pdb.set_trace() - # kv_shape = ( - # (past_key_value[0][-2] if reuse_cache else past_key_value[0].shape[-2]) - # if isinstance(past_key_value, tuple) - # else past_key_value.get_usable_length(kv_seq_len, self.layer_idx) - # ) - # if token_idx is not None: - # kv_seq_len = kv_shape - # else: - # kv_seq_len += kv_shape cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) query_states, key_states = apply_customized_rope(query_states, key_states, cos, sin, position_ids) @@ -274,7 +261,9 @@ def forward( key_states = update(past_key_value[0], key_states, 2, token_idx, self.inp_seq_len) value_states = update(past_key_value[1], value_states, 2, token_idx, self.inp_seq_len) else: - key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + key_states, value_states = past_key_value.update( + key_states, value_states, self.layer_idx, cache_kwargs + ) if use_cache: if reuse_cache: @@ -459,10 +448,10 @@ def forward( class GaudiMixtralModel(MixtralModel): def __init__(self, config: MixtralConfig): super().__init__(config) - + def allocate_kv_cache(self, batch_size, max_seq_len, inp_seq_len, kv_cache_fp8): for layer in self.layers: - layer.allocate_kv_cache(batch_size, max_seq_len, inp_seq_len, kv_cache_fp8) + layer.allocate_kv_cache(batch_size, max_seq_len, inp_seq_len, kv_cache_fp8) def forward( self, @@ -651,7 +640,7 @@ class GaudiMixtralForCausalLM(MixtralForCausalLM): - from step2 when enable KV cache, slice next_input_ids from input_ids base on the token_idx - from step2 when enable KV cache, slice next_position_ids from position_ids base on the token_idx """ - + def allocate_kv_cache(self, batch_size, max_seq_len, inp_seq_len, kv_cache_fp8): self.model.allocate_kv_cache(batch_size, max_seq_len, inp_seq_len, kv_cache_fp8) self.kv_cache_len = max_seq_len From 67535f0e8e2c1ec8baa9eddb87d384874015c382 Mon Sep 17 00:00:00 2001 From: Jinyan chen Date: Tue, 26 Mar 2024 15:54:58 +0800 Subject: [PATCH 03/21] make style --- optimum/habana/transformers/generation/utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/optimum/habana/transformers/generation/utils.py b/optimum/habana/transformers/generation/utils.py index c0d71c2f68..d6added203 100755 --- a/optimum/habana/transformers/generation/utils.py +++ b/optimum/habana/transformers/generation/utils.py @@ -582,7 +582,7 @@ def generate( "llama", "mistral", "mixtral", - ], "reuse_cache only supported by llama and mistral at the moment" + ], "reuse_cache only supported by llama, mistral and mixtral at the moment" if not generation_config.bucket_internal: assert ( generation_config.bucket_size <= 0 From 838b55f53d5861e0ad964f2f488941dd45570780 Mon Sep 17 00:00:00 2001 From: Jinyan chen Date: Tue, 26 Mar 2024 16:01:42 +0800 Subject: [PATCH 04/21] remove debug code --- optimum/habana/transformers/models/mixtral/modeling_mixtral.py | 1 - 1 file changed, 1 deletion(-) diff --git a/optimum/habana/transformers/models/mixtral/modeling_mixtral.py b/optimum/habana/transformers/models/mixtral/modeling_mixtral.py index 6509ff7263..3e1c428ba0 100644 --- a/optimum/habana/transformers/models/mixtral/modeling_mixtral.py +++ b/optimum/habana/transformers/models/mixtral/modeling_mixtral.py @@ -65,7 +65,6 @@ except ImportError: print("Not using HPU fused scaled dot-product attention kernel.") FusedSDPA = None -FusedSDPA = None logger = logging.get_logger(__name__) From 458375a9afca2d0cf208b5166f9fc8c4d817b33e Mon Sep 17 00:00:00 2001 From: Jinyan chen Date: Tue, 26 Mar 2024 17:06:45 +0800 Subject: [PATCH 05/21] add fp8 support of non-sdpa attn --- .../models/mixtral/modeling_mixtral.py | 14 ++++++++++++-- 1 file changed, 12 insertions(+), 2 deletions(-) diff --git a/optimum/habana/transformers/models/mixtral/modeling_mixtral.py b/optimum/habana/transformers/models/mixtral/modeling_mixtral.py index 3e1c428ba0..a4d1f8bd64 100644 --- a/optimum/habana/transformers/models/mixtral/modeling_mixtral.py +++ b/optimum/habana/transformers/models/mixtral/modeling_mixtral.py @@ -150,6 +150,14 @@ def gaudi_mixtral_repeat_kv( return query_states, key_states, value_states, attention_mask +class Matmul(torch.nn.Module): + def __init__(self): + super().__init__() + + def forward(self, x, y): + return torch.matmul(x, y) + + class KVCache(torch.nn.Module): def __init__(self): super(KVCache, self).__init__() @@ -181,6 +189,8 @@ class GaudiMixtralAttention(MixtralAttention): def __init__(self, config: MixtralConfig, layer_idx: Optional[int] = None): super().__init__(config, layer_idx) + self.matmul_qk = Matmul() + self.matmul_av = Matmul() self.k_cache = KVCache() self.v_cache = KVCache() self.inp_seq_len = -1 @@ -292,7 +302,7 @@ def forward( query_states, key_states, value_states, attention_mask, self.num_key_value_groups ) - attn_weights = torch.matmul(query_states, key_states.transpose(-2, -1)) * self.norm_factor + attn_weights = self.matmul_qk(query_states, key_states.transpose(-2, -1)) * self.norm_factor if attention_mask is not None: attention_mask = attention_mask.unsqueeze(2) @@ -301,7 +311,7 @@ def forward( # upcast attention to fp32 attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype) attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training) - attn_output = torch.matmul(attn_weights, value_states) + attn_output = self.matmul_av(attn_weights, value_states) attn_output = attn_output.reshape(bsz, self.num_heads, q_len, self.head_dim).contiguous() From ac3f004c674044c99ef490ffe7ef8daa0bf1c087 Mon Sep 17 00:00:00 2001 From: Jinyan chen Date: Tue, 26 Mar 2024 17:34:54 +0800 Subject: [PATCH 06/21] add bucket_internal support of Mixtral --- .../models/mixtral/modeling_mixtral.py | 22 +++++++++++++++++++ 1 file changed, 22 insertions(+) diff --git a/optimum/habana/transformers/models/mixtral/modeling_mixtral.py b/optimum/habana/transformers/models/mixtral/modeling_mixtral.py index a4d1f8bd64..566130f7dd 100644 --- a/optimum/habana/transformers/models/mixtral/modeling_mixtral.py +++ b/optimum/habana/transformers/models/mixtral/modeling_mixtral.py @@ -214,6 +214,7 @@ def forward( token_idx: Optional[torch.Tensor] = None, reuse_cache: Optional[bool] = False, flash_attention_recompute: Optional[bool] = False, + cache_idx: int = None, **kwargs, ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: """ @@ -269,6 +270,13 @@ def forward( else: key_states = update(past_key_value[0], key_states, 2, token_idx, self.inp_seq_len) value_states = update(past_key_value[1], value_states, 2, token_idx, self.inp_seq_len) + + if cache_idx is not None and q_len == 1: + key_states = key_states[:, :, :cache_idx, :] + value_states = value_states[:, :, :cache_idx, :] + if attention_mask is not None: + attention_mask = attention_mask[:, :, :, :cache_idx] + kv_seq_len = key_states.shape[-2] else: key_states, value_states = past_key_value.update( key_states, value_states, self.layer_idx, cache_kwargs @@ -401,6 +409,8 @@ def forward( use_cache: Optional[bool] = False, token_idx: Optional[torch.Tensor] = None, reuse_cache: Optional[bool] = False, + flash_attention_recompute: Optional[bool] = False, + cache_idx: int = None, **kwargs, ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: """ @@ -429,6 +439,8 @@ def forward( use_cache=use_cache, token_idx=token_idx, reuse_cache=reuse_cache, + flash_attention_recompute=flash_attention_recompute, + cache_idx=cache_idx, ) hidden_states = residual + hidden_states htcore.mark_step() @@ -476,6 +488,8 @@ def forward( return_dict: Optional[bool] = None, token_idx: Optional[torch.Tensor] = None, reuse_cache: Optional[bool] = False, + flash_attention_recompute: Optional[bool] = False, + cache_idx: int = None, ) -> Union[Tuple, MoeModelOutputWithPast]: """ Copied from MixtralModel.forward: https://github.com/huggingface/transformers/blob/v4.37.0/src/transformers/models/mixtral/modeling_mixtral.py#L1069 @@ -600,6 +614,8 @@ def forward( use_cache=use_cache, token_idx=token_idx, reuse_cache=reuse_cache, + flash_attention_recompute=flash_attention_recompute, + cache_idx=cache_idx, ) hidden_states = layer_outputs[0] @@ -669,6 +685,8 @@ def forward( return_dict: Optional[bool] = None, token_idx: Optional[torch.Tensor] = None, reuse_cache: Optional[bool] = None, + flash_attention_recompute: Optional[bool] = False, + cache_idx: int = None, ) -> Union[Tuple, MoeCausalLMOutputWithPast]: output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_router_logits = ( @@ -694,6 +712,8 @@ def forward( return_dict=return_dict, token_idx=token_idx, reuse_cache=reuse_cache, + flash_attention_recompute=flash_attention_recompute, + cache_idx=cache_idx, ) hidden_states = outputs[0] @@ -809,6 +829,8 @@ def prepare_inputs_for_generation( "attention_mask": attention_mask, "token_idx": token_idx, "reuse_cache": reuse_cache, + "flash_attention_recompute": kwargs.get("flash_attention_recompute"), + "cache_idx": kwargs.get("cache_idx"), } ) return model_inputs From 6816af0e96fd326b56bb3f0fa86925131e622ec2 Mon Sep 17 00:00:00 2001 From: Jinyan chen Date: Wed, 3 Apr 2024 13:58:31 +0800 Subject: [PATCH 07/21] fit to r1.15 and also fp8 sdpa --- .../maxabs_quant_mixtral.json | 4 +- .../habana/transformers/generation/utils.py | 1 + .../models/mixtral/modeling_mixtral.py | 246 ++++++++++++------ 3 files changed, 174 insertions(+), 77 deletions(-) diff --git a/examples/text-generation/quantization_config/maxabs_quant_mixtral.json b/examples/text-generation/quantization_config/maxabs_quant_mixtral.json index 737edcc413..b3fd2e26db 100644 --- a/examples/text-generation/quantization_config/maxabs_quant_mixtral.json +++ b/examples/text-generation/quantization_config/maxabs_quant_mixtral.json @@ -3,8 +3,8 @@ "mode": "QUANTIZE", "observer": "maxabs", "scale_method": "maxabs_hw", - "whitelist": {"types": [], "names": ["gate","w1","w3","w2"]}, - "blacklist": {"types": [], "names": [ + "allowlist": {"types": [], "names": ["gate","w1","w3","w2"]}, + "blocklist": {"types": [], "names": [ "model.layers.1.block_sparse_moe.experts.(3|4).w2", "model.layers.[29-31].block_sparse_moe.experts.[0-7].w2" ]}, diff --git a/optimum/habana/transformers/generation/utils.py b/optimum/habana/transformers/generation/utils.py index ffc02be17b..ff2e34fac8 100755 --- a/optimum/habana/transformers/generation/utils.py +++ b/optimum/habana/transformers/generation/utils.py @@ -584,6 +584,7 @@ def generate( assert self.config.model_type in [ "llama", "mistral", + "mixtral", "falcon", ], "reuse_cache only supported by llama, mistral, falcon and mixtral at the moment" if not generation_config.bucket_internal: diff --git a/optimum/habana/transformers/models/mixtral/modeling_mixtral.py b/optimum/habana/transformers/models/mixtral/modeling_mixtral.py index 566130f7dd..1fd3330181 100644 --- a/optimum/habana/transformers/models/mixtral/modeling_mixtral.py +++ b/optimum/habana/transformers/models/mixtral/modeling_mixtral.py @@ -20,6 +20,8 @@ """PyTorch Mixtral model.""" +import contextlib +import os import math import warnings from typing import List, Optional, Tuple, Union @@ -42,6 +44,7 @@ MixtralDecoderLayer, MixtralForCausalLM, MixtralModel, + repeat_kv, apply_rotary_pos_emb, load_balancing_loss_func, ) @@ -66,33 +69,22 @@ print("Not using HPU fused scaled dot-product attention kernel.") FusedSDPA = None -logger = logging.get_logger(__name__) - +try: + from habana_frameworks.torch.hpu import sdp_kernel + SDPContext = True +except ImportError: + SDPContext = False -def update(prev, cur, dim, idx, inp_seq_len): - orig_cur = cur - if prev.dtype == torch.float8_e4m3fn: - from habana_frameworks.torch.hpex.kernels.Fp8Ops import cast_to_fp8_v2 - - cur = cast_to_fp8_v2(cur, None, False, False, prev.dtype)[0] - if cur.shape[2] > 1 and cur.shape[2] <= prev.shape[2]: - # Initialize - prev[:, :, :inp_seq_len, :].copy_(cur) - return orig_cur - assert cur.shape[2] == 1, f"Cannot update kv-cache. Unsupported shapes. prev:{prev.shape} cur:{cur.shape}" - if idx is not None: - prev.index_copy_(dim, idx - 1, cur) - prev_cast = prev.to(orig_cur.dtype) - return prev_cast - else: - return torch.cat((prev, cur), dim=dim) +logger = logging.get_logger(__name__) def apply_customized_rope(q, k, cos, sin, position_ids): if q.device.type == "hpu" and FusedRoPE: return FusedRoPE.apply( - q, cos.unsqueeze(0).unsqueeze(0), sin.unsqueeze(0).unsqueeze(0), position_ids - ), FusedRoPE.apply(k, cos.unsqueeze(0).unsqueeze(0), sin.unsqueeze(0).unsqueeze(0), position_ids) + q, cos.unsqueeze(0).unsqueeze(0).clone(), sin.unsqueeze(0).unsqueeze(0).clone(), position_ids + ), FusedRoPE.apply( + k, cos.unsqueeze(0).unsqueeze(0).clone(), sin.unsqueeze(0).unsqueeze(0).clone(), position_ids + ) else: return apply_rotary_pos_emb(q, k, cos, sin, position_ids) @@ -150,12 +142,86 @@ def gaudi_mixtral_repeat_kv( return query_states, key_states, value_states, attention_mask -class Matmul(torch.nn.Module): +class Softmax(nn.Module): + def __init__(self): + super().__init__() + + def forward(self, x, dim=None, invAttnHead=None): + return torch.ops.hpu.softmax_fp8(x, dim, None, None, invAttnHead) + + +class Matmul(nn.Module): def __init__(self): super().__init__() - def forward(self, x, y): - return torch.matmul(x, y) + def forward(self, *args, **kwargs): + return torch.matmul(*args, **kwargs) + + +class ScaledDotProductAttention(nn.Module): + def __init__(self, config: MixtralConfig): + super().__init__() + # self.num_heads = config.num_attention_heads + # self.num_kv_groups = config.num_attention_heads // config.num_key_value_heads + self.head_dim = config.hidden_size // config.num_attention_heads + self.bmm1 = Matmul() + self.bmm2 = Matmul() + self.softmax = Softmax() + + def forward(self, query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False, scale=None) -> torch.Tensor: + L, S = query.size(-2), key.size(-2) + scale_factor = 1 / math.sqrt(self.head_dim) + invAttnHead = torch.tensor(scale_factor, dtype=torch.float32).to("hpu") + + if is_causal: + assert attn_mask is None + attn_bias = torch.zeros(L, S, dtype=query.dtype) + temp_mask = torch.ones(L, S, dtype=torch.bool).tril(diagonal=0) + attn_bias.masked_fill_(temp_mask.logical_not(), float("-inf")) + attn_bias.to(query.dtype) + + if attn_mask is not None: + if attn_mask.dtype == torch.bool: + attn_mask.masked_fill_(attn_mask.logical_not(), float("-inf")) + + attn_weight = self.bmm1(query, key.transpose(-2, -1)) + + attn_weight += attn_mask + attn_weight = self.softmax(attn_weight, dim=-1, invAttnHead=invAttnHead) + attn_weight = torch.dropout(attn_weight, dropout_p, train=True) + return self.bmm2(attn_weight, value) + + # Try broadcast matmuls but seems not supported currently + ''' + def forward(self, query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False, scale=None) -> torch.Tensor: + query, key, value, attn_mask = gaudi_mixtral_repeat_kv( + query, key, value, attn_mask, self.num_kv_groups + ) + bsz = query.size(0) + L, S = query.size(-2), key.size(-2) + scale_factor = 1 / math.sqrt(self.head_dim) + invAttnHead = torch.tensor(scale_factor, dtype=torch.float32).to("hpu") + + if is_causal: + assert attn_mask is None + attn_bias = torch.zeros(L, S, dtype=query.dtype) + temp_mask = torch.ones(L, S, dtype=torch.bool).tril(diagonal=0) + attn_bias.masked_fill_(temp_mask.logical_not(), float("-inf")) + attn_bias.to(query.dtype) + + if attn_mask is not None: + if attn_mask.dtype == torch.bool: + attn_mask.masked_fill_(attn_mask.logical_not(), float("-inf")) + + attn_weight = self.bmm1(query, key.transpose(-2, -1)) + + attn_weight += attn_mask + attn_weight = self.softmax(attn_weight, dim=-1, invAttnHead=invAttnHead) + attn_weight = torch.dropout(attn_weight, dropout_p, train=True) + attn_output = self.bmm2(attn_weight, value) + attn_output = attn_output.reshape(bsz, self.num_heads, L, self.head_dim).contiguous() + return attn_output + ''' class KVCache(torch.nn.Module): @@ -164,11 +230,9 @@ def __init__(self): self.cache = None self.inp_seq_len = -1 - def allocate(self, inp_seq_len, kv_cache_fp8, dtype, device, shape): + def allocate(self, inp_seq_len, dtype, device, shape): if self.cache is None or self.cache.shape != shape: self.inp_seq_len = inp_seq_len - if kv_cache_fp8: - dtype = torch.float8_e4m3fn self.cache = torch.zeros(shape, dtype=dtype, device=device) else: assert ( @@ -176,32 +240,49 @@ def allocate(self, inp_seq_len, kv_cache_fp8, dtype, device, shape): ), f"inp_seq_len must be the same. self.inp_seq_len:{self.inp_seq_len} inp_seq_len:{inp_seq_len}" self.cache.fill_(0) + def update(self, prev, cur, dim, idx, inp_seq_len): + orig_cur = cur + if prev.shape == cur.shape: + prev.copy_(cur) + return orig_cur + if cur.shape[2] > 1 and cur.shape[2] <= prev.shape[2]: + # Initialize + prev[:, :, :inp_seq_len, :].copy_(cur) + return orig_cur + assert cur.shape[2] == 1, f"Cannot update kv-cache. Unsupported shapes. prev:{prev.shape} cur:{cur.shape}" + if idx is not None: + prev.index_copy_(dim, idx - 1, cur) + return prev + else: + return torch.cat((prev, cur), dim=dim) + def get_shape(self): if self.cache is None: return None return self.cache.shape def forward(self, cur, dim, idx): - return update(self.cache, cur, dim, idx, self.inp_seq_len) + return self.update(self.cache, cur, dim, idx, self.inp_seq_len) class GaudiMixtralAttention(MixtralAttention): def __init__(self, config: MixtralConfig, layer_idx: Optional[int] = None): super().__init__(config, layer_idx) - self.matmul_qk = Matmul() - self.matmul_av = Matmul() + if os.getenv("QUANT_CONFIG", ""): + self.sdpa = ScaledDotProductAttention(config) + self.k_cache = KVCache() self.v_cache = KVCache() self.inp_seq_len = -1 self.norm_factor = 1.0 / math.sqrt(self.head_dim) - def allocate_kv_cache(self, batch_size, max_seq_len, inp_seq_len, kv_cache_fp8): + def allocate_kv_cache(self, batch_size, max_seq_len, inp_seq_len): cache_shape = (batch_size, self.num_key_value_heads, max_seq_len, self.head_dim) device = self.k_proj.weight.device dtype = self.config.torch_dtype - self.k_cache.allocate(inp_seq_len, kv_cache_fp8, dtype, device, cache_shape) - self.v_cache.allocate(inp_seq_len, kv_cache_fp8, dtype, device, cache_shape) + self.k_cache.allocate(inp_seq_len, dtype, device, cache_shape) + self.v_cache.allocate(inp_seq_len, dtype, device, cache_shape) def forward( self, @@ -251,48 +332,61 @@ def forward( if hasattr(past_key_value, "get_usable_length"): kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx) else: - kv_seq_len += past_key_value[0].shape[-2] + kv_seq_len += past_key_value[0][0].shape[-2] else: if reuse_cache: kv_seq_len = past_key_value[0][0][-2] else: - kv_seq_len = past_key_value[0].shape[-2] + kv_seq_len = past_key_value[0][0].shape[-2] cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) query_states, key_states = apply_customized_rope(query_states, key_states, cos, sin, position_ids) - if past_key_value is not None or reuse_cache: - cache_kwargs = {"sin": sin, "cos": cos} # Specific to RoPE models - if token_idx is not None: - if reuse_cache: - key_states = self.k_cache(key_states, 2, token_idx) - value_states = self.v_cache(value_states, 2, token_idx) - else: - key_states = update(past_key_value[0], key_states, 2, token_idx, self.inp_seq_len) - value_states = update(past_key_value[1], value_states, 2, token_idx, self.inp_seq_len) - - if cache_idx is not None and q_len == 1: - key_states = key_states[:, :, :cache_idx, :] - value_states = value_states[:, :, :cache_idx, :] - if attention_mask is not None: - attention_mask = attention_mask[:, :, :, :cache_idx] - kv_seq_len = key_states.shape[-2] - else: - key_states, value_states = past_key_value.update( - key_states, value_states, self.layer_idx, cache_kwargs - ) - if use_cache: + cache_kwargs = {"sin": sin, "cos": cos} # Specific to RoPE models if reuse_cache: + key_states = self.k_cache(key_states, 2, token_idx) + value_states = self.v_cache(value_states, 2, token_idx) past_key_value = (self.k_cache.get_shape(), self.v_cache.get_shape()) else: - past_key_value = (key_states.contiguous(), value_states.contiguous()) + if past_key_value is None: + past_key = torch.zeros( + key_states.shape, dtype=self.k_proj.weight.dtype, device=key_states.device + ) + past_value = torch.zeros( + key_states.shape, dtype=self.k_proj.weight.dtype, device=key_states.device + ) + past_key_value = (past_key, past_value) + key_states = self.k_cache.update(past_key_value[0], key_states, 2, token_idx, self.inp_seq_len) + value_states = self.k_cache.update(past_key_value[1], value_states, 2, token_idx, self.inp_seq_len) + if token_idx is None: + past_key_value = (key_states, value_states) + + if cache_idx is not None and q_len == 1: + key_states = key_states[:, :, :cache_idx, :] + value_states = value_states[:, :, :cache_idx, :] + if attention_mask is not None: + attention_mask = attention_mask[:, :, :, :cache_idx] + kv_seq_len = key_states.shape[-2] else: past_key_value = None if FusedSDPA: + if os.getenv("QUANT_CONFIG", ""): + # WA for GQA optimization is not supported currently + key_states = repeat_kv(key_states, self.num_key_value_groups) + value_states = repeat_kv(value_states, self.num_key_value_groups) + attn_output = self.sdpa( + query_states, key_states, value_states, attention_mask, 0.0, False, None + ) + else: + with sdp_kernel(enable_recompute=False) if SDPContext else contextlib.nullcontext(): + attn_output = FusedSDPA.apply( + query_states, key_states, value_states, attention_mask, 0.0, False, None + ) + # pervious version + ''' import habana_frameworks.torch.hpu as ht - if q_len == 1: # next token with ht.sdp_kernel(enable_recompute=False): @@ -305,12 +399,13 @@ def forward( attn_output = FusedSDPA.apply( query_states, key_states, value_states, attention_mask, 0.0, False, None ) + ''' else: query_states, key_states, value_states, attention_mask = gaudi_mixtral_repeat_kv( query_states, key_states, value_states, attention_mask, self.num_key_value_groups ) - attn_weights = self.matmul_qk(query_states, key_states.transpose(-2, -1)) * self.norm_factor + attn_weights = torch.matmul(query_states, key_states.transpose(-2, -1)) * self.norm_factor if attention_mask is not None: attention_mask = attention_mask.unsqueeze(2) @@ -319,7 +414,7 @@ def forward( # upcast attention to fp32 attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype) attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training) - attn_output = self.matmul_av(attn_weights, value_states) + attn_output = torch.matmul(attn_weights, value_states) attn_output = attn_output.reshape(bsz, self.num_heads, q_len, self.head_dim).contiguous() @@ -395,8 +490,8 @@ def __init__(self, config: MixtralConfig, layer_idx: int): self.self_attn = MIXTRAL_ATTENTION_CLASSES[config._attn_implementation](config, layer_idx) - def allocate_kv_cache(self, batch_size, max_seq_len, inp_seq_len, kv_cache_fp8): - self.self_attn.allocate_kv_cache(batch_size, max_seq_len, inp_seq_len, kv_cache_fp8) + def allocate_kv_cache(self, batch_size, max_seq_len, inp_seq_len): + self.self_attn.allocate_kv_cache(batch_size, max_seq_len, inp_seq_len) def forward( self, @@ -470,9 +565,9 @@ class GaudiMixtralModel(MixtralModel): def __init__(self, config: MixtralConfig): super().__init__(config) - def allocate_kv_cache(self, batch_size, max_seq_len, inp_seq_len, kv_cache_fp8): + def allocate_kv_cache(self, batch_size, max_seq_len, inp_seq_len): for layer in self.layers: - layer.allocate_kv_cache(batch_size, max_seq_len, inp_seq_len, kv_cache_fp8) + layer.allocate_kv_cache(batch_size, max_seq_len, inp_seq_len) def forward( self, @@ -519,7 +614,7 @@ def forward( raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds") past_key_values_length = 0 - use_new_cache = False + use_new_cache = False # Ignoring new Cache path for HPU if self.gradient_checkpointing and self.training: if use_cache: @@ -532,10 +627,12 @@ def forward( if reuse_cache: past_key_values_length = past_key_values[0][0][2] else: - use_legacy_cache = not isinstance(past_key_values, Cache) - if use_legacy_cache: - past_key_values = DynamicCache.from_legacy_cache(past_key_values) - past_key_values_length = past_key_values.get_usable_length(seq_length) + if use_new_cache: + if not instance(past_key_values, StaticCache): + past_key_values = DynamicCache.from_legacy_cache(past_key_values) + past_seen_tokens = past_key_values.get_seq_length() + else: + past_seen_tokens = past_key_values[0][0].shape[2] if position_ids is None: device = input_ids.device if input_ids is not None else inputs_embeds.device @@ -666,8 +763,8 @@ class GaudiMixtralForCausalLM(MixtralForCausalLM): - from step2 when enable KV cache, slice next_position_ids from position_ids base on the token_idx """ - def allocate_kv_cache(self, batch_size, max_seq_len, inp_seq_len, kv_cache_fp8): - self.model.allocate_kv_cache(batch_size, max_seq_len, inp_seq_len, kv_cache_fp8) + def allocate_kv_cache(self, batch_size, max_seq_len, inp_seq_len): + self.model.allocate_kv_cache(batch_size, max_seq_len, inp_seq_len) self.kv_cache_len = max_seq_len def forward( @@ -767,9 +864,10 @@ def prepare_inputs_for_generation( reuse_cache = kwargs.get("reuse_cache") token_idx = kwargs.get("token_idx", None) - # Omit tokens covered by past_key_values if past_key_values is not None: - if token_idx is None: + if token_idx is not None: + input_ids = torch.index_select(input_ids, 1, token_idx - 1) + else: if isinstance(past_key_values, Cache): cache_length = past_key_values.get_seq_length() past_length = past_key_values.seen_tokens @@ -797,8 +895,6 @@ def prepare_inputs_for_generation( and cache_length + input_ids.shape[1] > max_cache_length ): attention_mask = attention_mask[:, -max_cache_length:] - else: - input_ids = torch.index_select(input_ids, 1, token_idx - 1) elif reuse_cache and token_idx is not None: # With reuse_cache, KV cache is pre allocated hence for the 1st token we can slice the inputs till token idx for the fwd pass input_ids = input_ids[:, :token_idx] From de9d07f0c16ab7d9b4ac5bc465f4f86b4800b4f8 Mon Sep 17 00:00:00 2001 From: Jinyan chen Date: Wed, 3 Apr 2024 14:07:12 +0800 Subject: [PATCH 08/21] make style --- .../models/mixtral/modeling_mixtral.py | 32 ++++++++----------- 1 file changed, 13 insertions(+), 19 deletions(-) diff --git a/optimum/habana/transformers/models/mixtral/modeling_mixtral.py b/optimum/habana/transformers/models/mixtral/modeling_mixtral.py index 1fd3330181..7edd38d6c1 100644 --- a/optimum/habana/transformers/models/mixtral/modeling_mixtral.py +++ b/optimum/habana/transformers/models/mixtral/modeling_mixtral.py @@ -21,8 +21,8 @@ """PyTorch Mixtral model.""" import contextlib -import os import math +import os import warnings from typing import List, Optional, Tuple, Union @@ -31,7 +31,7 @@ import torch.nn.functional as F from torch import nn from torch.nn import CrossEntropyLoss -from transformers.cache_utils import Cache, DynamicCache +from transformers.cache_utils import Cache, DynamicCache, StaticCache from transformers.integrations.deepspeed import is_deepspeed_available from transformers.modeling_attn_mask_utils import ( _prepare_4d_causal_attention_mask, @@ -44,9 +44,9 @@ MixtralDecoderLayer, MixtralForCausalLM, MixtralModel, - repeat_kv, apply_rotary_pos_emb, load_balancing_loss_func, + repeat_kv, ) from transformers.utils import logging @@ -71,6 +71,7 @@ try: from habana_frameworks.torch.hpu import sdp_kernel + SDPContext = True except ImportError: SDPContext = False @@ -192,7 +193,7 @@ def forward(self, query, key, value, attn_mask=None, dropout_p=0.0, is_causal=Fa return self.bmm2(attn_weight, value) # Try broadcast matmuls but seems not supported currently - ''' + """ def forward(self, query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False, scale=None) -> torch.Tensor: query, key, value, attn_mask = gaudi_mixtral_repeat_kv( query, key, value, attn_mask, self.num_kv_groups @@ -221,7 +222,7 @@ def forward(self, query, key, value, attn_mask=None, dropout_p=0.0, is_causal=Fa attn_output = self.bmm2(attn_weight, value) attn_output = attn_output.reshape(bsz, self.num_heads, L, self.head_dim).contiguous() return attn_output - ''' + """ class KVCache(torch.nn.Module): @@ -343,16 +344,13 @@ def forward( query_states, key_states = apply_customized_rope(query_states, key_states, cos, sin, position_ids) if use_cache: - cache_kwargs = {"sin": sin, "cos": cos} # Specific to RoPE models if reuse_cache: key_states = self.k_cache(key_states, 2, token_idx) value_states = self.v_cache(value_states, 2, token_idx) past_key_value = (self.k_cache.get_shape(), self.v_cache.get_shape()) else: if past_key_value is None: - past_key = torch.zeros( - key_states.shape, dtype=self.k_proj.weight.dtype, device=key_states.device - ) + past_key = torch.zeros(key_states.shape, dtype=self.k_proj.weight.dtype, device=key_states.device) past_value = torch.zeros( key_states.shape, dtype=self.k_proj.weight.dtype, device=key_states.device ) @@ -376,16 +374,14 @@ def forward( # WA for GQA optimization is not supported currently key_states = repeat_kv(key_states, self.num_key_value_groups) value_states = repeat_kv(value_states, self.num_key_value_groups) - attn_output = self.sdpa( - query_states, key_states, value_states, attention_mask, 0.0, False, None - ) + attn_output = self.sdpa(query_states, key_states, value_states, attention_mask, 0.0, False, None) else: with sdp_kernel(enable_recompute=False) if SDPContext else contextlib.nullcontext(): attn_output = FusedSDPA.apply( query_states, key_states, value_states, attention_mask, 0.0, False, None ) # pervious version - ''' + """ import habana_frameworks.torch.hpu as ht if q_len == 1: # next token @@ -399,7 +395,7 @@ def forward( attn_output = FusedSDPA.apply( query_states, key_states, value_states, attention_mask, 0.0, False, None ) - ''' + """ else: query_states, key_states, value_states, attention_mask = gaudi_mixtral_repeat_kv( query_states, key_states, value_states, attention_mask, self.num_key_value_groups @@ -614,7 +610,7 @@ def forward( raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds") past_key_values_length = 0 - use_new_cache = False # Ignoring new Cache path for HPU + use_new_cache = False # Ignoring new Cache path for HPU if self.gradient_checkpointing and self.training: if use_cache: @@ -628,11 +624,9 @@ def forward( past_key_values_length = past_key_values[0][0][2] else: if use_new_cache: - if not instance(past_key_values, StaticCache): + if not isinstance(past_key_values, StaticCache): past_key_values = DynamicCache.from_legacy_cache(past_key_values) - past_seen_tokens = past_key_values.get_seq_length() - else: - past_seen_tokens = past_key_values[0][0].shape[2] + past_key_values_length = past_key_values.get_usable_length(seq_length) if position_ids is None: device = input_ids.device if input_ids is not None else inputs_embeds.device From 4fa87415e956d821652890a4fdd0458ef5995e94 Mon Sep 17 00:00:00 2001 From: Jinyan Chen Date: Tue, 9 Apr 2024 03:07:35 -0700 Subject: [PATCH 09/21] support long sequence prompt --- .../models/mixtral/modeling_mixtral.py | 90 ++++++++----------- 1 file changed, 38 insertions(+), 52 deletions(-) diff --git a/optimum/habana/transformers/models/mixtral/modeling_mixtral.py b/optimum/habana/transformers/models/mixtral/modeling_mixtral.py index 7edd38d6c1..2fa29cec8a 100644 --- a/optimum/habana/transformers/models/mixtral/modeling_mixtral.py +++ b/optimum/habana/transformers/models/mixtral/modeling_mixtral.py @@ -192,37 +192,33 @@ def forward(self, query, key, value, attn_mask=None, dropout_p=0.0, is_causal=Fa attn_weight = torch.dropout(attn_weight, dropout_p, train=True) return self.bmm2(attn_weight, value) - # Try broadcast matmuls but seems not supported currently - """ - def forward(self, query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False, scale=None) -> torch.Tensor: - query, key, value, attn_mask = gaudi_mixtral_repeat_kv( - query, key, value, attn_mask, self.num_kv_groups - ) - bsz = query.size(0) - L, S = query.size(-2), key.size(-2) - scale_factor = 1 / math.sqrt(self.head_dim) - invAttnHead = torch.tensor(scale_factor, dtype=torch.float32).to("hpu") - if is_causal: - assert attn_mask is None - attn_bias = torch.zeros(L, S, dtype=query.dtype) - temp_mask = torch.ones(L, S, dtype=torch.bool).tril(diagonal=0) - attn_bias.masked_fill_(temp_mask.logical_not(), float("-inf")) - attn_bias.to(query.dtype) +class NaiveFlashAttention(nn.Module): + @staticmethod + def forward(q, k, v, mask, causal, q_bucket_size, k_bucket_size, scale): + """ + Support long sequence prompt + """ + bsz, num_heads, head_dim = q.size(0), q.size(1), q.size(-1) + q_len = q.size(-2) + kvlen = k.size(-2) + query_tiles = q_len // q_bucket_size + query_states = query_states.reshape(bsz, num_heads, q_bucket_size, query_tiles, head_dim) - if attn_mask is not None: - if attn_mask.dtype == torch.bool: - attn_mask.masked_fill_(attn_mask.logical_not(), float("-inf")) + if mask is not None: + mask = mask.unsqueeze(2) + mask = mask.reshape(bsz, 1, q_bucket_size, query_tiles, kvlen) - attn_weight = self.bmm1(query, key.transpose(-2, -1)) + attn_output = [] + for i in range(query_tiles): + row_q = query_states[:,:,:,i,:] + row_mask = mask[:,:,:,i,:] + + row_o = FusedSDPA.apply(row_q, k, v, row_mask, 0.0, causal, None) + attn_output.append(row_o) + attn_output = torch.cat(attn_output, dim=2) - attn_weight += attn_mask - attn_weight = self.softmax(attn_weight, dim=-1, invAttnHead=invAttnHead) - attn_weight = torch.dropout(attn_weight, dropout_p, train=True) - attn_output = self.bmm2(attn_weight, value) - attn_output = attn_output.reshape(bsz, self.num_heads, L, self.head_dim).contiguous() return attn_output - """ class KVCache(torch.nn.Module): @@ -277,6 +273,7 @@ def __init__(self, config: MixtralConfig, layer_idx: Optional[int] = None): self.v_cache = KVCache() self.inp_seq_len = -1 self.norm_factor = 1.0 / math.sqrt(self.head_dim) + self.bucket_size = 1024 def allocate_kv_cache(self, batch_size, max_seq_len, inp_seq_len): cache_shape = (batch_size, self.num_key_value_heads, max_seq_len, self.head_dim) @@ -370,32 +367,24 @@ def forward( past_key_value = None if FusedSDPA: - if os.getenv("QUANT_CONFIG", ""): - # WA for GQA optimization is not supported currently + if not self.training and q_len == key_states.size(-2) and \ + q_len >= 8192 and q_len % self.bucket_size == 0: key_states = repeat_kv(key_states, self.num_key_value_groups) value_states = repeat_kv(value_states, self.num_key_value_groups) - attn_output = self.sdpa(query_states, key_states, value_states, attention_mask, 0.0, False, None) - else: - with sdp_kernel(enable_recompute=False) if SDPContext else contextlib.nullcontext(): - attn_output = FusedSDPA.apply( - query_states, key_states, value_states, attention_mask, 0.0, False, None - ) - # pervious version - """ - import habana_frameworks.torch.hpu as ht - if q_len == 1: - # next token - with ht.sdp_kernel(enable_recompute=False): - attn_output = FusedSDPA.apply( - query_states, key_states, value_states, attention_mask, 0.0, False, None - ) + attn_output = NaiveFlashAttention.forward(query_states, key_states, value_states, + attention_mask, False, self.bucket_size, self.bucket_size, self.norm_factor + ) else: - # first token - with ht.sdp_kernel(enable_recompute=flash_attention_recompute): - attn_output = FusedSDPA.apply( - query_states, key_states, value_states, attention_mask, 0.0, False, None - ) - """ + if os.getenv("QUANT_CONFIG", ""): + # WA for GQA optimization is not supported currently + key_states = repeat_kv(key_states, self.num_key_value_groups) + value_states = repeat_kv(value_states, self.num_key_value_groups) + attn_output = self.sdpa(query_states, key_states, value_states, attention_mask, 0.0, False, None) + else: + with sdp_kernel(enable_recompute=False) if SDPContext else contextlib.nullcontext(): + attn_output = FusedSDPA.apply( + query_states, key_states, value_states, attention_mask, 0.0, False, None + ) else: query_states, key_states, value_states, attention_mask = gaudi_mixtral_repeat_kv( query_states, key_states, value_states, attention_mask, self.num_key_value_groups @@ -515,7 +504,6 @@ def forward( "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`" ) - htcore.mark_step() residual = hidden_states hidden_states = self.input_layernorm(hidden_states) @@ -534,14 +522,12 @@ def forward( cache_idx=cache_idx, ) hidden_states = residual + hidden_states - htcore.mark_step() # Fully Connected residual = hidden_states hidden_states = self.post_attention_layernorm(hidden_states) hidden_states, router_logits = self.block_sparse_moe(hidden_states) hidden_states = residual + hidden_states - htcore.mark_step() outputs = (hidden_states,) From 0a3b7ace03d24a2a8bacd2b01c72ec3c857380c5 Mon Sep 17 00:00:00 2001 From: Jinyan Chen Date: Tue, 9 Apr 2024 18:23:21 +0800 Subject: [PATCH 10/21] make style --- optimum/habana/transformers/models/mixtral/modeling_mixtral.py | 1 - 1 file changed, 1 deletion(-) diff --git a/optimum/habana/transformers/models/mixtral/modeling_mixtral.py b/optimum/habana/transformers/models/mixtral/modeling_mixtral.py index 2fa29cec8a..4da739298d 100644 --- a/optimum/habana/transformers/models/mixtral/modeling_mixtral.py +++ b/optimum/habana/transformers/models/mixtral/modeling_mixtral.py @@ -26,7 +26,6 @@ import warnings from typing import List, Optional, Tuple, Union -import habana_frameworks.torch.core as htcore import torch import torch.nn.functional as F from torch import nn From cb74e209c0df870d8b35702c646c613041b2a2dd Mon Sep 17 00:00:00 2001 From: Jinyan Chen Date: Tue, 9 Apr 2024 18:26:21 +0800 Subject: [PATCH 11/21] make style --- .../models/mixtral/modeling_mixtral.py | 28 +++++++++++-------- 1 file changed, 17 insertions(+), 11 deletions(-) diff --git a/optimum/habana/transformers/models/mixtral/modeling_mixtral.py b/optimum/habana/transformers/models/mixtral/modeling_mixtral.py index 4da739298d..748dcecfd0 100644 --- a/optimum/habana/transformers/models/mixtral/modeling_mixtral.py +++ b/optimum/habana/transformers/models/mixtral/modeling_mixtral.py @@ -194,24 +194,24 @@ def forward(self, query, key, value, attn_mask=None, dropout_p=0.0, is_causal=Fa class NaiveFlashAttention(nn.Module): @staticmethod - def forward(q, k, v, mask, causal, q_bucket_size, k_bucket_size, scale): + def forward(q, k, v, mask, causal, q_bucket_size, k_bucket_size=None, scale=None): """ Support long sequence prompt """ bsz, num_heads, head_dim = q.size(0), q.size(1), q.size(-1) q_len = q.size(-2) kvlen = k.size(-2) - query_tiles = q_len // q_bucket_size - query_states = query_states.reshape(bsz, num_heads, q_bucket_size, query_tiles, head_dim) + q_tiles = q_len // q_bucket_size + q = q.reshape(bsz, num_heads, q_bucket_size, q_tiles, head_dim) if mask is not None: mask = mask.unsqueeze(2) - mask = mask.reshape(bsz, 1, q_bucket_size, query_tiles, kvlen) + mask = mask.reshape(bsz, 1, q_bucket_size, q_tiles, kvlen) attn_output = [] - for i in range(query_tiles): - row_q = query_states[:,:,:,i,:] - row_mask = mask[:,:,:,i,:] + for i in range(q_tiles): + row_q = q[:, :, :, i, :] + row_mask = mask[:, :, :, i, :] row_o = FusedSDPA.apply(row_q, k, v, row_mask, 0.0, causal, None) attn_output.append(row_o) @@ -366,12 +366,18 @@ def forward( past_key_value = None if FusedSDPA: - if not self.training and q_len == key_states.size(-2) and \ - q_len >= 8192 and q_len % self.bucket_size == 0: + if not self.training and q_len == key_states.size(-2) and q_len >= 8192 and q_len % self.bucket_size == 0: key_states = repeat_kv(key_states, self.num_key_value_groups) value_states = repeat_kv(value_states, self.num_key_value_groups) - attn_output = NaiveFlashAttention.forward(query_states, key_states, value_states, - attention_mask, False, self.bucket_size, self.bucket_size, self.norm_factor + attn_output = NaiveFlashAttention.forward( + query_states, + key_states, + value_states, + attention_mask, + False, + self.bucket_size, + self.bucket_size, + self.norm_factor, ) else: if os.getenv("QUANT_CONFIG", ""): From ef96dc1b187708727165ad35d2af89f7c7306c96 Mon Sep 17 00:00:00 2001 From: Jinyan Chen Date: Tue, 9 Apr 2024 22:52:35 -0700 Subject: [PATCH 12/21] update long seq support --- .../habana/transformers/models/mixtral/modeling_mixtral.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/optimum/habana/transformers/models/mixtral/modeling_mixtral.py b/optimum/habana/transformers/models/mixtral/modeling_mixtral.py index 748dcecfd0..023ef74a95 100644 --- a/optimum/habana/transformers/models/mixtral/modeling_mixtral.py +++ b/optimum/habana/transformers/models/mixtral/modeling_mixtral.py @@ -210,8 +210,8 @@ def forward(q, k, v, mask, causal, q_bucket_size, k_bucket_size=None, scale=None attn_output = [] for i in range(q_tiles): - row_q = q[:, :, :, i, :] - row_mask = mask[:, :, :, i, :] + row_q = q[..., i, :] + row_mask = mask[..., i, :] row_o = FusedSDPA.apply(row_q, k, v, row_mask, 0.0, causal, None) attn_output.append(row_o) @@ -367,8 +367,6 @@ def forward( if FusedSDPA: if not self.training and q_len == key_states.size(-2) and q_len >= 8192 and q_len % self.bucket_size == 0: - key_states = repeat_kv(key_states, self.num_key_value_groups) - value_states = repeat_kv(value_states, self.num_key_value_groups) attn_output = NaiveFlashAttention.forward( query_states, key_states, From 09fcfed7dfddd4440399d6120714a12e510d6503 Mon Sep 17 00:00:00 2001 From: Jinyan Chen Date: Wed, 10 Apr 2024 09:29:29 +0000 Subject: [PATCH 13/21] fix bug for w/o reuse_kvcache --- .../models/mixtral/modeling_mixtral.py | 54 +++++++++++++------ 1 file changed, 39 insertions(+), 15 deletions(-) diff --git a/optimum/habana/transformers/models/mixtral/modeling_mixtral.py b/optimum/habana/transformers/models/mixtral/modeling_mixtral.py index 023ef74a95..a8db427c8e 100644 --- a/optimum/habana/transformers/models/mixtral/modeling_mixtral.py +++ b/optimum/habana/transformers/models/mixtral/modeling_mixtral.py @@ -194,14 +194,24 @@ def forward(self, query, key, value, attn_mask=None, dropout_p=0.0, is_causal=Fa class NaiveFlashAttention(nn.Module): @staticmethod - def forward(q, k, v, mask, causal, q_bucket_size, k_bucket_size=None, scale=None): + def forward(q, k, v, mask, causal, q_bucket_size): """ Support long sequence prompt """ bsz, num_heads, head_dim = q.size(0), q.size(1), q.size(-1) q_len = q.size(-2) kvlen = k.size(-2) - q_tiles = q_len // q_bucket_size + + q_padding = 0 + if q_len % q_bucket_size == 0: + q_tiles = q_len // q_bucket_size + else: + q_tiles = math.ceil(q_len / q_bucket_size) + q_padding = q_tiles * q_bucket_size - q_len + q = F.pad(q, (0,0,q_padding,0), "constant", 0) + if mask is not None: + mask = F.pad(mask, (0,0,q_padding,0), "constant", -3.3895e+38) + q = q.reshape(bsz, num_heads, q_bucket_size, q_tiles, head_dim) if mask is not None: @@ -217,6 +227,9 @@ def forward(q, k, v, mask, causal, q_bucket_size, k_bucket_size=None, scale=None attn_output.append(row_o) attn_output = torch.cat(attn_output, dim=2) + if q_padding != 0: + attn_output = attn_output[..., :-q_padding, :] + return attn_output @@ -329,12 +342,12 @@ def forward( if hasattr(past_key_value, "get_usable_length"): kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx) else: - kv_seq_len += past_key_value[0][0].shape[-2] + kv_seq_len += past_key_value[0].shape[-2] else: if reuse_cache: - kv_seq_len = past_key_value[0][0][-2] + kv_seq_len = past_key_value[0][-2] else: - kv_seq_len = past_key_value[0][0].shape[-2] + kv_seq_len = past_key_value[0].shape[-2] cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) query_states, key_states = apply_customized_rope(query_states, key_states, cos, sin, position_ids) @@ -346,13 +359,24 @@ def forward( past_key_value = (self.k_cache.get_shape(), self.v_cache.get_shape()) else: if past_key_value is None: - past_key = torch.zeros(key_states.shape, dtype=self.k_proj.weight.dtype, device=key_states.device) + past_key = torch.zeros( + key_states.shape, + dtype=self.k_proj.weight.dtype, + device=key_states.device + ) past_value = torch.zeros( - key_states.shape, dtype=self.k_proj.weight.dtype, device=key_states.device + key_states.shape, + dtype=self.k_proj.weight.dtype, + device=key_states.device ) past_key_value = (past_key, past_value) - key_states = self.k_cache.update(past_key_value[0], key_states, 2, token_idx, self.inp_seq_len) - value_states = self.k_cache.update(past_key_value[1], value_states, 2, token_idx, self.inp_seq_len) + key_states = self.k_cache.update( + past_key_value[0], key_states, 2, token_idx, self.inp_seq_len + ) + value_states = self.k_cache.update( + past_key_value[1], value_states, 2, token_idx, self.inp_seq_len + ) + if token_idx is None: past_key_value = (key_states, value_states) @@ -366,7 +390,7 @@ def forward( past_key_value = None if FusedSDPA: - if not self.training and q_len == key_states.size(-2) and q_len >= 8192 and q_len % self.bucket_size == 0: + if not self.training and q_len == key_states.size(-2) and q_len >= 8192: attn_output = NaiveFlashAttention.forward( query_states, key_states, @@ -374,8 +398,6 @@ def forward( attention_mask, False, self.bucket_size, - self.bucket_size, - self.norm_factor, ) else: if os.getenv("QUANT_CONFIG", ""): @@ -615,7 +637,9 @@ def forward( if use_new_cache: if not isinstance(past_key_values, StaticCache): past_key_values = DynamicCache.from_legacy_cache(past_key_values) - past_key_values_length = past_key_values.get_usable_length(seq_length) + past_key_values_length = past_key_values.get_usable_length() + else: + past_key_values_length = past_key_values[0][0].shape[2] if position_ids is None: device = input_ids.device if input_ids is not None else inputs_embeds.device @@ -668,7 +692,7 @@ def forward( all_router_logits = () if output_router_logits else None next_decoder_cache = () if not use_new_cache else None - for decoder_layer in self.layers: + for layer_idx, decoder_layer in enumerate(self.layers): if output_hidden_states: all_hidden_states += (hidden_states,) @@ -688,7 +712,7 @@ def forward( hidden_states, attention_mask=attention_mask, position_ids=position_ids, - past_key_value=past_key_values, + past_key_value=None if past_key_values is None else past_key_values[layer_idx], output_attentions=output_attentions, output_router_logits=output_router_logits, use_cache=use_cache, From 5fb535482681893d926f47b21007ff0207c38819 Mon Sep 17 00:00:00 2001 From: Jinyan Chen Date: Wed, 10 Apr 2024 17:41:51 +0800 Subject: [PATCH 14/21] make style --- .../models/mixtral/modeling_mixtral.py | 22 +++++-------------- 1 file changed, 6 insertions(+), 16 deletions(-) diff --git a/optimum/habana/transformers/models/mixtral/modeling_mixtral.py b/optimum/habana/transformers/models/mixtral/modeling_mixtral.py index a8db427c8e..6b237e63c7 100644 --- a/optimum/habana/transformers/models/mixtral/modeling_mixtral.py +++ b/optimum/habana/transformers/models/mixtral/modeling_mixtral.py @@ -208,9 +208,9 @@ def forward(q, k, v, mask, causal, q_bucket_size): else: q_tiles = math.ceil(q_len / q_bucket_size) q_padding = q_tiles * q_bucket_size - q_len - q = F.pad(q, (0,0,q_padding,0), "constant", 0) + q = F.pad(q, (0, 0, q_padding, 0), "constant", 0) if mask is not None: - mask = F.pad(mask, (0,0,q_padding,0), "constant", -3.3895e+38) + mask = F.pad(mask, (0, 0, q_padding, 0), "constant", -3.3895e38) q = q.reshape(bsz, num_heads, q_bucket_size, q_tiles, head_dim) @@ -359,23 +359,13 @@ def forward( past_key_value = (self.k_cache.get_shape(), self.v_cache.get_shape()) else: if past_key_value is None: - past_key = torch.zeros( - key_states.shape, - dtype=self.k_proj.weight.dtype, - device=key_states.device - ) + past_key = torch.zeros(key_states.shape, dtype=self.k_proj.weight.dtype, device=key_states.device) past_value = torch.zeros( - key_states.shape, - dtype=self.k_proj.weight.dtype, - device=key_states.device + key_states.shape, dtype=self.k_proj.weight.dtype, device=key_states.device ) past_key_value = (past_key, past_value) - key_states = self.k_cache.update( - past_key_value[0], key_states, 2, token_idx, self.inp_seq_len - ) - value_states = self.k_cache.update( - past_key_value[1], value_states, 2, token_idx, self.inp_seq_len - ) + key_states = self.k_cache.update(past_key_value[0], key_states, 2, token_idx, self.inp_seq_len) + value_states = self.k_cache.update(past_key_value[1], value_states, 2, token_idx, self.inp_seq_len) if token_idx is None: past_key_value = (key_states, value_states) From 8d413026c92dff545ba05645b8e2305164864c0f Mon Sep 17 00:00:00 2001 From: Jinyan Chen Date: Thu, 11 Apr 2024 01:43:16 +0000 Subject: [PATCH 15/21] fix accuracy of NaiveFA --- .../models/mixtral/modeling_mixtral.py | 39 ++++++++++--------- 1 file changed, 21 insertions(+), 18 deletions(-) diff --git a/optimum/habana/transformers/models/mixtral/modeling_mixtral.py b/optimum/habana/transformers/models/mixtral/modeling_mixtral.py index 6b237e63c7..96b978e517 100644 --- a/optimum/habana/transformers/models/mixtral/modeling_mixtral.py +++ b/optimum/habana/transformers/models/mixtral/modeling_mixtral.py @@ -198,9 +198,7 @@ def forward(q, k, v, mask, causal, q_bucket_size): """ Support long sequence prompt """ - bsz, num_heads, head_dim = q.size(0), q.size(1), q.size(-1) q_len = q.size(-2) - kvlen = k.size(-2) q_padding = 0 if q_len % q_bucket_size == 0: @@ -208,27 +206,22 @@ def forward(q, k, v, mask, causal, q_bucket_size): else: q_tiles = math.ceil(q_len / q_bucket_size) q_padding = q_tiles * q_bucket_size - q_len - q = F.pad(q, (0, 0, q_padding, 0), "constant", 0) + q = F.pad(q, (0, 0, 0, q_padding), "constant", 0) if mask is not None: - mask = F.pad(mask, (0, 0, q_padding, 0), "constant", -3.3895e38) - - q = q.reshape(bsz, num_heads, q_bucket_size, q_tiles, head_dim) - - if mask is not None: - mask = mask.unsqueeze(2) - mask = mask.reshape(bsz, 1, q_bucket_size, q_tiles, kvlen) + mask = F.pad(mask, (0, 0, 0, q_padding), "constant", -10000.0) attn_output = [] for i in range(q_tiles): - row_q = q[..., i, :] - row_mask = mask[..., i, :] + s, e = i * q_bucket_size, (i + 1) * q_bucket_size + row_q = q[:, :, s : e, :] + row_mask = mask[:, :, s : e, :] row_o = FusedSDPA.apply(row_q, k, v, row_mask, 0.0, causal, None) attn_output.append(row_o) - attn_output = torch.cat(attn_output, dim=2) + attn_output = torch.cat(attn_output, dim=-2) if q_padding != 0: - attn_output = attn_output[..., :-q_padding, :] + attn_output = attn_output[:, :, :-q_padding, :] return attn_output @@ -359,13 +352,23 @@ def forward( past_key_value = (self.k_cache.get_shape(), self.v_cache.get_shape()) else: if past_key_value is None: - past_key = torch.zeros(key_states.shape, dtype=self.k_proj.weight.dtype, device=key_states.device) + past_key = torch.zeros( + key_states.shape, + dtype=self.k_proj.weight.dtype, + device=key_states.device + ) past_value = torch.zeros( - key_states.shape, dtype=self.k_proj.weight.dtype, device=key_states.device + key_states.shape, + dtype=self.k_proj.weight.dtype, + device=key_states.device ) past_key_value = (past_key, past_value) - key_states = self.k_cache.update(past_key_value[0], key_states, 2, token_idx, self.inp_seq_len) - value_states = self.k_cache.update(past_key_value[1], value_states, 2, token_idx, self.inp_seq_len) + key_states = self.k_cache.update( + past_key_value[0], key_states, 2, token_idx, self.inp_seq_len + ) + value_states = self.k_cache.update( + past_key_value[1], value_states, 2, token_idx, self.inp_seq_len + ) if token_idx is None: past_key_value = (key_states, value_states) From 416720da5b03e884d39ec616141c37f643ffa8cf Mon Sep 17 00:00:00 2001 From: Jinyan Chen Date: Thu, 11 Apr 2024 09:55:36 +0800 Subject: [PATCH 16/21] update thresh of long seq --- .../models/mixtral/modeling_mixtral.py | 24 ++++++------------- 1 file changed, 7 insertions(+), 17 deletions(-) diff --git a/optimum/habana/transformers/models/mixtral/modeling_mixtral.py b/optimum/habana/transformers/models/mixtral/modeling_mixtral.py index 96b978e517..dca7c38961 100644 --- a/optimum/habana/transformers/models/mixtral/modeling_mixtral.py +++ b/optimum/habana/transformers/models/mixtral/modeling_mixtral.py @@ -213,8 +213,8 @@ def forward(q, k, v, mask, causal, q_bucket_size): attn_output = [] for i in range(q_tiles): s, e = i * q_bucket_size, (i + 1) * q_bucket_size - row_q = q[:, :, s : e, :] - row_mask = mask[:, :, s : e, :] + row_q = q[:, :, s:e, :] + row_mask = mask[:, :, s:e, :] row_o = FusedSDPA.apply(row_q, k, v, row_mask, 0.0, causal, None) attn_output.append(row_o) @@ -352,23 +352,13 @@ def forward( past_key_value = (self.k_cache.get_shape(), self.v_cache.get_shape()) else: if past_key_value is None: - past_key = torch.zeros( - key_states.shape, - dtype=self.k_proj.weight.dtype, - device=key_states.device - ) + past_key = torch.zeros(key_states.shape, dtype=self.k_proj.weight.dtype, device=key_states.device) past_value = torch.zeros( - key_states.shape, - dtype=self.k_proj.weight.dtype, - device=key_states.device + key_states.shape, dtype=self.k_proj.weight.dtype, device=key_states.device ) past_key_value = (past_key, past_value) - key_states = self.k_cache.update( - past_key_value[0], key_states, 2, token_idx, self.inp_seq_len - ) - value_states = self.k_cache.update( - past_key_value[1], value_states, 2, token_idx, self.inp_seq_len - ) + key_states = self.k_cache.update(past_key_value[0], key_states, 2, token_idx, self.inp_seq_len) + value_states = self.k_cache.update(past_key_value[1], value_states, 2, token_idx, self.inp_seq_len) if token_idx is None: past_key_value = (key_states, value_states) @@ -383,7 +373,7 @@ def forward( past_key_value = None if FusedSDPA: - if not self.training and q_len == key_states.size(-2) and q_len >= 8192: + if not self.training and q_len == key_states.size(-2) and q_len > 8192: attn_output = NaiveFlashAttention.forward( query_states, key_states, From 6cdedcf1134a7d4f7b538940ec88b800c7d7305f Mon Sep 17 00:00:00 2001 From: Jinyan Chen Date: Thu, 11 Apr 2024 02:53:15 +0000 Subject: [PATCH 17/21] update bucket size for naive fa --- .../habana/transformers/models/mixtral/modeling_mixtral.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/optimum/habana/transformers/models/mixtral/modeling_mixtral.py b/optimum/habana/transformers/models/mixtral/modeling_mixtral.py index dca7c38961..310321675f 100644 --- a/optimum/habana/transformers/models/mixtral/modeling_mixtral.py +++ b/optimum/habana/transformers/models/mixtral/modeling_mixtral.py @@ -49,6 +49,7 @@ ) from transformers.utils import logging +import habana_frameworks.torch.core as htcore try: from habana_frameworks.torch.hpex.kernels import RotaryPosEmbeddingHelperV2 as FusedRoPE @@ -278,7 +279,7 @@ def __init__(self, config: MixtralConfig, layer_idx: Optional[int] = None): self.v_cache = KVCache() self.inp_seq_len = -1 self.norm_factor = 1.0 / math.sqrt(self.head_dim) - self.bucket_size = 1024 + self.bucket_size = 4096 # 1024 def allocate_kv_cache(self, batch_size, max_seq_len, inp_seq_len): cache_shape = (batch_size, self.num_key_value_heads, max_seq_len, self.head_dim) @@ -374,6 +375,7 @@ def forward( if FusedSDPA: if not self.training and q_len == key_states.size(-2) and q_len > 8192: + htcore.mark_step() attn_output = NaiveFlashAttention.forward( query_states, key_states, @@ -382,6 +384,7 @@ def forward( False, self.bucket_size, ) + htcore.mark_step() else: if os.getenv("QUANT_CONFIG", ""): # WA for GQA optimization is not supported currently From 745bb4b4a716ff7c0e61bc48b69d747cbacc714c Mon Sep 17 00:00:00 2001 From: Jinyan Chen Date: Thu, 11 Apr 2024 11:07:16 +0800 Subject: [PATCH 18/21] make style --- .../habana/transformers/models/mixtral/modeling_mixtral.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/optimum/habana/transformers/models/mixtral/modeling_mixtral.py b/optimum/habana/transformers/models/mixtral/modeling_mixtral.py index 310321675f..5bcf440ff1 100644 --- a/optimum/habana/transformers/models/mixtral/modeling_mixtral.py +++ b/optimum/habana/transformers/models/mixtral/modeling_mixtral.py @@ -26,6 +26,7 @@ import warnings from typing import List, Optional, Tuple, Union +import habana_frameworks.torch.core as htcore import torch import torch.nn.functional as F from torch import nn @@ -49,7 +50,6 @@ ) from transformers.utils import logging -import habana_frameworks.torch.core as htcore try: from habana_frameworks.torch.hpex.kernels import RotaryPosEmbeddingHelperV2 as FusedRoPE @@ -279,7 +279,7 @@ def __init__(self, config: MixtralConfig, layer_idx: Optional[int] = None): self.v_cache = KVCache() self.inp_seq_len = -1 self.norm_factor = 1.0 / math.sqrt(self.head_dim) - self.bucket_size = 4096 # 1024 + self.bucket_size = 4096 # 1024 def allocate_kv_cache(self, batch_size, max_seq_len, inp_seq_len): cache_shape = (batch_size, self.num_key_value_heads, max_seq_len, self.head_dim) From 3e0ca1988183b4b2700370bfd3926844ce81d393 Mon Sep 17 00:00:00 2001 From: Jinyan Chen Date: Thu, 11 Apr 2024 03:09:37 +0000 Subject: [PATCH 19/21] tune bucket size --- optimum/habana/transformers/models/mixtral/modeling_mixtral.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/optimum/habana/transformers/models/mixtral/modeling_mixtral.py b/optimum/habana/transformers/models/mixtral/modeling_mixtral.py index 5bcf440ff1..8e340fc6b7 100644 --- a/optimum/habana/transformers/models/mixtral/modeling_mixtral.py +++ b/optimum/habana/transformers/models/mixtral/modeling_mixtral.py @@ -279,7 +279,7 @@ def __init__(self, config: MixtralConfig, layer_idx: Optional[int] = None): self.v_cache = KVCache() self.inp_seq_len = -1 self.norm_factor = 1.0 / math.sqrt(self.head_dim) - self.bucket_size = 4096 # 1024 + self.bucket_size = 1024 # 1024 def allocate_kv_cache(self, batch_size, max_seq_len, inp_seq_len): cache_shape = (batch_size, self.num_key_value_heads, max_seq_len, self.head_dim) From 9b0bad6716981efb6235f68c02a7d9d7f8426e37 Mon Sep 17 00:00:00 2001 From: Jinyan Chen Date: Thu, 11 Apr 2024 18:53:44 -0700 Subject: [PATCH 20/21] update naive fa --- .../models/mixtral/modeling_mixtral.py | 34 ++++++++----------- 1 file changed, 14 insertions(+), 20 deletions(-) diff --git a/optimum/habana/transformers/models/mixtral/modeling_mixtral.py b/optimum/habana/transformers/models/mixtral/modeling_mixtral.py index 8e340fc6b7..5eac0abd7b 100644 --- a/optimum/habana/transformers/models/mixtral/modeling_mixtral.py +++ b/optimum/habana/transformers/models/mixtral/modeling_mixtral.py @@ -50,7 +50,6 @@ ) from transformers.utils import logging - try: from habana_frameworks.torch.hpex.kernels import RotaryPosEmbeddingHelperV2 as FusedRoPE except ImportError: @@ -200,26 +199,21 @@ def forward(q, k, v, mask, causal, q_bucket_size): Support long sequence prompt """ q_len = q.size(-2) + q_tiles = (q_len // q_bucket_size) if (q_len % q_bucket_size == 0) else math.ceil(q_len / q_bucket_size) + q_padding = (q_tiles * q_bucket_size - q_len) + q = F.pad(q, (0, 0, 0, q_padding), "constant", 0) + if mask is not None: + mask = F.pad(mask, (0, 0, 0, q_padding), "constant", -10000.0) + attn_output = torch.zeros_like(q) + + row_tiles = zip( + q.split(q_bucket_size, dim=-2), + attn_output.split(q_bucket_size, dim=-2), + mask.split(q_bucket_size, dim=-2), + ) - q_padding = 0 - if q_len % q_bucket_size == 0: - q_tiles = q_len // q_bucket_size - else: - q_tiles = math.ceil(q_len / q_bucket_size) - q_padding = q_tiles * q_bucket_size - q_len - q = F.pad(q, (0, 0, 0, q_padding), "constant", 0) - if mask is not None: - mask = F.pad(mask, (0, 0, 0, q_padding), "constant", -10000.0) - - attn_output = [] - for i in range(q_tiles): - s, e = i * q_bucket_size, (i + 1) * q_bucket_size - row_q = q[:, :, s:e, :] - row_mask = mask[:, :, s:e, :] - - row_o = FusedSDPA.apply(row_q, k, v, row_mask, 0.0, causal, None) - attn_output.append(row_o) - attn_output = torch.cat(attn_output, dim=-2) + for row_q, row_o, row_mask in row_tiles: + row_o.fill_(FusedSDPA.apply(row_q, k, v, row_mask, 0.0, causal, None)) if q_padding != 0: attn_output = attn_output[:, :, :-q_padding, :] From e9fba18750173c736082cb7a20ccb365667e9d76 Mon Sep 17 00:00:00 2001 From: Jinyan Chen Date: Fri, 12 Apr 2024 10:04:59 +0800 Subject: [PATCH 21/21] make style --- optimum/habana/transformers/models/mixtral/modeling_mixtral.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/optimum/habana/transformers/models/mixtral/modeling_mixtral.py b/optimum/habana/transformers/models/mixtral/modeling_mixtral.py index 5eac0abd7b..68ba890fe6 100644 --- a/optimum/habana/transformers/models/mixtral/modeling_mixtral.py +++ b/optimum/habana/transformers/models/mixtral/modeling_mixtral.py @@ -50,6 +50,7 @@ ) from transformers.utils import logging + try: from habana_frameworks.torch.hpex.kernels import RotaryPosEmbeddingHelperV2 as FusedRoPE except ImportError: @@ -200,7 +201,7 @@ def forward(q, k, v, mask, causal, q_bucket_size): """ q_len = q.size(-2) q_tiles = (q_len // q_bucket_size) if (q_len % q_bucket_size == 0) else math.ceil(q_len / q_bucket_size) - q_padding = (q_tiles * q_bucket_size - q_len) + q_padding = q_tiles * q_bucket_size - q_len q = F.pad(q, (0, 0, 0, q_padding), "constant", 0) if mask is not None: mask = F.pad(mask, (0, 0, 0, q_padding), "constant", -10000.0)