diff --git a/optimum/habana/transformers/models/llama/modeling_llama.py b/optimum/habana/transformers/models/llama/modeling_llama.py
index 025b0081d1..a4e0130bab 100755
--- a/optimum/habana/transformers/models/llama/modeling_llama.py
+++ b/optimum/habana/transformers/models/llama/modeling_llama.py
@@ -692,6 +692,11 @@ def forward(
             htcore.mark_step()
 
         for layer_idx, decoder_layer in enumerate(self.layers):
+            # In lazy mode, call mark_step() before every decoder layer, but only when
+            # torch.distributed is not initialized. NOTE(review): rationale for skipping it in distributed runs is not visible here - confirm.
+            if lazy_mode and not torch.distributed.is_initialized():
+                htcore.mark_step()
+
             if output_hidden_states:
                 all_hidden_states += (hidden_states,)
 