From e9a1e2060f4f693fdbae716abf796dddb29e20f8 Mon Sep 17 00:00:00 2001 From: Sayantan Sarkar Date: Thu, 14 Mar 2024 09:43:45 +0200 Subject: [PATCH 01/16] Initial commit reuse_cache --- .../habana/transformers/generation/utils.py | 7 +- optimum/habana/transformers/modeling_utils.py | 13 +- .../habana/transformers/models/__init__.py | 6 +- .../transformers/models/mistral/__init__.py | 6 +- .../models/mistral/modeling_mistral.py | 660 +++++++++++------- 5 files changed, 407 insertions(+), 285 deletions(-) diff --git a/optimum/habana/transformers/generation/utils.py b/optimum/habana/transformers/generation/utils.py index ed83b65e3d..785fbfdaab 100755 --- a/optimum/habana/transformers/generation/utils.py +++ b/optimum/habana/transformers/generation/utils.py @@ -577,7 +577,7 @@ def generate( if model_kwargs["reduce_recompile"]: assert generation_config.bucket_size if generation_config.reuse_cache: - assert self.config.model_type in ["llama"], "reuse_cache only supported by llama at the moment" + assert self.config.model_type in ["llama", "mistral"], "reuse_cache only supported by llama at the moment" if not generation_config.bucket_internal: assert ( generation_config.bucket_size <= 0 @@ -720,14 +720,16 @@ def generate( model_kwargs["use_flash_attention"] = generation_config.use_flash_attention model_kwargs["flash_attention_recompute"] = True if generation_config.flash_attention_recompute else False model_kwargs["use_fused_rope"] = False if generation_config.use_fused_rope is False else True - + #import pdb; pdb.set_trace() if not self.config.is_encoder_decoder: calculated_max_length = input_ids.shape[-1] if not generation_config.static_shapes and generation_config.max_new_tokens is not None: calculated_max_length = input_ids.shape[-1] + generation_config.max_new_tokens if generation_config.use_cache and generation_config.reuse_cache: bs, _ = input_ids.shape + #import pdb; pdb.set_trace() if not is_greedy_or_beam_and_bucket: + #import pdb; pdb.set_trace() unwrap_deepspeed_model(self).allocate_kv_cache( bs * generation_config.num_beams, calculated_max_length, @@ -1442,6 +1444,7 @@ def greedy_search( output_hidden_states=output_hidden_states, **hpu_graphs_kwargs, ) + #import pdb; pdb.set_trace() if synced_gpus and this_peer_finished: continue # don't waste resources running the code we don't need diff --git a/optimum/habana/transformers/modeling_utils.py b/optimum/habana/transformers/modeling_utils.py index 384f033cd6..ac74640a7b 100644 --- a/optimum/habana/transformers/modeling_utils.py +++ b/optimum/habana/transformers/modeling_utils.py @@ -42,7 +42,10 @@ GaudiLlamaMLP, GaudiLlamaModel, GaudiLlamaRotaryEmbedding, + GaudiMistralAttention, + GaudiMistralDecoderLayer, GaudiMistralForCausalLM, + GaudiMistralModel, GaudiMixtralForCausalLM, GaudiMptForCausalLM, GaudiMptModel, @@ -98,9 +101,6 @@ gaudi_gptj_model_forward, gaudi_invert_attention_mask, gaudi_llama_rmsnorm_forward, - gaudi_mistral_attention_forward, - gaudi_mistral_decoder_layer_forward, - gaudi_mistral_model_forward, gaudi_mixtral_attention_forward, gaudi_mixtral_block_sparse_moe_forward, gaudi_mixtral_decoder_layer_forward, @@ -312,9 +312,10 @@ def adapt_transformers_to_gaudi(): # Optimization for mistral on Gaudi transformers.models.mistral.modeling_mistral.MistralForCausalLM = GaudiMistralForCausalLM - transformers.models.mistral.modeling_mistral.MistralAttention.forward = gaudi_mistral_attention_forward - transformers.models.mistral.modeling_mistral.MistralDecoderLayer.forward = gaudi_mistral_decoder_layer_forward - transformers.models.mistral.modeling_mistral.MistralModel.forward = gaudi_mistral_model_forward + #import pdb; pdb.set_trace() + transformers.models.mistral.modeling_mistral.MistralAttention = GaudiMistralAttention + transformers.models.mistral.modeling_mistral.MistralDecoderLayer = GaudiMistralDecoderLayer + transformers.models.mistral.modeling_mistral.MistralModel = GaudiMistralModel # Optimization for blip Text model on Gaudi transformers.models.blip.BlipTextModel.forward = gaudi_BlipTextModel_forward diff --git a/optimum/habana/transformers/models/__init__.py b/optimum/habana/transformers/models/__init__.py index a369627d2f..c56f5bf5f9 100644 --- a/optimum/habana/transformers/models/__init__.py +++ b/optimum/habana/transformers/models/__init__.py @@ -82,9 +82,9 @@ ) from .mistral import ( GaudiMistralForCausalLM, - gaudi_mistral_attention_forward, - gaudi_mistral_decoder_layer_forward, - gaudi_mistral_model_forward, + GaudiMistralDecoderLayer, + GaudiMistralAttention, + GaudiMistralModel ) from .mixtral import ( GaudiMixtralForCausalLM, diff --git a/optimum/habana/transformers/models/mistral/__init__.py b/optimum/habana/transformers/models/mistral/__init__.py index 83367a6e25..32523fe370 100644 --- a/optimum/habana/transformers/models/mistral/__init__.py +++ b/optimum/habana/transformers/models/mistral/__init__.py @@ -1,6 +1,6 @@ from .modeling_mistral import ( GaudiMistralForCausalLM, - gaudi_mistral_attention_forward, - gaudi_mistral_decoder_layer_forward, - gaudi_mistral_model_forward, + GaudiMistralDecoderLayer, + GaudiMistralAttention, + GaudiMistralModel ) diff --git a/optimum/habana/transformers/models/mistral/modeling_mistral.py b/optimum/habana/transformers/models/mistral/modeling_mistral.py index a5035b6829..bab7697257 100644 --- a/optimum/habana/transformers/models/mistral/modeling_mistral.py +++ b/optimum/habana/transformers/models/mistral/modeling_mistral.py @@ -37,316 +37,430 @@ ) +from transformers.models.mistral.configuration_mistral import MistralConfig +from transformers.models.mistral.modeling_mistral import ( + MistralAttention, + MistralDecoderLayer, + MistralForCausalLM, + MistralModel, + MistralMLP, + MistralRMSNorm, + apply_rotary_pos_emb, + logger, +) + logger = logging.get_logger(__name__) +def update(prev, cur, dim, idx): + orig_cur = cur + if prev.shape == cur.shape: + # Initialize + prev.copy_(cur) + return orig_cur + assert cur.shape[2] == 1, f"Cannot update kv-cache. Unsupported shapes. prev:{prev.shape} cur:{cur.shape}" + if idx is not None: + return prev.index_copy_(dim, idx - 1, cur) + else: + return torch.cat((prev, cur), dim=dim) + + +class GaudiMistralAttention(MistralAttention): + def __init__(self, config: MistralConfig, layer_idx: int): + #import pdb; pdb.set_trace() + super().__init__(config) + self.past_key = None + self.past_value = None + self.layer_idx = layer_idx + + def allocate_kv_cache(self, batch_size, seq_len): + key_shape = (batch_size, self.num_key_value_heads, seq_len, self.head_dim) + value_shape = (batch_size, self.num_key_value_heads, seq_len, self.head_dim) + if self.past_key is None or self.past_key.shape != key_shape: + #if not hasattr(self, 'past_key') or self.past_key.shape != key_shape: + device = self.k_proj.weight.device + dtype = self.k_proj.weight.dtype + self.past_key = torch.empty(key_shape, dtype=dtype, device=device) + self.past_value = torch.empty(value_shape, dtype=dtype, device=device) + + def reorder(self, tensor, beam_idx, dim_a, dim_b): + updated = tensor.index_select(0, beam_idx) + tensor.copy_(updated) + + def reorder_kv_cache(self, beam_idx: torch.LongTensor): + if self.past_key is None: + #if not hasattr(self, 'past_key'): + return (None, None) + + head_dim = self.past_key.size(-1) + seq_length = self.past_key.size(-2) + self.reorder(self.past_key, beam_idx, seq_length, head_dim) + self.reorder(self.past_value, beam_idx, seq_length, head_dim) + return (self.past_key.shape, self.past_value.shape) -def gaudi_mistral_attention_forward( - self, - hidden_states: torch.Tensor, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[Cache] = None, - output_attentions: bool = False, - use_cache: bool = False, - token_idx: Optional[torch.Tensor] = None, - **kwargs, -) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: - """ - Copied from MistralAttention.forward: https://github.com/huggingface/transformers/blob/v4.34.1/src/transformers/models/mistral/modeling_mistral.py - The only differences are: - - add new args token_idx - """ - if "padding_mask" in kwargs: - warnings.warn( - "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`" - ) - bsz, q_len, _ = hidden_states.size() - query_states = self.q_proj(hidden_states) - key_states = self.k_proj(hidden_states) - value_states = self.v_proj(hidden_states) + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + output_attentions: bool = False, + use_cache: bool = False, + token_idx: Optional[torch.Tensor] = None, + reuse_cache: Optional[bool] = False, + **kwargs, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + """ + Copied from MistralAttention.forward: https://github.com/huggingface/transformers/blob/v4.34.1/src/transformers/models/mistral/modeling_mistral.py + The only differences are: + - add new args token_idx + """ + if "padding_mask" in kwargs: + warnings.warn( + "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`" + ) + bsz, q_len, _ = hidden_states.size() + + query_states = self.q_proj(hidden_states) + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) + + query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + + kv_seq_len = key_states.shape[-2] + if past_key_value is not None: + if self.layer_idx is None: + raise ValueError( + f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} " + "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class " + "with a layer index." + ) + kv_shape = ( + (past_key_value[0][-2] if reuse_cache else past_key_value[0].shape[-2]) + if isinstance(past_key_value, tuple) + else past_key_value.get_usable_length(kv_seq_len, self.layer_idx) + ) + if token_idx is not None: + kv_seq_len = kv_shape + else: + kv_seq_len += kv_shape + cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) + + if past_key_value is not None or reuse_cache: + if reuse_cache: + past_key = self.past_key + past_value = self.past_value + else: + past_key = past_key_value[0] + past_value = past_key_value[1] + key_states = update(past_key, key_states, 2, token_idx) + value_states = update(past_value, value_states, 2, token_idx) + #if token_idx is not None: + # past_key_value[0].index_copy_(2, token_idx - 1, key_states) + # past_key_value[1].index_copy_(2, token_idx - 1, value_states) + # key_states = past_key_value[0] + # value_states = past_key_value[1] + #else: + # cache_kwargs = {"sin": sin, "cos": cos} # Specific to RoPE models + # key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + + if use_cache: + if reuse_cache: + past_key_value = (key_states.contiguous().shape, value_states.contiguous().shape) + else: + past_key_value = (key_states.contiguous(), value_states.contiguous()) + else: + past_key_value = None - query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) - key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) - value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + # repeat k/v heads if n_kv_heads < n_heads + key_states = repeat_kv(key_states, self.num_key_value_groups) + value_states = repeat_kv(value_states, self.num_key_value_groups) - kv_seq_len = key_states.shape[-2] - if past_key_value is not None: - if self.layer_idx is None: + attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim) + + if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len): raise ValueError( - f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} " - "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class " - "with a layer index." + f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is" + f" {attn_weights.size()}" ) - kv_shape = ( - past_key_value[0].shape[-2] - if isinstance(past_key_value, tuple) - else past_key_value.get_usable_length(kv_seq_len, self.layer_idx) - ) - if token_idx is not None: - kv_seq_len = kv_shape - else: - kv_seq_len += kv_shape - cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) - query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) - - if past_key_value is not None: - if token_idx is not None: - past_key_value[0].index_copy_(2, token_idx - 1, key_states) - past_key_value[1].index_copy_(2, token_idx - 1, value_states) - key_states = past_key_value[0] - value_states = past_key_value[1] - else: - cache_kwargs = {"sin": sin, "cos": cos} # Specific to RoPE models - key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) - - past_key_value = (key_states, value_states) if use_cache else None - # repeat k/v heads if n_kv_heads < n_heads - key_states = repeat_kv(key_states, self.num_key_value_groups) - value_states = repeat_kv(value_states, self.num_key_value_groups) + if attention_mask is not None: + if attention_mask.size() != (bsz, 1, q_len, kv_seq_len): + raise ValueError( + f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}" + ) - attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim) + attn_weights = attn_weights + attention_mask - if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len): - raise ValueError( - f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is" - f" {attn_weights.size()}" - ) + # upcast attention to fp32 + attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype) + attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training) + attn_output = torch.matmul(attn_weights, value_states) - if attention_mask is not None: - if attention_mask.size() != (bsz, 1, q_len, kv_seq_len): + if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim): raise ValueError( - f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}" + f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is" + f" {attn_output.size()}" ) - attn_weights = attn_weights + attention_mask + attn_output = attn_output.transpose(1, 2).contiguous() + attn_output = attn_output.reshape(bsz, q_len, self.hidden_size) - # upcast attention to fp32 - attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype) - attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training) - attn_output = torch.matmul(attn_weights, value_states) + attn_output = self.o_proj(attn_output) - if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim): - raise ValueError( - f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is" - f" {attn_output.size()}" - ) + if not output_attentions: + attn_weights = None + #import pdb; pdb.set_trace() - attn_output = attn_output.transpose(1, 2).contiguous() - attn_output = attn_output.reshape(bsz, q_len, self.hidden_size) - - attn_output = self.o_proj(attn_output) - - if not output_attentions: - attn_weights = None - - return attn_output, attn_weights, past_key_value - - -def gaudi_mistral_decoder_layer_forward( - self, - hidden_states: torch.Tensor, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[Tuple[torch.Tensor]] = None, - output_attentions: Optional[bool] = False, - use_cache: Optional[bool] = False, - token_idx: Optional[torch.Tensor] = None, - **kwargs, -) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: - """ - Copied from MistralDecoderLayer.forward: https://github.com/huggingface/transformers/blob/v4.34.1/src/transformers/models/mistral/modeling_mistral.py - The only differences are: - - add new args token_idx - """ - if "padding_mask" in kwargs: - warnings.warn( - "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`" - ) + return attn_output, attn_weights, past_key_value - residual = hidden_states - - hidden_states = self.input_layernorm(hidden_states) - - # Self Attention - hidden_states, self_attn_weights, present_key_value = self.self_attn( - hidden_states=hidden_states, - attention_mask=attention_mask, - position_ids=position_ids, - past_key_value=past_key_value, - output_attentions=output_attentions, - use_cache=use_cache, - token_idx=token_idx, - ) - hidden_states = residual + hidden_states - - # Fully Connected - residual = hidden_states - hidden_states = self.post_attention_layernorm(hidden_states) - hidden_states = self.mlp(hidden_states) - hidden_states = residual + hidden_states - - outputs = (hidden_states,) - - if output_attentions: - outputs += (self_attn_weights,) - - if use_cache: - outputs += (present_key_value,) - - return outputs - - -def gaudi_mistral_model_forward( - self, - input_ids: torch.LongTensor = None, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[List[torch.FloatTensor]] = None, - inputs_embeds: Optional[torch.FloatTensor] = None, - use_cache: Optional[bool] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - token_idx: Optional[torch.Tensor] = None, -) -> Union[Tuple, BaseModelOutputWithPast]: - """ - Copied from MistralModel.forward: https://github.com/huggingface/transformers/blob/v4.34.1/src/transformers/models/mistral/modeling_mistral.py - The only differences are: - - add new args token_idx - """ - output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions - output_hidden_states = ( - output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states - ) - use_cache = use_cache if use_cache is not None else self.config.use_cache - - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - - # retrieve input_ids and inputs_embeds - if input_ids is not None and inputs_embeds is not None: - raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time") - elif input_ids is not None: - batch_size, seq_length = input_ids.shape - elif inputs_embeds is not None: - batch_size, seq_length, _ = inputs_embeds.shape - else: - raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds") - if self.gradient_checkpointing and self.training: - if use_cache: - logger.warning_once( - "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." +class GaudiMistralDecoderLayer(MistralDecoderLayer): + def __init__(self, config: MistralConfig, layer_idx: int): + super().__init__(config, layer_idx) + self.hidden_size = config.hidden_size + #import pdb; pdb.set_trace() + + self.self_attn = GaudiMistralAttention(config, layer_idx) + + self.mlp = MistralMLP(config) + self.input_layernorm = MistralRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.post_attention_layernorm = MistralRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + + + def allocate_kv_cache(self, batch_size, seq_len): + self.self_attn.allocate_kv_cache(batch_size, seq_len) + + def reorder_kv_cache(self, beam_idx: torch.LongTensor): + return self.self_attn.reorder_kv_cache(beam_idx) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Tuple[torch.Tensor]] = None, + output_attentions: Optional[bool] = False, + use_cache: Optional[bool] = False, + token_idx: Optional[torch.Tensor] = None, + reuse_cache: Optional[bool] = False, + **kwargs, + ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: + """ + Copied from MistralDecoderLayer.forward: https://github.com/huggingface/transformers/blob/v4.34.1/src/transformers/models/mistral/modeling_mistral.py + The only differences are: + - add new args token_idx + """ + if "padding_mask" in kwargs: + warnings.warn( + "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`" ) - use_cache = False - - past_key_values_length = 0 - use_legacy_cache = True - use_new_cache = False - if past_key_values is not None: - if use_cache and use_new_cache: - use_legacy_cache = not isinstance(past_key_values, Cache) - if use_legacy_cache: - past_key_values = DynamicCache.from_legacy_cache(past_key_values) - past_key_values_length = past_key_values.get_usable_length(seq_length) - - if position_ids is None: - device = input_ids.device if input_ids is not None else inputs_embeds.device - position_ids = torch.arange( - past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device - ) - position_ids = position_ids.unsqueeze(0).view(-1, seq_length) - else: - position_ids = position_ids.view(-1, seq_length).long() - - if inputs_embeds is None: - inputs_embeds = self.embed_tokens(input_ids) - - if self._attn_implementation == "sdpa" and not output_attentions: - # output_attentions=True can not be supported when using SDPA, and we fall back on - # the manual implementation that requires a 4D causal mask in all cases. - attention_mask = _prepare_4d_causal_attention_mask_for_sdpa( - attention_mask, - (batch_size, seq_length), - inputs_embeds, - past_key_values_length, - ) - else: - # 4d mask is passed through the layers - attention_mask = _gaudi_prepare_4d_causal_attention_mask( - attention_mask, - (batch_size, seq_length), - inputs_embeds, - past_key_values_length, - sliding_window=self.config.sliding_window, + + residual = hidden_states + + hidden_states = self.input_layernorm(hidden_states) + + # Self Attention + hidden_states, self_attn_weights, present_key_value = self.self_attn( + hidden_states=hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_value, + output_attentions=output_attentions, + use_cache=use_cache, + token_idx=token_idx, + reuse_cache=reuse_cache ) + #import pdb; pdb.set_trace() + hidden_states = residual + hidden_states - hidden_states = inputs_embeds + # Fully Connected + residual = hidden_states + hidden_states = self.post_attention_layernorm(hidden_states) + hidden_states = self.mlp(hidden_states) + hidden_states = residual + hidden_states - # decoder layers - all_hidden_states = () if output_hidden_states else None - all_self_attns = () if output_attentions else None - next_decoder_cache = () if not use_new_cache else None + outputs = (hidden_states,) - for layer_idx, decoder_layer in enumerate(self.layers): - if output_hidden_states: - all_hidden_states += (hidden_states,) + if output_attentions: + outputs += (self_attn_weights,) + + if use_cache: + outputs += (present_key_value,) + + return outputs + + +class GaudiMistralModel(MistralModel): + def allocate_kv_cache(self, batch_size, seq_len): + for layer in self.layers: + layer.allocate_kv_cache(batch_size, seq_len) + def reorder_kv_cache(self, beam_idx: torch.LongTensor): + return tuple(layer.reorder_kv_cache(beam_idx) for layer in self.layers) + + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + token_idx: Optional[torch.Tensor] = None, + reuse_cache: Optional[bool] = False, + ) -> Union[Tuple, BaseModelOutputWithPast]: + """ + Copied from MistralModel.forward: https://github.com/huggingface/transformers/blob/v4.34.1/src/transformers/models/mistral/modeling_mistral.py + The only differences are: + - add new args token_idx + """ + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + use_cache = use_cache if use_cache is not None else self.config.use_cache + + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # retrieve input_ids and inputs_embeds + if input_ids is not None and inputs_embeds is not None: + raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time") + elif input_ids is not None: + batch_size, seq_length = input_ids.shape + elif inputs_embeds is not None: + batch_size, seq_length, _ = inputs_embeds.shape + else: + raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds") if self.gradient_checkpointing and self.training: - layer_outputs = self._gradient_checkpointing_func( - decoder_layer.__call__, - hidden_states, + if use_cache: + logger.warning_once( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." + ) + use_cache = False + + past_key_values_length = 0 + use_legacy_cache = True + use_new_cache = False + if past_key_values is not None and use_cache: + if reuse_cache: + past_seen_tokens = past_key_values[0][0][2] + else: + if use_new_cache: + use_legacy_cache = not isinstance(past_key_values, Cache) + if use_legacy_cache: + past_key_values = DynamicCache.from_legacy_cache(past_key_values) + past_key_values_length = past_key_values.get_usable_length(seq_length) + + if position_ids is None: + device = input_ids.device if input_ids is not None else inputs_embeds.device + position_ids = torch.arange( + past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device + ) + position_ids = position_ids.unsqueeze(0).view(-1, seq_length) + else: + position_ids = position_ids.view(-1, seq_length).long() + + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) + + if self._attn_implementation == "sdpa" and not output_attentions: + # output_attentions=True can not be supported when using SDPA, and we fall back on + # the manual implementation that requires a 4D causal mask in all cases. + attention_mask = _prepare_4d_causal_attention_mask_for_sdpa( attention_mask, - position_ids, - None if past_key_values is None else past_key_values[layer_idx], - output_attentions, - use_cache, - None, + (batch_size, seq_length), + inputs_embeds, + past_key_values_length, ) else: - layer_outputs = decoder_layer( - hidden_states, - attention_mask=attention_mask, - position_ids=position_ids, - past_key_value=None if past_key_values is None else past_key_values[layer_idx], - output_attentions=output_attentions, - use_cache=use_cache, - token_idx=token_idx, + # 4d mask is passed through the layers + attention_mask = _gaudi_prepare_4d_causal_attention_mask( + attention_mask, + (batch_size, seq_length), + inputs_embeds, + past_key_values_length, + sliding_window=self.config.sliding_window, ) - hidden_states = layer_outputs[0] + hidden_states = inputs_embeds + + # decoder layers + all_hidden_states = () if output_hidden_states else None + all_self_attns = () if output_attentions else None + next_decoder_cache = () if not use_new_cache else None + + for layer_idx, decoder_layer in enumerate(self.layers): + if output_hidden_states: + all_hidden_states += (hidden_states,) + + if self.gradient_checkpointing and self.training: + layer_outputs = self._gradient_checkpointing_func( + decoder_layer.__call__, + hidden_states, + attention_mask, + position_ids, + None if past_key_values is None else past_key_values[layer_idx], + output_attentions, + use_cache, + None, + ) + else: + layer_outputs = decoder_layer( + hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=None if past_key_values is None else past_key_values[layer_idx], + output_attentions=output_attentions, + use_cache=use_cache, + token_idx=token_idx, + reuse_cache=reuse_cache + ) - if use_cache: - next_decoder_cache += (layer_outputs[2 if output_attentions else 1],) + hidden_states = layer_outputs[0] - if output_attentions: - all_self_attns += (layer_outputs[1],) + if use_cache: + #import pdb; pdb.set_trace() + next_decoder_cache += (layer_outputs[2 if output_attentions else 1],) - hidden_states = self.norm(hidden_states) + if output_attentions: + all_self_attns += (layer_outputs[1],) - # add hidden states from the last decoder layer - if output_hidden_states: - all_hidden_states += (hidden_states,) + hidden_states = self.norm(hidden_states) - next_cache = None - if use_cache: - next_cache = ( - next_decoder_cache - if not use_new_cache - else (next_decoder_cache.to_legacy_cache() if use_legacy_cache else next_decoder_cache) + # add hidden states from the last decoder layer + if output_hidden_states: + all_hidden_states += (hidden_states,) + + next_cache = None + if use_cache: + next_cache = ( + next_decoder_cache + if not use_new_cache + else (next_decoder_cache.to_legacy_cache() if use_legacy_cache else next_decoder_cache) + ) + #import pdb; pdb.set_trace() + if not return_dict: + return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None) + return BaseModelOutputWithPast( + last_hidden_state=hidden_states, + past_key_values=next_cache, + hidden_states=all_hidden_states, + attentions=all_self_attns, ) - if not return_dict: - return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None) - return BaseModelOutputWithPast( - last_hidden_state=hidden_states, - past_key_values=next_cache, - hidden_states=all_hidden_states, - attentions=all_self_attns, - ) class GaudiMistralForCausalLM(MistralForCausalLM): + def allocate_kv_cache(self, batch_size, seq_len, _, __): + self.model.allocate_kv_cache(batch_size, seq_len) + def reorder_kv_cache(self, beam_idx: torch.LongTensor): + return self.model.reorder_kv_cache(beam_idx) def forward( self, input_ids: torch.LongTensor = None, @@ -360,6 +474,7 @@ def forward( output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, token_idx: Optional[torch.Tensor] = None, + reuse_cache: Optional[bool] = False ) -> Union[Tuple, CausalLMOutputWithPast]: """ Inherits from MistralForCausalLM: https://github.com/huggingface/transformers/blob/v4.34.1/src/transformers/models/mistral/modeling_mistral.py @@ -385,7 +500,9 @@ def forward( output_hidden_states=output_hidden_states, return_dict=return_dict, token_idx=token_idx, + reuse_cache=reuse_cache, ) + #import pdb; pdb.set_trace() hidden_states = outputs[0] logits = self.lm_head(hidden_states) @@ -397,11 +514,11 @@ def forward( shift_logits = logits[..., :-1, :].contiguous() shift_labels = labels[..., 1:].contiguous() # Flatten the tokens + loss_fct = CrossEntropyLoss() shift_logits = shift_logits.view(-1, self.config.vocab_size) shift_labels = shift_labels.view(-1) - # Ensure tensors are on the same device + # Enable model parallelism shift_labels = shift_labels.to(shift_logits.device) - loss_fct = CrossEntropyLoss() loss = loss_fct(shift_logits, shift_labels) if not return_dict: @@ -486,6 +603,7 @@ def prepare_inputs_for_generation( "use_cache": kwargs.get("use_cache"), "attention_mask": attention_mask, "token_idx": token_idx, + "reuse_cache": kwargs.get("reuse_cache"), } ) return model_inputs From 65ca028b89382ba94223df8ea4d884618b1da314 Mon Sep 17 00:00:00 2001 From: Libin Tang Date: Thu, 14 Mar 2024 20:13:54 +0000 Subject: [PATCH 02/16] Add repeat-kv gaudi change. --- .../models/mistral/modeling_mistral.py | 63 +++++++++++++++---- 1 file changed, 51 insertions(+), 12 deletions(-) diff --git a/optimum/habana/transformers/models/mistral/modeling_mistral.py b/optimum/habana/transformers/models/mistral/modeling_mistral.py index bab7697257..d196bc26d7 100644 --- a/optimum/habana/transformers/models/mistral/modeling_mistral.py +++ b/optimum/habana/transformers/models/mistral/modeling_mistral.py @@ -62,6 +62,39 @@ def update(prev, cur, dim, idx): return prev.index_copy_(dim, idx - 1, cur) else: return torch.cat((prev, cur), dim=dim) +# Copied from transformers.models.llama.modeling_llama.repeat_kv +def gaudi_mistral_repeat_kv( + query_states: torch.Tensor, + key_states: torch.Tensor, + value_states: torch.Tensor, + attention_mask: torch.Tensor, + n_rep: int, +): + """ + Copied from repeat_kv: https://github.com/huggingface/transformers/blob/v4.38.2/src/transformers/models/mistral/modeling_mistral.py + The only differences are: + - Append num_key_value_heads == 1 check as kv states can be broadcasted during matmuls so need to expand and reshape them. + - Add new args query_states, key_states, value_states and attention_mask and update the logic for expansion. + The query states go from (batch, num_heads, seqlen, head_dim) to (batch, num_key_value_heads, n_rep, seqlen, head_dim) + The key/value states go from (batch, num_key_value_heads, seqlen, head_dim) to (batch, num_key_value_heads, 1, seqlen, head_dim) + """ + batch, num_key_value_heads, kv_len, head_dim = key_states.shape + if n_rep == 1 or num_key_value_heads == 1: + return query_states, key_states, value_states, attention_mask + + new_kv_shape = (batch, num_key_value_heads, 1, kv_len, head_dim) + key_states = key_states.reshape(new_kv_shape) + value_states = value_states.reshape(new_kv_shape) + + batch, _, q_len, head_dim = query_states.shape + new_q_shape = (batch, num_key_value_heads, n_rep, q_len, head_dim) + query_states = query_states.reshape(new_q_shape) + + if attention_mask is not None: + # Add groups dim and set to 1 + attention_mask = attention_mask.unsqueeze(1) + + return query_states, key_states, value_states, attention_mask class GaudiMistralAttention(MistralAttention): @@ -176,29 +209,35 @@ def forward( past_key_value = None # repeat k/v heads if n_kv_heads < n_heads - key_states = repeat_kv(key_states, self.num_key_value_groups) - value_states = repeat_kv(value_states, self.num_key_value_groups) - - attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim) - - if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len): - raise ValueError( - f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is" - f" {attn_weights.size()}" + query_states, key_states, value_states, attention_mask = gaudi_mistral_repeat_kv( + query_states, key_states, value_states, attention_mask, self.num_key_value_groups ) + attn_weights = torch.matmul(query_states, key_states.transpose(-2, -1)) / math.sqrt(self.head_dim) - if attention_mask is not None: - if attention_mask.size() != (bsz, 1, q_len, kv_seq_len): + if attn_weights.size() not in [ + (bsz, self.num_heads, q_len, kv_seq_len), + (bsz, self.num_key_value_heads, self.num_key_value_groups, q_len, kv_seq_len), + ]: raise ValueError( - f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}" + f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)} or" + f" {(bsz, self.num_key_value_heads, self.num_key_value_groups, q_len, kv_seq_len)}, but is" + f" {attn_weights.size()}" ) + if attention_mask is not None: + if attention_mask.size() not in [(bsz, 1, q_len, kv_seq_len), (bsz, 1, 1, q_len, kv_seq_len)]: + raise ValueError( + f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)} or {(bsz, 1, 1, q_len, kv_seq_len)}," + f" but is {attention_mask.size()}" + ) + attn_weights = attn_weights + attention_mask # upcast attention to fp32 attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype) attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training) attn_output = torch.matmul(attn_weights, value_states) + attn_output = attn_output.reshape(bsz, -1, q_len, self.head_dim) if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim): raise ValueError( From 30e4e76f05c1d4d50a122c9df5736484ea3ee122 Mon Sep 17 00:00:00 2001 From: Sayantan Sarkar Date: Thu, 14 Mar 2024 21:08:10 +0000 Subject: [PATCH 03/16] trim logits --- .../transformers/models/mistral/modeling_mistral.py | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/optimum/habana/transformers/models/mistral/modeling_mistral.py b/optimum/habana/transformers/models/mistral/modeling_mistral.py index d196bc26d7..c4dfcf920c 100644 --- a/optimum/habana/transformers/models/mistral/modeling_mistral.py +++ b/optimum/habana/transformers/models/mistral/modeling_mistral.py @@ -513,7 +513,8 @@ def forward( output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, token_idx: Optional[torch.Tensor] = None, - reuse_cache: Optional[bool] = False + reuse_cache: Optional[bool] = False, + trim_logits: Optional[bool] = False, ) -> Union[Tuple, CausalLMOutputWithPast]: """ Inherits from MistralForCausalLM: https://github.com/huggingface/transformers/blob/v4.34.1/src/transformers/models/mistral/modeling_mistral.py @@ -544,6 +545,12 @@ def forward( #import pdb; pdb.set_trace() hidden_states = outputs[0] + _, seq_len, _ = hidden_states.shape + if seq_len > 1 and trim_logits and not self.training: + if token_idx is not None: + hidden_states = hidden_states.index_select(1, token_idx - 1) + else: + hidden_states = hidden_states[:, -1, :] logits = self.lm_head(hidden_states) logits = logits.float() @@ -643,6 +650,7 @@ def prepare_inputs_for_generation( "attention_mask": attention_mask, "token_idx": token_idx, "reuse_cache": kwargs.get("reuse_cache"), + "trim_logits": kwargs.get("trim_logits"), } ) return model_inputs From 0e67530906cf8ae72acc5d6ed13c8f567c863e22 Mon Sep 17 00:00:00 2001 From: Sayantan Sarkar Date: Thu, 14 Mar 2024 21:36:32 +0000 Subject: [PATCH 04/16] fusedrms --- optimum/habana/transformers/modeling_utils.py | 2 ++ .../habana/transformers/models/__init__.py | 3 +- .../transformers/models/mistral/__init__.py | 3 +- .../models/mistral/modeling_mistral.py | 28 +++++++++++++++++++ 4 files changed, 34 insertions(+), 2 deletions(-) diff --git a/optimum/habana/transformers/modeling_utils.py b/optimum/habana/transformers/modeling_utils.py index ac74640a7b..4253765a53 100644 --- a/optimum/habana/transformers/modeling_utils.py +++ b/optimum/habana/transformers/modeling_utils.py @@ -106,6 +106,7 @@ gaudi_mixtral_decoder_layer_forward, gaudi_mixtral_model_forward, gaudi_mixtral_rmsnorm_forward, + gaudi_mistral_rmsnorm_forward, gaudi_mpt_attention_forward, gaudi_mpt_block_forward, gaudi_opt_attention_forward, @@ -316,6 +317,7 @@ def adapt_transformers_to_gaudi(): transformers.models.mistral.modeling_mistral.MistralAttention = GaudiMistralAttention transformers.models.mistral.modeling_mistral.MistralDecoderLayer = GaudiMistralDecoderLayer transformers.models.mistral.modeling_mistral.MistralModel = GaudiMistralModel + transformers.models.mistral.modeling_mistral.MistralRMSNorm.forward = gaudi_mistral_rmsnorm_forward # Optimization for blip Text model on Gaudi transformers.models.blip.BlipTextModel.forward = gaudi_BlipTextModel_forward diff --git a/optimum/habana/transformers/models/__init__.py b/optimum/habana/transformers/models/__init__.py index c56f5bf5f9..5563e89a39 100644 --- a/optimum/habana/transformers/models/__init__.py +++ b/optimum/habana/transformers/models/__init__.py @@ -84,7 +84,8 @@ GaudiMistralForCausalLM, GaudiMistralDecoderLayer, GaudiMistralAttention, - GaudiMistralModel + GaudiMistralModel, + gaudi_mistral_rmsnorm_forward, ) from .mixtral import ( GaudiMixtralForCausalLM, diff --git a/optimum/habana/transformers/models/mistral/__init__.py b/optimum/habana/transformers/models/mistral/__init__.py index 32523fe370..b5beaa9148 100644 --- a/optimum/habana/transformers/models/mistral/__init__.py +++ b/optimum/habana/transformers/models/mistral/__init__.py @@ -2,5 +2,6 @@ GaudiMistralForCausalLM, GaudiMistralDecoderLayer, GaudiMistralAttention, - GaudiMistralModel + GaudiMistralModel, + gaudi_mistral_rmsnorm_forward ) diff --git a/optimum/habana/transformers/models/mistral/modeling_mistral.py b/optimum/habana/transformers/models/mistral/modeling_mistral.py index c4dfcf920c..b03ae85538 100644 --- a/optimum/habana/transformers/models/mistral/modeling_mistral.py +++ b/optimum/habana/transformers/models/mistral/modeling_mistral.py @@ -49,6 +49,12 @@ logger, ) +try: + from habana_frameworks.torch.hpex.normalization import FusedRMSNorm as FusedRMSNorm +except ImportError: + print("Not using HPU fused kernel for RMSNorm") + FusedRMSNorm = None + logger = logging.get_logger(__name__) def update(prev, cur, dim, idx): @@ -97,6 +103,28 @@ def gaudi_mistral_repeat_kv( return query_states, key_states, value_states, attention_mask +def gaudi_mistral_rmsnorm_forward(self, hidden_states): + """ + Copied from MistralRMSNorm.forward: https://github.com/huggingface/transformers/blob/main/src/transformers/models/mistral/modeling_mistral.py + The only differences are: + - override RMSNorm with Habana fused RMSNorm + """ + if hidden_states.device.type == "hpu" and FusedRMSNorm: + # mixed dtypes are not good for FusedRMSNorm, both inputs need to have same dtype + if hidden_states.dtype != self.weight.dtype: + orig_dtype = hidden_states.dtype + hidden_states = FusedRMSNorm.apply(hidden_states.to(self.weight.dtype), self.weight, self.variance_epsilon) + return hidden_states.to(orig_dtype) + else: + hidden_states = FusedRMSNorm.apply(hidden_states, self.weight, self.variance_epsilon) + return hidden_states + else: + input_dtype = hidden_states.dtype + hidden_states = hidden_states.to(torch.float32) + variance = hidden_states.pow(2).mean(-1, keepdim=True) + hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) + return self.weight * hidden_states.to(input_dtype) + class GaudiMistralAttention(MistralAttention): def __init__(self, config: MistralConfig, layer_idx: int): #import pdb; pdb.set_trace() From 3c580e28d5aa68078c00681ab45ec2cf28396bbd Mon Sep 17 00:00:00 2001 From: Libin Tang Date: Thu, 14 Mar 2024 22:24:52 +0000 Subject: [PATCH 05/16] Added bucket_internal, rotary embedding cache --- .../models/mistral/modeling_mistral.py | 44 +++++++++++++------ 1 file changed, 30 insertions(+), 14 deletions(-) diff --git a/optimum/habana/transformers/models/mistral/modeling_mistral.py b/optimum/habana/transformers/models/mistral/modeling_mistral.py index b03ae85538..86b4341f3d 100644 --- a/optimum/habana/transformers/models/mistral/modeling_mistral.py +++ b/optimum/habana/transformers/models/mistral/modeling_mistral.py @@ -101,6 +101,13 @@ def gaudi_mistral_repeat_kv( attention_mask = attention_mask.unsqueeze(1) return query_states, key_states, value_states, attention_mask +def update_sincos_cache(self, seq_len): + # Call rotary emb forward() to update cos/sin cache when infering more than self.max_position_embeddings + # This helps in avoiding creation of these caches during actual model forward pass and + # reduce memory consumption and improve performance. + if seq_len > self.max_position_embeddings: + self.max_position_embeddings = seq_len + _, _ = self.rotary_emb(self.k_proj.weight, seq_len=seq_len) def gaudi_mistral_rmsnorm_forward(self, hidden_states): @@ -169,12 +176,15 @@ def forward( use_cache: bool = False, token_idx: Optional[torch.Tensor] = None, reuse_cache: Optional[bool] = False, + cache_idx: Optional[int] = None, **kwargs, ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: """ Copied from MistralAttention.forward: https://github.com/huggingface/transformers/blob/v4.34.1/src/transformers/models/mistral/modeling_mistral.py The only differences are: - add new args token_idx + - add new args reuse_cache + - add new args cache_idx """ if "padding_mask" in kwargs: warnings.warn( @@ -219,15 +229,6 @@ def forward( past_value = past_key_value[1] key_states = update(past_key, key_states, 2, token_idx) value_states = update(past_value, value_states, 2, token_idx) - #if token_idx is not None: - # past_key_value[0].index_copy_(2, token_idx - 1, key_states) - # past_key_value[1].index_copy_(2, token_idx - 1, value_states) - # key_states = past_key_value[0] - # value_states = past_key_value[1] - #else: - # cache_kwargs = {"sin": sin, "cos": cos} # Specific to RoPE models - # key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) - if use_cache: if reuse_cache: past_key_value = (key_states.contiguous().shape, value_states.contiguous().shape) @@ -235,6 +236,11 @@ def forward( past_key_value = (key_states.contiguous(), value_states.contiguous()) else: past_key_value = None + if cache_idx is not None and q_len == 1: + key_states = key_states[:, :, :cache_idx, :] + value_states = value_states[:, :, :cache_idx, :] + attention_mask = attention_mask[:, :, :, :cache_idx] + kv_seq_len = key_states.shape[-2] # repeat k/v heads if n_kv_heads < n_heads query_states, key_states, value_states, attention_mask = gaudi_mistral_repeat_kv( @@ -304,6 +310,9 @@ def allocate_kv_cache(self, batch_size, seq_len): def reorder_kv_cache(self, beam_idx: torch.LongTensor): return self.self_attn.reorder_kv_cache(beam_idx) + def update_sincos_cache(self, seq_len): + self.self_attn.update_sincos_cache(seq_len) + def forward( self, hidden_states: torch.Tensor, @@ -314,6 +323,7 @@ def forward( use_cache: Optional[bool] = False, token_idx: Optional[torch.Tensor] = None, reuse_cache: Optional[bool] = False, + cache_idx: Optional[int] = None, **kwargs, ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: """ @@ -339,7 +349,8 @@ def forward( output_attentions=output_attentions, use_cache=use_cache, token_idx=token_idx, - reuse_cache=reuse_cache + reuse_cache=reuse_cache, + cache_idx=cache_idx ) #import pdb; pdb.set_trace() hidden_states = residual + hidden_states @@ -381,6 +392,7 @@ def forward( return_dict: Optional[bool] = None, token_idx: Optional[torch.Tensor] = None, reuse_cache: Optional[bool] = False, + cache_idx: Optional[int] = None, ) -> Union[Tuple, BaseModelOutputWithPast]: """ Copied from MistralModel.forward: https://github.com/huggingface/transformers/blob/v4.34.1/src/transformers/models/mistral/modeling_mistral.py @@ -487,7 +499,8 @@ def forward( output_attentions=output_attentions, use_cache=use_cache, token_idx=token_idx, - reuse_cache=reuse_cache + reuse_cache=reuse_cache, + cache_idx=cache_idx, ) hidden_states = layer_outputs[0] @@ -527,7 +540,9 @@ class GaudiMistralForCausalLM(MistralForCausalLM): def allocate_kv_cache(self, batch_size, seq_len, _, __): self.model.allocate_kv_cache(batch_size, seq_len) def reorder_kv_cache(self, beam_idx: torch.LongTensor): - return self.model.reorder_kv_cache(beam_idx) + return self.model.reorder_kv_cache(beam_idx) + def update_sincos_cache(self, seq_len): + self.model.update_sincos_cache(seq_len) def forward( self, input_ids: torch.LongTensor = None, @@ -543,6 +558,7 @@ def forward( token_idx: Optional[torch.Tensor] = None, reuse_cache: Optional[bool] = False, trim_logits: Optional[bool] = False, + cache_idx: Optional[int] = None, ) -> Union[Tuple, CausalLMOutputWithPast]: """ Inherits from MistralForCausalLM: https://github.com/huggingface/transformers/blob/v4.34.1/src/transformers/models/mistral/modeling_mistral.py @@ -569,9 +585,8 @@ def forward( return_dict=return_dict, token_idx=token_idx, reuse_cache=reuse_cache, + cache_idx=cache_idx, ) - #import pdb; pdb.set_trace() - hidden_states = outputs[0] _, seq_len, _ = hidden_states.shape if seq_len > 1 and trim_logits and not self.training: @@ -679,6 +694,7 @@ def prepare_inputs_for_generation( "token_idx": token_idx, "reuse_cache": kwargs.get("reuse_cache"), "trim_logits": kwargs.get("trim_logits"), + "cache_idx": kwargs.get("cache_idx"), } ) return model_inputs From c08a4e2a1cb5d4838cf58827a1e67d506ac54c9d Mon Sep 17 00:00:00 2001 From: Sayantan Sarkar Date: Thu, 14 Mar 2024 23:49:01 +0000 Subject: [PATCH 06/16] attn softmax bf16 --- .../models/mistral/modeling_mistral.py | 17 ++++++++++++++--- 1 file changed, 14 insertions(+), 3 deletions(-) diff --git a/optimum/habana/transformers/models/mistral/modeling_mistral.py b/optimum/habana/transformers/models/mistral/modeling_mistral.py index 86b4341f3d..993e876c78 100644 --- a/optimum/habana/transformers/models/mistral/modeling_mistral.py +++ b/optimum/habana/transformers/models/mistral/modeling_mistral.py @@ -177,6 +177,7 @@ def forward( token_idx: Optional[torch.Tensor] = None, reuse_cache: Optional[bool] = False, cache_idx: Optional[int] = None, + attn_softmax_bf16: Optional[bool] = False, **kwargs, ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: """ @@ -267,8 +268,11 @@ def forward( attn_weights = attn_weights + attention_mask - # upcast attention to fp32 - attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype) + if attn_softmax_bf16: + attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=query_states.dtype) + else: + # upcast attention to fp32 + attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype) attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training) attn_output = torch.matmul(attn_weights, value_states) attn_output = attn_output.reshape(bsz, -1, q_len, self.head_dim) @@ -324,6 +328,7 @@ def forward( token_idx: Optional[torch.Tensor] = None, reuse_cache: Optional[bool] = False, cache_idx: Optional[int] = None, + attn_softmax_bf16: Optional[bool] = False, **kwargs, ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: """ @@ -350,7 +355,8 @@ def forward( use_cache=use_cache, token_idx=token_idx, reuse_cache=reuse_cache, - cache_idx=cache_idx + cache_idx=cache_idx, + attn_softmax_bf16=attn_softmax_bf16 ) #import pdb; pdb.set_trace() hidden_states = residual + hidden_states @@ -393,6 +399,7 @@ def forward( token_idx: Optional[torch.Tensor] = None, reuse_cache: Optional[bool] = False, cache_idx: Optional[int] = None, + attn_softmax_bf16: Optional[bool] = False, ) -> Union[Tuple, BaseModelOutputWithPast]: """ Copied from MistralModel.forward: https://github.com/huggingface/transformers/blob/v4.34.1/src/transformers/models/mistral/modeling_mistral.py @@ -501,6 +508,7 @@ def forward( token_idx=token_idx, reuse_cache=reuse_cache, cache_idx=cache_idx, + attn_softmax_bf16=attn_softmax_bf16, ) hidden_states = layer_outputs[0] @@ -559,6 +567,7 @@ def forward( reuse_cache: Optional[bool] = False, trim_logits: Optional[bool] = False, cache_idx: Optional[int] = None, + attn_softmax_bf16: Optional[bool] = False, ) -> Union[Tuple, CausalLMOutputWithPast]: """ Inherits from MistralForCausalLM: https://github.com/huggingface/transformers/blob/v4.34.1/src/transformers/models/mistral/modeling_mistral.py @@ -586,6 +595,7 @@ def forward( token_idx=token_idx, reuse_cache=reuse_cache, cache_idx=cache_idx, + attn_softmax_bf16=attn_softmax_bf16, ) hidden_states = outputs[0] _, seq_len, _ = hidden_states.shape @@ -695,6 +705,7 @@ def prepare_inputs_for_generation( "reuse_cache": kwargs.get("reuse_cache"), "trim_logits": kwargs.get("trim_logits"), "cache_idx": kwargs.get("cache_idx"), + "attn_softmax_bf16": kwargs.get("attn_softmax_bf16"), } ) return model_inputs From 68659560d70ea48d02ad60abd4fe8a6225c81f25 Mon Sep 17 00:00:00 2001 From: Sayantan Sarkar Date: Fri, 15 Mar 2024 04:14:11 +0000 Subject: [PATCH 07/16] remove pdb --- optimum/habana/transformers/generation/utils.py | 4 ---- optimum/habana/transformers/modeling_utils.py | 1 - .../habana/transformers/models/mistral/modeling_mistral.py | 6 ------ 3 files changed, 11 deletions(-) diff --git a/optimum/habana/transformers/generation/utils.py b/optimum/habana/transformers/generation/utils.py index 785fbfdaab..2c77196d87 100755 --- a/optimum/habana/transformers/generation/utils.py +++ b/optimum/habana/transformers/generation/utils.py @@ -720,16 +720,13 @@ def generate( model_kwargs["use_flash_attention"] = generation_config.use_flash_attention model_kwargs["flash_attention_recompute"] = True if generation_config.flash_attention_recompute else False model_kwargs["use_fused_rope"] = False if generation_config.use_fused_rope is False else True - #import pdb; pdb.set_trace() if not self.config.is_encoder_decoder: calculated_max_length = input_ids.shape[-1] if not generation_config.static_shapes and generation_config.max_new_tokens is not None: calculated_max_length = input_ids.shape[-1] + generation_config.max_new_tokens if generation_config.use_cache and generation_config.reuse_cache: bs, _ = input_ids.shape - #import pdb; pdb.set_trace() if not is_greedy_or_beam_and_bucket: - #import pdb; pdb.set_trace() unwrap_deepspeed_model(self).allocate_kv_cache( bs * generation_config.num_beams, calculated_max_length, @@ -1444,7 +1441,6 @@ def greedy_search( output_hidden_states=output_hidden_states, **hpu_graphs_kwargs, ) - #import pdb; pdb.set_trace() if synced_gpus and this_peer_finished: continue # don't waste resources running the code we don't need diff --git a/optimum/habana/transformers/modeling_utils.py b/optimum/habana/transformers/modeling_utils.py index 4253765a53..c672160cbe 100644 --- a/optimum/habana/transformers/modeling_utils.py +++ b/optimum/habana/transformers/modeling_utils.py @@ -313,7 +313,6 @@ def adapt_transformers_to_gaudi(): # Optimization for mistral on Gaudi transformers.models.mistral.modeling_mistral.MistralForCausalLM = GaudiMistralForCausalLM - #import pdb; pdb.set_trace() transformers.models.mistral.modeling_mistral.MistralAttention = GaudiMistralAttention transformers.models.mistral.modeling_mistral.MistralDecoderLayer = GaudiMistralDecoderLayer transformers.models.mistral.modeling_mistral.MistralModel = GaudiMistralModel diff --git a/optimum/habana/transformers/models/mistral/modeling_mistral.py b/optimum/habana/transformers/models/mistral/modeling_mistral.py index 993e876c78..0346d1a0e2 100644 --- a/optimum/habana/transformers/models/mistral/modeling_mistral.py +++ b/optimum/habana/transformers/models/mistral/modeling_mistral.py @@ -134,7 +134,6 @@ def gaudi_mistral_rmsnorm_forward(self, hidden_states): class GaudiMistralAttention(MistralAttention): def __init__(self, config: MistralConfig, layer_idx: int): - #import pdb; pdb.set_trace() super().__init__(config) self.past_key = None self.past_value = None @@ -290,7 +289,6 @@ def forward( if not output_attentions: attn_weights = None - #import pdb; pdb.set_trace() return attn_output, attn_weights, past_key_value @@ -299,7 +297,6 @@ class GaudiMistralDecoderLayer(MistralDecoderLayer): def __init__(self, config: MistralConfig, layer_idx: int): super().__init__(config, layer_idx) self.hidden_size = config.hidden_size - #import pdb; pdb.set_trace() self.self_attn = GaudiMistralAttention(config, layer_idx) @@ -358,7 +355,6 @@ def forward( cache_idx=cache_idx, attn_softmax_bf16=attn_softmax_bf16 ) - #import pdb; pdb.set_trace() hidden_states = residual + hidden_states # Fully Connected @@ -514,7 +510,6 @@ def forward( hidden_states = layer_outputs[0] if use_cache: - #import pdb; pdb.set_trace() next_decoder_cache += (layer_outputs[2 if output_attentions else 1],) if output_attentions: @@ -533,7 +528,6 @@ def forward( if not use_new_cache else (next_decoder_cache.to_legacy_cache() if use_legacy_cache else next_decoder_cache) ) - #import pdb; pdb.set_trace() if not return_dict: return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None) return BaseModelOutputWithPast( From 9c485a968b2a4875446af7521c449162d39fe3e4 Mon Sep 17 00:00:00 2001 From: Sayantan Sarkar Date: Fri, 15 Mar 2024 04:29:30 +0000 Subject: [PATCH 08/16] style --- .../training/textual_inversion.py | 6 +- optimum/habana/transformers/modeling_utils.py | 2 +- .../habana/transformers/models/__init__.py | 4 +- .../transformers/models/mistral/__init__.py | 6 +- .../models/mistral/modeling_mistral.py | 98 ++++++++++--------- 5 files changed, 61 insertions(+), 55 deletions(-) diff --git a/examples/stable-diffusion/training/textual_inversion.py b/examples/stable-diffusion/training/textual_inversion.py index f0169aebff..e185b45df4 100644 --- a/examples/stable-diffusion/training/textual_inversion.py +++ b/examples/stable-diffusion/training/textual_inversion.py @@ -886,9 +886,9 @@ def main(): index_no_updates[min(placeholder_token_ids) : max(placeholder_token_ids) + 1] = False with torch.no_grad(): - accelerator.unwrap_model(text_encoder).get_input_embeddings().weight[index_no_updates] = ( - orig_embeds_params[index_no_updates] - ) + accelerator.unwrap_model(text_encoder).get_input_embeddings().weight[ + index_no_updates + ] = orig_embeds_params[index_no_updates] # Checks if the accelerator has performed an optimization step behind the scenes if accelerator.sync_gradients: diff --git a/optimum/habana/transformers/modeling_utils.py b/optimum/habana/transformers/modeling_utils.py index c672160cbe..84750dd89a 100644 --- a/optimum/habana/transformers/modeling_utils.py +++ b/optimum/habana/transformers/modeling_utils.py @@ -101,12 +101,12 @@ gaudi_gptj_model_forward, gaudi_invert_attention_mask, gaudi_llama_rmsnorm_forward, + gaudi_mistral_rmsnorm_forward, gaudi_mixtral_attention_forward, gaudi_mixtral_block_sparse_moe_forward, gaudi_mixtral_decoder_layer_forward, gaudi_mixtral_model_forward, gaudi_mixtral_rmsnorm_forward, - gaudi_mistral_rmsnorm_forward, gaudi_mpt_attention_forward, gaudi_mpt_block_forward, gaudi_opt_attention_forward, diff --git a/optimum/habana/transformers/models/__init__.py b/optimum/habana/transformers/models/__init__.py index 5563e89a39..879e8333d8 100644 --- a/optimum/habana/transformers/models/__init__.py +++ b/optimum/habana/transformers/models/__init__.py @@ -81,9 +81,9 @@ gaudi_llama_rmsnorm_forward, ) from .mistral import ( - GaudiMistralForCausalLM, - GaudiMistralDecoderLayer, GaudiMistralAttention, + GaudiMistralDecoderLayer, + GaudiMistralForCausalLM, GaudiMistralModel, gaudi_mistral_rmsnorm_forward, ) diff --git a/optimum/habana/transformers/models/mistral/__init__.py b/optimum/habana/transformers/models/mistral/__init__.py index b5beaa9148..192c267791 100644 --- a/optimum/habana/transformers/models/mistral/__init__.py +++ b/optimum/habana/transformers/models/mistral/__init__.py @@ -1,7 +1,7 @@ from .modeling_mistral import ( - GaudiMistralForCausalLM, - GaudiMistralDecoderLayer, GaudiMistralAttention, + GaudiMistralDecoderLayer, + GaudiMistralForCausalLM, GaudiMistralModel, - gaudi_mistral_rmsnorm_forward + gaudi_mistral_rmsnorm_forward, ) diff --git a/optimum/habana/transformers/models/mistral/modeling_mistral.py b/optimum/habana/transformers/models/mistral/modeling_mistral.py index 0346d1a0e2..65e395b1f1 100644 --- a/optimum/habana/transformers/models/mistral/modeling_mistral.py +++ b/optimum/habana/transformers/models/mistral/modeling_mistral.py @@ -29,25 +29,22 @@ from transformers.cache_utils import Cache, DynamicCache from transformers.modeling_attn_mask_utils import _prepare_4d_causal_attention_mask_for_sdpa from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast -from transformers.models.mistral.modeling_mistral import MistralForCausalLM, apply_rotary_pos_emb, repeat_kv -from transformers.utils import logging - -from ...modeling_attn_mask_utils import ( - _gaudi_prepare_4d_causal_attention_mask, -) - - from transformers.models.mistral.configuration_mistral import MistralConfig from transformers.models.mistral.modeling_mistral import ( MistralAttention, MistralDecoderLayer, MistralForCausalLM, - MistralModel, MistralMLP, + MistralModel, MistralRMSNorm, apply_rotary_pos_emb, - logger, ) +from transformers.utils import logging + +from ...modeling_attn_mask_utils import ( + _gaudi_prepare_4d_causal_attention_mask, +) + try: from habana_frameworks.torch.hpex.normalization import FusedRMSNorm as FusedRMSNorm @@ -57,6 +54,7 @@ logger = logging.get_logger(__name__) + def update(prev, cur, dim, idx): orig_cur = cur if prev.shape == cur.shape: @@ -68,6 +66,8 @@ def update(prev, cur, dim, idx): return prev.index_copy_(dim, idx - 1, cur) else: return torch.cat((prev, cur), dim=dim) + + # Copied from transformers.models.llama.modeling_llama.repeat_kv def gaudi_mistral_repeat_kv( query_states: torch.Tensor, @@ -101,13 +101,15 @@ def gaudi_mistral_repeat_kv( attention_mask = attention_mask.unsqueeze(1) return query_states, key_states, value_states, attention_mask + + def update_sincos_cache(self, seq_len): - # Call rotary emb forward() to update cos/sin cache when infering more than self.max_position_embeddings - # This helps in avoiding creation of these caches during actual model forward pass and - # reduce memory consumption and improve performance. - if seq_len > self.max_position_embeddings: - self.max_position_embeddings = seq_len - _, _ = self.rotary_emb(self.k_proj.weight, seq_len=seq_len) + # Call rotary emb forward() to update cos/sin cache when infering more than self.max_position_embeddings + # This helps in avoiding creation of these caches during actual model forward pass and + # reduce memory consumption and improve performance. + if seq_len > self.max_position_embeddings: + self.max_position_embeddings = seq_len + _, _ = self.rotary_emb(self.k_proj.weight, seq_len=seq_len) def gaudi_mistral_rmsnorm_forward(self, hidden_states): @@ -132,6 +134,7 @@ def gaudi_mistral_rmsnorm_forward(self, hidden_states): hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) return self.weight * hidden_states.to(input_dtype) + class GaudiMistralAttention(MistralAttention): def __init__(self, config: MistralConfig, layer_idx: int): super().__init__(config) @@ -143,7 +146,7 @@ def allocate_kv_cache(self, batch_size, seq_len): key_shape = (batch_size, self.num_key_value_heads, seq_len, self.head_dim) value_shape = (batch_size, self.num_key_value_heads, seq_len, self.head_dim) if self.past_key is None or self.past_key.shape != key_shape: - #if not hasattr(self, 'past_key') or self.past_key.shape != key_shape: + # if not hasattr(self, 'past_key') or self.past_key.shape != key_shape: device = self.k_proj.weight.device dtype = self.k_proj.weight.dtype self.past_key = torch.empty(key_shape, dtype=dtype, device=device) @@ -155,7 +158,7 @@ def reorder(self, tensor, beam_idx, dim_a, dim_b): def reorder_kv_cache(self, beam_idx: torch.LongTensor): if self.past_key is None: - #if not hasattr(self, 'past_key'): + # if not hasattr(self, 'past_key'): return (None, None) head_dim = self.past_key.size(-1) @@ -164,7 +167,6 @@ def reorder_kv_cache(self, beam_idx: torch.LongTensor): self.reorder(self.past_value, beam_idx, seq_length, head_dim) return (self.past_key.shape, self.past_value.shape) - def forward( self, hidden_states: torch.Tensor, @@ -180,11 +182,11 @@ def forward( **kwargs, ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: """ - Copied from MistralAttention.forward: https://github.com/huggingface/transformers/blob/v4.34.1/src/transformers/models/mistral/modeling_mistral.py - The only differences are: - - add new args token_idx - - add new args reuse_cache - - add new args cache_idx + Copied from MistralAttention.forward: https://github.com/huggingface/transformers/blob/v4.34.1/src/transformers/models/mistral/modeling_mistral.py + The only differences are: + - add new args token_idx + - add new args reuse_cache + - add new args cache_idx """ if "padding_mask" in kwargs: warnings.warn( @@ -237,33 +239,33 @@ def forward( else: past_key_value = None if cache_idx is not None and q_len == 1: - key_states = key_states[:, :, :cache_idx, :] - value_states = value_states[:, :, :cache_idx, :] - attention_mask = attention_mask[:, :, :, :cache_idx] - kv_seq_len = key_states.shape[-2] + key_states = key_states[:, :, :cache_idx, :] + value_states = value_states[:, :, :cache_idx, :] + attention_mask = attention_mask[:, :, :, :cache_idx] + kv_seq_len = key_states.shape[-2] # repeat k/v heads if n_kv_heads < n_heads query_states, key_states, value_states, attention_mask = gaudi_mistral_repeat_kv( - query_states, key_states, value_states, attention_mask, self.num_key_value_groups - ) + query_states, key_states, value_states, attention_mask, self.num_key_value_groups + ) attn_weights = torch.matmul(query_states, key_states.transpose(-2, -1)) / math.sqrt(self.head_dim) if attn_weights.size() not in [ - (bsz, self.num_heads, q_len, kv_seq_len), - (bsz, self.num_key_value_heads, self.num_key_value_groups, q_len, kv_seq_len), - ]: - raise ValueError( - f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)} or" - f" {(bsz, self.num_key_value_heads, self.num_key_value_groups, q_len, kv_seq_len)}, but is" - f" {attn_weights.size()}" - ) + (bsz, self.num_heads, q_len, kv_seq_len), + (bsz, self.num_key_value_heads, self.num_key_value_groups, q_len, kv_seq_len), + ]: + raise ValueError( + f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)} or" + f" {(bsz, self.num_key_value_heads, self.num_key_value_groups, q_len, kv_seq_len)}, but is" + f" {attn_weights.size()}" + ) if attention_mask is not None: if attention_mask.size() not in [(bsz, 1, q_len, kv_seq_len), (bsz, 1, 1, q_len, kv_seq_len)]: - raise ValueError( - f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)} or {(bsz, 1, 1, q_len, kv_seq_len)}," - f" but is {attention_mask.size()}" - ) + raise ValueError( + f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)} or {(bsz, 1, 1, q_len, kv_seq_len)}," + f" but is {attention_mask.size()}" + ) attn_weights = attn_weights + attention_mask @@ -304,7 +306,6 @@ def __init__(self, config: MistralConfig, layer_idx: int): self.input_layernorm = MistralRMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.post_attention_layernorm = MistralRMSNorm(config.hidden_size, eps=config.rms_norm_eps) - def allocate_kv_cache(self, batch_size, seq_len): self.self_attn.allocate_kv_cache(batch_size, seq_len) @@ -353,7 +354,7 @@ def forward( token_idx=token_idx, reuse_cache=reuse_cache, cache_idx=cache_idx, - attn_softmax_bf16=attn_softmax_bf16 + attn_softmax_bf16=attn_softmax_bf16, ) hidden_states = residual + hidden_states @@ -378,6 +379,7 @@ class GaudiMistralModel(MistralModel): def allocate_kv_cache(self, batch_size, seq_len): for layer in self.layers: layer.allocate_kv_cache(batch_size, seq_len) + def reorder_kv_cache(self, beam_idx: torch.LongTensor): return tuple(layer.reorder_kv_cache(beam_idx) for layer in self.layers) @@ -432,7 +434,8 @@ def forward( use_new_cache = False if past_key_values is not None and use_cache: if reuse_cache: - past_seen_tokens = past_key_values[0][0][2] + # past_seen_tokens = past_key_values[0][0][2] + pass else: if use_new_cache: use_legacy_cache = not isinstance(past_key_values, Cache) @@ -541,10 +544,13 @@ def forward( class GaudiMistralForCausalLM(MistralForCausalLM): def allocate_kv_cache(self, batch_size, seq_len, _, __): self.model.allocate_kv_cache(batch_size, seq_len) + def reorder_kv_cache(self, beam_idx: torch.LongTensor): return self.model.reorder_kv_cache(beam_idx) + def update_sincos_cache(self, seq_len): - self.model.update_sincos_cache(seq_len) + self.model.update_sincos_cache(seq_len) + def forward( self, input_ids: torch.LongTensor = None, From 61ed38fe1beb4b0b123295640c65330aa99cff26 Mon Sep 17 00:00:00 2001 From: Sayantan Sarkar Date: Fri, 15 Mar 2024 05:02:53 +0000 Subject: [PATCH 09/16] restore unrelated file to original --- examples/stable-diffusion/training/textual_inversion.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/examples/stable-diffusion/training/textual_inversion.py b/examples/stable-diffusion/training/textual_inversion.py index e185b45df4..f0169aebff 100644 --- a/examples/stable-diffusion/training/textual_inversion.py +++ b/examples/stable-diffusion/training/textual_inversion.py @@ -886,9 +886,9 @@ def main(): index_no_updates[min(placeholder_token_ids) : max(placeholder_token_ids) + 1] = False with torch.no_grad(): - accelerator.unwrap_model(text_encoder).get_input_embeddings().weight[ - index_no_updates - ] = orig_embeds_params[index_no_updates] + accelerator.unwrap_model(text_encoder).get_input_embeddings().weight[index_no_updates] = ( + orig_embeds_params[index_no_updates] + ) # Checks if the accelerator has performed an optimization step behind the scenes if accelerator.sync_gradients: From f71bf61fc71cb30a8eab21be170062a313e0911f Mon Sep 17 00:00:00 2001 From: Sayantan Sarkar Date: Mon, 18 Mar 2024 10:44:45 -0700 Subject: [PATCH 10/16] Update utils.py assert message --- optimum/habana/transformers/generation/utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/optimum/habana/transformers/generation/utils.py b/optimum/habana/transformers/generation/utils.py index 2c77196d87..41b70dfecc 100755 --- a/optimum/habana/transformers/generation/utils.py +++ b/optimum/habana/transformers/generation/utils.py @@ -577,7 +577,7 @@ def generate( if model_kwargs["reduce_recompile"]: assert generation_config.bucket_size if generation_config.reuse_cache: - assert self.config.model_type in ["llama", "mistral"], "reuse_cache only supported by llama at the moment" + assert self.config.model_type in ["llama", "mistral"], "reuse_cache only supported by llama and mistral at the moment" if not generation_config.bucket_internal: assert ( generation_config.bucket_size <= 0 From a48a726bf3e3b53ed5ab66ae465fe4234ef51089 Mon Sep 17 00:00:00 2001 From: Sayantan Sarkar Date: Mon, 18 Mar 2024 10:47:38 -0700 Subject: [PATCH 11/16] Add markstep in the middle of the model for larger BS --- .../habana/transformers/models/mistral/modeling_mistral.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/optimum/habana/transformers/models/mistral/modeling_mistral.py b/optimum/habana/transformers/models/mistral/modeling_mistral.py index 65e395b1f1..c5081fd86f 100644 --- a/optimum/habana/transformers/models/mistral/modeling_mistral.py +++ b/optimum/habana/transformers/models/mistral/modeling_mistral.py @@ -44,7 +44,7 @@ from ...modeling_attn_mask_utils import ( _gaudi_prepare_4d_causal_attention_mask, ) - +import habana_frameworks.torch.core as htcore try: from habana_frameworks.torch.hpex.normalization import FusedRMSNorm as FusedRMSNorm @@ -482,6 +482,8 @@ def forward( next_decoder_cache = () if not use_new_cache else None for layer_idx, decoder_layer in enumerate(self.layers): + if layer_idx == len(self.layers)//2: + htcore.mark_step() if output_hidden_states: all_hidden_states += (hidden_states,) From 805907776d99ab9c1ac580da794d2bc6f4d112fd Mon Sep 17 00:00:00 2001 From: Sayantan Sarkar Date: Mon, 18 Mar 2024 18:02:00 +0000 Subject: [PATCH 12/16] style --- optimum/habana/transformers/generation/utils.py | 5 ++++- .../habana/transformers/models/mistral/modeling_mistral.py | 5 +++-- 2 files changed, 7 insertions(+), 3 deletions(-) diff --git a/optimum/habana/transformers/generation/utils.py b/optimum/habana/transformers/generation/utils.py index d660066107..bad54641d6 100755 --- a/optimum/habana/transformers/generation/utils.py +++ b/optimum/habana/transformers/generation/utils.py @@ -578,7 +578,10 @@ def generate( if model_kwargs["reduce_recompile"]: assert generation_config.bucket_size if generation_config.reuse_cache: - assert self.config.model_type in ["llama", "mistral"], "reuse_cache only supported by llama and mistral at the moment" + assert self.config.model_type in [ + "llama", + "mistral", + ], "reuse_cache only supported by llama and mistral at the moment" if not generation_config.bucket_internal: assert ( generation_config.bucket_size <= 0 diff --git a/optimum/habana/transformers/models/mistral/modeling_mistral.py b/optimum/habana/transformers/models/mistral/modeling_mistral.py index c5081fd86f..aeba7fe505 100644 --- a/optimum/habana/transformers/models/mistral/modeling_mistral.py +++ b/optimum/habana/transformers/models/mistral/modeling_mistral.py @@ -23,6 +23,7 @@ import warnings from typing import List, Optional, Tuple, Union +import habana_frameworks.torch.core as htcore import torch from torch import nn from torch.nn import CrossEntropyLoss @@ -44,7 +45,7 @@ from ...modeling_attn_mask_utils import ( _gaudi_prepare_4d_causal_attention_mask, ) -import habana_frameworks.torch.core as htcore + try: from habana_frameworks.torch.hpex.normalization import FusedRMSNorm as FusedRMSNorm @@ -482,7 +483,7 @@ def forward( next_decoder_cache = () if not use_new_cache else None for layer_idx, decoder_layer in enumerate(self.layers): - if layer_idx == len(self.layers)//2: + if layer_idx == len(self.layers) // 2: htcore.mark_step() if output_hidden_states: all_hidden_states += (hidden_states,) From d8cdc75d8168a00398d151dac178c5c89ace89bc Mon Sep 17 00:00:00 2001 From: Sayantan Sarkar Date: Tue, 19 Mar 2024 18:33:44 +0000 Subject: [PATCH 13/16] address comments --- .../models/mistral/modeling_mistral.py | 99 ++++++++----------- 1 file changed, 41 insertions(+), 58 deletions(-) diff --git a/optimum/habana/transformers/models/mistral/modeling_mistral.py b/optimum/habana/transformers/models/mistral/modeling_mistral.py index aeba7fe505..4a5edcf218 100644 --- a/optimum/habana/transformers/models/mistral/modeling_mistral.py +++ b/optimum/habana/transformers/models/mistral/modeling_mistral.py @@ -69,6 +69,30 @@ def update(prev, cur, dim, idx): return torch.cat((prev, cur), dim=dim) + + +def gaudi_mistral_rmsnorm_forward(self, hidden_states): + """ + Copied from MistralRMSNorm.forward: https://github.com/huggingface/transformers/blob/main/src/transformers/models/mistral/modeling_mistral.py + The only differences are: + - override RMSNorm with Habana fused RMSNorm + """ + if hidden_states.device.type == "hpu" and FusedRMSNorm: + # mixed dtypes are not good for FusedRMSNorm, both inputs need to have same dtype + if hidden_states.dtype != self.weight.dtype: + orig_dtype = hidden_states.dtype + hidden_states = FusedRMSNorm.apply(hidden_states.to(self.weight.dtype), self.weight, self.variance_epsilon) + return hidden_states.to(orig_dtype) + else: + hidden_states = FusedRMSNorm.apply(hidden_states, self.weight, self.variance_epsilon) + return hidden_states + else: + input_dtype = hidden_states.dtype + hidden_states = hidden_states.to(torch.float32) + variance = hidden_states.pow(2).mean(-1, keepdim=True) + hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) + return self.weight * hidden_states.to(input_dtype) + # Copied from transformers.models.llama.modeling_llama.repeat_kv def gaudi_mistral_repeat_kv( query_states: torch.Tensor, @@ -103,56 +127,29 @@ def gaudi_mistral_repeat_kv( return query_states, key_states, value_states, attention_mask - -def update_sincos_cache(self, seq_len): - # Call rotary emb forward() to update cos/sin cache when infering more than self.max_position_embeddings - # This helps in avoiding creation of these caches during actual model forward pass and - # reduce memory consumption and improve performance. - if seq_len > self.max_position_embeddings: - self.max_position_embeddings = seq_len - _, _ = self.rotary_emb(self.k_proj.weight, seq_len=seq_len) - - -def gaudi_mistral_rmsnorm_forward(self, hidden_states): - """ - Copied from MistralRMSNorm.forward: https://github.com/huggingface/transformers/blob/main/src/transformers/models/mistral/modeling_mistral.py - The only differences are: - - override RMSNorm with Habana fused RMSNorm - """ - if hidden_states.device.type == "hpu" and FusedRMSNorm: - # mixed dtypes are not good for FusedRMSNorm, both inputs need to have same dtype - if hidden_states.dtype != self.weight.dtype: - orig_dtype = hidden_states.dtype - hidden_states = FusedRMSNorm.apply(hidden_states.to(self.weight.dtype), self.weight, self.variance_epsilon) - return hidden_states.to(orig_dtype) - else: - hidden_states = FusedRMSNorm.apply(hidden_states, self.weight, self.variance_epsilon) - return hidden_states - else: - input_dtype = hidden_states.dtype - hidden_states = hidden_states.to(torch.float32) - variance = hidden_states.pow(2).mean(-1, keepdim=True) - hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) - return self.weight * hidden_states.to(input_dtype) - - class GaudiMistralAttention(MistralAttention): - def __init__(self, config: MistralConfig, layer_idx: int): - super().__init__(config) + def __init__(self, config: MistralConfig, layer_idx: Optional[int] = None): + super().__init__(config, layer_idx) self.past_key = None self.past_value = None - self.layer_idx = layer_idx def allocate_kv_cache(self, batch_size, seq_len): key_shape = (batch_size, self.num_key_value_heads, seq_len, self.head_dim) - value_shape = (batch_size, self.num_key_value_heads, seq_len, self.head_dim) + value_shape = key_shape if self.past_key is None or self.past_key.shape != key_shape: - # if not hasattr(self, 'past_key') or self.past_key.shape != key_shape: device = self.k_proj.weight.device dtype = self.k_proj.weight.dtype self.past_key = torch.empty(key_shape, dtype=dtype, device=device) self.past_value = torch.empty(value_shape, dtype=dtype, device=device) + def update_sincos_cache(self, seq_len): + # Call rotary emb forward() to update cos/sin cache when infering more than self.max_position_embeddings + # This helps in avoiding creation of these caches during actual model forward pass and + # reduce memory consumption and improve performance. + if seq_len > self.max_position_embeddings: + self.max_position_embeddings = seq_len + _, _ = self.rotary_emb(self.k_proj.weight, seq_len=seq_len) + def reorder(self, tensor, beam_idx, dim_a, dim_b): updated = tensor.index_select(0, beam_idx) tensor.copy_(updated) @@ -299,14 +296,8 @@ def forward( class GaudiMistralDecoderLayer(MistralDecoderLayer): def __init__(self, config: MistralConfig, layer_idx: int): super().__init__(config, layer_idx) - self.hidden_size = config.hidden_size - self.self_attn = GaudiMistralAttention(config, layer_idx) - self.mlp = MistralMLP(config) - self.input_layernorm = MistralRMSNorm(config.hidden_size, eps=config.rms_norm_eps) - self.post_attention_layernorm = MistralRMSNorm(config.hidden_size, eps=config.rms_norm_eps) - def allocate_kv_cache(self, batch_size, seq_len): self.self_attn.allocate_kv_cache(batch_size, seq_len) @@ -335,10 +326,6 @@ def forward( The only differences are: - add new args token_idx """ - if "padding_mask" in kwargs: - warnings.warn( - "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`" - ) residual = hidden_states @@ -433,16 +420,12 @@ def forward( past_key_values_length = 0 use_legacy_cache = True use_new_cache = False - if past_key_values is not None and use_cache: - if reuse_cache: - # past_seen_tokens = past_key_values[0][0][2] - pass - else: - if use_new_cache: - use_legacy_cache = not isinstance(past_key_values, Cache) - if use_legacy_cache: - past_key_values = DynamicCache.from_legacy_cache(past_key_values) - past_key_values_length = past_key_values.get_usable_length(seq_length) + if past_key_values is not None and use_cache and not reuse_cache: + if use_new_cache: + use_legacy_cache = not isinstance(past_key_values, Cache) + if use_legacy_cache: + past_key_values = DynamicCache.from_legacy_cache(past_key_values) + past_key_values_length = past_key_values.get_usable_length(seq_length) if position_ids is None: device = input_ids.device if input_ids is not None else inputs_embeds.device From b4e6d2d7a972f1fbd59709f25b443c13a7fcce1b Mon Sep 17 00:00:00 2001 From: Sayantan Sarkar Date: Wed, 20 Mar 2024 00:53:37 +0200 Subject: [PATCH 14/16] address comments --- .../habana/transformers/models/mistral/modeling_mistral.py | 7 +------ 1 file changed, 1 insertion(+), 6 deletions(-) diff --git a/optimum/habana/transformers/models/mistral/modeling_mistral.py b/optimum/habana/transformers/models/mistral/modeling_mistral.py index 4a5edcf218..64d05e039f 100644 --- a/optimum/habana/transformers/models/mistral/modeling_mistral.py +++ b/optimum/habana/transformers/models/mistral/modeling_mistral.py @@ -93,7 +93,6 @@ def gaudi_mistral_rmsnorm_forward(self, hidden_states): hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) return self.weight * hidden_states.to(input_dtype) -# Copied from transformers.models.llama.modeling_llama.repeat_kv def gaudi_mistral_repeat_kv( query_states: torch.Tensor, key_states: torch.Tensor, @@ -186,10 +185,6 @@ def forward( - add new args reuse_cache - add new args cache_idx """ - if "padding_mask" in kwargs: - warnings.warn( - "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`" - ) bsz, q_len, _ = hidden_states.size() query_states = self.q_proj(hidden_states) @@ -602,7 +597,7 @@ def forward( loss_fct = CrossEntropyLoss() shift_logits = shift_logits.view(-1, self.config.vocab_size) shift_labels = shift_labels.view(-1) - # Enable model parallelism + # Ensure tensors are on the same device shift_labels = shift_labels.to(shift_logits.device) loss = loss_fct(shift_logits, shift_labels) From 78280d7545bdfe350ca0f6b5948bdff3fcf5d881 Mon Sep 17 00:00:00 2001 From: Sayantan Sarkar Date: Tue, 19 Mar 2024 22:56:01 +0000 Subject: [PATCH 15/16] style --- .../habana/transformers/models/mistral/modeling_mistral.py | 7 ++----- 1 file changed, 2 insertions(+), 5 deletions(-) diff --git a/optimum/habana/transformers/models/mistral/modeling_mistral.py b/optimum/habana/transformers/models/mistral/modeling_mistral.py index 64d05e039f..f500ebff1e 100644 --- a/optimum/habana/transformers/models/mistral/modeling_mistral.py +++ b/optimum/habana/transformers/models/mistral/modeling_mistral.py @@ -20,7 +20,6 @@ """PyTorch Mistral model.""" import math -import warnings from typing import List, Optional, Tuple, Union import habana_frameworks.torch.core as htcore @@ -35,9 +34,7 @@ MistralAttention, MistralDecoderLayer, MistralForCausalLM, - MistralMLP, MistralModel, - MistralRMSNorm, apply_rotary_pos_emb, ) from transformers.utils import logging @@ -69,8 +66,6 @@ def update(prev, cur, dim, idx): return torch.cat((prev, cur), dim=dim) - - def gaudi_mistral_rmsnorm_forward(self, hidden_states): """ Copied from MistralRMSNorm.forward: https://github.com/huggingface/transformers/blob/main/src/transformers/models/mistral/modeling_mistral.py @@ -93,6 +88,7 @@ def gaudi_mistral_rmsnorm_forward(self, hidden_states): hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) return self.weight * hidden_states.to(input_dtype) + def gaudi_mistral_repeat_kv( query_states: torch.Tensor, key_states: torch.Tensor, @@ -126,6 +122,7 @@ def gaudi_mistral_repeat_kv( return query_states, key_states, value_states, attention_mask + class GaudiMistralAttention(MistralAttention): def __init__(self, config: MistralConfig, layer_idx: Optional[int] = None): super().__init__(config, layer_idx) From 17b3ff84d2dba87d266b255a432e82a46e1b814b Mon Sep 17 00:00:00 2001 From: Sayantan Sarkar Date: Thu, 21 Mar 2024 22:32:23 +0000 Subject: [PATCH 16/16] address comments --- .../transformers/models/mistral/modeling_mistral.py | 12 +++++------- 1 file changed, 5 insertions(+), 7 deletions(-) diff --git a/optimum/habana/transformers/models/mistral/modeling_mistral.py b/optimum/habana/transformers/models/mistral/modeling_mistral.py index f500ebff1e..bda8738516 100644 --- a/optimum/habana/transformers/models/mistral/modeling_mistral.py +++ b/optimum/habana/transformers/models/mistral/modeling_mistral.py @@ -130,13 +130,12 @@ def __init__(self, config: MistralConfig, layer_idx: Optional[int] = None): self.past_value = None def allocate_kv_cache(self, batch_size, seq_len): - key_shape = (batch_size, self.num_key_value_heads, seq_len, self.head_dim) - value_shape = key_shape - if self.past_key is None or self.past_key.shape != key_shape: + kv_shape = (batch_size, self.num_key_value_heads, seq_len, self.head_dim) + if self.past_key is None or self.past_key.shape != kv_shape: device = self.k_proj.weight.device dtype = self.k_proj.weight.dtype - self.past_key = torch.empty(key_shape, dtype=dtype, device=device) - self.past_value = torch.empty(value_shape, dtype=dtype, device=device) + self.past_key = torch.empty(kv_shape, dtype=dtype, device=device) + self.past_value = torch.empty(kv_shape, dtype=dtype, device=device) def update_sincos_cache(self, seq_len): # Call rotary emb forward() to update cos/sin cache when infering more than self.max_position_embeddings @@ -152,7 +151,6 @@ def reorder(self, tensor, beam_idx, dim_a, dim_b): def reorder_kv_cache(self, beam_idx: torch.LongTensor): if self.past_key is None: - # if not hasattr(self, 'past_key'): return (None, None) head_dim = self.past_key.size(-1) @@ -591,11 +589,11 @@ def forward( shift_logits = logits[..., :-1, :].contiguous() shift_labels = labels[..., 1:].contiguous() # Flatten the tokens - loss_fct = CrossEntropyLoss() shift_logits = shift_logits.view(-1, self.config.vocab_size) shift_labels = shift_labels.view(-1) # Ensure tensors are on the same device shift_labels = shift_labels.to(shift_logits.device) + loss_fct = CrossEntropyLoss() loss = loss_fct(shift_logits, shift_labels) if not return_dict: