From 6fca3bb8abe37efe0f0a90b7c44b865e980eda33 Mon Sep 17 00:00:00 2001 From: Tim Moon Date: Mon, 28 Nov 2022 16:41:04 -0800 Subject: [PATCH 1/8] Coalesce reduce-scatters in distributed Adam --- .../optimizers/distributed_fused_adam.py | 41 +++++++++++++------ 1 file changed, 29 insertions(+), 12 deletions(-) diff --git a/apex/contrib/optimizers/distributed_fused_adam.py b/apex/contrib/optimizers/distributed_fused_adam.py index 7839e34f2..380c92cd7 100644 --- a/apex/contrib/optimizers/distributed_fused_adam.py +++ b/apex/contrib/optimizers/distributed_fused_adam.py @@ -852,13 +852,8 @@ def _start_bucket_grad_sync(self, buckets): else: reduce_op = torch.distributed.ReduceOp.SUM - # Side stream for communication - main_stream = torch.cuda.current_stream() - comm_stream = self._pipeline_streams[-1] - comm_stream.wait_stream(main_stream) - - # Reduce-scatter over distributed process group - for i, bucket in enumerate(buckets): + # Initialize grad state and buffers + for bucket in buckets: if bucket.status == self.GradientStatus.SYNCING: self._finish_bucket_grad_sync() bucket.status = self.GradientStatus.SYNCING @@ -877,30 +872,52 @@ def _start_bucket_grad_sync(self, buckets): dtype=self.grad_sync_dtype, device=self.device, ) - with torch.cuda.stream(comm_stream): + + # Side stream for communication + main_stream = torch.cuda.current_stream() + comm_stream = self._pipeline_streams[-1] + comm_stream.wait_stream(main_stream) + + # Reduce-scatter over distributed process group + if self.distributed_size > 1: + with torch.cuda.stream(comm_stream): + for bucket in buckets: + bucket.sync_wait() + sync_requests = [] + group = self.distributed_process_group + group._start_coalescing() + for bucket in buckets: bucket.sync_request = ( reduce_scatter_tensor( bucket.sync_grads_shard, bucket.grads_bucket, op=reduce_op, - group=self.distributed_process_group, + group=group, async_op=True, ) ) + sync_requests.append(bucket.sync_request) + group._end_coalescing(sync_requests) # All-reduce over redundant process group if self.redundant_size > 1: - for i, bucket in enumerate(buckets): - with torch.cuda.stream(comm_stream): + with torch.cuda.stream(comm_stream): + for bucket in buckets: bucket.sync_wait() + sync_requests = [] + group = self.redundant_process_group + group._start_coalescing() + for bucket in buckets: bucket.sync_request = ( torch.distributed.all_reduce( bucket.sync_grads_shard, op=reduce_op, - group=self.redundant_process_group, + group=group, async_op=True, ) ) + sync_requests.append(bucket.sync_request) + group._end_coalescing(sync_requests) def _finish_bucket_grad_sync(self): """Wait for any gradient synchronizations that are in progress""" From bb5e1376656316e6a3d1a8a266e00cdb8b00d841 Mon Sep 17 00:00:00 2001 From: Tim Moon Date: Tue, 29 Nov 2022 18:39:48 -0800 Subject: [PATCH 2/8] Support variable-size param buckets in dist Adam optimizer --- .../optimizers/distributed_fused_adam.py | 333 ++++++++++-------- 1 file changed, 181 insertions(+), 152 deletions(-) diff --git a/apex/contrib/optimizers/distributed_fused_adam.py b/apex/contrib/optimizers/distributed_fused_adam.py index 380c92cd7..b329433a5 100644 --- a/apex/contrib/optimizers/distributed_fused_adam.py +++ b/apex/contrib/optimizers/distributed_fused_adam.py @@ -42,6 +42,10 @@ '`--deprecated_fused_adam`.' 
) +def _ceildiv(numer, denom): + """Assumes arguments are positive integers""" + return (numer + denom - 1) // denom + def _round_to_multiple(number, multiple, round_up=True): """Assumes arguments are positive integers""" return (number+multiple-1 if round_up else number) // multiple * multiple @@ -228,6 +232,7 @@ def __init__( class StateBucket: def __init__( self, + bucket_size, shard_size, dtype, device, @@ -235,6 +240,8 @@ def __init__( store_param_remainders=False, ): """Optimizer state for a bucket""" + self.bucket_size = bucket_size + self.shard_size = shard_size # Buffer ranges corresponding to parameter fragments self.fragments = [] # Local shard of parameters @@ -435,9 +442,7 @@ def __init__(self, shard_size = int(bucket_size / self.distributed_size) shard_size = _round_to_multiple(shard_size, self.alignment, round_up=False) shard_size = max(shard_size, self.alignment) - bucket_size = shard_size * self.distributed_size - self.bucket_size = bucket_size - self.shard_size = shard_size + self.default_shard_size = shard_size # Optimizer state self.state['buckets'] = [] @@ -528,6 +533,7 @@ def _register_post_backward_hooks(self): def _init_grad_buffer(self): """Allocate contiguous buffer for grad buckets""" + ### TODO Remove grad_buffer_size = 0 for group in self.param_groups: for param in group['params']: @@ -585,95 +591,89 @@ def _init_param_state( ): """Initialize optimizer state for a parameter""" - # Make sure there is at least one bucket - if not self.state['buckets']: + # Param position within bucket + param_size = param.numel() + param_start, param_end = 0, param_size + bucket_id = len(self.state['buckets']) - 1 + if bucket_id >= 0: + bucket = self.state['buckets'][bucket_id] + bucket_size = bucket.bucket_size + shard_size = bucket.shard_size + if bucket.fragments: + _, bucket_start = bucket.fragments[-1].bucket_range + bucket_start = _round_to_multiple( + bucket_start, + self.alignment, + round_up=True, + ) + else: + bucket_start = 0 + else: + bucket_size = 0 + shard_size = 0 + bucket_start = 0 + bucket_end = bucket_start + param_size + + # Create new bucket if param does not fit + if bucket_end > bucket_size: + shard_size = max( + self.default_shard_size, + _ceildiv(param_size, self.distributed_size), + ) + shard_size = _round_to_multiple( + shard_size, + self.alignment, + round_up=True, + ) + bucket_size = shard_size * self.distributed_size self.state['buckets'].append( self.StateBucket( - self.shard_size, + bucket_size, + shard_size, self.dtype, self.device, store_params=self.store_params, store_param_remainders=self.store_param_remainders, ) ) - - # Split parameter values into fragments - # Note: Each fragment resides within a bucket - param_start = 0 - param_size = param.numel() - self.state[param]['fragments'] = [] - while param_start < param_size: - - # Get current bucket - bucket_id = len(self.state['buckets']) - 1 - bucket = self.state['buckets'][bucket_id] - fragment_id = len(bucket.fragments) - - # Determine fragment position within bucket - if fragment_id == 0: - bucket_start = 0 - else: - _, bucket_start = bucket.fragments[-1].bucket_range - bucket_start = _round_to_multiple(bucket_start, self.alignment) - fragment_size = min(param_size-param_start, self.bucket_size-bucket_start) - param_end = param_start + fragment_size - bucket_end = bucket_start + fragment_size - - # Create new bucket if current one is full - if fragment_size <= 0: - self.state['buckets'].append( - self.StateBucket( - self.shard_size, - self.dtype, - self.device, - 
store_params=self.store_params, - store_param_remainders=self.store_param_remainders, - ) - ) - continue - - # Fragment position within local shard - shard_id = self.distributed_rank - shard_start = bucket_start - self.shard_size*shard_id - shard_end = bucket_end - self.shard_size*shard_id - shard_start = min(max(shard_start, 0), self.shard_size) - shard_end = min(max(shard_end, 0), self.shard_size) - in_local_shard = shard_start < shard_end - if in_local_shard: - shard_bucket_start = shard_start + self.shard_size*shard_id - shard_bucket_end = shard_bucket_start + shard_end - shard_start - shard_param_start = shard_bucket_start - bucket_start + param_start - shard_param_end = shard_param_start + shard_end - shard_start - else: - shard_bucket_start, shard_bucket_end = None, None - shard_param_start, shard_param_end = None, None - - # Record fragment info - fragment = self.ParameterFragment( - param_group_id=param_group_id, - param_id=param_id, - bucket_id=bucket_id, - param_range=(param_start,param_end), - bucket_range=(bucket_start,bucket_end), - in_local_shard=in_local_shard, - shard_range=(shard_start,shard_end), - shard_bucket_range=(shard_bucket_start,shard_bucket_end), - shard_param_range=(shard_param_start,shard_param_end), - ) - self.state[param]['fragments'].append(fragment) - bucket.fragments.append(fragment) - param_start = param_end + return self._init_param_state(param, param_group_id, param_id) + + # Fragment position within local shard + shard_id = self.distributed_rank + shard_start = bucket_start - shard_size*shard_id + shard_end = bucket_end - shard_size*shard_id + shard_start = min(max(shard_start, 0), shard_size) + shard_end = min(max(shard_end, 0), shard_size) + in_local_shard = shard_start < shard_end + if in_local_shard: + shard_bucket_start = shard_start + shard_size*shard_id + shard_bucket_end = shard_bucket_start + shard_end - shard_start + shard_param_start = shard_bucket_start - bucket_start + param_start + shard_param_end = shard_param_start + shard_end - shard_start + else: + shard_bucket_start, shard_bucket_end = None, None + shard_param_start, shard_param_end = None, None + + # Record fragment info + fragment = self.ParameterFragment( + param_group_id=param_group_id, + param_id=param_id, + bucket_id=bucket_id, + param_range=(param_start,param_end), + bucket_range=(bucket_start,bucket_end), + in_local_shard=in_local_shard, + shard_range=(shard_start,shard_end), + shard_bucket_range=(shard_bucket_start,shard_bucket_end), + shard_param_range=(shard_param_start,shard_param_end), + ) + self.state[param]['fragments'] = [fragment] + bucket.fragments.append(fragment) # Initialize main param buffer - if self.store_params: - for fragment in self.state[param]['fragments']: - if fragment.in_local_shard: - bucket = self.state['buckets'][fragment.bucket_id] - param_start, param_end = fragment.shard_param_range - shard_start, shard_end = fragment.shard_range - model_param_fragment = param.detach().view(-1)[param_start:param_end] - main_param_fragment = bucket.params_shard[shard_start:shard_end] - main_param_fragment.copy_(model_param_fragment) + if self.store_params and in_local_shard: + model_param_fragment = param.detach().view(-1)[shard_param_start:shard_param_end] + main_param_fragment = bucket.params_shard[shard_start:shard_end] + main_param_fragment.copy_(model_param_fragment) def zero_grad(self, set_to_none=True): """Clear parameter gradients""" @@ -683,6 +683,7 @@ def zero_grad(self, set_to_none=True): # Construct views into contiguous grad buffer, if needed if 
self.contiguous_grad_buffer: + ### TODO Fix self._grad_buffer.zero_() for bucket_id in range(len(self.state['buckets'])): bucket_start = bucket_id * self.bucket_size @@ -711,6 +712,7 @@ def _grad_copy(self, param): # Get fragment position bucket_id = fragment.bucket_id bucket = self._grads_buckets[bucket_id] + bucket_size = self.state['buckets'][bucket_id].bucket_size grad_start, grad_end = fragment.param_range bucket_start, bucket_end = fragment.bucket_range @@ -721,6 +723,7 @@ def _grad_copy(self, param): # Allocate gradient buffer if needed if bucket.grads_bucket is None and self.contiguous_grad_buffer: + ### TODO Fix grad_buffer_start = bucket_id * self.bucket_size grad_buffer_end = grad_buffer_start + self.bucket_size grad_buffer = self._grad_buffer[grad_buffer_start:grad_buffer_end] @@ -730,7 +733,7 @@ def _grad_copy(self, param): bucket.grads_bucket.zero_() if bucket.grads_bucket is None: bucket.grads_bucket = torch.zeros( - [self.bucket_size], + [bucket_size], dtype=self.grad_sync_dtype, device=self.device, ) @@ -746,24 +749,21 @@ def _grad_copy(self, param): param.grad = None def grad_buffer_view(self, param): - """Construct view into grad buffer corresponding to param - - Assumes optimizer is using a contiguous grad buffer. - - """ - assert self.contiguous_grad_buffer + """Construct view into grad buffer corresponding to param""" # Figure out corresponding position in grad buffer - param_fragments = self.state[param]['fragments'] - start_bucket_id = param_fragments[0].bucket_id - start_bucket_offset, _ = param_fragments[0].bucket_range - end_bucket_id = param_fragments[-1].bucket_id - _, end_bucket_offset = param_fragments[-1].bucket_range - buffer_start = start_bucket_id * self.bucket_size + start_bucket_offset - buffer_end = end_bucket_id * self.bucket_size + end_bucket_offset + assert len(self.state[param]['fragments']) == 1 + fragment = self.state[param]['fragments'][0] + bucket_id = fragment.bucket_id + bucket_start, bucket_end = fragment.bucket_range + + # Allocate gradient buffer if needed + bucket = self._grads_buckets[bucket_id] + if bucket.grads_bucket is None: + pass ### TODO Implement # Construct view into grad buffer - flat_buffer = self._grad_buffer[buffer_start:buffer_end] + flat_buffer = bucket.grads_bucket[bucket_start:bucket_end] return flat_buffer.detach().view(param.size()) def _force_bucket_grad_sync(self): @@ -771,11 +771,17 @@ def _force_bucket_grad_sync(self): # Synchronize all unsynchronized buckets Status = self.GradientStatus - buckets = [ - bucket - for bucket_id, bucket in sorted(self._grads_buckets.items()) - if bucket.status not in (Status.READY, Status.SYNCING) - ] + buckets = [] + for bucket_id, bucket in sorted(self._grads_buckets.items()): + if bucket.status not in (Status.READY, Status.SYNCING): + buckets.append(bucket) + if bucket.grads_bucket is None: + bucket_size = self.state['buckets'][bucket_id].bucket_size + bucket.grads_bucket = torch.zeros( + [bucket_size], + dtype=self.grad_sync_dtype, + device=self.device, + ) if buckets: self._start_bucket_grad_sync(buckets) self._finish_bucket_grad_sync() @@ -784,8 +790,9 @@ def _force_bucket_grad_sync(self): for bucket_id in range(len(self.state['buckets'])): bucket = self._grads_buckets[bucket_id] if bucket.grads_shard is None: + shard_size = self.state['buckets'][bucket_id].shard_size bucket.grads_shard = torch.zeros( - [self.shard_size], + [shard_size], dtype=self.grad_sync_dtype, device=self.device, ) @@ -820,6 +827,13 @@ def _try_start_bucket_grad_sync( bucket.grads_generated.add(param) if 
len(bucket.grads_generated) == len(bucket_fragments): bucket.status = self.GradientStatus.FULLY_FILLED + if bucket.grads_bucket is None: + bucket_size = self.state['buckets'][bucket_id].bucket_size + bucket.grads_bucket = torch.zeros( + [bucket_size], + dtype=self.grad_sync_dtype, + device=self.device, + ) # Launch reductions if enough buckets are ready filled_buckets = [] @@ -836,7 +850,8 @@ def _start_bucket_grad_sync(self, buckets): Gradient synchronization is asynchronous. Involves reduce-scatter over distributed process group and allreduce - over redundant process group. + over redundant process group. Assumes grad bucket buffers are + already initialized. """ @@ -858,17 +873,13 @@ def _start_bucket_grad_sync(self, buckets): self._finish_bucket_grad_sync() bucket.status = self.GradientStatus.SYNCING bucket.grads_generated.clear() - if bucket.grads_bucket is None: - bucket.grads_bucket = torch.zeros( - [self.bucket_size], - dtype=self.grad_sync_dtype, - device=self.device, - ) if self.distributed_size == 1: bucket.sync_grads_shard = bucket.grads_bucket else: + bucket_size = bucket.grads_bucket.numel() + shard_size = bucket_size // self.distributed_size bucket.sync_grads_shard = torch.empty( - [self.shard_size], + [shard_size], dtype=self.grad_sync_dtype, device=self.device, ) @@ -1146,57 +1157,75 @@ def step(self, closure=None, *, grad_scaler=None): return self._grad_scale = self._grad_scale.to(dtype=torch.float32, device=self.device) - # Construct workspace buffers - params_bucket_buffers = [ - torch.empty( - [self.bucket_size], - dtype=self.param_sync_dtype, - device=self.device, - ) - for _ in range(self.pipeline_size) - ] - - # Apply optimizer step to each bucket and synchronize params - self.state['step'] += 1 + # Side stream for communication main_stream = torch.cuda.current_stream() comm_stream = self._pipeline_streams[-1] - for stream in self._pipeline_streams: - stream.wait_stream(main_stream) - for bucket_id in range(len(self.state['buckets'])): - stream_id = bucket_id % self.pipeline_size - stream = self._pipeline_streams[stream_id] - with torch.cuda.stream(stream): - - # Buffers for param sync - params_bucket = params_bucket_buffers[stream_id] - bucket_start = self.distributed_rank * self.shard_size - bucket_end = bucket_start + self.shard_size - params_bucket_shard = params_bucket[bucket_start:bucket_end] - # Apply optimizer step to local shard + # Apply optimizer step to each bucket and synchronize params + # Note: In order to overlap communication and compute, we + # split the optimizer step into three pipeline stages: (1) + # local optimizer step, (2) param all-gather, (3) copy to + # model param buffers. This could be implemented more simply + # by just processing each bucket on an independent CUDA + # stream, but we experience poor overlapping when running with + # a single hardware work queue (i.e. + # CUDA_DEVICE_MAX_CONNECTIONS=1). 
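+ # Concretely, on loop iteration i the local optimizer step runs for
+ # bucket i on the main stream while bucket i-1 is all-gathered on the
+ # communication stream and the gathered values for bucket i-2 are
+ # copied into the model param buffers, so the loop runs for
+ # num_buckets + 2 iterations to drain the pipeline.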
+ self.state['step'] += 1 + num_buckets = len(self.state['buckets']) + params_buckets = {} + params_bucket_shards = {} + for local_step_bucket_id in range(num_buckets + 2): + all_gather_bucket_id = local_step_bucket_id - 1 + copy_bucket_id = local_step_bucket_id - 2 + + # Synchronize compute and communication streams + if self.distributed_size > 1: + if 0 <= all_gather_bucket_id < num_buckets: + comm_stream.wait_stream(main_stream) + if 0 <= copy_bucket_id < num_buckets: + main_stream.wait_stream(comm_stream) + + # Apply optimizer step to local shard + if 0 <= local_step_bucket_id < num_buckets: + bucket_id = local_step_bucket_id + bucket = self.state['buckets'][bucket_id] + bucket_size = bucket.bucket_size + shard_size = bucket.shard_size + bucket_start = self.distributed_rank * shard_size + bucket_end = bucket_start + shard_size + params_buckets[bucket_id] = torch.empty( + [bucket_size], + dtype=self.param_sync_dtype, + device=self.device, + ) + params_bucket_shards[bucket_id] = ( + params_buckets[bucket_id][bucket_start:bucket_end] + ) if self.store_param_remainders: self._local_step_with_param_remainders( bucket_id, - params_bucket_shard, + params_bucket_shards[bucket_id], ) else: - self._local_step(bucket_id, params_bucket_shard) + self._local_step( + bucket_id, + params_bucket_shards[bucket_id], + ) - # All-gather updated parameters - # Note: Call all-gather in main stream to ensure they - # are executed in the correct order. Reconsider when - # tagged collectives are available. + # All-gather updated parameters + if 0 <= all_gather_bucket_id < num_buckets: + bucket_id = all_gather_bucket_id if self.distributed_size > 1: - comm_stream.wait_stream(stream) with torch.cuda.stream(comm_stream): all_gather_into_tensor( - params_bucket, - params_bucket_shard, + params_buckets[bucket_id], + params_bucket_shards[bucket_id], group=self.distributed_process_group, ) - stream.wait_stream(comm_stream) - # Copy values to param buffers + # Copy all-gathered values to param buffers + if 0 <= copy_bucket_id < num_buckets: + bucket_id = copy_bucket_id params_in = [] params_out = [] fragments = self.state['buckets'][bucket_id].fragments @@ -1206,17 +1235,15 @@ def step(self, closure=None, *, grad_scaler=None): param = self.param_groups[param_group_id]['params'][param_id] bucket_start, bucket_end = fragment.bucket_range param_start, param_end = fragment.param_range - params_in.append(params_bucket[bucket_start:bucket_end]) + params_in.append(params_buckets[bucket_id][bucket_start:bucket_end]) params_out.append(param.detach().view(-1)[param_start:param_end]) _multi_tensor_copy( params_in, params_out, dummy_overflow_buf=self._dummy_overflow_buf, ) - - # Synchronize pipeline streams - for stream in self._pipeline_streams: - main_stream.wait_stream(stream) + del params_buckets[bucket_id] + del params_bucket_shards[bucket_id] return loss @@ -1340,6 +1367,7 @@ def state_dict(self, gather_on_root=True): ranks on the root rank (default: True) """ + ### TODO Fix state_dict = super().state_dict() if not gather_on_root: return state_dict @@ -1473,6 +1501,7 @@ def state_dict(self, gather_on_root=True): def load_state_dict(self, state_dict): """Load optimizer state""" + ### TODO Fix # State dict contains state for all ranks if 'gathered_states' in state_dict: From 6856f3de6a6e24b92daf72052fbb257d75f39876 Mon Sep 17 00:00:00 2001 From: Tim Moon Date: Thu, 1 Dec 2022 14:00:50 -0800 Subject: [PATCH 3/8] Support contiguous grad buffer with variable-size param buckets --- .../optimizers/distributed_fused_adam.py | 110 
++++++++++-------- 1 file changed, 62 insertions(+), 48 deletions(-) diff --git a/apex/contrib/optimizers/distributed_fused_adam.py b/apex/contrib/optimizers/distributed_fused_adam.py index b329433a5..05c7061f4 100644 --- a/apex/contrib/optimizers/distributed_fused_adam.py +++ b/apex/contrib/optimizers/distributed_fused_adam.py @@ -230,18 +230,25 @@ def __init__( self.shard_param_range = shard_param_range class StateBucket: + """Optimizer state for a bucket""" def __init__( self, bucket_size, shard_size, dtype, device, + contiguous_buffer_offset=0, store_params=False, store_param_remainders=False, ): - """Optimizer state for a bucket""" + # Size of parameter bucket self.bucket_size = bucket_size + # Size of local shard of parameter bucket self.shard_size = shard_size + # Size of the filled region in the bucket + self.filled_size = 0 + # Offset to bucket in contiguous buffers + self.contiguous_buffer_offset = contiguous_buffer_offset # Buffer ranges corresponding to parameter fragments self.fragments = [] # Local shard of parameters @@ -411,8 +418,6 @@ def __init__(self, self.overlap_grad_sync = overlap_grad_sync # Number of buckets to synchronize at a time self.pipeline_size = pipeline_size - # Allocate contiguous buffer for gradients - self.contiguous_grad_buffer = contiguous_grad_buffer # Store params or param remainders if store_param_remainders: @@ -451,6 +456,11 @@ def __init__(self, # Gradient state self._grads_buckets = collections.defaultdict(self.GradientBucket) + # Whether to allocate contiguous buffer for gradients + self.contiguous_grad_buffer = contiguous_grad_buffer + # Contiguous buffer for gradients + self._grad_buffer = None + # Side streams for optimizer step and communication self._pipeline_streams = [torch.cuda.Stream() for _ in range(self.pipeline_size+1)] @@ -476,10 +486,6 @@ def __init__(self, # Attach hooks for gradient synchronization self._register_post_backward_hooks() - # Allocate contiguous gradient buffer if needed - if self.contiguous_grad_buffer: - self._init_grad_buffer() - def _make_post_backward_hook(self, param, param_group_id, param_id): """Create callback function to call after param generates grad @@ -533,19 +539,16 @@ def _register_post_backward_hooks(self): def _init_grad_buffer(self): """Allocate contiguous buffer for grad buckets""" - ### TODO Remove - grad_buffer_size = 0 - for group in self.param_groups: - for param in group['params']: - if param.requires_grad: - grad_size = _round_to_multiple(param.numel(), self.alignment) - grad_buffer_size += grad_size - grad_buffer_size = _round_to_multiple( - grad_buffer_size, - self.bucket_size, - ) + self.init_params() # Make sure all params are initialized + if self.state['buckets']: + buffer_size = max( + bucket.contiguous_buffer_offset + bucket.bucket_size + for bucket in self.state['buckets'] + ) + else: + buffer_size = 0 self._grad_buffer = torch.zeros( - [grad_buffer_size], + [buffer_size], dtype=self.dtype, device=self.device, ) @@ -599,23 +602,20 @@ def _init_param_state( bucket = self.state['buckets'][bucket_id] bucket_size = bucket.bucket_size shard_size = bucket.shard_size - if bucket.fragments: - _, bucket_start = bucket.fragments[-1].bucket_range - bucket_start = _round_to_multiple( - bucket_start, - self.alignment, - round_up=True, - ) - else: - bucket_start = 0 + bucket_start = _round_to_multiple( + bucket.filled_size, + self.alignment, + round_up=True, + ) else: + bucket = None bucket_size = 0 shard_size = 0 bucket_start = 0 bucket_end = bucket_start + param_size # Create new bucket if 
param does not fit - if bucket_end > bucket_size: + if bucket is None or bucket_end > bucket_size: shard_size = max( self.default_shard_size, _ceildiv(param_size, self.distributed_size), @@ -626,12 +626,17 @@ def _init_param_state( round_up=True, ) bucket_size = shard_size * self.distributed_size + if bucket is None: + buffer_offset = 0 + else: + buffer_offset = bucket.contiguous_buffer_offset + bucket.bucket_size self.state['buckets'].append( self.StateBucket( bucket_size, shard_size, self.dtype, self.device, + contiguous_buffer_offset=buffer_offset, store_params=self.store_params, store_param_remainders=self.store_param_remainders, ) @@ -668,6 +673,7 @@ def _init_param_state( ) self.state[param]['fragments'] = [fragment] bucket.fragments.append(fragment) + bucket.filled_size = bucket_end # Initialize main param buffer if self.store_params and in_local_shard: @@ -683,13 +689,15 @@ def zero_grad(self, set_to_none=True): # Construct views into contiguous grad buffer, if needed if self.contiguous_grad_buffer: - ### TODO Fix + if self._grad_buffer is None: + self._init_grad_buffer() self._grad_buffer.zero_() - for bucket_id in range(len(self.state['buckets'])): - bucket_start = bucket_id * self.bucket_size - bucket_end = bucket_start + self.bucket_size - bucket = self._grads_buckets[bucket_id] - bucket.grads_bucket = self._grad_buffer[bucket_start:bucket_end] + for bucket_id, bucket in enumerate(self.state['buckets']): + bucket_size = bucket.bucket_size + grad_buffer_start = bucket.contiguous_buffer_offset + grad_buffer_end = bucket_start + bucket_size + grad_buffer = self._grad_buffer[grad_buffer_start:grad_buffer_end] + self._grads_buckets[bucket_id].grads_bucket = grad_buffer # Reset param grads for param in self.parameters(): @@ -723,9 +731,10 @@ def _grad_copy(self, param): # Allocate gradient buffer if needed if bucket.grads_bucket is None and self.contiguous_grad_buffer: - ### TODO Fix - grad_buffer_start = bucket_id * self.bucket_size - grad_buffer_end = grad_buffer_start + self.bucket_size + if self._grad_buffer is None: + self._init_grad_buffer() + grad_buffer_start = self.state['buckets'][bucket_id].contiguous_buffer_offset + grad_buffer_end = bucket_start + bucket_size grad_buffer = self._grad_buffer[grad_buffer_start:grad_buffer_end] if (bucket.grads_shard is None or bucket.grads_shard.data_ptr() != grad_buffer.data_ptr()): @@ -749,21 +758,27 @@ def _grad_copy(self, param): param.grad = None def grad_buffer_view(self, param): - """Construct view into grad buffer corresponding to param""" + """Construct view into grad buffer corresponding to param + + Assumes optimizer is using a contiguous grad buffer. 
+ + """ + + # Initialize contiguous grad buffer if needed + assert self.contiguous_grad_buffer + if self._grad_buffer is None: + self._init_grad_buffer() # Figure out corresponding position in grad buffer - assert len(self.state[param]['fragments']) == 1 fragment = self.state[param]['fragments'][0] bucket_id = fragment.bucket_id bucket_start, bucket_end = fragment.bucket_range - - # Allocate gradient buffer if needed - bucket = self._grads_buckets[bucket_id] - if bucket.grads_bucket is None: - pass ### TODO Implement + buffer_bucket_start = self.state['buckets'][bucket_id].contiguous_buffer_offset + buffer_start = buffer_bucket_start + buffer_start + buffer_end = buffer_bucket_start + buffer_end # Construct view into grad buffer - flat_buffer = bucket.grads_bucket[bucket_start:bucket_end] + flat_buffer = self._grad_buffer[buffer_start:buffer_end] return flat_buffer.detach().view(param.size()) def _force_bucket_grad_sync(self): @@ -1389,7 +1404,7 @@ def state_dict(self, gather_on_root=True): max_state_size = max(state_sizes) # Construct workspace buffers - chunk_size = self.shard_size * torch.finfo(self.grad_sync_dtype).bits // 8 + chunk_size = self.default_shard_size * torch.finfo(self.grad_sync_dtype).bits // 8 if self.distributed_rank == 0: gathered_state_bytes = [ torch.empty([size], dtype=torch.uint8, device='cpu') @@ -1501,7 +1516,6 @@ def state_dict(self, gather_on_root=True): def load_state_dict(self, state_dict): """Load optimizer state""" - ### TODO Fix # State dict contains state for all ranks if 'gathered_states' in state_dict: From 0146de4d556ef8537a380ee4909a66908077bf10 Mon Sep 17 00:00:00 2001 From: Tim Moon Date: Thu, 1 Dec 2022 16:25:04 -0800 Subject: [PATCH 4/8] Add dist Adam unit test with contiguous grad buffers --- .../optimizers/distributed_fused_adam.py | 87 ++++++++++++++++--- .../contrib/test/optimizers/test_dist_adam.py | 11 ++- 2 files changed, 85 insertions(+), 13 deletions(-) diff --git a/apex/contrib/optimizers/distributed_fused_adam.py b/apex/contrib/optimizers/distributed_fused_adam.py index 05c7061f4..4137e8de8 100644 --- a/apex/contrib/optimizers/distributed_fused_adam.py +++ b/apex/contrib/optimizers/distributed_fused_adam.py @@ -562,6 +562,8 @@ def parameters(self): def init_params(self, params=None): """Initialize optimizer state for parameters + Ignores parameters that have already been initialized. + Arguments: params (iterable, optional): parameters to initialize (default: all parameters) @@ -574,6 +576,15 @@ def init_params(self, params=None): elif isinstance(params, torch.Tensor): params = [params] + # Ignore parameters that have already been initialized + params = [ + param + for param in params + if 'fragments' not in self.state[param] + ] + if not params: + return + # Get indices corresponding to parameters id_map = dict() for param_group_id, group in enumerate(self.param_groups): @@ -582,10 +593,66 @@ def init_params(self, params=None): # Initialize parameters for param in params: - if param in id_map and 'fragments' not in self.state[param]: + if param in id_map: param_group_id, param_id = id_map[param] self._init_param_state(param, param_group_id, param_id) + def init_params_bucket(self, params): + """Initialize optimizer state for parameters in a single bucket + + Ignores parameters that have already been initialized. 
+ + Arguments: + params (iterable): parameters to initialize + + """ + if isinstance(params, torch.Tensor): + params = [params] + + # Ignore parameters that have already been initialized + params = [ + param + for param in params + if 'fragments' not in self.state[param] + ] + if not params: + return + + # Figure out bucket size + bucket_size = sum( + _round_to_multiple(param.numel(), self.alignment, round_up=True) + for param in params + ) + shard_size = _round_to_multiple( + _ceildiv(bucket_size, self.distributed_size), + self.alignment, + round_up=True, + ) + bucket_size = shard_size * self.distributed_size + + # Create new bucket + if self.state['buckets']: + last_bucket = self.state['buckets'][-1] + buffer_offset = last_bucket.contiguous_buffer_offset + last_bucket.bucket_size + else: + buffer_offset = 0 + bucket = self.StateBucket( + bucket_size, + shard_size, + self.dtype, + self.device, + contiguous_buffer_offset=buffer_offset, + store_params=self.store_params, + store_param_remainders=self.store_param_remainders, + ) + self.state['buckets'].append(bucket) + + # Initialize optimizer state for parameters + self.init_params(params) + + # Mark that bucket is fully filled + bucket.filled_size = bucket_size + def _init_param_state( self, param, @@ -694,9 +761,9 @@ def zero_grad(self, set_to_none=True): self._grad_buffer.zero_() for bucket_id, bucket in enumerate(self.state['buckets']): bucket_size = bucket.bucket_size - grad_buffer_start = bucket.contiguous_buffer_offset - grad_buffer_end = bucket_start + bucket_size - grad_buffer = self._grad_buffer[grad_buffer_start:grad_buffer_end] + buffer_start = bucket.contiguous_buffer_offset + buffer_end = buffer_start + bucket_size + grad_buffer = self._grad_buffer[buffer_start:buffer_end] self._grads_buckets[bucket_id].grads_bucket = grad_buffer # Reset param grads @@ -733,9 +800,9 @@ def _grad_copy(self, param): if bucket.grads_bucket is None and self.contiguous_grad_buffer: if self._grad_buffer is None: self._init_grad_buffer() - grad_buffer_start = self.state['buckets'][bucket_id].contiguous_buffer_offset - grad_buffer_end = bucket_start + bucket_size - grad_buffer = self._grad_buffer[grad_buffer_start:grad_buffer_end] + buffer_start = self.state['buckets'][bucket_id].contiguous_buffer_offset + buffer_end = buffer_start + bucket_size + grad_buffer = self._grad_buffer[buffer_start:buffer_end] if (bucket.grads_shard is None or bucket.grads_shard.data_ptr() != grad_buffer.data_ptr()): bucket.grads_bucket = grad_buffer @@ -773,9 +840,9 @@ def grad_buffer_view(self, param): fragment = self.state[param]['fragments'][0] bucket_id = fragment.bucket_id bucket_start, bucket_end = fragment.bucket_range - buffer_bucket_start = self.state['buckets'][bucket_id].contiguous_buffer_offset - buffer_start = buffer_bucket_start + buffer_start - buffer_end = buffer_bucket_start + buffer_end + buffer_offset = self.state['buckets'][bucket_id].contiguous_buffer_offset + buffer_start = buffer_offset + bucket_start + buffer_end = buffer_offset + bucket_end # Construct view into grad buffer flat_buffer = self._grad_buffer[buffer_start:buffer_end] diff --git a/apex/contrib/test/optimizers/test_dist_adam.py b/apex/contrib/test/optimizers/test_dist_adam.py index 3bb9e3c7d..23c984453 100644 --- a/apex/contrib/test/optimizers/test_dist_adam.py +++ b/apex/contrib/test/optimizers/test_dist_adam.py @@ -36,6 +36,7 @@ def make_models( param_sync_dtype=None, device='cuda', overlap_communication=True, + contiguous_buffers=False, store_params=False, store_param_remainders=False, 
): @@ -78,6 +79,7 @@ def make_models( bucket_cap_mb=71/(4*1024*1024), dtype=optim_dtype, param_sync_dtype=param_sync_dtype, + contiguous_grad_buffer=contiguous_buffers, store_params=store_params, store_param_remainders=store_param_remainders, **optim_args, @@ -115,6 +117,7 @@ def test_matches_pytorch( optim_dtype=None, param_sync_dtype=None, device='cuda', + contiguous_buffers=False, store_params=False, store_param_remainders=False, ): @@ -131,6 +134,7 @@ def test_matches_pytorch( param_sync_dtype=param_sync_dtype, device=device, overlap_communication=overlap_communication, + contiguous_buffers=contiguous_buffers, store_params=store_params, store_param_remainders=store_param_remainders, ) @@ -182,9 +186,7 @@ def test_matches_pytorch( dist_param, ref_param, rtol=rtol, atol=atol) def test_matches_pytorch_l2_reg(self): - self.test_matches_pytorch( - adam_w_mode=False, - ) + self.test_matches_pytorch(adam_w_mode=False) def test_matches_pytorch_no_overlap(self): self.test_matches_pytorch( @@ -195,6 +197,9 @@ def test_matches_pytorch_no_overlap(self): def test_matches_pytorch_sync_every_step(self): self.test_matches_pytorch(use_nosync=False) + def test_matches_pytorch_contiguous_buffers(self): + self.test_matches_pytorch(contiguous_buffers=True) + def test_matches_pytorch_fp64(self): self.test_matches_pytorch( rtol=1.3e-6, From 5d04d07d412a28e2ffea4492b43c7d8d746f430e Mon Sep 17 00:00:00 2001 From: Tim Moon Date: Thu, 1 Dec 2022 18:31:48 -0800 Subject: [PATCH 5/8] Optimize compute/communication overlap in dist Adam optim step --- apex/contrib/optimizers/distributed_fused_adam.py | 15 ++++++++------- 1 file changed, 8 insertions(+), 7 deletions(-) diff --git a/apex/contrib/optimizers/distributed_fused_adam.py b/apex/contrib/optimizers/distributed_fused_adam.py index 4137e8de8..3a37193d1 100644 --- a/apex/contrib/optimizers/distributed_fused_adam.py +++ b/apex/contrib/optimizers/distributed_fused_adam.py @@ -539,6 +539,7 @@ def _register_post_backward_hooks(self): def _init_grad_buffer(self): """Allocate contiguous buffer for grad buckets""" + self.contiguous_grad_buffer = True self.init_params() # Make sure all params are initialized if self.state['buckets']: buffer_size = max( @@ -1260,13 +1261,6 @@ def step(self, closure=None, *, grad_scaler=None): all_gather_bucket_id = local_step_bucket_id - 1 copy_bucket_id = local_step_bucket_id - 2 - # Synchronize compute and communication streams - if self.distributed_size > 1: - if 0 <= all_gather_bucket_id < num_buckets: - comm_stream.wait_stream(main_stream) - if 0 <= copy_bucket_id < num_buckets: - main_stream.wait_stream(comm_stream) - # Apply optimizer step to local shard if 0 <= local_step_bucket_id < num_buckets: bucket_id = local_step_bucket_id @@ -1294,6 +1288,13 @@ def step(self, closure=None, *, grad_scaler=None): params_bucket_shards[bucket_id], ) + # Synchronize compute and communication streams if needed + if self.distributed_size > 1: + if 0 <= all_gather_bucket_id < num_buckets: + comm_stream.wait_stream(main_stream) + if 0 <= copy_bucket_id < num_buckets: + main_stream.wait_stream(comm_stream) + # All-gather updated parameters if 0 <= all_gather_bucket_id < num_buckets: bucket_id = all_gather_bucket_id From 2e16b21104d0172660045f86df9f413baf25b1cf Mon Sep 17 00:00:00 2001 From: Tim Moon Date: Fri, 2 Dec 2022 11:17:27 -0800 Subject: [PATCH 6/8] Restore Dist Adam default of splitting params across default-sized buckets --- .../optimizers/distributed_fused_adam.py | 160 ++++++++++-------- 1 file changed, 88 insertions(+), 72 deletions(-) 
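Note: the restored default carves a parameter into per-bucket fragments
instead of growing a bucket to fit it. Roughly, the bookkeeping behaves
like the following standalone sketch (the function name and return
layout are illustrative only, not the optimizer's API):

    def split_into_fragments(param_numel, bucket_size, alignment, filled=0):
        """Illustrative only: mimic how a param is carved across buckets."""
        fragments = []          # (param_start, bucket_id, bucket_start, size)
        bucket_id, param_start = 0, 0
        while param_start < param_numel:
            # Round the bucket's fill point up to the alignment boundary
            bucket_start = -(-filled // alignment) * alignment
            size = min(param_numel - param_start, bucket_size - bucket_start)
            if size <= 0:
                bucket_id += 1  # current bucket is full, open a new one
                filled = 0
                continue
            fragments.append((param_start, bucket_id, bucket_start, size))
            param_start += size
            filled = bucket_start + size
        return fragments

    # A 300000-element param with 131072-element buckets and 128-element
    # alignment spans three buckets: two full fragments plus a remainder.
    assert len(split_into_fragments(300000, 131072, 128)) == 3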
diff --git a/apex/contrib/optimizers/distributed_fused_adam.py b/apex/contrib/optimizers/distributed_fused_adam.py index 3a37193d1..d35d5d66a 100644 --- a/apex/contrib/optimizers/distributed_fused_adam.py +++ b/apex/contrib/optimizers/distributed_fused_adam.py @@ -662,42 +662,11 @@ def _init_param_state( ): """Initialize optimizer state for a parameter""" - # Param position within bucket - param_size = param.numel() - param_start, param_end = 0, param_size - bucket_id = len(self.state['buckets']) - 1 - if bucket_id >= 0: - bucket = self.state['buckets'][bucket_id] - bucket_size = bucket.bucket_size - shard_size = bucket.shard_size - bucket_start = _round_to_multiple( - bucket.filled_size, - self.alignment, - round_up=True, - ) - else: - bucket = None - bucket_size = 0 - shard_size = 0 - bucket_start = 0 - bucket_end = bucket_start + param_size - - # Create new bucket if param does not fit - if bucket is None or bucket_end > bucket_size: - shard_size = max( - self.default_shard_size, - _ceildiv(param_size, self.distributed_size), - ) - shard_size = _round_to_multiple( - shard_size, - self.alignment, - round_up=True, - ) + # Make sure there is at least one bucket + if not self.state['buckets']: + shard_size = self.default_shard_size bucket_size = shard_size * self.distributed_size - if bucket is None: - buffer_offset = 0 - else: - buffer_offset = bucket.contiguous_buffer_offset + bucket.bucket_size + buffer_offset = 0 self.state['buckets'].append( self.StateBucket( bucket_size, @@ -709,45 +678,92 @@ def _init_param_state( store_param_remainders=self.store_param_remainders, ) ) - return self._init_param_state(param, param_group_id, param_id) - - # Fragment position within local shard - shard_id = self.distributed_rank - shard_start = bucket_start - shard_size*shard_id - shard_end = bucket_end - shard_size*shard_id - shard_start = min(max(shard_start, 0), shard_size) - shard_end = min(max(shard_end, 0), shard_size) - in_local_shard = shard_start < shard_end - if in_local_shard: - shard_bucket_start = shard_start + shard_size*shard_id - shard_bucket_end = shard_bucket_start + shard_end - shard_start - shard_param_start = shard_bucket_start - bucket_start + param_start - shard_param_end = shard_param_start + shard_end - shard_start - else: - shard_bucket_start, shard_bucket_end = None, None - shard_param_start, shard_param_end = None, None - - # Record fragment info - fragment = self.ParameterFragment( - param_group_id=param_group_id, - param_id=param_id, - bucket_id=bucket_id, - param_range=(param_start,param_end), - bucket_range=(bucket_start,bucket_end), - in_local_shard=in_local_shard, - shard_range=(shard_start,shard_end), - shard_bucket_range=(shard_bucket_start,shard_bucket_end), - shard_param_range=(shard_param_start,shard_param_end), - ) - self.state[param]['fragments'] = [fragment] - bucket.fragments.append(fragment) - bucket.filled_size = bucket_end + + # Split parameter values into fragments + # Note: Each fragment resides within a bucket + param_start = 0 + param_size = param.numel() + self.state[param]['fragments'] = [] + while param_start < param_size: + + # Get current bucket + bucket_id = len(self.state['buckets']) - 1 + bucket = self.state['buckets'][bucket_id] + fragment_id = len(bucket.fragments) + bucket_size = bucket.bucket_size + shard_size = bucket.shard_size + + # Determine fragment position within bucket + bucket_start = _round_to_multiple( + bucket.filled_size, + self.alignment, + round_up=True, + ) + fragment_size = min(param_size-param_start, 
bucket_size-bucket_start) + param_end = param_start + fragment_size + bucket_end = bucket_start + fragment_size + + # Create new bucket if current one is full + if fragment_size <= 0: + shard_size = self.default_shard_size + bucket_size = shard_size * self.distributed_size + buffer_offset = bucket.contiguous_buffer_offset + bucket.bucket_size + self.state['buckets'].append( + self.StateBucket( + bucket_size, + shard_size, + self.dtype, + self.device, + contiguous_buffer_offset=buffer_offset, + store_params=self.store_params, + store_param_remainders=self.store_param_remainders, + ) + ) + continue + + # Fragment position within local shard + shard_id = self.distributed_rank + shard_start = bucket_start - shard_size*shard_id + shard_end = bucket_end - shard_size*shard_id + shard_start = min(max(shard_start, 0), shard_size) + shard_end = min(max(shard_end, 0), shard_size) + in_local_shard = shard_start < shard_end + if in_local_shard: + shard_bucket_start = shard_start + shard_size*shard_id + shard_bucket_end = shard_bucket_start + shard_end - shard_start + shard_param_start = shard_bucket_start - bucket_start + param_start + shard_param_end = shard_param_start + shard_end - shard_start + else: + shard_bucket_start, shard_bucket_end = None, None + shard_param_start, shard_param_end = None, None + + # Record fragment info + fragment = self.ParameterFragment( + param_group_id=param_group_id, + param_id=param_id, + bucket_id=bucket_id, + param_range=(param_start,param_end), + bucket_range=(bucket_start,bucket_end), + in_local_shard=in_local_shard, + shard_range=(shard_start,shard_end), + shard_bucket_range=(shard_bucket_start,shard_bucket_end), + shard_param_range=(shard_param_start,shard_param_end), + ) + self.state[param]['fragments'].append(fragment) + bucket.fragments.append(fragment) + bucket.filled_size = bucket_end + param_start = param_end # Initialize main param buffer - if self.store_params and in_local_shard: - model_param_fragment = param.detach().view(-1)[shard_param_start:shard_param_end] - main_param_fragment = bucket.params_shard[shard_start:shard_end] - main_param_fragment.copy_(model_param_fragment) + if self.store_params: + for fragment in self.state[param]['fragments']: + if fragment.in_local_shard: + bucket = self.state['buckets'][fragment.bucket_id] + param_start, param_end = fragment.shard_param_range + shard_start, shard_end = fragment.shard_range + model_param_fragment = param.detach().view(-1)[param_start:param_end] + main_param_fragment = bucket.params_shard[shard_start:shard_end] + main_param_fragment.copy_(model_param_fragment) def zero_grad(self, set_to_none=True): """Clear parameter gradients""" From 4d25813219ce1857be9be0c55faada7b0554cef1 Mon Sep 17 00:00:00 2001 From: Tim Moon Date: Mon, 5 Dec 2022 20:48:34 -0800 Subject: [PATCH 7/8] Support initializing multiple dist Adam param buckets together The buckets perform communication together, so they are effectively a large bucket. 
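For example, a caller can group each layer's parameters so that their
gradient reductions are issued together (illustrative usage; `model` and
its per-layer structure are hypothetical):

    opt = DistributedFusedAdam(model.parameters(), lr=1e-3)
    for layer in model.layers:
        opt.init_params_bucket(list(layer.parameters()))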
--- .../optimizers/distributed_fused_adam.py | 124 ++++++++++-------- 1 file changed, 70 insertions(+), 54 deletions(-) diff --git a/apex/contrib/optimizers/distributed_fused_adam.py b/apex/contrib/optimizers/distributed_fused_adam.py index d35d5d66a..19f40b2ce 100644 --- a/apex/contrib/optimizers/distributed_fused_adam.py +++ b/apex/contrib/optimizers/distributed_fused_adam.py @@ -63,8 +63,9 @@ def _multi_tensor_copy( # Group buffers by device and dtype buffer_groups = collections.defaultdict(list) for buf_in, buf_out in zip(buffers_in, buffers_out): - if buf_in.data_ptr() == buf_out.data_ptr(): + if buf_in.data_ptr() == buf_out.data_ptr() or buf_in.numel() == 0: # Nothing to be done if input and output buffers are same + # or have no entries continue if buf_in.dtype == buf_out.dtype: # Just copy bytes if dtypes are same @@ -599,18 +600,20 @@ def init_params(self, params=None): self._init_param_state(param, param_group_id, param_id) def init_params_bucket(self, params): - """Initialize optimizer state for parameters in a single bucket + """Initialize optimizer state for parameters in one effective bucket - Ignores parameters that have already been initialized. + The buckets corresponding to the provided parameters are + configured so they all perform communication together. Ignores + parameters that have already been initialized. Arguments: params (iterable): parameters to initialize """ - if isinstance(params, torch.Tensor): - params = [params] # Ignore parameters that have already been initialized + if isinstance(params, torch.Tensor): + params = [params] params = [ param for param in params @@ -619,40 +622,47 @@ def init_params_bucket(self, params): if not params: return - # Figure out bucket size - bucket_size = sum( - _round_to_multiple(param.numel(), self.alignment, round_up=True) - for param in params - ) - shard_size = _round_to_multiple( - _ceildiv(bucket_size, self.distributed_size), - self.alignment, - round_up=True, - ) - bucket_size = shard_size * self.distributed_size + # Get indices corresponding to parameters + id_map = dict() + for param_group_id, group in enumerate(self.param_groups): + for param_id, param in enumerate(group['params']): + id_map[param] = [param_group_id, param_id] + param_ids = [tuple([param] + id_map[param]) for param in params] - # Create new bucket - if self.state['buckets']: - last_bucket = self.state['buckets'][-1] - buffer_offset = last_bucket.contiguous_buffer_offset + last_bucket.bucket_size - else: - buffer_offset = 0 - bucket = self.StateBucket( - bucket_size, - shard_size, - self.dtype, - self.device, - contiguous_buffer_offset=buffer_offset, - store_params=self.store_params, - store_param_remainders=self.store_param_remainders, - ) - self.state['buckets'].append(bucket) + # Mark existings bucket as fully filled + for bucket in self.state['buckets']: + bucket.filled_size = bucket.bucket_size # Initialize optimizer state for parameters + start_bucket_id = len(self.state['buckets']) self.init_params(params) + end_bucket_id = len(self.state['buckets']) - # Mark that bucket is fully filled - bucket.filled_size = bucket_size + # Make sure all added buckets depend on provided params + for bucket_id in range(start_bucket_id, end_bucket_id): + bucket = self.state['buckets'][bucket_id] + bucket_size = bucket.bucket_size + bucket.filled_size = bucket_size + ids_in_bucket = set( + (fragment.param_group_id, fragment.param_id) + for fragment in bucket.fragments + ) + for param, param_group_id, param_id in param_ids: + if (param_group_id, param_id) not in 
ids_in_bucket: + param_size = param.numel() + fragment = self.ParameterFragment( + param_group_id=param_group_id, + param_id=param_id, + bucket_id=bucket_id, + param_range=(param_size, param_size), + bucket_range=(bucket_size, bucket_size), + in_local_shard=False, + shard_range=(None, None), + shard_bucket_range=(None, None), + shard_param_range=(None, None), + ) + self.state[param]['fragments'].append(fragment) + bucket.fragments.append(fragment) def _init_param_state( self, @@ -734,6 +744,7 @@ def _init_param_state( shard_param_start = shard_bucket_start - bucket_start + param_start shard_param_end = shard_param_start + shard_end - shard_start else: + shard_start, shard_end = None, None shard_bucket_start, shard_bucket_end = None, None shard_param_start, shard_param_end = None, None @@ -856,10 +867,11 @@ def grad_buffer_view(self, param): # Figure out corresponding position in grad buffer fragment = self.state[param]['fragments'][0] bucket_id = fragment.bucket_id - bucket_start, bucket_end = fragment.bucket_range + param_size = param.numel() + bucket_start, _ = fragment.bucket_range buffer_offset = self.state['buckets'][bucket_id].contiguous_buffer_offset buffer_start = buffer_offset + bucket_start - buffer_end = buffer_offset + bucket_end + buffer_end = buffer_start + param_size # Construct view into grad buffer flat_buffer = self._grad_buffer[buffer_start:buffer_end] @@ -1128,7 +1140,8 @@ def _local_grad_norm(self, parameters=None, norm_type=2.0): if fragment.in_local_shard: bucket = self._grads_buckets[fragment.bucket_id] shard_start, shard_end = fragment.shard_range - grads.append(bucket.grads_shard[shard_start:shard_end]) + if shard_end > shard_start: + grads.append(bucket.grads_shard[shard_start:shard_end]) if grads: grad_norm_sq = multi_tensor_applier( amp_C.multi_tensor_l2norm, @@ -1334,8 +1347,9 @@ def step(self, closure=None, *, grad_scaler=None): param = self.param_groups[param_group_id]['params'][param_id] bucket_start, bucket_end = fragment.bucket_range param_start, param_end = fragment.param_range - params_in.append(params_buckets[bucket_id][bucket_start:bucket_end]) - params_out.append(param.detach().view(-1)[param_start:param_end]) + if param_end > param_start: + params_in.append(params_buckets[bucket_id][bucket_start:bucket_end]) + params_out.append(param.detach().view(-1)[param_start:param_end]) _multi_tensor_copy( params_in, params_out, @@ -1370,13 +1384,14 @@ def _local_step(self, bucket_id, params_out): param_start, param_end = fragment.shard_param_range param_fragment = param.detach().view(-1)[param_start:param_end] param_fragment = param_fragment.to(dtype=self.dtype, device=self.device) - buffers[param_group_id].append([ - param_fragment, - exp_avg[shard_start:shard_end], - exp_avg_sq[shard_start:shard_end], - grads[shard_start:shard_end], - params_out[shard_start:shard_end], - ]) + if shard_end > shard_start: + buffers[param_group_id].append([ + param_fragment, + exp_avg[shard_start:shard_end], + exp_avg_sq[shard_start:shard_end], + grads[shard_start:shard_end], + params_out[shard_start:shard_end], + ]) # Apply optimizer step to each param group for group_id, group_buffers in buffers.items(): @@ -1424,14 +1439,15 @@ def _local_step_with_param_remainders(self, bucket_id, params_out): param = self.param_groups[param_group_id]['params'][param_id] param_fragment = param.detach().view(-1)[param_start:param_end] param_fragment = param_fragment.to(dtype=torch.bfloat16, device=self.device) - buffers[param_group_id].append([ - param_fragment, - 
param_remainders_shard[shard_start:shard_end], - exp_avg[shard_start:shard_end], - exp_avg_sq[shard_start:shard_end], - grads[shard_start:shard_end], - params_out[shard_start:shard_end], - ]) + if shard_end > shard_start: + buffers[param_group_id].append([ + param_fragment, + param_remainders_shard[shard_start:shard_end], + exp_avg[shard_start:shard_end], + exp_avg_sq[shard_start:shard_end], + grads[shard_start:shard_end], + params_out[shard_start:shard_end], + ]) # Apply optimizer step to each param group for group_id, group_buffers in buffers.items(): From d42f78118047cb680da0aec77ee1e23124e631ef Mon Sep 17 00:00:00 2001 From: Tim Moon Date: Fri, 13 Jan 2023 19:30:00 -0800 Subject: [PATCH 8/8] Handle recent change in PyTorch API for coalescing NCCL calls --- .../optimizers/distributed_fused_adam.py | 56 +++++++++++-------- 1 file changed, 32 insertions(+), 24 deletions(-) diff --git a/apex/contrib/optimizers/distributed_fused_adam.py b/apex/contrib/optimizers/distributed_fused_adam.py index 19f40b2ce..e232ee509 100644 --- a/apex/contrib/optimizers/distributed_fused_adam.py +++ b/apex/contrib/optimizers/distributed_fused_adam.py @@ -12,7 +12,7 @@ import amp_C import distributed_adam_cuda -# Fallback to private functions if using older PyTorch version +# Fallback to private functions if using PyTorch <1.13.0 try: from torch.distributed.distributed_c10d import get_global_rank except ImportError: @@ -29,6 +29,16 @@ from torch.distributed.distributed_c10d import _all_gather_base all_gather_into_tensor = _all_gather_base +# Add args to coalescing manager if using PyTorch <=1.13.1 +from torch.distributed.distributed_c10d import _coalescing_manager +if 'device' not in inspect.signature(_coalescing_manager).parameters.keys(): + _coalescing_manager_nodevice = _coalescing_manager + @contextlib.contextmanager + def _coalescing_manager(group, device, reqs): + with _coalescing_manager_nodevice(group, reqs): + yield + +# Import optional CUDA kernels _FOUND_DEPRECATED_FUSED_ADAM = False try: import fused_adam_cuda @@ -1007,19 +1017,18 @@ def _start_bucket_grad_sync(self, buckets): bucket.sync_wait() sync_requests = [] group = self.distributed_process_group - group._start_coalescing() - for bucket in buckets: - bucket.sync_request = ( - reduce_scatter_tensor( - bucket.sync_grads_shard, - bucket.grads_bucket, - op=reduce_op, - group=group, - async_op=True, + with _coalescing_manager(group, self.device, sync_requests): + for bucket in buckets: + bucket.sync_request = ( + reduce_scatter_tensor( + bucket.sync_grads_shard, + bucket.grads_bucket, + op=reduce_op, + group=group, + async_op=True, + ) ) - ) - sync_requests.append(bucket.sync_request) - group._end_coalescing(sync_requests) + sync_requests.append(bucket.sync_request) # All-reduce over redundant process group if self.redundant_size > 1: @@ -1028,18 +1037,17 @@ def _start_bucket_grad_sync(self, buckets): bucket.sync_wait() sync_requests = [] group = self.redundant_process_group - group._start_coalescing() - for bucket in buckets: - bucket.sync_request = ( - torch.distributed.all_reduce( - bucket.sync_grads_shard, - op=reduce_op, - group=group, - async_op=True, + with _coalescing_manager(group, self.device, sync_requests): + for bucket in buckets: + bucket.sync_request = ( + torch.distributed.all_reduce( + bucket.sync_grads_shard, + op=reduce_op, + group=group, + async_op=True, + ) ) - ) - sync_requests.append(bucket.sync_request) - group._end_coalescing(sync_requests) + sync_requests.append(bucket.sync_request) def _finish_bucket_grad_sync(self): 
"""Wait for any gradient synchronizations that are in progress"""