diff --git a/optimum/habana/transformers/generation/utils.py b/optimum/habana/transformers/generation/utils.py index bad54641d6..673b86524b 100755 --- a/optimum/habana/transformers/generation/utils.py +++ b/optimum/habana/transformers/generation/utils.py @@ -1430,6 +1430,7 @@ def greedy_search( ) # prepare model inputs + model_kwargs["lazy_mode"] = lazy_mode model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs) hpu_graphs_kwargs = self._get_hpu_graphs_kwargs(model_kwargs) @@ -1781,6 +1782,7 @@ def sample( break # prepare model inputs + model_kwargs["lazy_mode"] = lazy_mode model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs) hpu_graphs_kwargs = self._get_hpu_graphs_kwargs(model_kwargs) @@ -2227,6 +2229,7 @@ def expand_if_needed(tensor, new_size, value, dim=-1): params, input_ids, model_kwargs, pad_token_id, bucket_size, reduce_recompile ) + model_kwargs["lazy_mode"] = lazy_mode model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs) # if sequential is True, split the input to batches of batch_size and run sequentially @@ -3007,6 +3010,7 @@ def constrained_beam_search( if this_peer_finished_flag.item() == 0.0: break + model_kwargs["lazy_mode"] = lazy_mode model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs) hpu_graphs_kwargs = self._get_hpu_graphs_kwargs(model_kwargs) diff --git a/optimum/habana/transformers/models/llama/modeling_llama.py b/optimum/habana/transformers/models/llama/modeling_llama.py index 72e4b0fa55..0b0cdae96f 100755 --- a/optimum/habana/transformers/models/llama/modeling_llama.py +++ b/optimum/habana/transformers/models/llama/modeling_llama.py @@ -43,6 +43,8 @@ print("Not using HPU fused scaled dot-product attention kernel.") FusedSDPA = None +import habana_frameworks.torch.core as htcore + def update(prev, cur, dim, idx, inp_seq_len): orig_cur = cur @@ -514,7 +516,7 @@ def forward( ) residual = hidden_states - output_pre_attn, self_attn_weights, present_key_value = self.pre_attn( + hidden_states, self_attn_weights, present_key_value = self.pre_attn( hidden_states, attention_mask, position_ids, @@ -530,13 +532,12 @@ def forward( cache_idx=cache_idx, **kwargs, ) + self.self_attn.attention_all_reduce(hidden_states) + hidden_states, residual = self.post_attn_pre_mlp(hidden_states, residual) + self.mlp.mlp_all_reduce(hidden_states) + hidden_states = self.post_mlp(hidden_states, residual) - self.self_attn.attention_all_reduce(output_pre_attn) - output_post_attn_pre_mlp, residual_mlp = self.post_attn_pre_mlp(output_pre_attn, residual) - self.mlp.mlp_all_reduce(output_post_attn_pre_mlp) - output_post_mlp = self.post_mlp(output_post_attn_pre_mlp, residual_mlp) - - outputs = (output_post_mlp,) + outputs = (hidden_states,) if output_attentions: outputs += (self_attn_weights,) @@ -562,7 +563,7 @@ def pre_attn( cache_idx: int = None, ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: hidden_states = self.input_layernorm(hidden_states) - output_attn, attn_weights, present_key_value = self.self_attn.pre_attn_forward( + hidden_states, attn_weights, present_key_value = self.self_attn.pre_attn_forward( hidden_states, attention_mask, position_ids, @@ -577,23 +578,33 @@ def pre_attn( flash_attention_recompute, cache_idx=cache_idx, ) - return output_attn, attn_weights, present_key_value + return hidden_states, attn_weights, present_key_value - def post_attn_pre_mlp(self, input, residual): - output_post_attn = self.self_attn.post_attn_forward(input) + def post_attn_pre_mlp(self, hidden_states, residual): + hidden_states = self.self_attn.post_attn_forward(hidden_states) - hidden_states = residual + output_post_attn - residual = hidden_states + if self.training: + hidden_states = hidden_states + residual + residual = hidden_states + else: + residual.add_(hidden_states) + hidden_states = residual hidden_states = self.post_attention_layernorm(hidden_states) hidden_states = self.mlp.pre_mlp_forward(hidden_states) return hidden_states, residual - def post_mlp(self, input, residual): - output_post_mlp = self.mlp.post_mlp_forward(input) - output = output_post_mlp + residual - return output + def post_mlp(self, hidden_states, residual): + hidden_states = self.mlp.post_mlp_forward(hidden_states) + + if self.training: + hidden_states = hidden_states + residual + else: + residual.add_(hidden_states) + hidden_states = residual + + return hidden_states class GaudiLlamaModel(LlamaModel): @@ -658,6 +669,7 @@ def forward( use_flash_attention: Optional[bool] = False, flash_attention_recompute: Optional[bool] = False, cache_idx: int = None, + lazy_mode: Optional[bool] = True, ) -> Union[Tuple, BaseModelOutputWithPast]: """ Copied from LlamaModel.forward: https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py @@ -667,6 +679,7 @@ def forward( - add new args reuse_cache - add new args use_flash_attention - add new arg flash_attention_recompute + - add new arg lazy_mode """ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( @@ -743,6 +756,9 @@ def forward( all_self_attns = () if output_attentions else None next_decoder_cache = () if not use_new_cache else None + if lazy_mode: + htcore.mark_step() + for layer_idx, decoder_layer in enumerate(self.layers): if output_hidden_states: all_hidden_states += (hidden_states,) @@ -850,6 +866,7 @@ def forward( use_flash_attention: Optional[bool] = False, flash_attention_recompute: Optional[bool] = False, cache_idx: int = None, + lazy_mode: Optional[bool] = True, ) -> Union[Tuple, CausalLMOutputWithPast]: output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( @@ -878,6 +895,7 @@ def forward( use_flash_attention=use_flash_attention, flash_attention_recompute=flash_attention_recompute, cache_idx=cache_idx, + lazy_mode=lazy_mode, ) hidden_states = outputs[0] _, seq_len, _ = hidden_states.shape @@ -1011,6 +1029,7 @@ def prepare_inputs_for_generation( "use_flash_attention": kwargs.get("use_flash_attention"), "flash_attention_recompute": kwargs.get("flash_attention_recompute"), "cache_idx": kwargs.get("cache_idx"), + "lazy_mode": kwargs.get("lazy_mode"), } ) return model_inputs