Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 11 additions & 1 deletion optimum/habana/transformers/models/mistral/modeling_mistral.py
Original file line number Diff line number Diff line change
Expand Up @@ -385,11 +385,13 @@ def forward(
reuse_cache: Optional[bool] = False,
cache_idx: Optional[int] = None,
attn_softmax_bf16: Optional[bool] = False,
lazy_mode: Optional[bool] = True,
) -> Union[Tuple, BaseModelOutputWithPast]:
"""
Copied from MistralModel.forward: https://github.com/huggingface/transformers/blob/v4.34.1/src/transformers/models/mistral/modeling_mistral.py
The only differences are:
- add new args token_idx
- add new arg lazy_mode
"""
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
Expand Down Expand Up @@ -468,8 +470,13 @@ def forward(
all_self_attns = () if output_attentions else None
next_decoder_cache = () if not use_new_cache else None

if lazy_mode:
htcore.mark_step()

for layer_idx, decoder_layer in enumerate(self.layers):
if layer_idx == len(self.layers)//2:
if layer_idx == len(self.layers)//2 or \
(lazy_mode and not self.training and \
(torch.distributed.is_initialized() is False or torch.distributed.get_world_size() == 1)):
htcore.mark_step()
if output_hidden_states:
all_hidden_states += (hidden_states,)
Expand Down Expand Up @@ -557,6 +564,7 @@ def forward(
trim_logits: Optional[bool] = False,
cache_idx: Optional[int] = None,
attn_softmax_bf16: Optional[bool] = False,
lazy_mode: Optional[bool] = True,
) -> Union[Tuple, CausalLMOutputWithPast]:
"""
Inherits from MistralForCausalLM: https://github.com/huggingface/transformers/blob/v4.34.1/src/transformers/models/mistral/modeling_mistral.py
Expand Down Expand Up @@ -585,6 +593,7 @@ def forward(
reuse_cache=reuse_cache,
cache_idx=cache_idx,
attn_softmax_bf16=attn_softmax_bf16,
lazy_mode=lazy_mode,
)
hidden_states = outputs[0]
_, seq_len, _ = hidden_states.shape
Expand Down Expand Up @@ -695,6 +704,7 @@ def prepare_inputs_for_generation(
"trim_logits": kwargs.get("trim_logits"),
"cache_idx": kwargs.get("cache_idx"),
"attn_softmax_bf16": kwargs.get("attn_softmax_bf16"),
"lazy_mode": kwargs.get("lazy_mode"),
}
)
return model_inputs