From 254929092599b3048cb1c5495f1510967ae74c10 Mon Sep 17 00:00:00 2001 From: Shiv Kaul Date: Tue, 23 Apr 2024 10:49:34 -0700 Subject: [PATCH 01/12] add fp8 related changes to mistral for text-generation --- .../models/mistral/modeling_mistral.py | 261 +++++++++++------- 1 file changed, 164 insertions(+), 97 deletions(-) diff --git a/optimum/habana/transformers/models/mistral/modeling_mistral.py b/optimum/habana/transformers/models/mistral/modeling_mistral.py index cf5fa6f2c0..850c31faf4 100644 --- a/optimum/habana/transformers/models/mistral/modeling_mistral.py +++ b/optimum/habana/transformers/models/mistral/modeling_mistral.py @@ -20,9 +20,9 @@ """PyTorch Mistral model.""" import math +import warnings from typing import List, Optional, Tuple, Union -import habana_frameworks.torch.core as htcore import torch from torch import nn from torch.nn import CrossEntropyLoss @@ -34,15 +34,25 @@ MistralAttention, MistralDecoderLayer, MistralForCausalLM, + MistralMLP, MistralModel, + MistralRMSNorm, apply_rotary_pos_emb, ) from transformers.utils import logging - +from optimum.habana.transformers.models.modeling_all_models import KVCache from ...modeling_attn_mask_utils import ( _gaudi_prepare_4d_causal_attention_mask, ) +import habana_frameworks.torch.core as htcore + +try: + from habana_frameworks.torch.hpex.kernels import RotaryPosEmbeddingHelperV2 as FusedRoPE + has_fused_rope = True +except ImportError: + has_fused_rope = False + print("Not using HPU fused kernel for apply_rotary_pos_emb") try: from habana_frameworks.torch.hpex.normalization import FusedRMSNorm as FusedRMSNorm @@ -50,45 +60,16 @@ print("Not using HPU fused kernel for RMSNorm") FusedRMSNorm = None -logger = logging.get_logger(__name__) - - -def update(prev, cur, dim, idx): - orig_cur = cur - if prev.shape == cur.shape: - # Initialize - prev.copy_(cur) - return orig_cur - assert cur.shape[2] == 1, f"Cannot update kv-cache. Unsupported shapes. prev:{prev.shape} cur:{cur.shape}" - if idx is not None: - return prev.index_copy_(dim, idx - 1, cur) - else: - return torch.cat((prev, cur), dim=dim) +class Matmul(torch.nn.Module): + def __init__(self): + super().__init__() + def forward(self, x, y): + return torch.matmul(x, y) -def gaudi_mistral_rmsnorm_forward(self, hidden_states): - """ - Copied from MistralRMSNorm.forward: https://github.com/huggingface/transformers/blob/main/src/transformers/models/mistral/modeling_mistral.py - The only differences are: - - override RMSNorm with Habana fused RMSNorm - """ - if hidden_states.device.type == "hpu" and FusedRMSNorm: - # mixed dtypes are not good for FusedRMSNorm, both inputs need to have same dtype - if hidden_states.dtype != self.weight.dtype: - orig_dtype = hidden_states.dtype - hidden_states = FusedRMSNorm.apply(hidden_states.to(self.weight.dtype), self.weight, self.variance_epsilon) - return hidden_states.to(orig_dtype) - else: - hidden_states = FusedRMSNorm.apply(hidden_states, self.weight, self.variance_epsilon) - return hidden_states - else: - input_dtype = hidden_states.dtype - hidden_states = hidden_states.to(torch.float32) - variance = hidden_states.pow(2).mean(-1, keepdim=True) - hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) - return self.weight * hidden_states.to(input_dtype) - +logger = logging.get_logger(__name__) +# Copied from transformers.models.llama.modeling_llama.repeat_kv def gaudi_mistral_repeat_kv( query_states: torch.Tensor, key_states: torch.Tensor, @@ -123,41 +104,71 @@ def gaudi_mistral_repeat_kv( return query_states, key_states, value_states, attention_mask -class GaudiMistralAttention(MistralAttention): - def __init__(self, config: MistralConfig, layer_idx: Optional[int] = None): - super().__init__(config, layer_idx) - self.past_key = None - self.past_value = None +def update_sincos_cache(self, seq_len): + # Call rotary emb forward() to update cos/sin cache when infering more than self.max_position_embeddings + # This helps in avoiding creation of these caches during actual model forward pass and + # reduce memory consumption and improve performance. + if seq_len > self.max_position_embeddings: + self.max_position_embeddings = seq_len + _, _ = self.rotary_emb(self.k_proj.weight, seq_len=seq_len) - def allocate_kv_cache(self, batch_size, seq_len): - kv_shape = (batch_size, self.num_key_value_heads, seq_len, self.head_dim) - if self.past_key is None or self.past_key.shape != kv_shape: - device = self.k_proj.weight.device - dtype = self.k_proj.weight.dtype - self.past_key = torch.empty(kv_shape, dtype=dtype, device=device) - self.past_value = torch.empty(kv_shape, dtype=dtype, device=device) - def update_sincos_cache(self, seq_len): - # Call rotary emb forward() to update cos/sin cache when infering more than self.max_position_embeddings - # This helps in avoiding creation of these caches during actual model forward pass and - # reduce memory consumption and improve performance. - if seq_len > self.max_position_embeddings: - self.max_position_embeddings = seq_len - _, _ = self.rotary_emb(self.k_proj.weight, seq_len=seq_len) +def gaudi_mistral_rmsnorm_forward(self, hidden_states): + """ + Copied from MistralRMSNorm.forward: https://github.com/huggingface/transformers/blob/main/src/transformers/models/mistral/modeling_mistral.py + The only differences are: + - override RMSNorm with Habana fused RMSNorm + """ + if hidden_states.device.type == "hpu" and FusedRMSNorm: + # mixed dtypes are not good for FusedRMSNorm, both inputs need to have same dtype + if hidden_states.dtype != self.weight.dtype: + orig_dtype = hidden_states.dtype + hidden_states = FusedRMSNorm.apply(hidden_states.to(self.weight.dtype), self.weight, self.variance_epsilon) + return hidden_states.to(orig_dtype) + else: + hidden_states = FusedRMSNorm.apply(hidden_states, self.weight, self.variance_epsilon) + return hidden_states + else: + input_dtype = hidden_states.dtype + hidden_states = hidden_states.to(torch.float32) + variance = hidden_states.pow(2).mean(-1, keepdim=True) + hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) + return self.weight * hidden_states.to(input_dtype) + + +class GaudiMistralAttention(MistralAttention): + def __init__(self, config: MistralConfig, layer_idx: int): + super().__init__(config) + self.k_cache = KVCache() + self.v_cache = KVCache() + self.matmul_qk = Matmul() + self.matmul_av = Matmul() + # TODO: replace these two + #self.past_key = None + #self.past_value = None + self.layer_idx = layer_idx + self.inp_seq_len = -1 + + def allocate_kv_cache(self, batch_size, max_seq_len, inp_seq_len): + cache_shape = (batch_size, self.num_key_value_heads, max_seq_len, self.head_dim) + device = self.k_proj.weight.device + dtype = self.config.torch_dtype + self.k_cache.allocate(inp_seq_len, dtype, device, cache_shape) + self.v_cache.allocate(inp_seq_len, dtype, device, cache_shape) def reorder(self, tensor, beam_idx, dim_a, dim_b): updated = tensor.index_select(0, beam_idx) tensor.copy_(updated) def reorder_kv_cache(self, beam_idx: torch.LongTensor): - if self.past_key is None: + if self.k_cache.cache is None: return (None, None) - head_dim = self.past_key.size(-1) - seq_length = self.past_key.size(-2) - self.reorder(self.past_key, beam_idx, seq_length, head_dim) - self.reorder(self.past_value, beam_idx, seq_length, head_dim) - return (self.past_key.shape, self.past_value.shape) + head_dim = self.k_cache.cache.size(-1) + seq_length = self.k_cache.cache.size(-2) + self.reorder(self.k_cache.cache, beam_idx, seq_length, head_dim) + self.reorder(self.v_cache.cache, beam_idx, seq_length, head_dim) + return (self.k_cache.cache.shape, self.v_cache.cache.shape) def forward( self, @@ -171,6 +182,7 @@ def forward( reuse_cache: Optional[bool] = False, cache_idx: Optional[int] = None, attn_softmax_bf16: Optional[bool] = False, + use_fused_rope: Optional[bool] = True, **kwargs, ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: """ @@ -180,6 +192,10 @@ def forward( - add new args reuse_cache - add new args cache_idx """ + if "padding_mask" in kwargs: + warnings.warn( + "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`" + ) bsz, q_len, _ = hidden_states.size() query_states = self.q_proj(hidden_states) @@ -208,35 +224,38 @@ def forward( else: kv_seq_len += kv_shape cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) - query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) + query_states, key_states = apply_customized_rope( + query_states, key_states, cos, sin, position_ids, use_fused_rope=use_fused_rope + ) - if past_key_value is not None or reuse_cache: - if reuse_cache: - past_key = self.past_key - past_value = self.past_value - else: - past_key = past_key_value[0] - past_value = past_key_value[1] - key_states = update(past_key, key_states, 2, token_idx) - value_states = update(past_value, value_states, 2, token_idx) if use_cache: + # reuse k, v, self_attention if reuse_cache: - past_key_value = (key_states.contiguous().shape, value_states.contiguous().shape) + key_states = self.k_cache(key_states, 2, token_idx) + value_states = self.v_cache(value_states, 2, token_idx) + past_key_value = (self.k_cache.get_shape(), self.v_cache.get_shape()) else: - past_key_value = (key_states.contiguous(), value_states.contiguous()) + if past_key_value is None: + past_key = torch.zeros(key_states.shape, dtype=self.k_proj.weight.dtype, device=key_states.device) + past_value = torch.zeros(key_states.shape, dtype=self.k_proj.weight.dtype, device=key_states.device) + past_key_value = (past_key, past_value) + key_states = self.k_cache.update(past_key_value[0], key_states, 2, token_idx, self.inp_seq_len) + value_states = self.v_cache.update(past_key_value[1], value_states, 2, token_idx, self.inp_seq_len) + + if cache_idx is not None and q_len == 1: + key_states = key_states[:, :, :cache_idx, :] + value_states = value_states[:, :, :cache_idx, :] + if attention_mask is not None: + attention_mask = attention_mask[:, :, :, :cache_idx] + kv_seq_len = key_states.shape[-2] else: past_key_value = None - if cache_idx is not None and q_len == 1: - key_states = key_states[:, :, :cache_idx, :] - value_states = value_states[:, :, :cache_idx, :] - attention_mask = attention_mask[:, :, :, :cache_idx] - kv_seq_len = key_states.shape[-2] # repeat k/v heads if n_kv_heads < n_heads query_states, key_states, value_states, attention_mask = gaudi_mistral_repeat_kv( query_states, key_states, value_states, attention_mask, self.num_key_value_groups ) - attn_weights = torch.matmul(query_states, key_states.transpose(-2, -1)) / math.sqrt(self.head_dim) + attn_weights = self.matmul_qk(query_states, key_states.transpose(-2, -1)) / math.sqrt(self.head_dim) if attn_weights.size() not in [ (bsz, self.num_heads, q_len, kv_seq_len), @@ -263,7 +282,7 @@ def forward( # upcast attention to fp32 attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype) attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training) - attn_output = torch.matmul(attn_weights, value_states) + attn_output = self.matmul_av(attn_weights, value_states) attn_output = attn_output.reshape(bsz, -1, q_len, self.head_dim) if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim): @@ -286,10 +305,16 @@ def forward( class GaudiMistralDecoderLayer(MistralDecoderLayer): def __init__(self, config: MistralConfig, layer_idx: int): super().__init__(config, layer_idx) + self.hidden_size = config.hidden_size + self.self_attn = GaudiMistralAttention(config, layer_idx) - def allocate_kv_cache(self, batch_size, seq_len): - self.self_attn.allocate_kv_cache(batch_size, seq_len) + self.mlp = MistralMLP(config) + self.input_layernorm = MistralRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.post_attention_layernorm = MistralRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + + def allocate_kv_cache(self, batch_size, max_seq_len, inp_seq_len): + self.self_attn.allocate_kv_cache(batch_size, max_seq_len, inp_seq_len) def reorder_kv_cache(self, beam_idx: torch.LongTensor): return self.self_attn.reorder_kv_cache(beam_idx) @@ -309,6 +334,7 @@ def forward( reuse_cache: Optional[bool] = False, cache_idx: Optional[int] = None, attn_softmax_bf16: Optional[bool] = False, + use_fused_rope: Optional[bool] = True, **kwargs, ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: """ @@ -316,6 +342,10 @@ def forward( The only differences are: - add new args token_idx """ + if "padding_mask" in kwargs: + warnings.warn( + "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`" + ) residual = hidden_states @@ -333,6 +363,7 @@ def forward( reuse_cache=reuse_cache, cache_idx=cache_idx, attn_softmax_bf16=attn_softmax_bf16, + use_fused_rope=use_fused_rope, ) hidden_states = residual + hidden_states @@ -354,9 +385,9 @@ def forward( class GaudiMistralModel(MistralModel): - def allocate_kv_cache(self, batch_size, seq_len): + def allocate_kv_cache(self, batch_size, max_seq_len, inp_seq_len): for layer in self.layers: - layer.allocate_kv_cache(batch_size, seq_len) + layer.allocate_kv_cache(batch_size, max_seq_len, inp_seq_len) def reorder_kv_cache(self, beam_idx: torch.LongTensor): return tuple(layer.reorder_kv_cache(beam_idx) for layer in self.layers) @@ -376,11 +407,14 @@ def forward( reuse_cache: Optional[bool] = False, cache_idx: Optional[int] = None, attn_softmax_bf16: Optional[bool] = False, + use_fused_rope: Optional[bool] = True, + lazy_mode: Optional[bool] = True, ) -> Union[Tuple, BaseModelOutputWithPast]: """ Copied from MistralModel.forward: https://github.com/huggingface/transformers/blob/v4.34.1/src/transformers/models/mistral/modeling_mistral.py The only differences are: - add new args token_idx + - add new arg lazy_mode """ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( @@ -410,12 +444,16 @@ def forward( past_key_values_length = 0 use_legacy_cache = True use_new_cache = False - if past_key_values is not None and use_cache and not reuse_cache: - if use_new_cache: - use_legacy_cache = not isinstance(past_key_values, Cache) - if use_legacy_cache: - past_key_values = DynamicCache.from_legacy_cache(past_key_values) - past_key_values_length = past_key_values.get_usable_length(seq_length) + if past_key_values is not None and use_cache: + if reuse_cache: + # past_seen_tokens = past_key_values[0][0][2] + pass + else: + if use_new_cache: + use_legacy_cache = not isinstance(past_key_values, Cache) + if use_legacy_cache: + past_key_values = DynamicCache.from_legacy_cache(past_key_values) + past_key_values_length = past_key_values.get_usable_length(seq_length) if position_ids is None: device = input_ids.device if input_ids is not None else inputs_embeds.device @@ -455,8 +493,13 @@ def forward( all_self_attns = () if output_attentions else None next_decoder_cache = () if not use_new_cache else None + if lazy_mode: + htcore.mark_step() + for layer_idx, decoder_layer in enumerate(self.layers): - if layer_idx == len(self.layers) // 2: + if layer_idx == len(self.layers)//2 or \ + (lazy_mode and not self.training and \ + (torch.distributed.is_initialized() is False or torch.distributed.get_world_size() == 1)): htcore.mark_step() if output_hidden_states: all_hidden_states += (hidden_states,) @@ -471,6 +514,7 @@ def forward( output_attentions, use_cache, None, + use_fused_rope, ) else: layer_outputs = decoder_layer( @@ -484,6 +528,7 @@ def forward( reuse_cache=reuse_cache, cache_idx=cache_idx, attn_softmax_bf16=attn_softmax_bf16, + use_fused_rope=use_fused_rope, ) hidden_states = layer_outputs[0] @@ -518,8 +563,8 @@ def forward( class GaudiMistralForCausalLM(MistralForCausalLM): - def allocate_kv_cache(self, batch_size, seq_len, _): - self.model.allocate_kv_cache(batch_size, seq_len) + def allocate_kv_cache(self, batch_size, max_seq_len, inp_seq_len): + self.model.allocate_kv_cache(batch_size, max_seq_len, inp_seq_len) def reorder_kv_cache(self, beam_idx: torch.LongTensor): return self.model.reorder_kv_cache(beam_idx) @@ -544,6 +589,8 @@ def forward( trim_logits: Optional[bool] = False, cache_idx: Optional[int] = None, attn_softmax_bf16: Optional[bool] = False, + use_fused_rope: Optional[bool] = True, + lazy_mode: Optional[bool] = True, ) -> Union[Tuple, CausalLMOutputWithPast]: """ Inherits from MistralForCausalLM: https://github.com/huggingface/transformers/blob/v4.34.1/src/transformers/models/mistral/modeling_mistral.py @@ -572,6 +619,8 @@ def forward( reuse_cache=reuse_cache, cache_idx=cache_idx, attn_softmax_bf16=attn_softmax_bf16, + use_fused_rope=use_fused_rope, + lazy_mode=lazy_mode, ) hidden_states = outputs[0] _, seq_len, _ = hidden_states.shape @@ -589,11 +638,11 @@ def forward( shift_logits = logits[..., :-1, :].contiguous() shift_labels = labels[..., 1:].contiguous() # Flatten the tokens + loss_fct = CrossEntropyLoss() shift_logits = shift_logits.view(-1, self.config.vocab_size) shift_labels = shift_labels.view(-1) - # Ensure tensors are on the same device + # Enable model parallelism shift_labels = shift_labels.to(shift_logits.device) - loss_fct = CrossEntropyLoss() loss = loss_fct(shift_logits, shift_labels) if not return_dict: @@ -682,6 +731,24 @@ def prepare_inputs_for_generation( "trim_logits": kwargs.get("trim_logits"), "cache_idx": kwargs.get("cache_idx"), "attn_softmax_bf16": kwargs.get("attn_softmax_bf16"), + "lazy_mode": kwargs.get("lazy_mode"), } ) return model_inputs + +def apply_customized_rope(q, k, cos, sin, position_ids, use_fused_rope=True): + if q.device.type == "hpu" and has_fused_rope and use_fused_rope: + # TODO: remove `.clone()` when SynapseAI v1.15 is released + if k.dtype==torch.bfloat16: + return FusedRoPE.apply( + q, cos.unsqueeze(0).unsqueeze(0).clone(), sin.unsqueeze(0).unsqueeze(0).clone(), position_ids + ), FusedRoPE.apply( + k, cos.unsqueeze(0).unsqueeze(0).clone().to(torch.bfloat16), sin.unsqueeze(0).unsqueeze(0).clone().to(torch.bfloat16), position_ids + ) + return FusedRoPE.apply( + q, cos.unsqueeze(0).unsqueeze(0).clone(), sin.unsqueeze(0).unsqueeze(0).clone(), position_ids + ), FusedRoPE.apply( + k, cos.unsqueeze(0).unsqueeze(0).clone(), sin.unsqueeze(0).unsqueeze(0).clone(), position_ids + ) + else: + return apply_rotary_pos_emb(q, k, cos, sin, position_ids) From 513aa10a759faf336dec3be642bdd6e326b22cc9 Mon Sep 17 00:00:00 2001 From: Shiv Kaul Date: Tue, 23 Apr 2024 10:55:08 -0700 Subject: [PATCH 02/12] add KVCache object --- .../models/mistral/modeling_mistral.py | 41 ++++++++++++++++++- 1 file changed, 40 insertions(+), 1 deletion(-) diff --git a/optimum/habana/transformers/models/mistral/modeling_mistral.py b/optimum/habana/transformers/models/mistral/modeling_mistral.py index 850c31faf4..d741f9100a 100644 --- a/optimum/habana/transformers/models/mistral/modeling_mistral.py +++ b/optimum/habana/transformers/models/mistral/modeling_mistral.py @@ -40,7 +40,6 @@ apply_rotary_pos_emb, ) from transformers.utils import logging -from optimum.habana.transformers.models.modeling_all_models import KVCache from ...modeling_attn_mask_utils import ( _gaudi_prepare_4d_causal_attention_mask, ) @@ -60,6 +59,46 @@ print("Not using HPU fused kernel for RMSNorm") FusedRMSNorm = None +class KVCache(torch.nn.Module): + def __init__(self): + super(KVCache, self).__init__() + self.cache = None + self.inp_seq_len = -1 + + def allocate(self, inp_seq_len, dtype, device, shape): + if self.cache is None or self.cache.shape != shape: + self.inp_seq_len = inp_seq_len + self.cache = torch.zeros(shape, dtype=dtype, device=device) + else: + assert ( + self.inp_seq_len == inp_seq_len + ), f"inp_seq_len must be the same. self.inp_seq_len:{self.inp_seq_len} inp_seq_len:{inp_seq_len}" + self.cache.fill_(0) + + def update(self, prev, cur, dim, idx, inp_seq_len): + orig_cur = cur + if prev.shape == cur.shape: + prev.copy_(cur) + return orig_cur + if cur.shape[2] > 1 and cur.shape[2] <= prev.shape[2]: + # Initialize + prev[:, :, :inp_seq_len, :].copy_(cur) + return orig_cur + assert cur.shape[2] == 1, f"Cannot update kv-cache. Unsupported shapes. prev:{prev.shape} cur:{cur.shape}" + if idx is not None: + prev.index_copy_(dim, idx - 1, cur) + return prev + else: + return torch.cat((prev, cur), dim=dim) + + def get_shape(self): + if self.cache is None: + return None + return self.cache.shape + + def forward(self, cur, dim, idx): + return self.update(self.cache, cur, dim, idx, self.inp_seq_len) + class Matmul(torch.nn.Module): def __init__(self): super().__init__() From be79f0f59df5d9adee4e90b1d668b556dc481602 Mon Sep 17 00:00:00 2001 From: Jimin Ha Date: Tue, 23 Apr 2024 16:15:18 -0700 Subject: [PATCH 03/12] Fix layer_idx warning issue --- .../habana/transformers/models/mistral/modeling_mistral.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/optimum/habana/transformers/models/mistral/modeling_mistral.py b/optimum/habana/transformers/models/mistral/modeling_mistral.py index d741f9100a..25301ab695 100644 --- a/optimum/habana/transformers/models/mistral/modeling_mistral.py +++ b/optimum/habana/transformers/models/mistral/modeling_mistral.py @@ -176,8 +176,8 @@ def gaudi_mistral_rmsnorm_forward(self, hidden_states): class GaudiMistralAttention(MistralAttention): - def __init__(self, config: MistralConfig, layer_idx: int): - super().__init__(config) + def __init__(self, config: MistralConfig, layer_idx: Optional[int] = None): + super().__init__(config, layer_idx) self.k_cache = KVCache() self.v_cache = KVCache() self.matmul_qk = Matmul() @@ -185,7 +185,6 @@ def __init__(self, config: MistralConfig, layer_idx: int): # TODO: replace these two #self.past_key = None #self.past_value = None - self.layer_idx = layer_idx self.inp_seq_len = -1 def allocate_kv_cache(self, batch_size, max_seq_len, inp_seq_len): From e955c2016ff11e12fe90a68c5e75d122ab7a408c Mon Sep 17 00:00:00 2001 From: Jimin Ha Date: Wed, 24 Apr 2024 09:06:05 -0700 Subject: [PATCH 04/12] Style fix --- .../models/mistral/modeling_mistral.py | 33 +++++++++++++------ 1 file changed, 23 insertions(+), 10 deletions(-) diff --git a/optimum/habana/transformers/models/mistral/modeling_mistral.py b/optimum/habana/transformers/models/mistral/modeling_mistral.py index 25301ab695..9f7336d73c 100644 --- a/optimum/habana/transformers/models/mistral/modeling_mistral.py +++ b/optimum/habana/transformers/models/mistral/modeling_mistral.py @@ -23,6 +23,7 @@ import warnings from typing import List, Optional, Tuple, Union +import habana_frameworks.torch.core as htcore import torch from torch import nn from torch.nn import CrossEntropyLoss @@ -40,10 +41,11 @@ apply_rotary_pos_emb, ) from transformers.utils import logging + from ...modeling_attn_mask_utils import ( _gaudi_prepare_4d_causal_attention_mask, ) -import habana_frameworks.torch.core as htcore + try: from habana_frameworks.torch.hpex.kernels import RotaryPosEmbeddingHelperV2 as FusedRoPE @@ -99,6 +101,7 @@ def get_shape(self): def forward(self, cur, dim, idx): return self.update(self.cache, cur, dim, idx, self.inp_seq_len) + class Matmul(torch.nn.Module): def __init__(self): super().__init__() @@ -106,8 +109,10 @@ def __init__(self): def forward(self, x, y): return torch.matmul(x, y) + logger = logging.get_logger(__name__) + # Copied from transformers.models.llama.modeling_llama.repeat_kv def gaudi_mistral_repeat_kv( query_states: torch.Tensor, @@ -182,9 +187,9 @@ def __init__(self, config: MistralConfig, layer_idx: Optional[int] = None): self.v_cache = KVCache() self.matmul_qk = Matmul() self.matmul_av = Matmul() - # TODO: replace these two - #self.past_key = None - #self.past_value = None + # TODO: replace these two + # self.past_key = None + # self.past_value = None self.inp_seq_len = -1 def allocate_kv_cache(self, batch_size, max_seq_len, inp_seq_len): @@ -275,7 +280,9 @@ def forward( else: if past_key_value is None: past_key = torch.zeros(key_states.shape, dtype=self.k_proj.weight.dtype, device=key_states.device) - past_value = torch.zeros(key_states.shape, dtype=self.k_proj.weight.dtype, device=key_states.device) + past_value = torch.zeros( + key_states.shape, dtype=self.k_proj.weight.dtype, device=key_states.device + ) past_key_value = (past_key, past_value) key_states = self.k_cache.update(past_key_value[0], key_states, 2, token_idx, self.inp_seq_len) value_states = self.v_cache.update(past_key_value[1], value_states, 2, token_idx, self.inp_seq_len) @@ -535,9 +542,11 @@ def forward( htcore.mark_step() for layer_idx, decoder_layer in enumerate(self.layers): - if layer_idx == len(self.layers)//2 or \ - (lazy_mode and not self.training and \ - (torch.distributed.is_initialized() is False or torch.distributed.get_world_size() == 1)): + if layer_idx == len(self.layers) // 2 or ( + lazy_mode + and not self.training + and (torch.distributed.is_initialized() is False or torch.distributed.get_world_size() == 1) + ): htcore.mark_step() if output_hidden_states: all_hidden_states += (hidden_states,) @@ -774,14 +783,18 @@ def prepare_inputs_for_generation( ) return model_inputs + def apply_customized_rope(q, k, cos, sin, position_ids, use_fused_rope=True): if q.device.type == "hpu" and has_fused_rope and use_fused_rope: # TODO: remove `.clone()` when SynapseAI v1.15 is released - if k.dtype==torch.bfloat16: + if k.dtype == torch.bfloat16: return FusedRoPE.apply( q, cos.unsqueeze(0).unsqueeze(0).clone(), sin.unsqueeze(0).unsqueeze(0).clone(), position_ids ), FusedRoPE.apply( - k, cos.unsqueeze(0).unsqueeze(0).clone().to(torch.bfloat16), sin.unsqueeze(0).unsqueeze(0).clone().to(torch.bfloat16), position_ids + k, + cos.unsqueeze(0).unsqueeze(0).clone().to(torch.bfloat16), + sin.unsqueeze(0).unsqueeze(0).clone().to(torch.bfloat16), + position_ids, ) return FusedRoPE.apply( q, cos.unsqueeze(0).unsqueeze(0).clone(), sin.unsqueeze(0).unsqueeze(0).clone(), position_ids From ffaaa1da4d7c0a1030e68096aedc56f72786722d Mon Sep 17 00:00:00 2001 From: Shiv Kaul Date: Wed, 24 Apr 2024 19:53:41 +0000 Subject: [PATCH 05/12] add reuse_cache and some other arguments to mistral inputs --- examples/text-generation/run_lm_eval.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/examples/text-generation/run_lm_eval.py b/examples/text-generation/run_lm_eval.py index b086c80b92..72539dadd2 100644 --- a/examples/text-generation/run_lm_eval.py +++ b/examples/text-generation/run_lm_eval.py @@ -75,13 +75,13 @@ def __init__(self, tokenizer, model, args, options): self.options = options self._device = args.device self.model_inputs = {"use_cache": self.options.use_cache} - if self.model.config.model_type in ["llama", "falcon"]: + if self.model.config.model_type in ["llama", "mistral", "falcon"]: self.model_inputs.update( { "reuse_cache": self.options.reuse_cache, } ) - if self.model.config.model_type == "llama": + if self.model.config.model_type in ["llama", "mistral"]: self.model_inputs.update( { "attn_softmax_bf16": self.options.attn_softmax_bf16, From 0ee033994d83c093f69ecf4e12d65a2f4425fc73 Mon Sep 17 00:00:00 2001 From: Shiv Kaul Date: Thu, 25 Apr 2024 17:48:07 +0000 Subject: [PATCH 06/12] style reformat --- optimum/habana/transformers/models/mistral/modeling_mistral.py | 1 + 1 file changed, 1 insertion(+) diff --git a/optimum/habana/transformers/models/mistral/modeling_mistral.py b/optimum/habana/transformers/models/mistral/modeling_mistral.py index 9f7336d73c..975b89056f 100644 --- a/optimum/habana/transformers/models/mistral/modeling_mistral.py +++ b/optimum/habana/transformers/models/mistral/modeling_mistral.py @@ -61,6 +61,7 @@ print("Not using HPU fused kernel for RMSNorm") FusedRMSNorm = None + class KVCache(torch.nn.Module): def __init__(self): super(KVCache, self).__init__() From dc7afb494fd1035c4f371b8486c29c8244e3385f Mon Sep 17 00:00:00 2001 From: Shiv Kaul Date: Thu, 25 Apr 2024 11:40:04 -0700 Subject: [PATCH 07/12] Update modeling_mistral.py remove padding_mask warning --- .../transformers/models/mistral/modeling_mistral.py | 9 --------- 1 file changed, 9 deletions(-) diff --git a/optimum/habana/transformers/models/mistral/modeling_mistral.py b/optimum/habana/transformers/models/mistral/modeling_mistral.py index 975b89056f..0235d101cc 100644 --- a/optimum/habana/transformers/models/mistral/modeling_mistral.py +++ b/optimum/habana/transformers/models/mistral/modeling_mistral.py @@ -236,10 +236,6 @@ def forward( - add new args reuse_cache - add new args cache_idx """ - if "padding_mask" in kwargs: - warnings.warn( - "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`" - ) bsz, q_len, _ = hidden_states.size() query_states = self.q_proj(hidden_states) @@ -388,11 +384,6 @@ def forward( The only differences are: - add new args token_idx """ - if "padding_mask" in kwargs: - warnings.warn( - "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`" - ) - residual = hidden_states hidden_states = self.input_layernorm(hidden_states) From 2e20a3f0b3b2a2b01770dab50bc21cb61a0345f9 Mon Sep 17 00:00:00 2001 From: Shiv Kaul Date: Thu, 25 Apr 2024 18:51:54 +0000 Subject: [PATCH 08/12] style fix --- optimum/habana/transformers/models/mistral/modeling_mistral.py | 1 - 1 file changed, 1 deletion(-) diff --git a/optimum/habana/transformers/models/mistral/modeling_mistral.py b/optimum/habana/transformers/models/mistral/modeling_mistral.py index 0235d101cc..fa3069e85c 100644 --- a/optimum/habana/transformers/models/mistral/modeling_mistral.py +++ b/optimum/habana/transformers/models/mistral/modeling_mistral.py @@ -20,7 +20,6 @@ """PyTorch Mistral model.""" import math -import warnings from typing import List, Optional, Tuple, Union import habana_frameworks.torch.core as htcore From d7874dfa50e05480dfae66a68385f9682b933c8b Mon Sep 17 00:00:00 2001 From: Shiv Kaul Date: Thu, 25 Apr 2024 22:36:23 +0000 Subject: [PATCH 09/12] add fp8 CI tests for Mistral --- tests/test_text_generation_example.py | 45 +++++++++++++++++++++++++++ 1 file changed, 45 insertions(+) diff --git a/tests/test_text_generation_example.py b/tests/test_text_generation_example.py index c8323f9588..dc4f9068c6 100644 --- a/tests/test_text_generation_example.py +++ b/tests/test_text_generation_example.py @@ -193,3 +193,48 @@ def test_text_generation_torch_compile(model_name: str, baseline: float, token: def test_text_generation_torch_compile_distributed(model_name: str, baseline: float, token: str): world_size = 8 _test_text_generation(model_name, baseline, token, deepspeed=True, world_size=world_size, torch_compile=True) + ("mistralai/Mistral-7B-Instruct-v0.2", 0.0), + +MISTRAL_FP8_CONFIG = { + "mistralai/Mistral-7B-Instruct-v0.2" = [ + ("896", "128", "128", 13310.566520719813), + ("120", "128", "2048", 7757.383448024244), + ("120", "2048", "128", 1352.070452897798 ), + ("44", "2048", "2048", 3101.5205518843136), + ], +} + + if "Mistral" in model_name: + command += [ + "--attn_softmax_bf16", + ] + command.remove("--max_new_tokens 100") + if "Mistral" in model_name: + command.insert(-2, "--limit_hpu_graphs") + command.insert(-2, "--max_input_tokens 1") + command.insert(-2, "--max_new_tokens 1") + command = [x for y in command for x in re.split(pattern, y) if x] + for model_config in MISTRAL_FP8_CONFIG[model_name]: + command[command.index("--batch_size") + 1] = model_config[0] + command[command.index("--max_input_tokens") + 1] = model_config[1] + command[command.index("--max_new_tokens") + 1] = model_config[2] + baseline = model_config[3] + proc = subprocess.run(command, env=env_variables) + + # Ensure the run finished without any issue + # Use try-except to avoid logging the token if used + try: + assert proc.returncode == 0 + except AssertionError as e: + if "'--token', 'hf_" in e.args[0]: + e.args = (f"The following command failed:\n{' '.join(command[:-2])}",) + raise + + with open(Path(tmp_dir) / "results.json") as fp: + results = json.load(fp) + + # Ensure performance requirements (throughput) are met + assert results["throughput"] >= (2 - TIME_PERF_FACTOR) * baseline + return + + From 717fd7d12520b95a8186210f139475498849e0d4 Mon Sep 17 00:00:00 2001 From: Shiv Kaul Date: Fri, 26 Apr 2024 19:55:11 +0000 Subject: [PATCH 10/12] . --- tests/test_text_generation_example.py | 91 ++++++++++++++------------- 1 file changed, 46 insertions(+), 45 deletions(-) diff --git a/tests/test_text_generation_example.py b/tests/test_text_generation_example.py index dc4f9068c6..90852c04ee 100644 --- a/tests/test_text_generation_example.py +++ b/tests/test_text_generation_example.py @@ -33,7 +33,8 @@ ("codellama/CodeLlama-34b-hf", 1, True, 32.644), ], "fp8": [ - ("tiiuae/falcon-180B", 52.85086442722326), + #("tiiuae/falcon-180B", 52.85086442722326), + ("mistralai/Mistral-7B-Instruct-v0.2", 0), ], "deepspeed": [ ("bigscience/bloomz", 36.77314954096159), @@ -75,6 +76,15 @@ } +MISTRAL_FP8_CONFIG = { + "mistralai/Mistral-7B-Instruct-v0.2": [ + ("896", "128", "128", 13310.566520719813), + ("120", "128", "2048", 7757.383448024244), + ("120", "2048", "128", 1352.070452897798 ), + ("44", "2048", "2048", 3101.5205518843136), + ], +} + def _test_text_generation( model_name: str, baseline: float, @@ -129,6 +139,12 @@ def _test_text_generation( "--trim_logits", ] + if "Mistral" in model_name: + command += [ + "--attn_softmax_bf16", + ] + command.remove("--max_new_tokens 100") + with TemporaryDirectory() as tmp_dir: command.append(f"--output_dir {tmp_dir}") print(f"\n\nCommand to test: {' '.join(command)}\n") @@ -148,6 +164,35 @@ def _test_text_generation( ) command.insert(-2, "--fp8") + if "Mistral" in model_name: + command.insert(-2, "--limit_hpu_graphs") + command.insert(-2, "--max_input_tokens 1") + command.insert(-2, "--max_new_tokens 1") + command = [x for y in command for x in re.split(pattern, y) if x] + for model_config in MISTRAL_FP8_CONFIG[model_name]: + command[command.index("--batch_size") + 1] = model_config[0] + command[command.index("--max_input_tokens") + 1] = model_config[1] + command[command.index("--max_new_tokens") + 1] = model_config[2] + baseline = model_config[3] + proc = subprocess.run(command, env=env_variables) + + # Ensure the run finished without any issue + # Use try-except to avoid logging the token if used + try: + assert proc.returncode == 0 + except AssertionError as e: + if "'--token', 'hf_" in e.args[0]: + e.args = (f"The following command failed:\n{' '.join(command[:-2])}",) + raise + + with open(Path(tmp_dir) / "results.json") as fp: + results = json.load(fp) + + # Ensure performance requirements (throughput) are met + assert results["throughput"] >= (2 - TIME_PERF_FACTOR) * baseline + return + + proc = subprocess.run(command, env=env_variables) # Ensure the run finished without any issue @@ -193,48 +238,4 @@ def test_text_generation_torch_compile(model_name: str, baseline: float, token: def test_text_generation_torch_compile_distributed(model_name: str, baseline: float, token: str): world_size = 8 _test_text_generation(model_name, baseline, token, deepspeed=True, world_size=world_size, torch_compile=True) - ("mistralai/Mistral-7B-Instruct-v0.2", 0.0), - -MISTRAL_FP8_CONFIG = { - "mistralai/Mistral-7B-Instruct-v0.2" = [ - ("896", "128", "128", 13310.566520719813), - ("120", "128", "2048", 7757.383448024244), - ("120", "2048", "128", 1352.070452897798 ), - ("44", "2048", "2048", 3101.5205518843136), - ], -} - - if "Mistral" in model_name: - command += [ - "--attn_softmax_bf16", - ] - command.remove("--max_new_tokens 100") - if "Mistral" in model_name: - command.insert(-2, "--limit_hpu_graphs") - command.insert(-2, "--max_input_tokens 1") - command.insert(-2, "--max_new_tokens 1") - command = [x for y in command for x in re.split(pattern, y) if x] - for model_config in MISTRAL_FP8_CONFIG[model_name]: - command[command.index("--batch_size") + 1] = model_config[0] - command[command.index("--max_input_tokens") + 1] = model_config[1] - command[command.index("--max_new_tokens") + 1] = model_config[2] - baseline = model_config[3] - proc = subprocess.run(command, env=env_variables) - - # Ensure the run finished without any issue - # Use try-except to avoid logging the token if used - try: - assert proc.returncode == 0 - except AssertionError as e: - if "'--token', 'hf_" in e.args[0]: - e.args = (f"The following command failed:\n{' '.join(command[:-2])}",) - raise - - with open(Path(tmp_dir) / "results.json") as fp: - results = json.load(fp) - - # Ensure performance requirements (throughput) are met - assert results["throughput"] >= (2 - TIME_PERF_FACTOR) * baseline - return - From 93af21870171476cc9cbc5315398ffc32f75b9ca Mon Sep 17 00:00:00 2001 From: Shiv Kaul Date: Fri, 26 Apr 2024 22:19:30 +0000 Subject: [PATCH 11/12] only change the tests, no other files --- examples/text-generation/run_lm_eval.py | 7 +- .../models/mistral/modeling_mistral.py | 301 +++++------------- 2 files changed, 77 insertions(+), 231 deletions(-) diff --git a/examples/text-generation/run_lm_eval.py b/examples/text-generation/run_lm_eval.py index 3ea74a6a69..b086c80b92 100644 --- a/examples/text-generation/run_lm_eval.py +++ b/examples/text-generation/run_lm_eval.py @@ -75,13 +75,13 @@ def __init__(self, tokenizer, model, args, options): self.options = options self._device = args.device self.model_inputs = {"use_cache": self.options.use_cache} - if self.model.config.model_type in ["llama", "mistral", "falcon"]: + if self.model.config.model_type in ["llama", "falcon"]: self.model_inputs.update( { "reuse_cache": self.options.reuse_cache, } ) - if self.model.config.model_type in ["llama", "mistral"]: + if self.model.config.model_type == "llama": self.model_inputs.update( { "attn_softmax_bf16": self.options.attn_softmax_bf16, @@ -152,8 +152,7 @@ def main(): model, tokenizer, generation_config = initialize_model(args, logger) lm_tasks = lm_eval.tasks.get_task_dict(args.tasks) - with torch.no_grad(): - lm = HabanaModelAdapter(tokenizer, model, args, generation_config) + lm = HabanaModelAdapter(tokenizer, model, args, generation_config) eval_start = time.perf_counter() results = lm_eval.evaluator.evaluate(lm, lm_tasks, limit=args.limit_iters) diff --git a/optimum/habana/transformers/models/mistral/modeling_mistral.py b/optimum/habana/transformers/models/mistral/modeling_mistral.py index b4ab8a3258..cf5fa6f2c0 100644 --- a/optimum/habana/transformers/models/mistral/modeling_mistral.py +++ b/optimum/habana/transformers/models/mistral/modeling_mistral.py @@ -34,9 +34,7 @@ MistralAttention, MistralDecoderLayer, MistralForCausalLM, - MistralMLP, MistralModel, - MistralRMSNorm, apply_rotary_pos_emb, ) from transformers.utils import logging @@ -44,20 +42,7 @@ from ...modeling_attn_mask_utils import ( _gaudi_prepare_4d_causal_attention_mask, ) -from ..llama.modeling_llama import ( - GaudiLlamaDynamicNTKScalingRotaryEmbedding, - GaudiLlamaLinearScalingRotaryEmbedding, - GaudiLlamaRotaryEmbedding, -) - - -try: - from habana_frameworks.torch.hpex.kernels import RotaryPosEmbeddingHelperV2 as FusedRoPE - has_fused_rope = True -except ImportError: - has_fused_rope = False - print("Not using HPU fused kernel for apply_rotary_pos_emb") try: from habana_frameworks.torch.hpex.normalization import FusedRMSNorm as FusedRMSNorm @@ -65,60 +50,45 @@ print("Not using HPU fused kernel for RMSNorm") FusedRMSNorm = None +logger = logging.get_logger(__name__) -class KVCache(torch.nn.Module): - def __init__(self): - super(KVCache, self).__init__() - self.cache = None - self.inp_seq_len = -1 - - def allocate(self, inp_seq_len, dtype, device, shape): - if self.cache is None or self.cache.shape != shape: - self.inp_seq_len = inp_seq_len - self.cache = torch.zeros(shape, dtype=dtype, device=device) - else: - assert ( - self.inp_seq_len == inp_seq_len - ), f"inp_seq_len must be the same. self.inp_seq_len:{self.inp_seq_len} inp_seq_len:{inp_seq_len}" - self.cache.fill_(0) - - def update(self, prev, cur, dim, idx, inp_seq_len): - orig_cur = cur - if prev.shape == cur.shape: - prev.copy_(cur) - return orig_cur - if cur.shape[2] > 1 and cur.shape[2] <= prev.shape[2]: - # Initialize - prev[:, :, :inp_seq_len, :].copy_(cur) - return orig_cur - assert cur.shape[2] == 1, f"Cannot update kv-cache. Unsupported shapes. prev:{prev.shape} cur:{cur.shape}" - if idx is not None: - prev.index_copy_(dim, idx - 1, cur) - return prev - else: - return torch.cat((prev, cur), dim=dim) - - def get_shape(self): - if self.cache is None: - return None - return self.cache.shape - - def forward(self, cur, dim, idx): - return self.update(self.cache, cur, dim, idx, self.inp_seq_len) - - -class Matmul(torch.nn.Module): - def __init__(self): - super().__init__() - def forward(self, x, y): - return torch.matmul(x, y) +def update(prev, cur, dim, idx): + orig_cur = cur + if prev.shape == cur.shape: + # Initialize + prev.copy_(cur) + return orig_cur + assert cur.shape[2] == 1, f"Cannot update kv-cache. Unsupported shapes. prev:{prev.shape} cur:{cur.shape}" + if idx is not None: + return prev.index_copy_(dim, idx - 1, cur) + else: + return torch.cat((prev, cur), dim=dim) -logger = logging.get_logger(__name__) +def gaudi_mistral_rmsnorm_forward(self, hidden_states): + """ + Copied from MistralRMSNorm.forward: https://github.com/huggingface/transformers/blob/main/src/transformers/models/mistral/modeling_mistral.py + The only differences are: + - override RMSNorm with Habana fused RMSNorm + """ + if hidden_states.device.type == "hpu" and FusedRMSNorm: + # mixed dtypes are not good for FusedRMSNorm, both inputs need to have same dtype + if hidden_states.dtype != self.weight.dtype: + orig_dtype = hidden_states.dtype + hidden_states = FusedRMSNorm.apply(hidden_states.to(self.weight.dtype), self.weight, self.variance_epsilon) + return hidden_states.to(orig_dtype) + else: + hidden_states = FusedRMSNorm.apply(hidden_states, self.weight, self.variance_epsilon) + return hidden_states + else: + input_dtype = hidden_states.dtype + hidden_states = hidden_states.to(torch.float32) + variance = hidden_states.pow(2).mean(-1, keepdim=True) + hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) + return self.weight * hidden_states.to(input_dtype) -# Copied from transformers.models.llama.modeling_llama.repeat_kv def gaudi_mistral_repeat_kv( query_states: torch.Tensor, key_states: torch.Tensor, @@ -153,79 +123,11 @@ def gaudi_mistral_repeat_kv( return query_states, key_states, value_states, attention_mask -def update_sincos_cache(self, seq_len): - # Call rotary emb forward() to update cos/sin cache when infering more than self.max_position_embeddings - # This helps in avoiding creation of these caches during actual model forward pass and - # reduce memory consumption and improve performance. - if seq_len > self.max_position_embeddings: - self.max_position_embeddings = seq_len - _, _ = self.rotary_emb(self.k_proj.weight, seq_len=seq_len) - - -def gaudi_mistral_rmsnorm_forward(self, hidden_states): - """ - Copied from MistralRMSNorm.forward: https://github.com/huggingface/transformers/blob/main/src/transformers/models/mistral/modeling_mistral.py - The only differences are: - - override RMSNorm with Habana fused RMSNorm - """ - if hidden_states.device.type == "hpu" and FusedRMSNorm: - # mixed dtypes are not good for FusedRMSNorm, both inputs need to have same dtype - if hidden_states.dtype != self.weight.dtype: - orig_dtype = hidden_states.dtype - hidden_states = FusedRMSNorm.apply(hidden_states.to(self.weight.dtype), self.weight, self.variance_epsilon) - return hidden_states.to(orig_dtype) - else: - hidden_states = FusedRMSNorm.apply(hidden_states, self.weight, self.variance_epsilon) - return hidden_states - else: - input_dtype = hidden_states.dtype - hidden_states = hidden_states.to(torch.float32) - variance = hidden_states.pow(2).mean(-1, keepdim=True) - hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) - return self.weight * hidden_states.to(input_dtype) - - class GaudiMistralAttention(MistralAttention): def __init__(self, config: MistralConfig, layer_idx: Optional[int] = None): super().__init__(config, layer_idx) - self.k_cache = KVCache() - self.v_cache = KVCache() - self.matmul_qk = Matmul() - self.matmul_av = Matmul() - self.inp_seq_len = -1 self.past_key = None self.past_value = None - self._init_rope() - - def _init_rope(self): - """ - Copied from: https://github.com/huggingface/transformers/blob/v4.38.2/src/transformers/models/llama/modeling_llama.py#L294 - """ - if self.config.rope_scaling is None: - self.rotary_emb = GaudiLlamaRotaryEmbedding( - self.head_dim, - max_position_embeddings=self.max_position_embeddings, - base=self.rope_theta, - ) - else: - scaling_type = self.config.rope_scaling["type"] - scaling_factor = self.config.rope_scaling["factor"] - if scaling_type == "linear": - self.rotary_emb = GaudiLlamaLinearScalingRotaryEmbedding( - self.head_dim, - max_position_embeddings=self.max_position_embeddings, - scaling_factor=scaling_factor, - base=self.rope_theta, - ) - elif scaling_type == "dynamic": - self.rotary_emb = GaudiLlamaDynamicNTKScalingRotaryEmbedding( - self.head_dim, - max_position_embeddings=self.max_position_embeddings, - scaling_factor=scaling_factor, - base=self.rope_theta, - ) - else: - raise ValueError(f"Unknown RoPE scaling type {scaling_type}") def allocate_kv_cache(self, batch_size, seq_len): kv_shape = (batch_size, self.num_key_value_heads, seq_len, self.head_dim) @@ -248,14 +150,14 @@ def reorder(self, tensor, beam_idx, dim_a, dim_b): tensor.copy_(updated) def reorder_kv_cache(self, beam_idx: torch.LongTensor): - if self.k_cache.cache is None: + if self.past_key is None: return (None, None) - head_dim = self.k_cache.cache.size(-1) - seq_length = self.k_cache.cache.size(-2) - self.reorder(self.k_cache.cache, beam_idx, seq_length, head_dim) - self.reorder(self.v_cache.cache, beam_idx, seq_length, head_dim) - return (self.k_cache.cache.shape, self.v_cache.cache.shape) + head_dim = self.past_key.size(-1) + seq_length = self.past_key.size(-2) + self.reorder(self.past_key, beam_idx, seq_length, head_dim) + self.reorder(self.past_value, beam_idx, seq_length, head_dim) + return (self.past_key.shape, self.past_value.shape) def forward( self, @@ -269,7 +171,6 @@ def forward( reuse_cache: Optional[bool] = False, cache_idx: Optional[int] = None, attn_softmax_bf16: Optional[bool] = False, - use_fused_rope: Optional[bool] = True, **kwargs, ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: """ @@ -307,40 +208,35 @@ def forward( else: kv_seq_len += kv_shape cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) - query_states, key_states = apply_customized_rope( - query_states, key_states, cos, sin, position_ids, use_fused_rope=use_fused_rope - ) + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) + if past_key_value is not None or reuse_cache: + if reuse_cache: + past_key = self.past_key + past_value = self.past_value + else: + past_key = past_key_value[0] + past_value = past_key_value[1] + key_states = update(past_key, key_states, 2, token_idx) + value_states = update(past_value, value_states, 2, token_idx) if use_cache: - # reuse k, v, self_attention if reuse_cache: - key_states = self.k_cache(key_states, 2, token_idx) - value_states = self.v_cache(value_states, 2, token_idx) - past_key_value = (self.k_cache.get_shape(), self.v_cache.get_shape()) + past_key_value = (key_states.contiguous().shape, value_states.contiguous().shape) else: - if past_key_value is None: - past_key = torch.zeros(key_states.shape, dtype=self.k_proj.weight.dtype, device=key_states.device) - past_value = torch.zeros( - key_states.shape, dtype=self.k_proj.weight.dtype, device=key_states.device - ) - past_key_value = (past_key, past_value) - key_states = self.k_cache.update(past_key_value[0], key_states, 2, token_idx, self.inp_seq_len) - value_states = self.v_cache.update(past_key_value[1], value_states, 2, token_idx, self.inp_seq_len) - - if cache_idx is not None and q_len == 1: - key_states = key_states[:, :, :cache_idx, :] - value_states = value_states[:, :, :cache_idx, :] - if attention_mask is not None: - attention_mask = attention_mask[:, :, :, :cache_idx] - kv_seq_len = key_states.shape[-2] + past_key_value = (key_states.contiguous(), value_states.contiguous()) else: past_key_value = None + if cache_idx is not None and q_len == 1: + key_states = key_states[:, :, :cache_idx, :] + value_states = value_states[:, :, :cache_idx, :] + attention_mask = attention_mask[:, :, :, :cache_idx] + kv_seq_len = key_states.shape[-2] # repeat k/v heads if n_kv_heads < n_heads query_states, key_states, value_states, attention_mask = gaudi_mistral_repeat_kv( query_states, key_states, value_states, attention_mask, self.num_key_value_groups ) - attn_weights = self.matmul_qk(query_states, key_states.transpose(-2, -1)) / math.sqrt(self.head_dim) + attn_weights = torch.matmul(query_states, key_states.transpose(-2, -1)) / math.sqrt(self.head_dim) if attn_weights.size() not in [ (bsz, self.num_heads, q_len, kv_seq_len), @@ -367,7 +263,7 @@ def forward( # upcast attention to fp32 attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype) attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training) - attn_output = self.matmul_av(attn_weights, value_states) + attn_output = torch.matmul(attn_weights, value_states) attn_output = attn_output.reshape(bsz, -1, q_len, self.head_dim) if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim): @@ -390,16 +286,10 @@ def forward( class GaudiMistralDecoderLayer(MistralDecoderLayer): def __init__(self, config: MistralConfig, layer_idx: int): super().__init__(config, layer_idx) - self.hidden_size = config.hidden_size - self.self_attn = GaudiMistralAttention(config, layer_idx) - self.mlp = MistralMLP(config) - self.input_layernorm = MistralRMSNorm(config.hidden_size, eps=config.rms_norm_eps) - self.post_attention_layernorm = MistralRMSNorm(config.hidden_size, eps=config.rms_norm_eps) - - def allocate_kv_cache(self, batch_size, max_seq_len, inp_seq_len): - self.self_attn.allocate_kv_cache(batch_size, max_seq_len, inp_seq_len) + def allocate_kv_cache(self, batch_size, seq_len): + self.self_attn.allocate_kv_cache(batch_size, seq_len) def reorder_kv_cache(self, beam_idx: torch.LongTensor): return self.self_attn.reorder_kv_cache(beam_idx) @@ -419,7 +309,6 @@ def forward( reuse_cache: Optional[bool] = False, cache_idx: Optional[int] = None, attn_softmax_bf16: Optional[bool] = False, - use_fused_rope: Optional[bool] = True, **kwargs, ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: """ @@ -427,6 +316,7 @@ def forward( The only differences are: - add new args token_idx """ + residual = hidden_states hidden_states = self.input_layernorm(hidden_states) @@ -443,7 +333,6 @@ def forward( reuse_cache=reuse_cache, cache_idx=cache_idx, attn_softmax_bf16=attn_softmax_bf16, - use_fused_rope=use_fused_rope, ) hidden_states = residual + hidden_states @@ -465,9 +354,9 @@ def forward( class GaudiMistralModel(MistralModel): - def allocate_kv_cache(self, batch_size, max_seq_len, inp_seq_len): + def allocate_kv_cache(self, batch_size, seq_len): for layer in self.layers: - layer.allocate_kv_cache(batch_size, max_seq_len, inp_seq_len) + layer.allocate_kv_cache(batch_size, seq_len) def reorder_kv_cache(self, beam_idx: torch.LongTensor): return tuple(layer.reorder_kv_cache(beam_idx) for layer in self.layers) @@ -487,14 +376,11 @@ def forward( reuse_cache: Optional[bool] = False, cache_idx: Optional[int] = None, attn_softmax_bf16: Optional[bool] = False, - use_fused_rope: Optional[bool] = True, - lazy_mode: Optional[bool] = True, ) -> Union[Tuple, BaseModelOutputWithPast]: """ Copied from MistralModel.forward: https://github.com/huggingface/transformers/blob/v4.34.1/src/transformers/models/mistral/modeling_mistral.py The only differences are: - add new args token_idx - - add new arg lazy_mode """ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( @@ -524,16 +410,12 @@ def forward( past_key_values_length = 0 use_legacy_cache = True use_new_cache = False - if past_key_values is not None and use_cache: - if reuse_cache: - # past_seen_tokens = past_key_values[0][0][2] - pass - else: - if use_new_cache: - use_legacy_cache = not isinstance(past_key_values, Cache) - if use_legacy_cache: - past_key_values = DynamicCache.from_legacy_cache(past_key_values) - past_key_values_length = past_key_values.get_usable_length(seq_length) + if past_key_values is not None and use_cache and not reuse_cache: + if use_new_cache: + use_legacy_cache = not isinstance(past_key_values, Cache) + if use_legacy_cache: + past_key_values = DynamicCache.from_legacy_cache(past_key_values) + past_key_values_length = past_key_values.get_usable_length(seq_length) if position_ids is None: device = input_ids.device if input_ids is not None else inputs_embeds.device @@ -573,15 +455,8 @@ def forward( all_self_attns = () if output_attentions else None next_decoder_cache = () if not use_new_cache else None - if lazy_mode: - htcore.mark_step() - for layer_idx, decoder_layer in enumerate(self.layers): - if layer_idx == len(self.layers) // 2 or ( - lazy_mode - and not self.training - and (torch.distributed.is_initialized() is False or torch.distributed.get_world_size() == 1) - ): + if layer_idx == len(self.layers) // 2: htcore.mark_step() if output_hidden_states: all_hidden_states += (hidden_states,) @@ -596,7 +471,6 @@ def forward( output_attentions, use_cache, None, - use_fused_rope, ) else: layer_outputs = decoder_layer( @@ -610,7 +484,6 @@ def forward( reuse_cache=reuse_cache, cache_idx=cache_idx, attn_softmax_bf16=attn_softmax_bf16, - use_fused_rope=use_fused_rope, ) hidden_states = layer_outputs[0] @@ -645,8 +518,8 @@ def forward( class GaudiMistralForCausalLM(MistralForCausalLM): - def allocate_kv_cache(self, batch_size, max_seq_len, inp_seq_len): - self.model.allocate_kv_cache(batch_size, max_seq_len, inp_seq_len) + def allocate_kv_cache(self, batch_size, seq_len, _): + self.model.allocate_kv_cache(batch_size, seq_len) def reorder_kv_cache(self, beam_idx: torch.LongTensor): return self.model.reorder_kv_cache(beam_idx) @@ -671,8 +544,6 @@ def forward( trim_logits: Optional[bool] = False, cache_idx: Optional[int] = None, attn_softmax_bf16: Optional[bool] = False, - use_fused_rope: Optional[bool] = True, - lazy_mode: Optional[bool] = True, ) -> Union[Tuple, CausalLMOutputWithPast]: """ Inherits from MistralForCausalLM: https://github.com/huggingface/transformers/blob/v4.34.1/src/transformers/models/mistral/modeling_mistral.py @@ -701,8 +572,6 @@ def forward( reuse_cache=reuse_cache, cache_idx=cache_idx, attn_softmax_bf16=attn_softmax_bf16, - use_fused_rope=use_fused_rope, - lazy_mode=lazy_mode, ) hidden_states = outputs[0] _, seq_len, _ = hidden_states.shape @@ -720,11 +589,11 @@ def forward( shift_logits = logits[..., :-1, :].contiguous() shift_labels = labels[..., 1:].contiguous() # Flatten the tokens - loss_fct = CrossEntropyLoss() shift_logits = shift_logits.view(-1, self.config.vocab_size) shift_labels = shift_labels.view(-1) - # Enable model parallelism + # Ensure tensors are on the same device shift_labels = shift_labels.to(shift_logits.device) + loss_fct = CrossEntropyLoss() loss = loss_fct(shift_logits, shift_labels) if not return_dict: @@ -813,28 +682,6 @@ def prepare_inputs_for_generation( "trim_logits": kwargs.get("trim_logits"), "cache_idx": kwargs.get("cache_idx"), "attn_softmax_bf16": kwargs.get("attn_softmax_bf16"), - "lazy_mode": kwargs.get("lazy_mode"), } ) return model_inputs - - -def apply_customized_rope(q, k, cos, sin, position_ids, use_fused_rope=True): - if q.device.type == "hpu" and has_fused_rope and use_fused_rope: - # TODO: remove `.clone()` when SynapseAI v1.15 is released - if k.dtype == torch.bfloat16: - return FusedRoPE.apply( - q, cos.unsqueeze(0).unsqueeze(0).clone(), sin.unsqueeze(0).unsqueeze(0).clone(), position_ids - ), FusedRoPE.apply( - k, - cos.unsqueeze(0).unsqueeze(0).clone().to(torch.bfloat16), - sin.unsqueeze(0).unsqueeze(0).clone().to(torch.bfloat16), - position_ids, - ) - return FusedRoPE.apply( - q, cos.unsqueeze(0).unsqueeze(0).clone(), sin.unsqueeze(0).unsqueeze(0).clone(), position_ids - ), FusedRoPE.apply( - k, cos.unsqueeze(0).unsqueeze(0).clone(), sin.unsqueeze(0).unsqueeze(0).clone(), position_ids - ) - else: - return apply_rotary_pos_emb(q, k, cos, sin, position_ids) From 4543a80a7c9f87d350772de5b026096f829ba668 Mon Sep 17 00:00:00 2001 From: Shiv Kaul Date: Fri, 26 Apr 2024 22:22:03 +0000 Subject: [PATCH 12/12] fix formatting --- examples/text-generation/run_lm_eval.py | 3 +- .../models/mistral/modeling_mistral.py | 36 +++++++++++++++++++ tests/test_text_generation_example.py | 7 ++-- 3 files changed, 41 insertions(+), 5 deletions(-) diff --git a/examples/text-generation/run_lm_eval.py b/examples/text-generation/run_lm_eval.py index b086c80b92..8382412bf5 100644 --- a/examples/text-generation/run_lm_eval.py +++ b/examples/text-generation/run_lm_eval.py @@ -152,7 +152,8 @@ def main(): model, tokenizer, generation_config = initialize_model(args, logger) lm_tasks = lm_eval.tasks.get_task_dict(args.tasks) - lm = HabanaModelAdapter(tokenizer, model, args, generation_config) + with torch.no_grad(): + lm = HabanaModelAdapter(tokenizer, model, args, generation_config) eval_start = time.perf_counter() results = lm_eval.evaluator.evaluate(lm, lm_tasks, limit=args.limit_iters) diff --git a/optimum/habana/transformers/models/mistral/modeling_mistral.py b/optimum/habana/transformers/models/mistral/modeling_mistral.py index cf5fa6f2c0..f2a153d5d4 100644 --- a/optimum/habana/transformers/models/mistral/modeling_mistral.py +++ b/optimum/habana/transformers/models/mistral/modeling_mistral.py @@ -42,6 +42,11 @@ from ...modeling_attn_mask_utils import ( _gaudi_prepare_4d_causal_attention_mask, ) +from ..llama.modeling_llama import ( + GaudiLlamaDynamicNTKScalingRotaryEmbedding, + GaudiLlamaLinearScalingRotaryEmbedding, + GaudiLlamaRotaryEmbedding, +) try: @@ -128,6 +133,37 @@ def __init__(self, config: MistralConfig, layer_idx: Optional[int] = None): super().__init__(config, layer_idx) self.past_key = None self.past_value = None + self._init_rope() + + def _init_rope(self): + """ + Copied from: https://github.com/huggingface/transformers/blob/v4.38.2/src/transformers/models/llama/modeling_llama.py#L294 + """ + if self.config.rope_scaling is None: + self.rotary_emb = GaudiLlamaRotaryEmbedding( + self.head_dim, + max_position_embeddings=self.max_position_embeddings, + base=self.rope_theta, + ) + else: + scaling_type = self.config.rope_scaling["type"] + scaling_factor = self.config.rope_scaling["factor"] + if scaling_type == "linear": + self.rotary_emb = GaudiLlamaLinearScalingRotaryEmbedding( + self.head_dim, + max_position_embeddings=self.max_position_embeddings, + scaling_factor=scaling_factor, + base=self.rope_theta, + ) + elif scaling_type == "dynamic": + self.rotary_emb = GaudiLlamaDynamicNTKScalingRotaryEmbedding( + self.head_dim, + max_position_embeddings=self.max_position_embeddings, + scaling_factor=scaling_factor, + base=self.rope_theta, + ) + else: + raise ValueError(f"Unknown RoPE scaling type {scaling_type}") def allocate_kv_cache(self, batch_size, seq_len): kv_shape = (batch_size, self.num_key_value_heads, seq_len, self.head_dim) diff --git a/tests/test_text_generation_example.py b/tests/test_text_generation_example.py index 53cbe5783a..712418d3e0 100644 --- a/tests/test_text_generation_example.py +++ b/tests/test_text_generation_example.py @@ -36,7 +36,7 @@ ("google/gemma-7b", 1, False, 109.70751574382221), ], "fp8": [ - #("tiiuae/falcon-180B", 52.85086442722326), + # ("tiiuae/falcon-180B", 52.85086442722326), ("mistralai/Mistral-7B-Instruct-v0.2", 0), ], "deepspeed": [ @@ -86,11 +86,12 @@ "mistralai/Mistral-7B-Instruct-v0.2": [ ("896", "128", "128", 13310.566520719813), ("120", "128", "2048", 7757.383448024244), - ("120", "2048", "128", 1352.070452897798 ), + ("120", "2048", "128", 1352.070452897798), ("44", "2048", "2048", 3101.5205518843136), ], } + def _test_text_generation( model_name: str, baseline: float, @@ -198,7 +199,6 @@ def _test_text_generation( assert results["throughput"] >= (2 - TIME_PERF_FACTOR) * baseline return - proc = subprocess.run(command, env=env_variables) # Ensure the run finished without any issue @@ -244,4 +244,3 @@ def test_text_generation_torch_compile(model_name: str, baseline: float, token: def test_text_generation_torch_compile_distributed(model_name: str, baseline: float, token: str): world_size = 8 _test_text_generation(model_name, baseline, token, deepspeed=True, world_size=world_size, torch_compile=True) -