Merged
124 changes: 65 additions & 59 deletions deepspeed/runtime/zero/stage1.py
@@ -1,14 +1,14 @@
import math
import torch
import torch.distributed as dist
from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors
from collections import defaultdict

from deepspeed.runtime.zero.utils import _initialize_parameter_parallel_groups
from deepspeed.runtime.fp16.loss_scaler import LossScaler, DynamicLossScaler
from deepspeed.runtime.utils import get_grad_norm, CheckOverflow
from deepspeed.runtime.zero.config import ZERO_OPTIMIZATION_OPTIMIZER_STATES
from deepspeed.utils import logger, log_dist
from deepspeed.ops.op_builder import UtilsBuilder


def get_alignment_padding(flattened_lean_size, sub_partition_id, sub_partition_size):
@@ -29,54 +29,6 @@ def get_group_alignment_padding(tensor_list, sub_partition_size, sub_partition_c
return group_paddings


def flatten_dense_tensors_sub_partition_aligned(tensor_list,
dp,
max_elements_per_comm,
pg):
assert max_elements_per_comm >= dp, f"max_elements_per_comm {max_elements_per_comm} < dp {dp}"

num_elements = sum(t.numel() for t in tensor_list)
log_dist("Total number of elements in model: {}, max elements per com: {}".format(
num_elements,
max_elements_per_comm),
ranks=[0])

# Compute aligned partition size based on parameter count
aligned_param_partition_size = math.ceil(num_elements / dp)

# Compute aligned partition size based on communication size
aligned_comm_partition_size = int(max_elements_per_comm // dp)

if aligned_param_partition_size <= aligned_comm_partition_size:
sub_partition_count = 1
sub_partition_size = aligned_param_partition_size
else:
sub_partition_count = math.ceil(aligned_param_partition_size /
aligned_comm_partition_size)
sub_partition_size = aligned_comm_partition_size

# Compute required padding for alignment to dp and max_elements_per_comm
padding = (sub_partition_count * sub_partition_size * dp) - num_elements

log_dist(
f"sub_partition_count: {sub_partition_count}, sub_partition_size: {sub_partition_size}, padding: {padding}",
ranks=[0])
log_dist(
f"number of elements with padding: {num_elements} + {padding} = {num_elements + padding}",
ranks=[0])

if padding == 0:
aligned_tensor_list = tensor_list
else:
pad_tensor = torch.zeros(padding,
device=tensor_list[0].device,
dtype=tensor_list[0].dtype)
aligned_tensor_list = tensor_list + [pad_tensor]

flat_tensors = _flatten_dense_tensors(aligned_tensor_list)
return flat_tensors


def _single_range_check(current_index, start_index, end_index, tensor_size):
offset = 0
if (current_index >= start_index) and (current_index < end_index):
@@ -127,6 +79,11 @@ def __init__(self,
max_elements_per_comm=5e8,
elastic_checkpoint=True):

# Load pre-built or JIT compile (un)flatten ops
util_ops = UtilsBuilder().load()
self.flatten = util_ops.flatten
self.unflatten = util_ops.unflatten

if dp_process_group is not None and partition_size is not None:
raise ValueError("Cannot specify both dp_process_group "
"and partition size")
@@ -209,7 +166,7 @@ def __init__(self,

# flattens all tensors into single 1d tensor aligned with sub-partition size for later dividing
# RS: create aligned sub-partitions
flat_aligned_params = flatten_dense_tensors_sub_partition_aligned(
flat_aligned_params = self.flatten_dense_tensors_sub_partition_aligned(
tensor_list=self.fp16_groups[i],
dp=dist.get_world_size(group=self.dp_process_group),
max_elements_per_comm=self.max_elems_per_comm[i],
@@ -218,8 +175,8 @@ def __init__(self,

# TODO: I don't think this does anything?
# set model fp16 weight to slices of flattened buffer
updated_params = _unflatten_dense_tensors(self.fp16_groups_flat[i],
self.fp16_groups[i])
updated_params = self.unflatten(self.fp16_groups_flat[i],
self.fp16_groups[i])
for p, q in zip(self.fp16_groups[i], updated_params):
p.data = q.data

@@ -455,8 +412,8 @@ def get_all_sub_partition_info(tensor_list,

return params_in_rank_sub_partition, params_in_rank_sub_partitions_offsets, params_not_local

@staticmethod
def get_flat_sub_partitions(comm_tensor_list,
def get_flat_sub_partitions(self,
comm_tensor_list,
comm_param_offsets,
sub_partition_size,
dtype,
@@ -527,7 +484,7 @@ def get_flat_sub_partitions(comm_tensor_list,
partition_params.append(my_params) #flat_tensor_list)
final_param_offsets.append(my_offsets)
assert len(flat_tensor_list) == len(my_offsets), "{} {}".format(len(flat_tensor_list), len(my_offsets))
flat_sub_partitions.append(_flatten_dense_tensors(flat_tensor_list))
flat_sub_partitions.append(self.flatten(flat_tensor_list))
if num_comm_intervals is not None and len(
flat_sub_partitions) < num_comm_intervals:
# logger.info("padding w. sub partitions to ensure uniform communication")
@@ -569,6 +526,55 @@ def free_grad_in_param_list(self, param_list):
else:
p.grad = None

def flatten_dense_tensors_sub_partition_aligned(self,
tensor_list,
dp,
max_elements_per_comm,
pg):
assert max_elements_per_comm >= dp, f"max_elements_per_comm {max_elements_per_comm} < dp {dp}"

num_elements = sum(t.numel() for t in tensor_list)
log_dist(
"Total number of elements in model: {}, max elements per com: {}".format(
num_elements,
max_elements_per_comm),
ranks=[0])

# Compute aligned partition size based on parameter count
aligned_param_partition_size = math.ceil(num_elements / dp)

# Compute aligned partition size based on communication size
aligned_comm_partition_size = int(max_elements_per_comm // dp)

if aligned_param_partition_size <= aligned_comm_partition_size:
sub_partition_count = 1
sub_partition_size = aligned_param_partition_size
else:
sub_partition_count = math.ceil(aligned_param_partition_size /
aligned_comm_partition_size)
sub_partition_size = aligned_comm_partition_size

# Compute required padding for alignment to dp and max_elements_per_comm
padding = (sub_partition_count * sub_partition_size * dp) - num_elements

log_dist(
f"sub_partition_count: {sub_partition_count}, sub_partition_size: {sub_partition_size}, padding: {padding}",
ranks=[0])
log_dist(
f"number of elements with padding: {num_elements} + {padding} = {num_elements + padding}",
ranks=[0])

if padding == 0:
aligned_tensor_list = tensor_list
else:
pad_tensor = torch.zeros(padding,
device=tensor_list[0].device,
dtype=tensor_list[0].dtype)
aligned_tensor_list = tensor_list + [pad_tensor]

flat_tensors = self.flatten(aligned_tensor_list)
return flat_tensors

def reduce_scatter_gradients(self,
postscale_gradients,
gradient_predivide_factor,
@@ -699,8 +705,8 @@ def step(self, closure=None):

# TODO: we probably don't need this? just to be safe
for i in range(len(norm_groups)):
updated_params = _unflatten_dense_tensors(self.fp16_groups_flat[i],
self.fp16_groups[i])
updated_params = self.unflatten(self.fp16_groups_flat[i],
self.fp16_groups[i])
for p, q in zip(self.fp16_groups[i], updated_params):
p.data = q.data

@@ -903,7 +909,7 @@ def _retrieve_group_sub_partition_weights(self,
sub_partition_idx = (comm_idx * num_partitions) + rank
all_sub_partition_weights[sub_partition_idx] = sub_partition_weights

flat_merged_weights = flatten_dense_tensors_sub_partition_aligned(
flat_merged_weights = self.flatten_dense_tensors_sub_partition_aligned(
tensor_list=all_sub_partition_weights,
dp=dist.get_world_size(group=self.dp_process_group),
max_elements_per_comm=max_elems_per_comm,
@@ -951,7 +957,7 @@ def _partition_base_optimizer_state(self,
return all_partition_states[0]

alignment = dist.get_world_size(group=self.dp_process_group)
flat_merged_partitions = flatten_dense_tensors_sub_partition_aligned(
flat_merged_partitions = self.flatten_dense_tensors_sub_partition_aligned(
tensor_list=all_partition_states,
dp=dist.get_world_size(group=self.dp_process_group),
max_elements_per_comm=max_elems_per_comm,
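For orientation, here is a minimal standalone sketch of the alignment arithmetic performed by flatten_dense_tensors_sub_partition_aligned, which this PR relocates from module level into the optimizer class. It is not the DeepSpeed API: it uses torch's built-in _flatten_dense_tensors in place of the UtilsBuilder-loaded flatten op the PR switches to, and the helper name is hypothetical.

import math
import torch
from torch._utils import _flatten_dense_tensors


def sub_partition_aligned_flatten(tensor_list, dp, max_elements_per_comm):
    # Pad and flatten tensor_list so the result splits evenly into
    # dp * sub_partition_count equal sub-partitions.
    num_elements = sum(t.numel() for t in tensor_list)

    # Partition size implied by the parameter count vs. the per-comm element cap.
    aligned_param_partition_size = math.ceil(num_elements / dp)
    aligned_comm_partition_size = int(max_elements_per_comm // dp)

    if aligned_param_partition_size <= aligned_comm_partition_size:
        sub_partition_count, sub_partition_size = 1, aligned_param_partition_size
    else:
        sub_partition_count = math.ceil(aligned_param_partition_size /
                                        aligned_comm_partition_size)
        sub_partition_size = aligned_comm_partition_size

    # Zero-pad so the flat length is exactly sub_partition_count * sub_partition_size * dp.
    padding = sub_partition_count * sub_partition_size * dp - num_elements
    if padding:
        tensor_list = tensor_list + [
            torch.zeros(padding,
                        device=tensor_list[0].device,
                        dtype=tensor_list[0].dtype)
        ]
    return _flatten_dense_tensors(tensor_list), sub_partition_count, sub_partition_size


# Example: 21 parameter elements, dp=4, at most 16 elements per communication.
params = [torch.randn(5), torch.randn(7), torch.randn(9)]
flat, count, size = sub_partition_aligned_flatten(params, dp=4, max_elements_per_comm=16)
print(flat.numel(), count, size)  # 32 2 4 -> 21 elements padded to 2 sub-partitions of 4 per rank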
82 changes: 39 additions & 43 deletions deepspeed/runtime/zero/stage2.py
@@ -3,7 +3,6 @@
'''

import torch
from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors
from torch.distributed.distributed_c10d import _get_global_rank
import torch.distributed as dist
import math
@@ -16,9 +15,8 @@
from deepspeed.runtime.utils import see_memory_usage, is_model_parallel_parameter
from deepspeed.runtime.zero.config import ZERO_OPTIMIZATION_GRADIENTS
from deepspeed.ops.adam import DeepSpeedCPUAdam

from deepspeed.ops.op_builder import UtilsBuilder
from deepspeed.utils import logger
from ...ops.op_builder import UtilsBuilder

#Toggle this to true to enable correctness test
#with gradient partitioning and without
@@ -52,28 +50,6 @@ def lcm(x, y):
return x * y // gcd(x, y)


# create a flat tensor aligned at the alignment boundary
def flatten_dense_tensors_aligned(tensor_list, alignment):
num_elements = 0
for tensor in tensor_list:
num_elements = num_elements + tensor.numel()

remaining = num_elements % alignment

if remaining:
elements_to_add = alignment - remaining
pad_tensor = torch.zeros(elements_to_add,
device=tensor_list[0].device,
dtype=tensor_list[0].dtype)
padded_tensor_list = tensor_list + [pad_tensor]

num_elements = num_elements + elements_to_add
else:
padded_tensor_list = tensor_list

return _flatten_dense_tensors(padded_tensor_list)


def get_alignment_padding(tensor_list, alignment):
num_elements = sum([tensor.numel() for tensor in tensor_list])
remainder = num_elements % alignment
@@ -121,11 +97,6 @@ def __init__(self,
gradient_predivide_factor=1.0,
gradient_accumulation_steps=1):

# Load pre-installed or JIT compile (un)flatten ops
util_ops = UtilsBuilder().load()
self.flatten = util_ops.flatten
self.unflatten = util_ops.unflatten

if dist.get_rank() == 0:
logger.info(f"Reduce bucket size {reduce_bucket_size}")
logger.info(f"Allgather bucket size {allgather_bucket_size}")
@@ -143,6 +114,11 @@ def __init__(self,
raise SystemError("Cannot use fp16 without CUDA.")
self.optimizer = init_optimizer

# Load pre-built or JIT compile (un)flatten ops
util_ops = UtilsBuilder().load()
self.flatten = util_ops.flatten
self.unflatten = util_ops.unflatten

self.timers = timers

self.reduce_scatter = reduce_scatter
@@ -236,7 +212,7 @@ def __init__(self,

#create flat buffer in CPU and move to GPU
self.fp16_groups_flat.append(
flatten_dense_tensors_aligned(
self.flatten_dense_tensors_aligned(
self.fp16_groups[i],
dist.get_world_size(group=self.dp_process_group)).cuda(
torch.cuda.current_device()))
@@ -247,8 +223,8 @@ def __init__(self,
f"After Flattening and after emptying param group {i} cache")

# set model fp16 weight to slices of flattened buffer
updated_params = _unflatten_dense_tensors(self.fp16_groups_flat[i],
self.fp16_groups[i])
updated_params = self.unflatten(self.fp16_groups_flat[i],
self.fp16_groups[i])
for p, q in zip(self.fp16_groups[i], updated_params):
p.data = q.data

@@ -611,6 +587,27 @@ def report_ipg_memory_usage(self, tag, param_elems):
f"{tag}: elems in_bucket {self.elements_in_ipg_bucket} param {param_elems} max_percent {percent_of_bucket_size}"
)

# create a flat tensor aligned at the alignment boundary
def flatten_dense_tensors_aligned(self, tensor_list, alignment):
num_elements = 0
for tensor in tensor_list:
num_elements = num_elements + tensor.numel()

remaining = num_elements % alignment

if remaining:
elements_to_add = alignment - remaining
pad_tensor = torch.zeros(elements_to_add,
device=tensor_list[0].device,
dtype=tensor_list[0].dtype)
padded_tensor_list = tensor_list + [pad_tensor]

num_elements = num_elements + elements_to_add
else:
padded_tensor_list = tensor_list

return self.flatten(padded_tensor_list)

############### Independent Partition Gradient ########################
def reduce_independent_p_g_buckets_and_remove_grads(self, param, i):
if self.elements_in_ipg_bucket + param.numel() > self.reduce_bucket_size:
@@ -1004,7 +1001,7 @@ def are_all_related_partitions_reduced(params_id):
self.param_dict[params_id].grad = None

def flatten_and_print(self, message, tensors, start=0, n=5):
flatten_tensor = _flatten_dense_tensors(tensors)
flatten_tensor = self.flatten(tensors)

def print_func():
logger.info(flatten_tensor.contiguous().view(-1).narrow(0, start, n))
@@ -1327,7 +1324,7 @@ def get_flat_partition(self,
if return_tensor_list:
return flat_tensor_list

return _flatten_dense_tensors(flat_tensor_list)
return self.flatten(flat_tensor_list)

def free_grad_in_param_list(self, param_list):
for p in param_list:
@@ -1419,14 +1416,13 @@ def step(self, closure=None):
#create a flat gradients for parameters updated by this process
# If we are last partition, ensure we have same size grads and partition size, if not pad with zero tensors
if partition_id == dist.get_world_size(group=self.dp_process_group) - 1:
single_grad_partition = flatten_dense_tensors_aligned(
single_grad_partition = self.flatten_dense_tensors_aligned(
self.averaged_gradients[i],
int(self.partition_size[i])).to(
self.single_partition_of_fp32_groups[i].dtype)
else:
single_grad_partition = _flatten_dense_tensors(
self.averaged_gradients[i]).to(
self.single_partition_of_fp32_groups[i].dtype)
single_grad_partition = self.flatten(self.averaged_gradients[i]).to(
self.single_partition_of_fp32_groups[i].dtype)
assert single_grad_partition.numel() == self.partition_size[i], \
"averaged gradients have different number of elements that partition size {} {} {} {}".format(single_grad_partition.numel(), self.partition_size[i], i, partition_id)

@@ -1507,8 +1503,8 @@ def step(self, closure=None):

# TODO: we probably don't need this? just to be safe
for i in range(len(norm_groups)):
updated_params = _unflatten_dense_tensors(self.fp16_groups_flat[i],
self.fp16_groups[i])
updated_params = self.unflatten(self.fp16_groups_flat[i],
self.fp16_groups[i])
for p, q in zip(self.fp16_groups[i], updated_params):
p.data = q.data

@@ -1749,7 +1745,7 @@ def _restore_from_fp32_weights(self, all_state_dict):
merged_partitions = [
sd['single_partition_of_fp32_groups'][i] for sd in all_state_dict
]
flat_merged_partitions = flatten_dense_tensors_aligned(
flat_merged_partitions = self.flatten_dense_tensors_aligned(
merged_partitions,
dist.get_world_size(group=self.dp_process_group))
dp_partitions = self.get_data_parallel_partitions(flat_merged_partitions)
@@ -1773,7 +1769,7 @@ def _partition_base_optimizer_state(self, state_key, all_partition_states):
partition_id = dist.get_rank(group=self.dp_process_group)
alignment = dist.get_world_size(group=self.dp_process_group)
if torch.is_tensor(all_partition_states[0]):
flat_merged_partitions = flatten_dense_tensors_aligned(
flat_merged_partitions = self.flatten_dense_tensors_aligned(
all_partition_states,
alignment)
dp_partitions = self.get_data_parallel_partitions(flat_merged_partitions)
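As a companion sketch for the stage2 changes (again assuming torch's _flatten_dense_tensors/_unflatten_dense_tensors as stand-ins for the loaded util ops, with a hypothetical helper name), this shows the alignment padding done by flatten_dense_tensors_aligned and the p.data = q.data step the diff repeatedly uses to re-point each fp16 parameter at a slice of the flat buffer:

import torch
from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors


def flatten_aligned(tensor_list, alignment):
    # Zero-pad the combined element count up to a multiple of `alignment`, then flatten.
    num_elements = sum(t.numel() for t in tensor_list)
    remainder = num_elements % alignment
    if remainder:
        pad = torch.zeros(alignment - remainder,
                          device=tensor_list[0].device,
                          dtype=tensor_list[0].dtype)
        tensor_list = tensor_list + [pad]
    return _flatten_dense_tensors(tensor_list)


params = [torch.randn(3), torch.randn(6)]       # 9 elements
flat = flatten_aligned(params, alignment=4)     # padded up to 12 elements
views = _unflatten_dense_tensors(flat, params)  # shaped views into the flat buffer
for p, q in zip(params, views):
    p.data = q.data  # each parameter now aliases its slice of the flat buffer,
                     # so writes to `flat` are visible through the parameters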