From 3fc88c06734438efae73e6270468b19ce1a00df9 Mon Sep 17 00:00:00 2001
From: carefree0910
Date: Mon, 7 Dec 2020 11:29:00 +0800
Subject: [PATCH] Supported customizing kwargs for lr_scheduler

---
 deepspeed/runtime/engine.py      | 8 ++++----
 deepspeed/runtime/pipe/engine.py | 4 ++--
 2 files changed, 6 insertions(+), 6 deletions(-)

diff --git a/deepspeed/runtime/engine.py b/deepspeed/runtime/engine.py
index ee515a072a91..e22ec1bf01bf 100755
--- a/deepspeed/runtime/engine.py
+++ b/deepspeed/runtime/engine.py
@@ -978,7 +978,7 @@ def clip_fp32_gradients(self):
         torch.nn.utils.clip_grad_norm_(parameters=self.module.parameters(),
                                        max_norm=self.gradient_clipping())
 
-    def _take_model_step(self):
+    def _take_model_step(self, lr_kwargs):
         if self.gradient_clipping() > 0.0:
             if not self.fp16_enabled() and not self.amp_enabled():
                 self.clip_fp32_gradients()
@@ -1009,14 +1009,14 @@ def _take_model_step(self):
             self.skipped_steps += 1
         else:
             if self.lr_scheduler is not None:
-                self.lr_scheduler.step()
+                self.lr_scheduler.step(**(lr_kwargs or {}))
 
         if report_progress and (self.global_steps + 1) % self.steps_per_print() == 0:
             self._report_progress(self.global_steps + 1)
 
         self.global_steps += 1
         self.global_samples += self.train_batch_size()
 
-    def step(self):
+    def step(self, lr_kwargs=None):
         r"""Execute the weight update step after forward and backward propagation
         on effective_train_batch.
         """
@@ -1033,7 +1033,7 @@ def step(self):
             if self.progressive_layer_drop:
                 self.progressive_layer_drop.update_state(self.global_steps)
 
-            self._take_model_step()
+            self._take_model_step(lr_kwargs)
 
         self.tput_timer.stop(report_progress)
 
diff --git a/deepspeed/runtime/pipe/engine.py b/deepspeed/runtime/pipe/engine.py
index 0dbcf88eb4e8..c5e243846231 100644
--- a/deepspeed/runtime/pipe/engine.py
+++ b/deepspeed/runtime/pipe/engine.py
@@ -937,14 +937,14 @@ def _exec_recv_grads(self, buffer_id):
         if self.wall_clock_breakdown():
             self.timers('pipe_recv_grad').stop()
 
-    def _exec_optimizer_step(self):
+    def _exec_optimizer_step(self, lr_kwargs=None):
        if self.wall_clock_breakdown():
             self.timers('step_microstep').start()
             self.timers('step').start()
         self.mem_status('BEFORE STEP', reset_max=True)
 
         self._force_grad_boundary = True
-        self._take_model_step()
+        self._take_model_step(lr_kwargs)
         self._force_grad_boundary = False
 
         self.mem_status('AFTER STEP')
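
Usage note (not part of the patch): a minimal sketch of how the new lr_kwargs parameter could be driven from a training loop. The model, data loader, argparse namespace, and scheduler below are illustrative placeholders and are not defined by this patch; the only behavior the patch itself adds is that engine.step(lr_kwargs=...) unpacks the dict into lr_scheduler.step(**lr_kwargs).

    import torch
    import deepspeed

    # Illustrative placeholders: `model`, `cmd_args` (an argparse namespace
    # carrying the DeepSpeed config), and `data_loader` are assumed to exist.
    # ReduceLROnPlateau is only an example of a scheduler whose step() takes
    # extra arguments; it is not mandated by the patch.
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode="min")

    engine, optimizer, _, scheduler = deepspeed.initialize(args=cmd_args,
                                                           model=model,
                                                           optimizer=optimizer,
                                                           lr_scheduler=scheduler)

    for batch in data_loader:
        loss = engine(batch)          # forward through the wrapped module
        engine.backward(loss)
        # step() now forwards the dict as lr_scheduler.step(**(lr_kwargs or {})),
        # so ReduceLROnPlateau receives the `metrics` value it requires.
        engine.step(lr_kwargs={"metrics": loss.item()})

Without the lr_kwargs argument, engine.step() behaves exactly as before, since a None value unpacks to an empty dict.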