diff --git a/examples/text-generation/run_lm_eval.py b/examples/text-generation/run_lm_eval.py index 8382412bf5..3ea74a6a69 100644 --- a/examples/text-generation/run_lm_eval.py +++ b/examples/text-generation/run_lm_eval.py @@ -75,13 +75,13 @@ def __init__(self, tokenizer, model, args, options): self.options = options self._device = args.device self.model_inputs = {"use_cache": self.options.use_cache} - if self.model.config.model_type in ["llama", "falcon"]: + if self.model.config.model_type in ["llama", "mistral", "falcon"]: self.model_inputs.update( { "reuse_cache": self.options.reuse_cache, } ) - if self.model.config.model_type == "llama": + if self.model.config.model_type in ["llama", "mistral"]: self.model_inputs.update( { "attn_softmax_bf16": self.options.attn_softmax_bf16, diff --git a/optimum/habana/transformers/models/mistral/modeling_mistral.py b/optimum/habana/transformers/models/mistral/modeling_mistral.py index f2a153d5d4..df99ff67a0 100644 --- a/optimum/habana/transformers/models/mistral/modeling_mistral.py +++ b/optimum/habana/transformers/models/mistral/modeling_mistral.py @@ -34,7 +34,9 @@ MistralAttention, MistralDecoderLayer, MistralForCausalLM, + MistralMLP, MistralModel, + MistralRMSNorm, apply_rotary_pos_emb, ) from transformers.utils import logging @@ -49,6 +51,14 @@ ) +try: + from habana_frameworks.torch.hpex.kernels import RotaryPosEmbeddingHelperV2 as FusedRoPE + + has_fused_rope = True +except ImportError: + has_fused_rope = False + print("Not using HPU fused kernel for apply_rotary_pos_emb") + try: from habana_frameworks.torch.hpex.normalization import FusedRMSNorm as FusedRMSNorm except ImportError: @@ -58,40 +68,53 @@ logger = logging.get_logger(__name__) -def update(prev, cur, dim, idx): - orig_cur = cur - if prev.shape == cur.shape: - # Initialize - prev.copy_(cur) - return orig_cur - assert cur.shape[2] == 1, f"Cannot update kv-cache. Unsupported shapes. prev:{prev.shape} cur:{cur.shape}" - if idx is not None: - return prev.index_copy_(dim, idx - 1, cur) - else: - return torch.cat((prev, cur), dim=dim) +class KVCache(torch.nn.Module): + def __init__(self): + super(KVCache, self).__init__() + self.cache = None + self.inp_seq_len = -1 - -def gaudi_mistral_rmsnorm_forward(self, hidden_states): - """ - Copied from MistralRMSNorm.forward: https://github.com/huggingface/transformers/blob/main/src/transformers/models/mistral/modeling_mistral.py - The only differences are: - - override RMSNorm with Habana fused RMSNorm - """ - if hidden_states.device.type == "hpu" and FusedRMSNorm: - # mixed dtypes are not good for FusedRMSNorm, both inputs need to have same dtype - if hidden_states.dtype != self.weight.dtype: - orig_dtype = hidden_states.dtype - hidden_states = FusedRMSNorm.apply(hidden_states.to(self.weight.dtype), self.weight, self.variance_epsilon) - return hidden_states.to(orig_dtype) + def allocate(self, inp_seq_len, dtype, device, shape): + if self.cache is None or self.cache.shape != shape: + self.inp_seq_len = inp_seq_len + self.cache = torch.zeros(shape, dtype=dtype, device=device) else: - hidden_states = FusedRMSNorm.apply(hidden_states, self.weight, self.variance_epsilon) - return hidden_states - else: - input_dtype = hidden_states.dtype - hidden_states = hidden_states.to(torch.float32) - variance = hidden_states.pow(2).mean(-1, keepdim=True) - hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) - return self.weight * hidden_states.to(input_dtype) + assert ( + self.inp_seq_len == inp_seq_len + ), f"inp_seq_len must be the same. self.inp_seq_len:{self.inp_seq_len} inp_seq_len:{inp_seq_len}" + self.cache.fill_(0) + + def update(self, prev, cur, dim, idx, inp_seq_len): + orig_cur = cur + if prev.shape == cur.shape: + prev.copy_(cur) + return orig_cur + if cur.shape[2] > 1 and cur.shape[2] <= prev.shape[2]: + # Initialize + prev[:, :, :inp_seq_len, :].copy_(cur) + return orig_cur + assert cur.shape[2] == 1, f"Cannot update kv-cache. Unsupported shapes. prev:{prev.shape} cur:{cur.shape}" + if idx is not None: + prev.index_copy_(dim, idx - 1, cur) + return prev + else: + return torch.cat((prev, cur), dim=dim) + + def get_shape(self): + if self.cache is None: + return None + return self.cache.shape + + def forward(self, cur, dim, idx): + return self.update(self.cache, cur, dim, idx, self.inp_seq_len) + + +class Matmul(torch.nn.Module): + def __init__(self): + super().__init__() + + def forward(self, x, y): + return torch.matmul(x, y) def gaudi_mistral_repeat_kv( @@ -128,11 +151,37 @@ def gaudi_mistral_repeat_kv( return query_states, key_states, value_states, attention_mask +def gaudi_mistral_rmsnorm_forward(self, hidden_states): + """ + Copied from MistralRMSNorm.forward: https://github.com/huggingface/transformers/blob/main/src/transformers/models/mistral/modeling_mistral.py + The only differences are: + - override RMSNorm with Habana fused RMSNorm + """ + if hidden_states.device.type == "hpu" and FusedRMSNorm: + # mixed dtypes are not good for FusedRMSNorm, both inputs need to have same dtype + if hidden_states.dtype != self.weight.dtype: + orig_dtype = hidden_states.dtype + hidden_states = FusedRMSNorm.apply(hidden_states.to(self.weight.dtype), self.weight, self.variance_epsilon) + return hidden_states.to(orig_dtype) + else: + hidden_states = FusedRMSNorm.apply(hidden_states, self.weight, self.variance_epsilon) + return hidden_states + else: + input_dtype = hidden_states.dtype + hidden_states = hidden_states.to(torch.float32) + variance = hidden_states.pow(2).mean(-1, keepdim=True) + hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) + return self.weight * hidden_states.to(input_dtype) + + class GaudiMistralAttention(MistralAttention): def __init__(self, config: MistralConfig, layer_idx: Optional[int] = None): super().__init__(config, layer_idx) - self.past_key = None - self.past_value = None + self.k_cache = KVCache() + self.v_cache = KVCache() + self.matmul_qk = Matmul() + self.matmul_av = Matmul() + self.inp_seq_len = -1 self._init_rope() def _init_rope(self): @@ -165,13 +214,12 @@ def _init_rope(self): else: raise ValueError(f"Unknown RoPE scaling type {scaling_type}") - def allocate_kv_cache(self, batch_size, seq_len): - kv_shape = (batch_size, self.num_key_value_heads, seq_len, self.head_dim) - if self.past_key is None or self.past_key.shape != kv_shape: - device = self.k_proj.weight.device - dtype = self.k_proj.weight.dtype - self.past_key = torch.empty(kv_shape, dtype=dtype, device=device) - self.past_value = torch.empty(kv_shape, dtype=dtype, device=device) + def allocate_kv_cache(self, batch_size, max_seq_len, inp_seq_len): + cache_shape = (batch_size, self.num_key_value_heads, max_seq_len, self.head_dim) + device = self.k_proj.weight.device + dtype = self.config.torch_dtype + self.k_cache.allocate(inp_seq_len, dtype, device, cache_shape) + self.v_cache.allocate(inp_seq_len, dtype, device, cache_shape) def update_sincos_cache(self, seq_len): # Call rotary emb forward() to update cos/sin cache when infering more than self.max_position_embeddings @@ -186,14 +234,14 @@ def reorder(self, tensor, beam_idx, dim_a, dim_b): tensor.copy_(updated) def reorder_kv_cache(self, beam_idx: torch.LongTensor): - if self.past_key is None: + if self.k_cache.cache is None: return (None, None) - head_dim = self.past_key.size(-1) - seq_length = self.past_key.size(-2) - self.reorder(self.past_key, beam_idx, seq_length, head_dim) - self.reorder(self.past_value, beam_idx, seq_length, head_dim) - return (self.past_key.shape, self.past_value.shape) + head_dim = self.k_cache.cache.size(-1) + seq_length = self.k_cache.cache.size(-2) + self.reorder(self.k_cache.cache, beam_idx, seq_length, head_dim) + self.reorder(self.v_cache.cache, beam_idx, seq_length, head_dim) + return (self.k_cache.cache.shape, self.v_cache.cache.shape) def forward( self, @@ -207,6 +255,7 @@ def forward( reuse_cache: Optional[bool] = False, cache_idx: Optional[int] = None, attn_softmax_bf16: Optional[bool] = False, + use_fused_rope: Optional[bool] = True, **kwargs, ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: """ @@ -244,35 +293,40 @@ def forward( else: kv_seq_len += kv_shape cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) - query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) + query_states, key_states = apply_customized_rope(query_states, key_states, cos, sin, position_ids) - if past_key_value is not None or reuse_cache: - if reuse_cache: - past_key = self.past_key - past_value = self.past_value - else: - past_key = past_key_value[0] - past_value = past_key_value[1] - key_states = update(past_key, key_states, 2, token_idx) - value_states = update(past_value, value_states, 2, token_idx) if use_cache: + # reuse k, v, self_attention if reuse_cache: - past_key_value = (key_states.contiguous().shape, value_states.contiguous().shape) + key_states = self.k_cache(key_states, 2, token_idx) + value_states = self.v_cache(value_states, 2, token_idx) + past_key_value = (self.k_cache.get_shape(), self.v_cache.get_shape()) else: - past_key_value = (key_states.contiguous(), value_states.contiguous()) + if past_key_value is None: + past_key = torch.zeros(key_states.shape, dtype=self.k_proj.weight.dtype, device=key_states.device) + past_value = torch.zeros( + key_states.shape, dtype=self.k_proj.weight.dtype, device=key_states.device + ) + past_key_value = (past_key, past_value) + key_states = self.k_cache.update(past_key_value[0], key_states, 2, token_idx, self.inp_seq_len) + value_states = self.v_cache.update(past_key_value[1], value_states, 2, token_idx, self.inp_seq_len) + if token_idx is None: + past_key_value = (key_states, value_states) + + if cache_idx is not None and q_len == 1: + key_states = key_states[:, :, :cache_idx, :] + value_states = value_states[:, :, :cache_idx, :] + if attention_mask is not None: + attention_mask = attention_mask[:, :, :, :cache_idx] + kv_seq_len = key_states.shape[-2] else: past_key_value = None - if cache_idx is not None and q_len == 1: - key_states = key_states[:, :, :cache_idx, :] - value_states = value_states[:, :, :cache_idx, :] - attention_mask = attention_mask[:, :, :, :cache_idx] - kv_seq_len = key_states.shape[-2] # repeat k/v heads if n_kv_heads < n_heads query_states, key_states, value_states, attention_mask = gaudi_mistral_repeat_kv( query_states, key_states, value_states, attention_mask, self.num_key_value_groups ) - attn_weights = torch.matmul(query_states, key_states.transpose(-2, -1)) / math.sqrt(self.head_dim) + attn_weights = self.matmul_qk(query_states, key_states.transpose(-2, -1)) / math.sqrt(self.head_dim) if attn_weights.size() not in [ (bsz, self.num_heads, q_len, kv_seq_len), @@ -299,7 +353,7 @@ def forward( # upcast attention to fp32 attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype) attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training) - attn_output = torch.matmul(attn_weights, value_states) + attn_output = self.matmul_av(attn_weights, value_states) attn_output = attn_output.reshape(bsz, -1, q_len, self.head_dim) if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim): @@ -321,11 +375,17 @@ def forward( class GaudiMistralDecoderLayer(MistralDecoderLayer): def __init__(self, config: MistralConfig, layer_idx: int): - super().__init__(config, layer_idx) + super(MistralDecoderLayer, self).__init__() + self.hidden_size = config.hidden_size + self.self_attn = GaudiMistralAttention(config, layer_idx) - def allocate_kv_cache(self, batch_size, seq_len): - self.self_attn.allocate_kv_cache(batch_size, seq_len) + self.mlp = MistralMLP(config) + self.input_layernorm = MistralRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.post_attention_layernorm = MistralRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + + def allocate_kv_cache(self, batch_size, max_seq_len, inp_seq_len): + self.self_attn.allocate_kv_cache(batch_size, max_seq_len, inp_seq_len) def reorder_kv_cache(self, beam_idx: torch.LongTensor): return self.self_attn.reorder_kv_cache(beam_idx) @@ -345,6 +405,7 @@ def forward( reuse_cache: Optional[bool] = False, cache_idx: Optional[int] = None, attn_softmax_bf16: Optional[bool] = False, + use_fused_rope: Optional[bool] = True, **kwargs, ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: """ @@ -352,7 +413,6 @@ def forward( The only differences are: - add new args token_idx """ - residual = hidden_states hidden_states = self.input_layernorm(hidden_states) @@ -369,6 +429,7 @@ def forward( reuse_cache=reuse_cache, cache_idx=cache_idx, attn_softmax_bf16=attn_softmax_bf16, + use_fused_rope=use_fused_rope, ) hidden_states = residual + hidden_states @@ -390,9 +451,9 @@ def forward( class GaudiMistralModel(MistralModel): - def allocate_kv_cache(self, batch_size, seq_len): + def allocate_kv_cache(self, batch_size, max_seq_len, inp_seq_len): for layer in self.layers: - layer.allocate_kv_cache(batch_size, seq_len) + layer.allocate_kv_cache(batch_size, max_seq_len, inp_seq_len) def reorder_kv_cache(self, beam_idx: torch.LongTensor): return tuple(layer.reorder_kv_cache(beam_idx) for layer in self.layers) @@ -412,11 +473,14 @@ def forward( reuse_cache: Optional[bool] = False, cache_idx: Optional[int] = None, attn_softmax_bf16: Optional[bool] = False, + use_fused_rope: Optional[bool] = True, + lazy_mode: Optional[bool] = True, ) -> Union[Tuple, BaseModelOutputWithPast]: """ Copied from MistralModel.forward: https://github.com/huggingface/transformers/blob/v4.34.1/src/transformers/models/mistral/modeling_mistral.py The only differences are: - add new args token_idx + - add new arg lazy_mode """ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( @@ -446,12 +510,11 @@ def forward( past_key_values_length = 0 use_legacy_cache = True use_new_cache = False - if past_key_values is not None and use_cache and not reuse_cache: - if use_new_cache: - use_legacy_cache = not isinstance(past_key_values, Cache) - if use_legacy_cache: - past_key_values = DynamicCache.from_legacy_cache(past_key_values) - past_key_values_length = past_key_values.get_usable_length(seq_length) + if past_key_values is not None and use_cache and not reuse_cache and use_new_cache: + use_legacy_cache = not isinstance(past_key_values, Cache) + if use_legacy_cache: + past_key_values = DynamicCache.from_legacy_cache(past_key_values) + past_key_values_length = past_key_values.get_usable_length(seq_length) if position_ids is None: device = input_ids.device if input_ids is not None else inputs_embeds.device @@ -491,8 +554,15 @@ def forward( all_self_attns = () if output_attentions else None next_decoder_cache = () if not use_new_cache else None + if lazy_mode: + htcore.mark_step() + for layer_idx, decoder_layer in enumerate(self.layers): - if layer_idx == len(self.layers) // 2: + if layer_idx == len(self.layers) // 2 or ( + lazy_mode + and not self.training + and (torch.distributed.is_initialized() is False or torch.distributed.get_world_size() == 1) + ): htcore.mark_step() if output_hidden_states: all_hidden_states += (hidden_states,) @@ -507,6 +577,7 @@ def forward( output_attentions, use_cache, None, + use_fused_rope, ) else: layer_outputs = decoder_layer( @@ -520,6 +591,7 @@ def forward( reuse_cache=reuse_cache, cache_idx=cache_idx, attn_softmax_bf16=attn_softmax_bf16, + use_fused_rope=use_fused_rope, ) hidden_states = layer_outputs[0] @@ -554,8 +626,8 @@ def forward( class GaudiMistralForCausalLM(MistralForCausalLM): - def allocate_kv_cache(self, batch_size, seq_len, _): - self.model.allocate_kv_cache(batch_size, seq_len) + def allocate_kv_cache(self, batch_size, max_seq_len, inp_seq_len): + self.model.allocate_kv_cache(batch_size, max_seq_len, inp_seq_len) def reorder_kv_cache(self, beam_idx: torch.LongTensor): return self.model.reorder_kv_cache(beam_idx) @@ -580,6 +652,8 @@ def forward( trim_logits: Optional[bool] = False, cache_idx: Optional[int] = None, attn_softmax_bf16: Optional[bool] = False, + use_fused_rope: Optional[bool] = True, + lazy_mode: Optional[bool] = True, ) -> Union[Tuple, CausalLMOutputWithPast]: """ Inherits from MistralForCausalLM: https://github.com/huggingface/transformers/blob/v4.34.1/src/transformers/models/mistral/modeling_mistral.py @@ -593,6 +667,10 @@ def forward( ) return_dict = return_dict if return_dict is not None else self.config.use_return_dict + if self.generation_config.use_fused_rope is False: + global has_fused_rope + has_fused_rope = False + # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) outputs = self.model( input_ids=input_ids, @@ -608,6 +686,8 @@ def forward( reuse_cache=reuse_cache, cache_idx=cache_idx, attn_softmax_bf16=attn_softmax_bf16, + use_fused_rope=use_fused_rope, + lazy_mode=lazy_mode, ) hidden_states = outputs[0] _, seq_len, _ = hidden_states.shape @@ -625,11 +705,11 @@ def forward( shift_logits = logits[..., :-1, :].contiguous() shift_labels = labels[..., 1:].contiguous() # Flatten the tokens + loss_fct = CrossEntropyLoss() shift_logits = shift_logits.view(-1, self.config.vocab_size) shift_labels = shift_labels.view(-1) - # Ensure tensors are on the same device + # Enable model parallelism shift_labels = shift_labels.to(shift_logits.device) - loss_fct = CrossEntropyLoss() loss = loss_fct(shift_logits, shift_labels) if not return_dict: @@ -718,6 +798,28 @@ def prepare_inputs_for_generation( "trim_logits": kwargs.get("trim_logits"), "cache_idx": kwargs.get("cache_idx"), "attn_softmax_bf16": kwargs.get("attn_softmax_bf16"), + "lazy_mode": kwargs.get("lazy_mode"), } ) return model_inputs + + +def apply_customized_rope(q, k, cos, sin, position_ids): + if q.device.type == "hpu" and has_fused_rope: + # TODO: remove `.clone()` when SynapseAI v1.15 is released + if k.dtype == torch.bfloat16: + return FusedRoPE.apply( + q, cos.unsqueeze(0).unsqueeze(0).clone(), sin.unsqueeze(0).unsqueeze(0).clone(), position_ids + ), FusedRoPE.apply( + k, + cos.unsqueeze(0).unsqueeze(0).clone().to(torch.bfloat16), + sin.unsqueeze(0).unsqueeze(0).clone().to(torch.bfloat16), + position_ids, + ) + return FusedRoPE.apply( + q, cos.unsqueeze(0).unsqueeze(0).clone(), sin.unsqueeze(0).unsqueeze(0).clone(), position_ids + ), FusedRoPE.apply( + k, cos.unsqueeze(0).unsqueeze(0).clone(), sin.unsqueeze(0).unsqueeze(0).clone(), position_ids + ) + else: + return apply_rotary_pos_emb(q, k, cos, sin, position_ids) diff --git a/tests/test_text_generation_example.py b/tests/test_text_generation_example.py index f71c89de5d..5956c3f8bf 100644 --- a/tests/test_text_generation_example.py +++ b/tests/test_text_generation_example.py @@ -37,6 +37,7 @@ ], "fp8": [ ("tiiuae/falcon-180B", 52.85086442722326), + ("mistralai/Mistral-7B-Instruct-v0.2", 0), ("mistralai/Mixtral-8x7B-v0.1", 39.26845661768185), ("meta-llama/Llama-2-7b-hf", 0.0), ("meta-llama/Llama-2-70b-hf", 0.0), @@ -68,6 +69,14 @@ ("64", "2048", "2048", 2773.173092391251), ], } + MISTRAL_FP8_CONFIG = { + "mistralai/Mistral-7B-Instruct-v0.2": [ + ("896", "128", "128", 12397.11410288204), + ("120", "128", "2048", 5394.675714459493), + ("120", "2048", "128", 919.8470890081497), + ("44", "2048", "2048", 2471.950758729518), + ], + } else: # Gaudi1 CI baselines MODELS_TO_TEST = { @@ -149,7 +158,7 @@ def _test_text_generation( if fp8: if "--trim_logits" not in command: command += ["--trim_logits"] - if "Llama-2" in model_name: + if "Llama-2" in model_name or "Mistral" in model_name: command.remove("--max_new_tokens 100") with TemporaryDirectory() as tmp_dir: @@ -172,12 +181,13 @@ def _test_text_generation( command.insert(-2, "--fp8") command.insert(-2, "--warmup 1") command.insert(-2, "--n_iterations 2") - if "Llama-2" in model_name: + if "Llama-2" in model_name or "Mistral" in model_name: + fp8_model_configs = LLAMA2_FP8_CONFIG if "Llama-2" in model_name else MISTRAL_FP8_CONFIG command.insert(-2, "--limit_hpu_graphs") command.insert(-2, "--max_input_tokens 1") command.insert(-2, "--max_new_tokens 1") command = [x for y in command for x in re.split(pattern, y) if x] - for model_config in LLAMA2_FP8_CONFIG[model_name]: + for model_config in fp8_model_configs[model_name]: command[command.index("--batch_size") + 1] = model_config[0] command[command.index("--max_input_tokens") + 1] = model_config[1] command[command.index("--max_new_tokens") + 1] = model_config[2]