diff --git a/deepspeed/runtime/engine.py b/deepspeed/runtime/engine.py
index 7431b2c892c4..76ba6af78b76 100755
--- a/deepspeed/runtime/engine.py
+++ b/deepspeed/runtime/engine.py
@@ -979,7 +979,7 @@ def clip_fp32_gradients(self):
         torch.nn.utils.clip_grad_norm_(parameters=self.module.parameters(),
                                        max_norm=self.gradient_clipping())
 
-    def _take_model_step(self):
+    def _take_model_step(self, lr_kwargs):
         if self.gradient_clipping() > 0.0:
             if not self.fp16_enabled() and not self.amp_enabled():
                 self.clip_fp32_gradients()
@@ -1010,14 +1010,14 @@ def _take_model_step(self):
                 self.skipped_steps += 1
         else:
             if self.lr_scheduler is not None:
-                self.lr_scheduler.step()
+                self.lr_scheduler.step(**(lr_kwargs or {}))
 
         if report_progress and (self.global_steps + 1) % self.steps_per_print() == 0:
             self._report_progress(self.global_steps + 1)
         self.global_steps += 1
         self.global_samples += self.train_batch_size()
 
-    def step(self):
+    def step(self, lr_kwargs=None):
         r"""Execute the weight update step after forward and backward propagation
         on effective_train_batch.
         """
@@ -1034,7 +1034,7 @@ def step(self):
         if self.progressive_layer_drop:
             self.progressive_layer_drop.update_state(self.global_steps)
 
-        self._take_model_step()
+        self._take_model_step(lr_kwargs)
 
         self.tput_timer.stop(report_progress)
 
diff --git a/deepspeed/runtime/pipe/engine.py b/deepspeed/runtime/pipe/engine.py
index 954774e58912..5c5d896dfc0d 100644
--- a/deepspeed/runtime/pipe/engine.py
+++ b/deepspeed/runtime/pipe/engine.py
@@ -940,14 +940,14 @@ def _exec_recv_grads(self, buffer_id):
         if self.wall_clock_breakdown():
             self.timers('pipe_recv_grad').stop()
 
-    def _exec_optimizer_step(self):
+    def _exec_optimizer_step(self, lr_kwargs=None):
         if self.wall_clock_breakdown():
             self.timers('step_microstep').start()
             self.timers('step').start()
         self.mem_status('BEFORE STEP', reset_max=True)
 
         self._force_grad_boundary = True
-        self._take_model_step()
+        self._take_model_step(lr_kwargs)
         self._force_grad_boundary = False
 
         self.mem_status('AFTER STEP')
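
Taken together, the changes thread an optional lr_kwargs dict from the public step() entry point through _take_model_step() into lr_scheduler.step(), so schedulers whose step() accepts arguments (an epoch index, a metric value, etc.) can receive them on every engine step. Below is a minimal, self-contained sketch of that forwarding pattern, not DeepSpeed code itself; the _FakeScheduler and _FakeEngine names are illustrative stand-ins:

    # Sketch of the call chain this patch creates:
    # engine.step(lr_kwargs=...) -> _take_model_step(lr_kwargs)
    #                            -> lr_scheduler.step(**(lr_kwargs or {}))
    class _FakeScheduler:
        def step(self, epoch=None):
            print(f"scheduler.step(epoch={epoch})")

    class _FakeEngine:
        def __init__(self):
            self.lr_scheduler = _FakeScheduler()

        def _take_model_step(self, lr_kwargs):
            if self.lr_scheduler is not None:
                # `lr_kwargs or {}` preserves the old no-argument call when
                # the caller passes nothing, so existing loops are unaffected.
                self.lr_scheduler.step(**(lr_kwargs or {}))

        def step(self, lr_kwargs=None):
            self._take_model_step(lr_kwargs)

    engine = _FakeEngine()
    engine.step()                        # prints: scheduler.step(epoch=None)
    engine.step(lr_kwargs={'epoch': 3})  # prints: scheduler.step(epoch=3)

With the real engine, a training loop would call engine.step(lr_kwargs={...}) and the keywords land unchanged in the configured scheduler's step(); omitting the argument (the None default) keeps the previous behavior.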