Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 9 additions & 6 deletions deepspeed/runtime/engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -1711,11 +1711,14 @@ def backward(self,
loss,
allreduce_gradients=True,
release_loss=False,
retain_graph=False,
scale_wrt_gas=True):
r"""Execute backward pass on the loss
Arguments:
loss: Torch tensor on which to execute backward propagation
allreduce_gradients: is deprecated, ignored, and will soon be removed'
retain_graph: bool, default: false
forward on user defined choice of retain_graph
"""

see_memory_usage("Engine before backward", force=self.memory_breakdown())
Expand Down Expand Up @@ -1751,29 +1754,29 @@ def backward(self,
self._start_timers(self.engine_timers.backward_inner_timers)

if self.zero_optimization():
self.optimizer.is_gradient_accumulation_boundary = (
self.is_gradient_accumulation_boundary())
self.optimizer.backward(loss)
self.optimizer.is_gradient_accumulation_boundary = self.is_gradient_accumulation_boundary(
)
self.optimizer.backward(loss, retain_graph=retain_graph)
elif self.amp_enabled():
# AMP requires delaying unscale when inside gradient accumulation boundaries
# https://nvidia.github.io/apex/advanced.html#gradient-accumulation-across-iterations
delay_unscale = not self.is_gradient_accumulation_boundary()
with amp.scale_loss(loss,
self.optimizer,
delay_unscale=delay_unscale) as scaled_loss:
scaled_loss.backward()
scaled_loss.backward(retain_graph=retain_graph)
elif self.fp16_enabled():
if self.eigenvalue_enabled():
self.optimizer.backward(loss, create_graph=True, retain_graph=True)
else:
self.optimizer.backward(loss)
self.optimizer.backward(loss, retain_graph=retain_graph)
elif self.bfloat16_enabled():
self.optimizer.backward(loss)
else:
if self.eigenvalue_enabled():
loss.backward(create_graph=True, retain_graph=True)
else:
loss.backward()
loss.backward(retain_graph=retain_graph)

self._stop_timers(self.engine_timers.backward_inner_timers)

Expand Down