Skip to content

Commit

Permalink
fix_stage3_fp16 (#39171)
Browse files Browse the repository at this point in the history
  • Loading branch information
Baibaifan authored Jan 25, 2022
1 parent 9059ef6 commit 8bb509d
Showing 1 changed file with 2 additions and 2 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -393,6 +393,7 @@ def _update_params(self):
else:
param.bw_storage.scale_(scale=self._world_size_scaling)
param.fw_storage = _VarBaseWrapper(param)
assert param.fw_storage.grad is None
param.fw_storage._copy_gradient_from(param.bw_storage)
update_list.append(param)
return update_list
Expand Down Expand Up @@ -495,10 +496,9 @@ def reduce(*_):
def _redefine_opt_step(self):
params_slice_func = self._update_params_slice
opt_step = self._optim.step
update_scaler = self._optim.update_scaler

def _opt_step(self):
if not update_scaler:
if not self.update_scaler:
params_slice_func()
if self.offload:
with device_guard(device="cpu"):
Expand Down

0 comments on commit 8bb509d

Please sign in to comment.