2 changes: 1 addition & 1 deletion deepspeed/runtime/config.py
@@ -623,7 +623,7 @@ def _do_error_check(self):
assert self.zero_optimization_stage <= MAX_STAGE_ZERO_OPTIMIZATION, "DeepSpeedConfig: Maximum supported ZeRO stage is {}".format(MAX_STAGE_ZERO_OPTIMIZATION)
if self.zero_config.cpu_offload is True:
assert self.zero_optimization_stage == ZERO_OPTIMIZATION_GRADIENTS, "DeepSpeedConfig: cpu-offload supported ZeRO stage is {}".format(ZERO_OPTIMIZATION_GRADIENTS)
assert self.gradient_accumulation_steps == 1, "DeepSpeedConfig: {}is not supported for {}".format(GRADIENT_ACCUMULATION_STEPS, ZERO_OPTIMIZATION_CPU_OFFLOAD)
#assert self.gradient_accumulation_steps == 1, "DeepSpeedConfig: {}is not supported for {}".format(GRADIENT_ACCUMULATION_STEPS, ZERO_OPTIMIZATION_CPU_OFFLOAD)

def _do_warning_check(self):
fp16_enabled = self.fp16_enabled or self.zero_enabled
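With the assertion above commented out, _do_error_check no longer rejects ZeRO stage 2 CPU offload combined with gradient accumulation steps greater than one. A minimal, hypothetical sketch of a config this change is meant to allow (key names follow the standard DeepSpeed JSON schema; the commented-out initialize call is illustrative only):

    ds_config = {
        "train_micro_batch_size_per_gpu": 4,
        "gradient_accumulation_steps": 4,  # previously forced to 1 whenever cpu_offload was enabled
        "fp16": {"enabled": True},
        "zero_optimization": {
            "stage": 2,
            "cpu_offload": True
        }
    }
    # model_engine, optimizer, _, _ = deepspeed.initialize(model=model,
    #                                                      model_parameters=model.parameters(),
    #                                                      config_params=ds_config)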
9 changes: 6 additions & 3 deletions deepspeed/runtime/zero/stage1.py
@@ -793,9 +793,12 @@ def _get_groups_without_padding(self, groups_with_padding):
def _get_state_without_padding(self, state_with_padding, padding):
lean_state = {}
for key, value in state_with_padding.items():
lean_length = value.numel() - padding
lean_state[key] = value[:lean_length]

if torch.is_tensor(value):
lean_length = value.numel() - padding
lean_state[key] = value[:lean_length]
else:
lean_state[key] = value

return lean_state

# Return base optimizer states.
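The torch.is_tensor guard matters because optimizer state dicts mix tensors with plain Python values: Adam, for example, stores an integer step count alongside exp_avg and exp_avg_sq, and slicing an int would raise a TypeError. A self-contained sketch of the same idea (function and variable names are illustrative, not taken from the PR):

    import torch

    def strip_padding(state_with_padding, padding):
        # Trim tensor entries back to their unpadded length; pass scalars through untouched.
        lean_state = {}
        for key, value in state_with_padding.items():
            if torch.is_tensor(value):
                lean_state[key] = value[:value.numel() - padding]
            else:
                lean_state[key] = value
        return lean_state

    state = {"step": 10, "exp_avg": torch.zeros(12), "exp_avg_sq": torch.zeros(12)}
    lean = strip_padding(state, padding=2)  # tensors shrink to 10 elements, "step" is kept as-is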
30 changes: 16 additions & 14 deletions deepspeed/runtime/zero/stage2.py
@@ -477,9 +477,10 @@ def independent_gradient_partition_epilogue(self):

if self.overlap_comm:
torch.cuda.synchronize()

if self.cpu_offload is False:
for i, _ in enumerate(self.fp16_groups):

if not i in self.averaged_gradients or self.averaged_gradients[i] is None:
self.averaged_gradients[i] = self.get_flat_partition(
self.params_in_partition[i],
@@ -498,6 +499,8 @@ def independent_gradient_partition_epilogue(self):

for accumulated_grad, new_avg_grad in zip(self.averaged_gradients[i],avg_new):
accumulated_grad.add_(new_avg_grad)


self._release_ipg_buffers()

# No need to keep the gradients anymore.
@@ -867,6 +870,7 @@ def async_inplace_copy_grad_to_fp32_buffer_from_gpu(self, param):

src_tensor = param.grad.view(-1).narrow(0, source_offset, num_elements).float()
dest_tensor.copy_(src_tensor, non_blocking=True)
param.grad=None

def complete_grad_norm_calculation_for_cpu_offload(self, params):
total_norm = 0.0
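Clearing param.grad immediately after the copy releases the fp16 gradient once its contents have been staged into the fp32 CPU buffer. A rough, standalone sketch of the underlying pattern (buffer names are made up; the destination is pinned so the non_blocking copy can actually overlap with compute):

    import torch

    if torch.cuda.is_available():
        grad = torch.randn(1024, device="cuda", dtype=torch.float16)     # stands in for param.grad
        fp32_cpu_buffer = torch.zeros(1024, dtype=torch.float32, pin_memory=True)
        fp32_cpu_buffer.copy_(grad.view(-1).float(), non_blocking=True)  # asynchronous device-to-host copy
        grad = None                                                      # mirrors param.grad = None above
        torch.cuda.synchronize()                                         # wait before reading the buffer on the host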
@@ -899,25 +903,19 @@ def complete_grad_norm_calculation_for_cpu_offload(self, params):

def copy_grads_in_partition(self, param):
if self.cpu_offload:
#print(f"GAS: {self.gradient_accumulation_steps}")
#print(f"GAS: {self.is_gradient_accumulation_boundary}")
#with torch.cuda.stream(torch.cuda.current_stream()):

self.update_overflow_tracker_for_param_grad(param)


if self.gradient_accumulation_steps > 1:
self.async_accumulate_grad_in_cpu_via_gpu(param)

if self.is_gradient_accumulation_boundary:
self.set_norm_for_param_grad_in_gpu(param)

self.update_overflow_tracker_for_param_grad(param)

self.async_inplace_copy_grad_to_fp32_buffer_from_gpu(param)

#new_grad_tensor = async_copy_to(param.grad.view(-1),
# 'cpu',
# self.cpu_computation_stream)
#param.grad.data = new_grad_tensor.data.view_as(param.grad)
return

#print(f"ID {self.get_param_id(param)} grad norm {param.grad.norm()}")
if self.grads_in_partition is None:
self.grads_in_partition_offset = 0
total_size = 0
@@ -938,6 +936,7 @@ def copy_grads_in_partition(self, param):
param.numel())
new_grad_tensor.copy_(param.grad.view(-1))
param.grad.data = new_grad_tensor.data.view_as(param.grad)
#print(f"Grad norm after copy to contiguous_buffer {param.grad.data.norm()}")
self.grads_in_partition_offset += param.numel()

def reduce_ipg_grads(self):
@@ -1319,6 +1318,7 @@ def free_grad_in_param_list(self, param_list):

def reset_cpu_buffers(self):
self.norm_for_param_grads = {}
self.local_overflow = False
with torch.cuda.stream(self.migration_stream):
for key, value in self.accumulated_grads_in_cpu.items():
value.mul_(0.0)
@@ -1327,7 +1327,7 @@ def step(self, closure=None):
"""
Not supporting closure.
"""
self.micro_step_id = 0
self.micro_step_id = -1

if self.cpu_offload:
torch.cuda.current_stream().wait_stream(self.migration_stream)
@@ -1346,6 +1346,8 @@ def step(self, closure=None):
self.zero_grad()
if self.cpu_offload:
self.reset_cpu_buffers()
else:
self.averaged_gradients = {}

see_memory_usage('After overflow after clearing gradients')

@@ -1557,6 +1559,7 @@ def backward(self, loss, retain_graph=False):
2. scaled_loss = fp32_loss*loss_scale
3. scaled_loss.backward(), which accumulates scaled gradients into the ``.grad`` attributes of the model's fp16 leaves
"""
self.micro_step_id += 1
if self.cpu_offload:
torch.cuda.current_stream().wait_stream(self.migration_stream)

@@ -1576,7 +1579,6 @@ def backward(self, loss, retain_graph=False):
self.ipg_index = 0

self.loss_scaler.backward(loss.float(), retain_graph=retain_graph)
self.micro_step_id += 1

def check_overflow(self, partition_gradients=True):
self._check_overflow(partition_gradients)
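Taken together, the stage2.py changes count micro-steps in backward(), accumulate offloaded gradients on the CPU, and only copy gradients into the fp32 buffers at an accumulation boundary. A hedged sketch of the training loop this enables (engine construction omitted; backward and step are the standard DeepSpeed engine entry points, and the loop body is illustrative):

    # With gradient_accumulation_steps = 4 in ds_config, the engine treats every
    # fourth micro-step as a boundary and runs the optimizer only there.
    for batch in data_loader:
        loss = model_engine(batch)       # forward pass
        model_engine.backward(loss)      # accumulates (possibly CPU-offloaded) gradients
        model_engine.step()              # optimizer update only at accumulation boundaries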