8 changes: 4 additions & 4 deletions deepspeed/runtime/zero/partition_parameters.py
@@ -910,8 +910,8 @@ def aligned_size():
def padding_size():
return self._padding_size(param)

def partitioned_size():
return self._partitioned_size(param)
def partition_numel():
return self._partition_numel(param)

def item_override():
param.all_gather()
@@ -953,7 +953,7 @@ def wrapped(*args, **kwargs):
# Partitioning size utilities
param.aligned_size = aligned_size
param.padding_size = padding_size
param.partitioned_size = partitioned_size
param.partition_numel = partition_numel
param.ds_summary = types.MethodType(ds_summary, param)

param.item = allgather_before(param.item)
@@ -967,7 +967,7 @@ def _padding_size(self, param):
remainder = param.ds_numel % self.world_size
return (self.world_size - remainder) if remainder else 0

def _partitioned_size(self, param):
def _partition_numel(self, param):
return param.ds_tensor.ds_numel

def _ensure_availability_of_partitioned_params(self, params):
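The hunks above attach a zero-argument `partition_numel` helper to each converted parameter and have it report the element count of the local shard (`param.ds_tensor.ds_numel`). Below is a minimal, self-contained sketch of that attach-a-helper pattern; the `convert` function, its sharding arithmetic, and the example sizes are illustrative stand-ins (the padding rule mirrors `_padding_size` above), not DeepSpeed's actual `_convert_to_deepspeed_param`.

```python
import torch

def convert(param: torch.nn.Parameter, world_size: int, rank: int) -> torch.nn.Parameter:
    flat = param.detach().flatten()
    # Pad so the flattened parameter divides evenly across ranks, keep this rank's shard.
    remainder = flat.numel() % world_size
    padding = (world_size - remainder) if remainder else 0
    shard_numel = (flat.numel() + padding) // world_size
    shard = torch.zeros(shard_numel, dtype=flat.dtype)
    start = rank * shard_numel
    src = flat[start:min(start + shard_numel, flat.numel())]
    shard[:src.numel()].copy_(src)

    param.ds_numel = flat.numel()      # full (unpartitioned) element count
    param.ds_tensor = shard            # this rank's local partition

    def partition_numel():
        # Same contract as the renamed helper: numel of the local partition.
        return param.ds_tensor.numel()

    param.partition_numel = partition_numel
    return param

p = convert(torch.nn.Parameter(torch.randn(10, 3)), world_size=4, rank=3)
print(p.ds_numel, p.partition_numel())   # 30 8  (30 elements padded to 32, split 4 ways)
```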
30 changes: 15 additions & 15 deletions deepspeed/runtime/zero/stage3.py
@@ -579,7 +579,7 @@ def _setup_for_real_optimizer(self):
all_params = list(itertools.chain.from_iterable(self.fp16_groups))

grad_partitions_flat_buffer: Tensor = torch.zeros(
sum(p.ds_tensor.ds_numel for p in all_params),
sum(p.partition_numel() for p in all_params),
dtype=self.dtype,
device=self.device,
pin_memory=self.offload_optimizer_pin_memory)
@@ -590,8 +590,8 @@ def _setup_for_real_optimizer(self):
param.ds_id] = grad_partitions_flat_buffer.narrow(
0,
offset,
param.ds_tensor.numel())
offset += param.ds_tensor.numel()
param.partition_numel())
offset += param.partition_numel()

def set_lr(self, lr):
"""Set the learning rate."""
@@ -748,7 +748,7 @@ def _move_to_flat_buffer(self, param_list, flat_buffer, avoid_copy=False):
'''if the parameter was initialized in nvme then bring it to the destination buffer directly'''
if src.status == PartitionedParamStatus.NOT_AVAILABLE:
print_rank_0(
f"Swapping in {param.ds_id} with partition size {param.ds_tensor.ds_numel} permanently to CPU"
f"Swapping in {param.ds_id} with partition size {param.partition_numel()} permanently to CPU"
)
param.nvme_swapper.swap_into_buffer(param, dest)
src.data = dest.data
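The branch above handles shards that are not resident in memory: they are materialized straight into their slot of the flat buffer, and the shard's `.data` is re-pointed at that slot so no second copy is needed. A toy sketch of the idea; `Status`, `move_to_flat_buffer`, and `load_from_nvme` are stand-ins, not the real `nvme_swapper` API.

```python
import enum
import torch

class Status(enum.Enum):
    AVAILABLE = 1
    NOT_AVAILABLE = 2

def move_to_flat_buffer(shard: torch.Tensor, status: Status,
                        dest: torch.Tensor, load_from_nvme) -> None:
    if status == Status.NOT_AVAILABLE:
        load_from_nvme(dest)        # fill the destination slot directly
    else:
        dest.copy_(shard)           # shard already in memory: plain copy
    shard.data = dest.data          # shard now aliases the flat-buffer slot

flat = torch.zeros(16)
shard = torch.full((8,), 3.0)
move_to_flat_buffer(shard, Status.AVAILABLE, flat.narrow(0, 8, 8), load_from_nvme=None)
assert shard.data_ptr() == flat[8:].data_ptr() and flat[8:].eq(3.0).all()
```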
@@ -767,7 +767,7 @@ def _create_param_groups_fp16_flat_cpu_memory(self):
aggregate_params_count = 0

for j, param_group in enumerate(self.optimizer.param_groups):
params_in_group = sum([p.ds_tensor.ds_numel for p in param_group['params']])
params_in_group = sum([p.partition_numel() for p in param_group['params']])

flat_buffer_size = params_in_group

@@ -816,7 +816,7 @@ def _create_fp16_partitions_with_defragmentation(self):

# record total elements of parameter partitions in sub group
self.fp16_partitioned_groups_flat_numel.append(
sum(p.ds_tensor.ds_numel for p in sub_group))
sum(p.partition_numel() for p in sub_group))

# record padding required to align group to world size (only applies to last rank)
rank_requires_padding = dist.get_rank(
@@ -839,7 +839,7 @@ def _create_fp16_partitions_with_defragmentation(self):
# contiguous flat buffer for all parameters that we created earlier
offset = 0
for sub_group in self.fp16_groups:
sub_group_numel = sum(param.ds_tensor.ds_numel for param in sub_group)
sub_group_numel = sum(param.partition_numel() for param in sub_group)
self.fp16_partitioned_groups_flat.append(
device_buffer.narrow(0,
offset,
@@ -851,7 +851,7 @@ def _create_fp16_partitions_with_defragmentation(self):
for param_group_idx, param_group in enumerate(param_groups):
flat_offset = 0
for i, sub_group in enumerate(param_group):
total_elements = sum(p.ds_tensor.ds_numel for p in sub_group)
total_elements = sum(p.partition_numel() for p in sub_group)
print_rank_0(f"Params in nvme and cpu {self.params_in_nvme_and_cpu}")
#Flat buffer may not be available for parameters that reside in NVME
if not self.params_in_nvme_and_cpu or flat_offset + total_elements <= self.param_groups_fp16_flat_cpu_memory[
@@ -887,7 +887,7 @@ def _create_fp16_partitions_with_defragmentation(self):
if should_create_fp16_flat_reuse_buffer:
max_partition_numel, largest_partition_numel = 0, None
for sub_group in self.fp16_groups:
total_elements = sum(t.ds_tensor.ds_numel for t in sub_group)
total_elements = sum(t.partition_numel() for t in sub_group)
if total_elements > max_partition_numel:
largest_partition_numel = [t.ds_numel for t in sub_group]
max_partition_numel = total_elements
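The reuse buffer sized above only needs to be as large as the biggest sub-group, measured in partition elements. A small sketch of that sizing pass, where the nested lists of shard sizes are made-up stand-ins for `self.fp16_groups` and their `partition_numel()` values:

```python
sub_group_partition_numels = [[8, 32, 5], [64, 4], [16, 16, 16]]  # per sub-group (illustrative)

max_partition_numel = 0
largest_sub_group = None
for sizes in sub_group_partition_numels:
    total_elements = sum(sizes)
    if total_elements > max_partition_numel:
        max_partition_numel = total_elements
        largest_sub_group = sizes

print(max_partition_numel, largest_sub_group)   # 68 [64, 4]
```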
@@ -905,7 +905,7 @@ def _swap_in_sub_group_to_flat_buffer(self, flat_buffer, sub_group_id):
dest = flat_buffer.narrow(0, offset, partitioned_param.ds_numel)
if partitioned_param.status == PartitionedParamStatus.NOT_AVAILABLE:
print_rank_0(
f"Swapping in {param.ds_id} with elements {param.ds_numel} and partition {param.ds_tensor.ds_numel}"
f"Swapping in {param.ds_id} with elements {param.ds_numel} and partition {param.partition_numel()}"
)
param.nvme_swapper.swap_in([param], async_op=False)
dest.data.copy_(partitioned_param.data)
@@ -935,7 +935,7 @@ def _get_sub_group_partitions(self, sub_group_id):
if partitioned_param.status == PartitionedParamStatus.NOT_AVAILABLE:
swap_path = param.nvme_swapper.get_path(param, True)
sub_group_partitions.append((partitioned_param,
param.ds_tensor.ds_numel,
param.partition_numel(),
swap_path))
else:
sub_group_partitions.append((partitioned_param,
@@ -1051,7 +1051,7 @@ def _create_fp32_partitions(self):

def _create_fp16_sub_groups(self, params_group):

params_group_numel = sum([param.partitioned_size() for param in params_group])
params_group_numel = sum([param.partition_numel() for param in params_group])
sub_group_size = self.sub_group_size

if sub_group_size is None or sub_group_size >= params_group_numel:
@@ -1063,7 +1063,7 @@ def _create_fp16_sub_groups(self, params_group):
for param in params_group:

sub_group.append(param)
local_sub_group_size += param.partitioned_size()
local_sub_group_size += param.partition_numel()

if local_sub_group_size >= sub_group_size or id(param) == id(
params_group[-1]):
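`_create_fp16_sub_groups` walks the parameters in order, accumulates their partition sizes, and cuts a new sub-group whenever the running total reaches `sub_group_size` or the last parameter is reached. A standalone sketch of that grouping logic, with plain integers standing in for parameters and their `partition_numel()`:

```python
from typing import List

def create_sub_groups(partition_numels: List[int], sub_group_size: int) -> List[List[int]]:
    # If no limit is set, or the whole group fits, keep a single sub-group.
    if sub_group_size is None or sub_group_size >= sum(partition_numels):
        return [list(partition_numels)]

    sub_groups, current, current_numel = [], [], 0
    for i, numel in enumerate(partition_numels):
        current.append(numel)
        current_numel += numel
        if current_numel >= sub_group_size or i == len(partition_numels) - 1:
            sub_groups.append(current)
            current, current_numel = [], 0
    return sub_groups

print(create_sub_groups([8, 32, 5, 64, 4, 16], sub_group_size=40))
# [[8, 32], [5, 64], [4, 16]]
```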
@@ -1633,7 +1633,7 @@ def set_grad_positions(self):
current_offset = 0
for param in group:
param_id = self.get_param_id(param)
num_elements = param.ds_tensor.ds_numel
num_elements = param.partition_numel()

self.grad_position[param_id] = [
int(i),
@@ -1699,7 +1699,7 @@ def __partition_grads(self,
params_to_release: List[Parameter],
grad_partitions: List[Tensor]) -> None:
for param, grad_partition in zip(params_to_release, grad_partitions):
if param.ds_tensor.ds_numel * dist.get_rank(
if param.partition_numel() * dist.get_rank(
self.dp_process_group) > param.ds_numel:
# this grad partition is empty - don't need to do anything
continue
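The check above detects ranks whose gradient partition is pure padding: with equal-sized partitions, `partition_numel() * rank` is where this rank's shard starts, and if that lies past `ds_numel` there is no real data to write. A sketch of the same arithmetic with plain integers standing in for the DeepSpeed fields:

```python
def partition_is_empty(partition_numel: int, rank: int, ds_numel: int) -> bool:
    # This rank's shard starts at partition_numel * rank; past ds_numel it is all padding.
    return partition_numel * rank > ds_numel

ds_numel, world_size = 30, 4            # 30 elements padded to 32, 8 per rank
for rank in range(world_size):
    print(rank, partition_is_empty(8, rank, ds_numel))
# Every rank owns at least one real element here, so nothing is skipped;
# with ds_numel=6 and the same 8-element partitions, ranks 1-3 would be empty.
```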