From ab65e67505e01660d267feb99701cb36a3aad250 Mon Sep 17 00:00:00 2001 From: Kalyan Date: Mon, 26 Feb 2024 19:01:49 +0200 Subject: [PATCH 1/2] Split the graphs to run with flash_attention on 1x --- optimum/habana/transformers/models/llama/modeling_llama.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/optimum/habana/transformers/models/llama/modeling_llama.py b/optimum/habana/transformers/models/llama/modeling_llama.py index 3d90fa4357..8ed5420d45 100755 --- a/optimum/habana/transformers/models/llama/modeling_llama.py +++ b/optimum/habana/transformers/models/llama/modeling_llama.py @@ -23,6 +23,7 @@ _gaudi_prepare_4d_causal_attention_mask, ) +import habana_frameworks.torch.core as htcore try: from habana_frameworks.torch.hpex.kernels import RotaryPosEmbeddingHelperV2 as FusedRoPE @@ -676,6 +677,9 @@ def forward( next_decoder_cache = () if not use_new_cache else None for layer_idx, decoder_layer in enumerate(self.layers): + if torch.distributed.is_initialized() == False: + htcore.mark_step() + if output_hidden_states: all_hidden_states += (hidden_states,) From d4d1b9cc4078c699704106a9b1fd2b1662e2ab2c Mon Sep 17 00:00:00 2001 From: Kalyan Date: Fri, 1 Mar 2024 09:01:37 +0200 Subject: [PATCH 2/2] Added lazy_mode check and removed additional htcore import --- optimum/habana/transformers/models/llama/modeling_llama.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/optimum/habana/transformers/models/llama/modeling_llama.py b/optimum/habana/transformers/models/llama/modeling_llama.py index eb91e25506..a4e0130bab 100755 --- a/optimum/habana/transformers/models/llama/modeling_llama.py +++ b/optimum/habana/transformers/models/llama/modeling_llama.py @@ -23,7 +23,6 @@ _gaudi_prepare_4d_causal_attention_mask, ) -import habana_frameworks.torch.core as htcore try: from habana_frameworks.torch.hpex.kernels import RotaryPosEmbeddingHelperV2 as FusedRoPE @@ -693,7 +692,7 @@ def forward( htcore.mark_step() for layer_idx, decoder_layer in enumerate(self.layers): - if torch.distributed.is_initialized() == False: + if lazy_mode and torch.distributed.is_initialized() == False: htcore.mark_step() if output_hidden_states: