Skip to content
7 changes: 7 additions & 0 deletions optimum/habana/transformers/models/llama/modeling_llama.py
Original file line number Diff line number Diff line change
Expand Up @@ -774,6 +774,13 @@ def forward(
htcore.mark_step()

for layer_idx, decoder_layer in enumerate(self.layers):
if (
lazy_mode
and not self.training
and (torch.distributed.is_initialized() is False or torch.distributed.get_world_size() == 1)
):
htcore.mark_step()

if output_hidden_states:
all_hidden_states += (hidden_states,)

Expand Down