From 570e428d90deddafeac368eb54a65ad1aa8ec36a Mon Sep 17 00:00:00 2001 From: Puneesh Khanna Date: Thu, 29 Feb 2024 17:05:21 +0530 Subject: [PATCH 1/3] Add mark step and inplace residual add in llama model code to reduce memory consumption (#65) * Add mark step and inplace add. Mark step helping in reducing workspace memory by approx twice of (BS,seq len, hidden dim). Inplace add helping in reducing persistent tensors by approc twice of (BS, seq len, hidden dim). Signed-off-by: Puneesh Khanna * Add lazy mode parameter * Move mark step within the loop * Move mark step before the loop * Fix indentation * update in place add only for inference --------- Signed-off-by: Puneesh Khanna --- .../habana/transformers/generation/utils.py | 4 ++ .../models/llama/modeling_llama.py | 56 +++++++++++++------ 2 files changed, 43 insertions(+), 17 deletions(-) diff --git a/optimum/habana/transformers/generation/utils.py b/optimum/habana/transformers/generation/utils.py index bad54641d6..673b86524b 100755 --- a/optimum/habana/transformers/generation/utils.py +++ b/optimum/habana/transformers/generation/utils.py @@ -1430,6 +1430,7 @@ def greedy_search( ) # prepare model inputs + model_kwargs["lazy_mode"] = lazy_mode model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs) hpu_graphs_kwargs = self._get_hpu_graphs_kwargs(model_kwargs) @@ -1781,6 +1782,7 @@ def sample( break # prepare model inputs + model_kwargs["lazy_mode"] = lazy_mode model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs) hpu_graphs_kwargs = self._get_hpu_graphs_kwargs(model_kwargs) @@ -2227,6 +2229,7 @@ def expand_if_needed(tensor, new_size, value, dim=-1): params, input_ids, model_kwargs, pad_token_id, bucket_size, reduce_recompile ) + model_kwargs["lazy_mode"] = lazy_mode model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs) # if sequential is True, split the input to batches of batch_size and run sequentially @@ -3007,6 +3010,7 @@ def constrained_beam_search( if this_peer_finished_flag.item() == 0.0: break + model_kwargs["lazy_mode"] = lazy_mode model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs) hpu_graphs_kwargs = self._get_hpu_graphs_kwargs(model_kwargs) diff --git a/optimum/habana/transformers/models/llama/modeling_llama.py b/optimum/habana/transformers/models/llama/modeling_llama.py index 72e4b0fa55..9e2457dfe4 100755 --- a/optimum/habana/transformers/models/llama/modeling_llama.py +++ b/optimum/habana/transformers/models/llama/modeling_llama.py @@ -43,6 +43,7 @@ print("Not using HPU fused scaled dot-product attention kernel.") FusedSDPA = None +import habana_frameworks.torch.core as htcore def update(prev, cur, dim, idx, inp_seq_len): orig_cur = cur @@ -514,7 +515,7 @@ def forward( ) residual = hidden_states - output_pre_attn, self_attn_weights, present_key_value = self.pre_attn( + hidden_states, self_attn_weights, present_key_value = self.pre_attn( hidden_states, attention_mask, position_ids, @@ -530,13 +531,12 @@ def forward( cache_idx=cache_idx, **kwargs, ) + self.self_attn.attention_all_reduce(hidden_states) + hidden_states, residual = self.post_attn_pre_mlp(hidden_states, residual) + self.mlp.mlp_all_reduce(hidden_states) + hidden_states = self.post_mlp(hidden_states, residual) - self.self_attn.attention_all_reduce(output_pre_attn) - output_post_attn_pre_mlp, residual_mlp = self.post_attn_pre_mlp(output_pre_attn, residual) - self.mlp.mlp_all_reduce(output_post_attn_pre_mlp) - output_post_mlp = self.post_mlp(output_post_attn_pre_mlp, residual_mlp) - - outputs = (output_post_mlp,) + outputs = (hidden_states,) if output_attentions: outputs += (self_attn_weights,) @@ -562,7 +562,7 @@ def pre_attn( cache_idx: int = None, ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: hidden_states = self.input_layernorm(hidden_states) - output_attn, attn_weights, present_key_value = self.self_attn.pre_attn_forward( + hidden_states, attn_weights, present_key_value = self.self_attn.pre_attn_forward( hidden_states, attention_mask, position_ids, @@ -577,23 +577,33 @@ def pre_attn( flash_attention_recompute, cache_idx=cache_idx, ) - return output_attn, attn_weights, present_key_value + return hidden_states, attn_weights, present_key_value - def post_attn_pre_mlp(self, input, residual): - output_post_attn = self.self_attn.post_attn_forward(input) + def post_attn_pre_mlp(self, hidden_states, residual): + hidden_states = self.self_attn.post_attn_forward(hidden_states) - hidden_states = residual + output_post_attn - residual = hidden_states + if self.training: + hidden_states = hidden_states + residual + residual = hidden_states + else: + residual.add_(hidden_states) + hidden_states = residual hidden_states = self.post_attention_layernorm(hidden_states) hidden_states = self.mlp.pre_mlp_forward(hidden_states) return hidden_states, residual - def post_mlp(self, input, residual): - output_post_mlp = self.mlp.post_mlp_forward(input) - output = output_post_mlp + residual - return output + def post_mlp(self, hidden_states, residual): + hidden_states = self.mlp.post_mlp_forward(hidden_states) + + if self.training: + hidden_states = hidden_states + residual + else: + residual.add_(hidden_states) + hidden_states = residual + + return hidden_states class GaudiLlamaModel(LlamaModel): @@ -658,6 +668,8 @@ def forward( use_flash_attention: Optional[bool] = False, flash_attention_recompute: Optional[bool] = False, cache_idx: int = None, + use_fused_rope: Optional[bool] = True, + lazy_mode: Optional[bool] = True, ) -> Union[Tuple, BaseModelOutputWithPast]: """ Copied from LlamaModel.forward: https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py @@ -667,6 +679,8 @@ def forward( - add new args reuse_cache - add new args use_flash_attention - add new arg flash_attention_recompute + - add new arg flash_attention_causal_mask + - add new arg lazy_mode """ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( @@ -743,6 +757,9 @@ def forward( all_self_attns = () if output_attentions else None next_decoder_cache = () if not use_new_cache else None + if lazy_mode: + htcore.mark_step() + for layer_idx, decoder_layer in enumerate(self.layers): if output_hidden_states: all_hidden_states += (hidden_states,) @@ -850,6 +867,8 @@ def forward( use_flash_attention: Optional[bool] = False, flash_attention_recompute: Optional[bool] = False, cache_idx: int = None, + use_fused_rope: Optional[bool] = True, + lazy_mode: Optional[bool] = True, ) -> Union[Tuple, CausalLMOutputWithPast]: output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( @@ -878,6 +897,8 @@ def forward( use_flash_attention=use_flash_attention, flash_attention_recompute=flash_attention_recompute, cache_idx=cache_idx, + use_fused_rope=use_fused_rope, + lazy_mode=lazy_mode, ) hidden_states = outputs[0] _, seq_len, _ = hidden_states.shape @@ -1011,6 +1032,7 @@ def prepare_inputs_for_generation( "use_flash_attention": kwargs.get("use_flash_attention"), "flash_attention_recompute": kwargs.get("flash_attention_recompute"), "cache_idx": kwargs.get("cache_idx"), + "lazy_mode": kwargs.get("lazy_mode"), } ) return model_inputs From f066050534255f871c64d8b67e50827a8af47a0d Mon Sep 17 00:00:00 2001 From: Puneesh Khanna Date: Mon, 25 Mar 2024 06:39:39 +0200 Subject: [PATCH 2/3] Fix merge conflicts --- optimum/habana/transformers/models/llama/modeling_llama.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/optimum/habana/transformers/models/llama/modeling_llama.py b/optimum/habana/transformers/models/llama/modeling_llama.py index 9e2457dfe4..6241d78f1a 100755 --- a/optimum/habana/transformers/models/llama/modeling_llama.py +++ b/optimum/habana/transformers/models/llama/modeling_llama.py @@ -668,7 +668,6 @@ def forward( use_flash_attention: Optional[bool] = False, flash_attention_recompute: Optional[bool] = False, cache_idx: int = None, - use_fused_rope: Optional[bool] = True, lazy_mode: Optional[bool] = True, ) -> Union[Tuple, BaseModelOutputWithPast]: """ @@ -679,7 +678,6 @@ def forward( - add new args reuse_cache - add new args use_flash_attention - add new arg flash_attention_recompute - - add new arg flash_attention_causal_mask - add new arg lazy_mode """ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions @@ -867,7 +865,6 @@ def forward( use_flash_attention: Optional[bool] = False, flash_attention_recompute: Optional[bool] = False, cache_idx: int = None, - use_fused_rope: Optional[bool] = True, lazy_mode: Optional[bool] = True, ) -> Union[Tuple, CausalLMOutputWithPast]: output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions @@ -897,7 +894,6 @@ def forward( use_flash_attention=use_flash_attention, flash_attention_recompute=flash_attention_recompute, cache_idx=cache_idx, - use_fused_rope=use_fused_rope, lazy_mode=lazy_mode, ) hidden_states = outputs[0] From c966e35123ae8d97a64e2d7f061f420126bb3dd5 Mon Sep 17 00:00:00 2001 From: Puneesh Khanna Date: Fri, 5 Apr 2024 11:29:21 +0530 Subject: [PATCH 3/3] Fix make style --- optimum/habana/transformers/models/llama/modeling_llama.py | 1 + 1 file changed, 1 insertion(+) diff --git a/optimum/habana/transformers/models/llama/modeling_llama.py b/optimum/habana/transformers/models/llama/modeling_llama.py index 6241d78f1a..0b0cdae96f 100755 --- a/optimum/habana/transformers/models/llama/modeling_llama.py +++ b/optimum/habana/transformers/models/llama/modeling_llama.py @@ -45,6 +45,7 @@ import habana_frameworks.torch.core as htcore + def update(prev, cur, dim, idx, inp_seq_len): orig_cur = cur if prev.dtype == torch.float8_e4m3fn: