From 91b5c1ac549bbd636305bc8eda660e748b914a05 Mon Sep 17 00:00:00 2001 From: Shiv Kaul Date: Thu, 11 Apr 2024 15:01:27 -0700 Subject: [PATCH 1/2] port llama related changes/optimizations to mistral if applicable. --- .../transformers/models/mistral/modeling_mistral.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/optimum/habana/transformers/models/mistral/modeling_mistral.py b/optimum/habana/transformers/models/mistral/modeling_mistral.py index 03ce7a4984..45968a7e78 100644 --- a/optimum/habana/transformers/models/mistral/modeling_mistral.py +++ b/optimum/habana/transformers/models/mistral/modeling_mistral.py @@ -385,11 +385,13 @@ def forward( reuse_cache: Optional[bool] = False, cache_idx: Optional[int] = None, attn_softmax_bf16: Optional[bool] = False, + lazy_mode: Optional[bool] = True, ) -> Union[Tuple, BaseModelOutputWithPast]: """ Copied from MistralModel.forward: https://github.com/huggingface/transformers/blob/v4.34.1/src/transformers/models/mistral/modeling_mistral.py The only differences are: - add new args token_idx + - add new arg lazy_mode """ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( @@ -468,6 +470,9 @@ def forward( all_self_attns = () if output_attentions else None next_decoder_cache = () if not use_new_cache else None + if lazy_mode: + htcore.mark_step() + for layer_idx, decoder_layer in enumerate(self.layers): if layer_idx == len(self.layers)//2: htcore.mark_step() @@ -557,6 +562,7 @@ def forward( trim_logits: Optional[bool] = False, cache_idx: Optional[int] = None, attn_softmax_bf16: Optional[bool] = False, + lazy_mode: Optional[bool] = True, ) -> Union[Tuple, CausalLMOutputWithPast]: """ Inherits from MistralForCausalLM: https://github.com/huggingface/transformers/blob/v4.34.1/src/transformers/models/mistral/modeling_mistral.py @@ -585,6 +591,7 @@ def forward( reuse_cache=reuse_cache, cache_idx=cache_idx, attn_softmax_bf16=attn_softmax_bf16, + lazy_mode=lazy_mode, ) hidden_states = outputs[0] _, seq_len, _ = hidden_states.shape @@ -695,6 +702,7 @@ def prepare_inputs_for_generation( "trim_logits": kwargs.get("trim_logits"), "cache_idx": kwargs.get("cache_idx"), "attn_softmax_bf16": kwargs.get("attn_softmax_bf16"), + "lazy_mode": kwargs.get("lazy_mode"), } ) return model_inputs From 105ef87ad8863b26b7648eda2f38fcf35a5e358c Mon Sep 17 00:00:00 2001 From: Shiv Kaul Date: Thu, 11 Apr 2024 16:22:28 -0700 Subject: [PATCH 2/2] add mark step as in https://github.com/huggingface/optimum-habana/pull/875 --- .../habana/transformers/models/mistral/modeling_mistral.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/optimum/habana/transformers/models/mistral/modeling_mistral.py b/optimum/habana/transformers/models/mistral/modeling_mistral.py index 45968a7e78..6c5dac919f 100644 --- a/optimum/habana/transformers/models/mistral/modeling_mistral.py +++ b/optimum/habana/transformers/models/mistral/modeling_mistral.py @@ -474,7 +474,9 @@ def forward( htcore.mark_step() for layer_idx, decoder_layer in enumerate(self.layers): - if layer_idx == len(self.layers)//2: + if layer_idx == len(self.layers)//2 or \ + (lazy_mode and not self.training and \ + (torch.distributed.is_initialized() is False or torch.distributed.get_world_size() == 1)): htcore.mark_step() if output_hidden_states: all_hidden_states += (hidden_states,)