From a29c0b9554d00d2e7636cc3fa609ae3600e365ca Mon Sep 17 00:00:00 2001
From: Tunji Ruwase
Date: Mon, 13 Jun 2022 22:25:44 +0500
Subject: [PATCH] Use partition size

---
 .../runtime/zero/partition_parameters.py |  8 ++---
 deepspeed/runtime/zero/stage3.py         | 30 +++++++++----------
 2 files changed, 19 insertions(+), 19 deletions(-)

diff --git a/deepspeed/runtime/zero/partition_parameters.py b/deepspeed/runtime/zero/partition_parameters.py
index 131d25faedc0..f86e050a0a10 100755
--- a/deepspeed/runtime/zero/partition_parameters.py
+++ b/deepspeed/runtime/zero/partition_parameters.py
@@ -910,8 +910,8 @@ def aligned_size():
         def padding_size():
             return self._padding_size(param)
 
-        def partitioned_size():
-            return self._partitioned_size(param)
+        def partition_numel():
+            return self._partition_numel(param)
 
         def item_override():
             param.all_gather()
@@ -953,7 +953,7 @@ def wrapped(*args, **kwargs):
         # Partitioning size utilities
         param.aligned_size = aligned_size
         param.padding_size = padding_size
-        param.partitioned_size = partitioned_size
+        param.partition_numel = partition_numel
         param.ds_summary = types.MethodType(ds_summary, param)
 
         param.item = allgather_before(param.item)
@@ -967,7 +967,7 @@ def _padding_size(self, param):
         remainder = param.ds_numel % self.world_size
         return (self.world_size - remainder) if remainder else 0
 
-    def _partitioned_size(self, param):
+    def _partition_numel(self, param):
         return param.ds_tensor.ds_numel
 
     def _ensure_availability_of_partitioned_params(self, params):
diff --git a/deepspeed/runtime/zero/stage3.py b/deepspeed/runtime/zero/stage3.py
index 6b87cbc13e02..27afc6817f2d 100755
--- a/deepspeed/runtime/zero/stage3.py
+++ b/deepspeed/runtime/zero/stage3.py
@@ -579,7 +579,7 @@ def _setup_for_real_optimizer(self):
         all_params = list(itertools.chain.from_iterable(self.fp16_groups))
 
         grad_partitions_flat_buffer: Tensor = torch.zeros(
-            sum(p.ds_tensor.ds_numel for p in all_params),
+            sum(p.partition_numel() for p in all_params),
             dtype=self.dtype,
             device=self.device,
             pin_memory=self.offload_optimizer_pin_memory)
@@ -590,8 +590,8 @@ def _setup_for_real_optimizer(self):
                 param.ds_id] = grad_partitions_flat_buffer.narrow(
                     0,
                     offset,
-                    param.ds_tensor.numel())
-            offset += param.ds_tensor.numel()
+                    param.partition_numel())
+            offset += param.partition_numel()
 
     def set_lr(self, lr):
         """Set the learning rate."""
@@ -748,7 +748,7 @@ def _move_to_flat_buffer(self, param_list, flat_buffer, avoid_copy=False):
                 '''if the parameter was initialized in nvme then bring it to the destination buffer directly'''
                 if src.status == PartitionedParamStatus.NOT_AVAILABLE:
                     print_rank_0(
-                        f"Swapping in {param.ds_id} with partition size {param.ds_tensor.ds_numel} permanently to CPU"
+                        f"Swapping in {param.ds_id} with partition size {param.partition_numel()} permanently to CPU"
                     )
                     param.nvme_swapper.swap_into_buffer(param, dest)
                     src.data = dest.data
@@ -767,7 +767,7 @@ def _create_param_groups_fp16_flat_cpu_memory(self):
         aggregate_params_count = 0
 
         for j, param_group in enumerate(self.optimizer.param_groups):
-            params_in_group = sum([p.ds_tensor.ds_numel for p in param_group['params']])
+            params_in_group = sum([p.partition_numel() for p in param_group['params']])
 
             flat_buffer_size = params_in_group
 
@@ -816,7 +816,7 @@ def _create_fp16_partitions_with_defragmentation(self):
 
             # record total elements of parameter partitions in sub group
             self.fp16_partitioned_groups_flat_numel.append(
-                sum(p.ds_tensor.ds_numel for p in sub_group))
+                sum(p.partition_numel() for p in sub_group))
 
             # record padding required to align group to world size (only applies to last rank)
             rank_requires_padding = dist.get_rank(
@@ -839,7 +839,7 @@ def _create_fp16_partitions_with_defragmentation(self):
             # contiguous flat buffer for all parameters that we created earlier
            offset = 0
            for sub_group in self.fp16_groups:
-                sub_group_numel = sum(param.ds_tensor.ds_numel for param in sub_group)
+                sub_group_numel = sum(param.partition_numel() for param in sub_group)
                 self.fp16_partitioned_groups_flat.append(
                     device_buffer.narrow(0,
                                          offset,
@@ -851,7 +851,7 @@ def _create_fp16_partitions_with_defragmentation(self):
             for param_group_idx, param_group in enumerate(param_groups):
                 flat_offset = 0
                 for i, sub_group in enumerate(param_group):
-                    total_elements = sum(p.ds_tensor.ds_numel for p in sub_group)
+                    total_elements = sum(p.partition_numel() for p in sub_group)
                     print_rank_0(f"Params in nvme and cpu {self.params_in_nvme_and_cpu}")
                     #Flat buffer may not be available for parameters that reside in NVME
                     if not self.params_in_nvme_and_cpu or flat_offset + total_elements <= self.param_groups_fp16_flat_cpu_memory[
@@ -887,7 +887,7 @@ def _create_fp16_partitions_with_defragmentation(self):
         if should_create_fp16_flat_reuse_buffer:
             max_partition_numel, largest_partition_numel = 0, None
             for sub_group in self.fp16_groups:
-                total_elements = sum(t.ds_tensor.ds_numel for t in sub_group)
+                total_elements = sum(t.partition_numel() for t in sub_group)
                 if total_elements > max_partition_numel:
                     largest_partition_numel = [t.ds_numel for t in sub_group]
                     max_partition_numel = total_elements
@@ -905,7 +905,7 @@ def _swap_in_sub_group_to_flat_buffer(self, flat_buffer, sub_group_id):
             dest = flat_buffer.narrow(0, offset, partitioned_param.ds_numel)
             if partitioned_param.status == PartitionedParamStatus.NOT_AVAILABLE:
                 print_rank_0(
-                    f"Swapping in {param.ds_id} with elements {param.ds_numel} and partition {param.ds_tensor.ds_numel}"
+                    f"Swapping in {param.ds_id} with elements {param.ds_numel} and partition {param.partition_numel()}"
                 )
                 param.nvme_swapper.swap_in([param], async_op=False)
                 dest.data.copy_(partitioned_param.data)
@@ -935,7 +935,7 @@ def _get_sub_group_partitions(self, sub_group_id):
             if partitioned_param.status == PartitionedParamStatus.NOT_AVAILABLE:
                 swap_path = param.nvme_swapper.get_path(param, True)
                 sub_group_partitions.append((partitioned_param,
-                                             param.ds_tensor.ds_numel,
+                                             param.partition_numel(),
                                              swap_path))
             else:
                 sub_group_partitions.append((partitioned_param,
@@ -1051,7 +1051,7 @@ def _create_fp32_partitions(self):
 
     def _create_fp16_sub_groups(self, params_group):
 
-        params_group_numel = sum([param.partitioned_size() for param in params_group])
+        params_group_numel = sum([param.partition_numel() for param in params_group])
         sub_group_size = self.sub_group_size
 
         if sub_group_size is None or sub_group_size >= params_group_numel:
@@ -1063,7 +1063,7 @@ def _create_fp16_sub_groups(self, params_group):
         for param in params_group:
 
             sub_group.append(param)
-            local_sub_group_size += param.partitioned_size()
+            local_sub_group_size += param.partition_numel()
 
             if local_sub_group_size >= sub_group_size or id(param) == id(
                     params_group[-1]):
@@ -1633,7 +1633,7 @@ def set_grad_positions(self):
             current_offset = 0
             for param in group:
                 param_id = self.get_param_id(param)
-                num_elements = param.ds_tensor.ds_numel
+                num_elements = param.partition_numel()
 
                 self.grad_position[param_id] = [
                     int(i),
@@ -1699,7 +1699,7 @@ def __partition_grads(self,
                           params_to_release: List[Parameter],
                           grad_partitions: List[Tensor]) -> None:
         for param, grad_partition in zip(params_to_release, grad_partitions):
-            if param.ds_tensor.ds_numel * dist.get_rank(
+            if param.partition_numel() * dist.get_rank(
                     self.dp_process_group) > param.ds_numel:
                 # this grad partition is empty - don't need to do anything
                 continue
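
The patch amounts to a rename plus encapsulation: callers in stage3.py stop reading
param.ds_tensor.ds_numel directly and instead go through the partition_numel()
accessor that partition_parameters.py attaches to each parameter. Below is a minimal,
self-contained Python sketch of that accessor pattern; _FakePartition and _FakeParam
are hypothetical stand-ins for illustration, not DeepSpeed classes.

# sketch: illustrative only, not DeepSpeed code
class _FakePartition:
    """Stand-in for param.ds_tensor: tracks only the local partition's element count."""
    def __init__(self, ds_numel):
        self.ds_numel = ds_numel


class _FakeParam:
    """Stand-in for a partitioned parameter."""
    def __init__(self, numel):
        self.ds_tensor = _FakePartition(numel)


def attach_partition_numel(param):
    # Same shape as the patch: expose the partition size through an accessor
    # so callers no longer reach into param.ds_tensor.ds_numel directly.
    def partition_numel():
        return param.ds_tensor.ds_numel

    param.partition_numel = partition_numel


params = [_FakeParam(1024), _FakeParam(512)]
for p in params:
    attach_partition_numel(p)

# Callers (e.g. the flat-buffer sizing in stage3.py) now use the accessor:
flat_buffer_numel = sum(p.partition_numel() for p in params)
print(flat_buffer_numel)  # 1536

Keeping the element count behind a single accessor presumably lets a later change to
how partitions are stored touch only partition_parameters.py rather than every
call site in stage3.py.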