diff --git a/optimum/habana/transformers/generation/utils.py b/optimum/habana/transformers/generation/utils.py index 2e8eb74d9b..57a85b2715 100755 --- a/optimum/habana/transformers/generation/utils.py +++ b/optimum/habana/transformers/generation/utils.py @@ -1414,6 +1414,7 @@ def greedy_search( ) # prepare model inputs + model_kwargs["lazy_mode"] = lazy_mode model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs) hpu_graphs_kwargs = self._get_hpu_graphs_kwargs(model_kwargs) @@ -1760,6 +1761,7 @@ def sample( break # prepare model inputs + model_kwargs["lazy_mode"] = lazy_mode model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs) hpu_graphs_kwargs = self._get_hpu_graphs_kwargs(model_kwargs) @@ -2196,6 +2198,7 @@ def expand_if_needed(tensor, new_size, value, dim=-1): params, input_ids, model_kwargs, pad_token_id, bucket_size, reduce_recompile ) + model_kwargs["lazy_mode"] = lazy_mode model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs) hpu_graphs_kwargs = self._get_hpu_graphs_kwargs(model_kwargs) @@ -2927,6 +2930,7 @@ def constrained_beam_search( if this_peer_finished_flag.item() == 0.0: break + model_kwargs["lazy_mode"] = lazy_mode model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs) hpu_graphs_kwargs = self._get_hpu_graphs_kwargs(model_kwargs) diff --git a/optimum/habana/transformers/models/llama/modeling_llama.py b/optimum/habana/transformers/models/llama/modeling_llama.py index ad09959925..9cd78e0edb 100755 --- a/optimum/habana/transformers/models/llama/modeling_llama.py +++ b/optimum/habana/transformers/models/llama/modeling_llama.py @@ -42,6 +42,7 @@ print("Not using HPU fused scaled dot-product attention kernel.") FusedSDPA = None +import habana_frameworks.torch.core as htcore @@ -480,7 +481,7 @@ def forward( ) residual = hidden_states - output_pre_attn, self_attn_weights, present_key_value = self.pre_attn( + hidden_states, self_attn_weights, present_key_value = self.pre_attn( hidden_states, attention_mask, position_ids, @@ -497,12 +498,12 @@ def forward( use_fused_rope=use_fused_rope, **kwargs, ) - self.self_attn.attention_all_reduce(output_pre_attn) - output_post_attn_pre_mlp, residual_mlp = self.post_attn_pre_mlp(output_pre_attn, residual) - self.mlp.mlp_all_reduce(output_post_attn_pre_mlp) - output_post_mlp = self.post_mlp(output_post_attn_pre_mlp, residual_mlp) + self.self_attn.attention_all_reduce(hidden_states) + hidden_states, residual = self.post_attn_pre_mlp(hidden_states, residual) + self.mlp.mlp_all_reduce(hidden_states) + hidden_states = self.post_mlp(hidden_states, residual) - outputs = (output_post_mlp,) + outputs = (hidden_states,) if output_attentions: outputs += (self_attn_weights,) @@ -529,7 +530,7 @@ def pre_attn( use_fused_rope: Optional[bool] = True, ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: hidden_states = self.input_layernorm(hidden_states) - output_attn, attn_weights, present_key_value = self.self_attn.pre_attn_forward( + hidden_states, attn_weights, present_key_value = self.self_attn.pre_attn_forward( hidden_states, attention_mask, position_ids, @@ -545,23 +546,33 @@ def pre_attn( cache_idx=cache_idx, use_fused_rope=use_fused_rope, ) - return output_attn, attn_weights, present_key_value + return hidden_states, attn_weights, present_key_value - def post_attn_pre_mlp(self, input, residual): - output_post_attn = self.self_attn.post_attn_forward(input) + def post_attn_pre_mlp(self, hidden_states, residual): + hidden_states = self.self_attn.post_attn_forward(hidden_states) - hidden_states = residual + output_post_attn - residual = hidden_states + if self.training: + hidden_states = hidden_states + residual + residual = hidden_states + else: + residual.add_(hidden_states) + hidden_states = residual hidden_states = self.post_attention_layernorm(hidden_states) hidden_states = self.mlp.pre_mlp_forward(hidden_states) return hidden_states, residual - def post_mlp(self, input, residual): - output_post_mlp = self.mlp.post_mlp_forward(input) - output = output_post_mlp + residual - return output + def post_mlp(self, hidden_states, residual): + hidden_states = self.mlp.post_mlp_forward(hidden_states) + + if self.training: + hidden_states = hidden_states + residual + else: + residual.add_(hidden_states) + hidden_states = residual + + return hidden_states class GaudiLlamaModel(LlamaModel): @@ -595,6 +606,7 @@ def forward( flash_attention_causal_mask: Optional[bool] = False, cache_idx: int = None, use_fused_rope: Optional[bool] = True, + lazy_mode: Optional[bool] = True, ) -> Union[Tuple, BaseModelOutputWithPast]: """ Copied from LlamaModel.forward: https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py @@ -605,6 +617,7 @@ def forward( - add new args use_flash_attention - add new arg flash_attention_recompute - add new arg flash_attention_causal_mask + - add new arg lazy_mode """ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( @@ -680,6 +693,9 @@ def forward( all_self_attns = () if output_attentions else None next_decoder_cache = () if not use_new_cache else None + if lazy_mode: + htcore.mark_step() + for layer_idx, decoder_layer in enumerate(self.layers): if output_hidden_states: all_hidden_states += (hidden_states,) @@ -791,6 +807,7 @@ def forward( flash_attention_causal_mask: Optional[bool] = False, cache_idx: int = None, use_fused_rope: Optional[bool] = True, + lazy_mode: Optional[bool] = True, ) -> Union[Tuple, CausalLMOutputWithPast]: output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( @@ -816,6 +833,7 @@ def forward( flash_attention_causal_mask=flash_attention_causal_mask, cache_idx=cache_idx, use_fused_rope=use_fused_rope, + lazy_mode=lazy_mode, ) hidden_states = outputs[0] _, seq_len, _ = hidden_states.shape @@ -929,6 +947,7 @@ def prepare_inputs_for_generation( "flash_attention_recompute": kwargs.get("flash_attention_recompute"), "flash_attention_causal_mask": kwargs.get("flash_attention_causal_mask"), "cache_idx": kwargs.get("cache_idx"), + "lazy_mode": kwargs.get("lazy_mode"), } ) return model_inputs