Merged
9 changes: 4 additions & 5 deletions deepspeed/runtime/engine.py
@@ -6,7 +6,6 @@
 import torch
 import warnings
 import torch.distributed as dist
-
 from apex import amp
 from torch.nn.modules import Module
 from torch.distributed.distributed_c10d import _get_global_rank
@@ -27,9 +26,7 @@
     ZERO_OPTIMIZATION_OPTIMIZER_STATES, ZERO_OPTIMIZATION_GRADIENTS
 from deepspeed.runtime.csr_tensor import CSRTensor
 import deepspeed.runtime.lr_schedules as lr_schedules
-
 from deepspeed.ops.lamb import FusedLamb
-
 from deepspeed.utils import logger
 from deepspeed.utils.timer import ThroughputTimer, SynchronizedWallClockTimer

@@ -618,7 +615,8 @@ def _configure_zero_optimizer(self, optimizer):
                 cpu_offload=self.zero_cpu_offload(),
                 mpu=self.mpu,
                 postscale_gradients=self.postscale_gradients(),
-                gradient_predivide_factor=self.gradient_predivide_factor())
+                gradient_predivide_factor=self.gradient_predivide_factor(),
+                gradient_accumulation_steps=self.gradient_accumulation_steps())
         else:
             raise NotImplementedError("ZeRO stage {} not implemented".format(zero_stage))

@@ -724,7 +722,6 @@ def forward(self, *inputs, **kwargs):
         return loss

     def allreduce_gradients(self, bucket_size=MEMORY_OPT_ALLREDUCE_SIZE):
-
         #Zero stage 2 communicates during non gradient accumulation boundaries as well
         if self.zero_optimization_partition_gradients():
             self.optimizer.overlapping_partition_gradients_reduce_epilogue()
@@ -780,6 +777,8 @@ def backward(self, loss, allreduce_gradients=True):
         self.timers('backward_inner').start()

         if self.zero_optimization():
+            self.optimizer.is_gradient_accumulation_boundary = self.is_gradient_accumulation_boundary(
+            )
             self.optimizer.backward(loss)
         elif self.amp_enabled():
             # AMP requires delaying unscale when inside gradient accumulation boundaries
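For context on the _configure_zero_optimizer hunk above: the engine now forwards its configured gradient_accumulation_steps to the ZeRO stage 2 optimizer, so the optimizer knows how many micro-batches make up one effective batch. Below is a minimal sketch of a user config that exercises this path; the batch sizes, learning rate, and optimizer choice are placeholders, not values taken from this PR.

# Illustrative DeepSpeed config (placeholder values) enabling ZeRO stage 2
# together with gradient accumulation.
ds_config = {
    "train_micro_batch_size_per_gpu": 4,
    "gradient_accumulation_steps": 8,   # surfaced via engine.gradient_accumulation_steps()
    "zero_optimization": {"stage": 2},  # routes through the stage-2 branch of _configure_zero_optimizer
    "fp16": {"enabled": True},
    "optimizer": {"type": "Adam", "params": {"lr": 1e-4}},
}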
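For the backward hunk: before delegating the backward pass, the engine now copies is_gradient_accumulation_boundary() onto the stage 2 optimizer, so the optimizer can tell accumulation micro-steps apart from the final micro-step of a window when it reduces and partitions gradients. Nothing changes on the user side; the sketch below (assuming the ds_config above, a stand-in torch.nn.Linear model, random data, and launching via the deepspeed launcher so distributed init is handled) shows where those boundaries come from: with gradient_accumulation_steps set to 8, every 8th engine.step() call lands on a boundary.

import torch
import deepspeed

model = torch.nn.Linear(16, 1)  # stand-in module, not part of this PR

# Newer DeepSpeed releases accept an in-memory config dict via config=;
# older releases read the JSON path from args.deepspeed_config instead.
engine, _, _, _ = deepspeed.initialize(model=model,
                                       model_parameters=model.parameters(),
                                       config=ds_config)

for micro_step in range(32):  # 32 micro-steps = 4 accumulation boundaries at 8 steps each
    x = torch.randn(4, 16, device=engine.device, dtype=torch.half)
    y = torch.randn(4, 1, device=engine.device, dtype=torch.half)
    loss = torch.nn.functional.mse_loss(engine(x), y)
    engine.backward(loss)  # the hunk above pushes the boundary flag onto the ZeRO-2 optimizer here
    engine.step()          # the optimizer only takes a real step on every 8th micro-step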