From 389938ae5c720702cf6114e580269722fc81e811 Mon Sep 17 00:00:00 2001 From: Puneesh Khanna Date: Fri, 23 Feb 2024 13:39:31 +0200 Subject: [PATCH 1/6] Add mark step and inplace add. Mark step helping in reducing workspace memory by approx twice of (BS,seq len, hidden dim). Inplace add helping in reducing persistent tensors by approc twice of (BS, seq len, hidden dim). Signed-off-by: Puneesh Khanna --- .../models/llama/modeling_llama.py | 37 +++++++++++-------- 1 file changed, 21 insertions(+), 16 deletions(-) diff --git a/optimum/habana/transformers/models/llama/modeling_llama.py b/optimum/habana/transformers/models/llama/modeling_llama.py index ead2d75d3d..d5b36f040e 100755 --- a/optimum/habana/transformers/models/llama/modeling_llama.py +++ b/optimum/habana/transformers/models/llama/modeling_llama.py @@ -42,6 +42,7 @@ print("Not using HPU fused scaled dot-product attention kernel.") FusedSDPA = None +import habana_frameworks.torch.core as htcore def update(prev, cur, dim, idx, inp_seq_len): orig_cur = cur @@ -484,7 +485,7 @@ def forward( ) residual = hidden_states - output_pre_attn, self_attn_weights, present_key_value = self.pre_attn( + hidden_states, self_attn_weights, present_key_value = self.pre_attn( hidden_states, attention_mask, position_ids, @@ -501,12 +502,12 @@ def forward( use_fused_rope=use_fused_rope, **kwargs, ) - self.self_attn.attention_all_reduce(output_pre_attn) - output_post_attn_pre_mlp, residual_mlp = self.post_attn_pre_mlp(output_pre_attn, residual) - self.mlp.mlp_all_reduce(output_post_attn_pre_mlp) - output_post_mlp = self.post_mlp(output_post_attn_pre_mlp, residual_mlp) + self.self_attn.attention_all_reduce(hidden_states) + hidden_states, residual = self.post_attn_pre_mlp(hidden_states, residual) + self.mlp.mlp_all_reduce(hidden_states) + hidden_states = self.post_mlp(hidden_states, residual) - outputs = (output_post_mlp,) + outputs = (hidden_states,) if output_attentions: outputs += (self_attn_weights,) @@ -533,7 +534,7 @@ def pre_attn( use_fused_rope: Optional[bool] = True, ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: hidden_states = self.input_layernorm(hidden_states) - output_attn, attn_weights, present_key_value = self.self_attn.pre_attn_forward( + hidden_states, attn_weights, present_key_value = self.self_attn.pre_attn_forward( hidden_states, attention_mask, position_ids, @@ -549,23 +550,26 @@ def pre_attn( cache_idx=cache_idx, use_fused_rope=use_fused_rope, ) - return output_attn, attn_weights, present_key_value + return hidden_states, attn_weights, present_key_value - def post_attn_pre_mlp(self, input, residual): - output_post_attn = self.self_attn.post_attn_forward(input) + def post_attn_pre_mlp(self, hidden_states, residual): + hidden_states = self.self_attn.post_attn_forward(hidden_states) - hidden_states = residual + output_post_attn - residual = hidden_states + residual.add_(hidden_states) + hidden_states = residual hidden_states = self.post_attention_layernorm(hidden_states) hidden_states = self.mlp.pre_mlp_forward(hidden_states) return hidden_states, residual - def post_mlp(self, input, residual): - output_post_mlp = self.mlp.post_mlp_forward(input) - output = output_post_mlp + residual - return output + def post_mlp(self, hidden_states, residual): + hidden_states = self.mlp.post_mlp_forward(hidden_states) + + residual.add_(hidden_states) + hidden_states = residual + + return hidden_states class GaudiLlamaModel(LlamaModel): @@ -684,6 +688,7 @@ def forward( all_self_attns = () if output_attentions else None next_decoder_cache = () if not use_new_cache else None + htcore.mark_step() for layer_idx, decoder_layer in enumerate(self.layers): if output_hidden_states: all_hidden_states += (hidden_states,) From 094f0b737827a049a615202cdd1f97cde1e5c98d Mon Sep 17 00:00:00 2001 From: Puneesh Khanna Date: Tue, 27 Feb 2024 10:22:05 +0200 Subject: [PATCH 2/6] Add lazy mode parameter --- optimum/habana/transformers/generation/utils.py | 4 ++++ .../habana/transformers/models/llama/modeling_llama.py | 8 +++++++- 2 files changed, 11 insertions(+), 1 deletion(-) diff --git a/optimum/habana/transformers/generation/utils.py b/optimum/habana/transformers/generation/utils.py index 2e8eb74d9b..57a85b2715 100755 --- a/optimum/habana/transformers/generation/utils.py +++ b/optimum/habana/transformers/generation/utils.py @@ -1414,6 +1414,7 @@ def greedy_search( ) # prepare model inputs + model_kwargs["lazy_mode"] = lazy_mode model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs) hpu_graphs_kwargs = self._get_hpu_graphs_kwargs(model_kwargs) @@ -1760,6 +1761,7 @@ def sample( break # prepare model inputs + model_kwargs["lazy_mode"] = lazy_mode model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs) hpu_graphs_kwargs = self._get_hpu_graphs_kwargs(model_kwargs) @@ -2196,6 +2198,7 @@ def expand_if_needed(tensor, new_size, value, dim=-1): params, input_ids, model_kwargs, pad_token_id, bucket_size, reduce_recompile ) + model_kwargs["lazy_mode"] = lazy_mode model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs) hpu_graphs_kwargs = self._get_hpu_graphs_kwargs(model_kwargs) @@ -2927,6 +2930,7 @@ def constrained_beam_search( if this_peer_finished_flag.item() == 0.0: break + model_kwargs["lazy_mode"] = lazy_mode model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs) hpu_graphs_kwargs = self._get_hpu_graphs_kwargs(model_kwargs) diff --git a/optimum/habana/transformers/models/llama/modeling_llama.py b/optimum/habana/transformers/models/llama/modeling_llama.py index 869c8b9bff..6880b68699 100755 --- a/optimum/habana/transformers/models/llama/modeling_llama.py +++ b/optimum/habana/transformers/models/llama/modeling_llama.py @@ -599,6 +599,7 @@ def forward( flash_attention_causal_mask: Optional[bool] = False, cache_idx: int = None, use_fused_rope: Optional[bool] = True, + lazy_mode: Optional[bool] = True, ) -> Union[Tuple, BaseModelOutputWithPast]: """ Copied from LlamaModel.forward: https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py @@ -609,6 +610,7 @@ def forward( - add new args use_flash_attention - add new arg flash_attention_recompute - add new arg flash_attention_causal_mask + - add new arg lazy_mode """ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( @@ -684,7 +686,8 @@ def forward( all_self_attns = () if output_attentions else None next_decoder_cache = () if not use_new_cache else None - htcore.mark_step() + if lazy_mode: + htcore.mark_step() for layer_idx, decoder_layer in enumerate(self.layers): if output_hidden_states: all_hidden_states += (hidden_states,) @@ -796,6 +799,7 @@ def forward( flash_attention_causal_mask: Optional[bool] = False, cache_idx: int = None, use_fused_rope: Optional[bool] = True, + lazy_mode: Optional[bool] = True, ) -> Union[Tuple, CausalLMOutputWithPast]: output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( @@ -821,6 +825,7 @@ def forward( flash_attention_causal_mask=flash_attention_causal_mask, cache_idx=cache_idx, use_fused_rope=use_fused_rope, + lazy_mode=lazy_mode, ) hidden_states = outputs[0] _, seq_len, _ = hidden_states.shape @@ -934,6 +939,7 @@ def prepare_inputs_for_generation( "flash_attention_recompute": kwargs.get("flash_attention_recompute"), "flash_attention_causal_mask": kwargs.get("flash_attention_causal_mask"), "cache_idx": kwargs.get("cache_idx"), + "lazy_mode": kwargs.get("lazy_mode"), } ) return model_inputs From 586e8ce929256c576c6aa2617ea933fed8cd3368 Mon Sep 17 00:00:00 2001 From: Puneesh Khanna Date: Wed, 28 Feb 2024 13:24:39 +0530 Subject: [PATCH 3/6] Move mark step within the loop --- optimum/habana/transformers/models/llama/modeling_llama.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/optimum/habana/transformers/models/llama/modeling_llama.py b/optimum/habana/transformers/models/llama/modeling_llama.py index 6880b68699..1f57217eaf 100755 --- a/optimum/habana/transformers/models/llama/modeling_llama.py +++ b/optimum/habana/transformers/models/llama/modeling_llama.py @@ -686,9 +686,10 @@ def forward( all_self_attns = () if output_attentions else None next_decoder_cache = () if not use_new_cache else None - if lazy_mode: - htcore.mark_step() for layer_idx, decoder_layer in enumerate(self.layers): + if lazy_mode: + htcore.mark_step() + if output_hidden_states: all_hidden_states += (hidden_states,) From f34fae3e2e4238485a27161d10c0767ccce3080b Mon Sep 17 00:00:00 2001 From: Puneesh Khanna Date: Wed, 28 Feb 2024 13:45:31 +0530 Subject: [PATCH 4/6] Move mark step before the loop --- optimum/habana/transformers/models/llama/modeling_llama.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/optimum/habana/transformers/models/llama/modeling_llama.py b/optimum/habana/transformers/models/llama/modeling_llama.py index 1f57217eaf..cb769f6b7e 100755 --- a/optimum/habana/transformers/models/llama/modeling_llama.py +++ b/optimum/habana/transformers/models/llama/modeling_llama.py @@ -686,10 +686,10 @@ def forward( all_self_attns = () if output_attentions else None next_decoder_cache = () if not use_new_cache else None - for layer_idx, decoder_layer in enumerate(self.layers): - if lazy_mode: + if lazy_mode: htcore.mark_step() + for layer_idx, decoder_layer in enumerate(self.layers): if output_hidden_states: all_hidden_states += (hidden_states,) From 6c6e5da06ec48b556eecd01da7f7afee1a78ea20 Mon Sep 17 00:00:00 2001 From: Puneesh Khanna Date: Thu, 29 Feb 2024 10:21:16 +0530 Subject: [PATCH 5/6] Fix indentation --- optimum/habana/transformers/models/llama/modeling_llama.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/optimum/habana/transformers/models/llama/modeling_llama.py b/optimum/habana/transformers/models/llama/modeling_llama.py index cb769f6b7e..d16cbfa8a7 100755 --- a/optimum/habana/transformers/models/llama/modeling_llama.py +++ b/optimum/habana/transformers/models/llama/modeling_llama.py @@ -687,7 +687,7 @@ def forward( next_decoder_cache = () if not use_new_cache else None if lazy_mode: - htcore.mark_step() + htcore.mark_step() for layer_idx, decoder_layer in enumerate(self.layers): if output_hidden_states: From 8eab266b339b013aea80750b5436896952606894 Mon Sep 17 00:00:00 2001 From: Puneesh Khanna Date: Thu, 29 Feb 2024 16:37:09 +0530 Subject: [PATCH 6/6] update in place add only for inference --- .../transformers/models/llama/modeling_llama.py | 15 +++++++++++---- 1 file changed, 11 insertions(+), 4 deletions(-) diff --git a/optimum/habana/transformers/models/llama/modeling_llama.py b/optimum/habana/transformers/models/llama/modeling_llama.py index d16cbfa8a7..9cd78e0edb 100755 --- a/optimum/habana/transformers/models/llama/modeling_llama.py +++ b/optimum/habana/transformers/models/llama/modeling_llama.py @@ -551,8 +551,12 @@ def pre_attn( def post_attn_pre_mlp(self, hidden_states, residual): hidden_states = self.self_attn.post_attn_forward(hidden_states) - residual.add_(hidden_states) - hidden_states = residual + if self.training: + hidden_states = hidden_states + residual + residual = hidden_states + else: + residual.add_(hidden_states) + hidden_states = residual hidden_states = self.post_attention_layernorm(hidden_states) @@ -562,8 +566,11 @@ def post_attn_pre_mlp(self, hidden_states, residual): def post_mlp(self, hidden_states, residual): hidden_states = self.mlp.post_mlp_forward(hidden_states) - residual.add_(hidden_states) - hidden_states = residual + if self.training: + hidden_states = hidden_states + residual + else: + residual.add_(hidden_states) + hidden_states = residual return hidden_states