diff --git a/optimum/habana/transformers/models/mistral/modeling_mistral.py b/optimum/habana/transformers/models/mistral/modeling_mistral.py index 03ce7a4984..6c5dac919f 100644 --- a/optimum/habana/transformers/models/mistral/modeling_mistral.py +++ b/optimum/habana/transformers/models/mistral/modeling_mistral.py @@ -385,11 +385,13 @@ def forward( reuse_cache: Optional[bool] = False, cache_idx: Optional[int] = None, attn_softmax_bf16: Optional[bool] = False, + lazy_mode: Optional[bool] = True, ) -> Union[Tuple, BaseModelOutputWithPast]: """ Copied from MistralModel.forward: https://github.com/huggingface/transformers/blob/v4.34.1/src/transformers/models/mistral/modeling_mistral.py The only differences are: - add new args token_idx + - add new arg lazy_mode """ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( @@ -468,8 +470,13 @@ def forward( all_self_attns = () if output_attentions else None next_decoder_cache = () if not use_new_cache else None + if lazy_mode: + htcore.mark_step() + for layer_idx, decoder_layer in enumerate(self.layers): - if layer_idx == len(self.layers)//2: + if layer_idx == len(self.layers)//2 or \ + (lazy_mode and not self.training and \ + (torch.distributed.is_initialized() is False or torch.distributed.get_world_size() == 1)): htcore.mark_step() if output_hidden_states: all_hidden_states += (hidden_states,) @@ -557,6 +564,7 @@ def forward( trim_logits: Optional[bool] = False, cache_idx: Optional[int] = None, attn_softmax_bf16: Optional[bool] = False, + lazy_mode: Optional[bool] = True, ) -> Union[Tuple, CausalLMOutputWithPast]: """ Inherits from MistralForCausalLM: https://github.com/huggingface/transformers/blob/v4.34.1/src/transformers/models/mistral/modeling_mistral.py @@ -585,6 +593,7 @@ def forward( reuse_cache=reuse_cache, cache_idx=cache_idx, attn_softmax_bf16=attn_softmax_bf16, + lazy_mode=lazy_mode, ) hidden_states = outputs[0] _, seq_len, _ = hidden_states.shape @@ -695,6 +704,7 @@ def prepare_inputs_for_generation( "trim_logits": kwargs.get("trim_logits"), "cache_idx": kwargs.get("cache_idx"), "attn_softmax_bf16": kwargs.get("attn_softmax_bf16"), + "lazy_mode": kwargs.get("lazy_mode"), } ) return model_inputs