diff --git a/optimum/habana/transformers/trainer.py b/optimum/habana/transformers/trainer.py index 99514c295b..09d9a25ce4 100644 --- a/optimum/habana/transformers/trainer.py +++ b/optimum/habana/transformers/trainer.py @@ -1425,7 +1425,7 @@ def training_step(self, model: torch.nn.Module, inputs: Dict[str, Union[torch.Te if self.args.n_gpu > 1: loss = loss.mean() # mean() to average on multi-gpu parallel training - if self.args.pipelining_fwd_bwd: + if self.args.use_lazy_mode and self.args.pipelining_fwd_bwd: self.htcore.mark_step() self.accelerator.backward(loss)