diff --git a/DeepSpeedExamples b/DeepSpeedExamples index fd869ae1c9de..33b9fb77c8ce 160000 --- a/DeepSpeedExamples +++ b/DeepSpeedExamples @@ -1 +1 @@ -Subproject commit fd869ae1c9de686f8cb92413efeba83fc989027c +Subproject commit 33b9fb77c8cecdb49118188890f662526d8e9397 diff --git a/deepspeed/pt/deepspeed_checkpointing.py b/deepspeed/pt/deepspeed_checkpointing.py index 2a5bb2ab688b..746adae1c599 100755 --- a/deepspeed/pt/deepspeed_checkpointing.py +++ b/deepspeed/pt/deepspeed_checkpointing.py @@ -602,11 +602,11 @@ def reset(): size_offsets = [] -def _configure_using_config_file(deepspeed_config): +def _configure_using_config_file(deepspeed_config, mpu=None): global num_layers, PARTITION_ACTIVATIONS, CONTIGUOUS_CHECKPOINTING, \ PA_TO_CPU, SYNCHRONIZE, PROFILE_TIME - config = DeepSpeedConfig(deepspeed_config).activation_checkpointing_config + config = DeepSpeedConfig(deepspeed_config, mpu=mpu).activation_checkpointing_config logger.info(config.repr()) PARTITION_ACTIVATIONS = config.partition_activations CONTIGUOUS_CHECKPOINTING = config.contiguous_memory_optimization @@ -684,12 +684,12 @@ def configure( _configure_defaults() - if deepspeed_config is not None: - _configure_using_config_file(deepspeed_config) - if mpu_ is not None: mpu = mpu_ + if deepspeed_config is not None: + _configure_using_config_file(deepspeed_config, mpu=mpu) + if partition_activations is not None: PARTITION_ACTIVATIONS = partition_activations diff --git a/deepspeed/pt/deepspeed_config.py b/deepspeed/pt/deepspeed_config.py index f618124d74c7..3181eb2e33da 100755 --- a/deepspeed/pt/deepspeed_config.py +++ b/deepspeed/pt/deepspeed_config.py @@ -457,10 +457,12 @@ def _do_error_check(self): if self.zero_enabled: assert self.fp16_enabled, "DeepSpeedConfig: ZeRO is only supported if fp16 is enabled" assert self.zero_optimization_stage <= MAX_STAGE_ZERO_OPTIMIZATION, "DeepSpeedConfig: Maximum supported ZeRO stage is {}".format(MAX_STAGE_ZERO_OPTIMIZATION) + if self.zero_config.cpu_offload is True: + assert self.zero_optimization_stage == ZERO_OPTIMIZATION_GRADIENTS, "DeepSpeedConfig: cpu-offload supported ZeRO stage is {}".format(ZERO_OPTIMIZATION_GRADIENTS) assert self.train_micro_batch_size_per_gpu, "DeepSpeedConfig: {} is not defined".format(TRAIN_MICRO_BATCH_SIZE_PER_GPU) - assert self.gradient_accumulation_steps, 'DeepSpeedConfig: {} is not defined'.format( + assert self.gradient_accumulation_steps, "DeepSpeedConfig: {} is not defined".format( GRADIENT_ACCUMULATION_STEPS) def _do_warning_check(self): diff --git a/deepspeed/pt/deepspeed_light.py b/deepspeed/pt/deepspeed_light.py index c6e7623b1792..5b127eed3595 100755 --- a/deepspeed/pt/deepspeed_light.py +++ b/deepspeed/pt/deepspeed_light.py @@ -107,7 +107,6 @@ def __init__(self, collate_fn=None, config_params=None): super(DeepSpeedLight, self).__init__() - self.client_optimizer = optimizer self.client_model_parameters = model_parameters self.client_lr_scheduler = lr_scheduler @@ -293,6 +292,9 @@ def zero_reduce_scatter(self): def zero_overlap_comm(self): return self._config.zero_config.overlap_comm + def zero_cpu_offload(self): + return self._config.zero_config.cpu_offload + def zero_optimization_stage(self): return self._config.zero_optimization_stage @@ -492,6 +494,7 @@ def _configure_distributed_model(self, model): # Configure optimizer def _configure_optimizer(self, client_optimizer, model_parameters): + if client_optimizer is not None: basic_optimizer = client_optimizer logger.info('Using client Optimizer as basic optimizer') @@ -505,13 +508,14 @@ def 
_configure_optimizer(self, client_optimizer, model_parameters): if self.zero_optimization(): assert not self.amp_enabled(), "Amp and ZeRO are not currently compatible, please use (legacy) fp16 mode which performs similar to amp opt_mode=O2" - if self.optimizer_name() != ADAM_OPTIMIZER: + if self.optimizer_name() not in [ADAM_OPTIMIZER]: assert self.zero_allow_untested_optimizer(), \ 'You are using an untested ZeRO Optimizer. Please add <"zero_allow_untested_optimizer": true> in the configuration file to use it.' logger.warning( "**** You are using ZeRO with an untested optimizer, proceed with caution *****" ) + self.optimizer = self._configure_zero_optimizer(basic_optimizer) elif self.amp_enabled(): assert not self.fp16_enabled(), "Cannot enable both amp with (legacy) fp16 mode" @@ -523,8 +527,8 @@ def _configure_optimizer(self, client_optimizer, model_parameters): self.optimizer = self._configure_fp16_optimizer(basic_optimizer) else: self.optimizer = basic_optimizer - - # logger.info('DeepSpeed Final Optimizer = {}'.format(self.optimizer.state_dict())) + logger.info('DeepSpeed Final Optimizer = {}'.format(self.optimizer)) + logger.info('DeepSpeed Final Optimizer = {}'.format(self.optimizer.state_dict())) def _configure_basic_optimizer(self, model_parameters): optimizer_parameters = self.optimizer_params() @@ -533,8 +537,11 @@ def _configure_basic_optimizer(self, model_parameters): "'max_grad_norm' is not supported as an optimizer parameter, please switch to using the deepspeed parameter 'gradient_clipping' see: https://www.deepspeed.ai/docs/config-json/#gradient-clipping for more details" ) if self.optimizer_name() == ADAM_OPTIMIZER: - from apex.optimizers.fused_adam import FusedAdam - optimizer = FusedAdam(model_parameters, **optimizer_parameters) + if self.zero_cpu_offload(): + optimizer = torch.optim.Adam(model_parameters, **optimizer_parameters) + else: + from apex.optimizers.fused_adam import FusedAdam + optimizer = FusedAdam(model_parameters, **optimizer_parameters) elif self.optimizer_name() == LAMB_OPTIMIZER: optimizer = FusedLamb(model_parameters, **optimizer_parameters) else: @@ -613,6 +620,7 @@ def _configure_zero_optimizer(self, optimizer): dp_process_group=self.data_parallel_group, reduce_scatter=self.zero_reduce_scatter(), overlap_comm=self.zero_overlap_comm(), + cpu_offload=self.zero_cpu_offload(), mpu=self.mpu, postscale_gradients=self.postscale_gradients(), gradient_predivide_factor=self.gradient_predivide_factor()) @@ -843,7 +851,6 @@ def step(self): master_params = amp.master_params(self.optimizer) torch.nn.utils.clip_grad_norm_(parameters=master_params, max_norm=self.gradient_clipping()) - self.optimizer.step() #zero grad in basic optimizer could be unreliable and may not exhibit @@ -946,6 +953,9 @@ def _get_optimizer_param(self, param_name): def get_lr(self): return self._get_optimizer_param('lr') + def get_type(self): + return self._get_optimizer_param('type') + def get_mom(self): return self._get_optimizer_param('betas') diff --git a/deepspeed/pt/deepspeed_zero_config.py b/deepspeed/pt/deepspeed_zero_config.py index 4f654d3b8c30..69b22d4d2bef 100755 --- a/deepspeed/pt/deepspeed_zero_config.py +++ b/deepspeed/pt/deepspeed_zero_config.py @@ -24,6 +24,7 @@ "overlap_comm": [true|false], "reduce_bucket_size": 500000000 "load_from_fp32_weights": [true|false] + "cpu_offload": [true|false] } } ''' @@ -63,21 +64,22 @@ ZERO_OPTIMIZATION_LOAD_FROM_FP32_WEIGHTS = 'load_from_fp32_weights' ZERO_OPTIMIZATION_LOAD_FROM_FP32_WEIGHTS_DEFAULT = True +ZERO_OPTIMIZATION_CPU_OFFLOAD = 
'cpu_offload' +ZERO_OPTIMIZATION_CPU_OFFLOAD_DEFAULT = False + ZERO_OPTIMIZATION_DEFAULT = { - ZERO_OPTIMIZATION_STAGE: - ZERO_OPTIMIZATION_STAGE_DEFAULT, + ZERO_OPTIMIZATION_STAGE: ZERO_OPTIMIZATION_STAGE_DEFAULT, ZERO_OPTIMIZATION_CONTIGUOUS_GRADIENTS: ZERO_OPTIMIZATION_CONTIGUOUS_GRADIENTS_DEFAULT, - ZERO_OPTIMIZATION_REDUCE_SCATTER: - ZERO_OPTIMIZATION_REDUCE_SCATTER_DEFAULT, - ZERO_OPTIMIZATION_REDUCE_BUCKET_SIZE: - ZERO_OPTIMIZATION_REDUCE_BUCKET_SIZE_DEFAULT, + ZERO_OPTIMIZATION_REDUCE_SCATTER: ZERO_OPTIMIZATION_REDUCE_SCATTER_DEFAULT, + ZERO_OPTIMIZATION_REDUCE_BUCKET_SIZE: ZERO_OPTIMIZATION_REDUCE_BUCKET_SIZE_DEFAULT, ZERO_OPTIMIZATION_ALLGATHER_PARTITIONS: ZERO_OPTIMIZATION_ALLGATHER_PARTITIONS_DEFAULT, ZERO_OPTIMIZATION_ALLGATHER_BUCKET_SIZE: ZERO_OPTIMIZATION_ALLGATHER_BUCKET_SIZE_DEFAULT, ZERO_OPTIMIZATION_LOAD_FROM_FP32_WEIGHTS: - ZERO_OPTIMIZATION_LOAD_FROM_FP32_WEIGHTS_DEFAULT + ZERO_OPTIMIZATION_LOAD_FROM_FP32_WEIGHTS_DEFAULT, + ZERO_OPTIMIZATION_CPU_OFFLOAD: ZERO_OPTIMIZATION_CPU_OFFLOAD_DEFAULT } @@ -93,6 +95,7 @@ def __init__(self, param_dict): self.allgather_bucket_size = None self.overlap_comm = None self.load_from_fp32_weights = None + self.cpu_offload = None if ZERO_OPTIMIZATION in param_dict.keys(): zero_config_dict = param_dict[ZERO_OPTIMIZATION] @@ -157,7 +160,12 @@ def _initialize(self, zero_config_dict): zero_config_dict, ZERO_OPTIMIZATION_ALLGATHER_BUCKET_SIZE, ZERO_OPTIMIZATION_ALLGATHER_BUCKET_SIZE_DEFAULT) + self.load_from_fp32_weights = get_scalar_param( zero_config_dict, ZERO_OPTIMIZATION_LOAD_FROM_FP32_WEIGHTS, ZERO_OPTIMIZATION_LOAD_FROM_FP32_WEIGHTS_DEFAULT) + + self.cpu_offload = get_scalar_param(zero_config_dict, + ZERO_OPTIMIZATION_CPU_OFFLOAD, + ZERO_OPTIMIZATION_CPU_OFFLOAD_DEFAULT) diff --git a/deepspeed/pt/deepspeed_zero_optimizer.py b/deepspeed/pt/deepspeed_zero_optimizer.py index cbfb249b501d..b0848a8750f2 100755 --- a/deepspeed/pt/deepspeed_zero_optimizer.py +++ b/deepspeed/pt/deepspeed_zero_optimizer.py @@ -9,11 +9,11 @@ import math from torch._six import inf from torch.autograd import Variable +import collections from deepspeed.pt.loss_scaler import LossScaler, DynamicLossScaler from deepspeed.pt.deepspeed_utils import see_memory_usage, is_model_parallel_parameter from deepspeed.pt.deepspeed_zero_config import ZERO_OPTIMIZATION_GRADIENTS - #Toggle this to true to enable correctness test #with gradient partitioning and without pg_correctness_test = False @@ -99,6 +99,39 @@ def print_rank_msg(msg): print(f"rank {dist.get_rank()} - {msg}") +#jie:asyn move to target device +def async_migrate_to(obj, dev, main_stream=None): + if torch.is_tensor(obj): + obj = Variable(obj) + if isinstance(obj, Variable): + v = obj.cuda(dev, async=True) + if main_stream is not None: + v.data.record_stream(main_stream) + return v + elif isinstance(obj, collections.Mapping): + return {k: async_copy_to(o, dev, main_stream) for k, o in obj.items()} + elif isinstance(obj, collections.Sequence): + return [async_copy_to(o, dev, main_stream) for o in obj] + else: + return obj + + +def async_copy_to(obj, dev, main_stream=None): + if torch.is_tensor(obj): + obj = Variable(obj) + if isinstance(obj, Variable): + target = torch.empty_like(obj, device=dev).copy_(obj) + if main_stream is not None: + target.data.record_stream(main_stream) + return target + elif isinstance(obj, collections.Mapping): + return {k: async_copy_to(o, dev, main_stream) for k, o in obj.items()} + elif isinstance(obj, collections.Sequence): + return [async_copy_to(o, dev, main_stream) for o in obj] + 
else: + return obj + + class FP16_DeepSpeedZeroOptimizer(object): """ DeepSpeedZeroOptimizer designed to reduce the memory footprint @@ -123,6 +156,7 @@ def __init__(self, dp_process_group=None, reduce_scatter=True, overlap_comm=False, + cpu_offload=False, mpu=None, clip_grad=0.0, allreduce_always_fp32=False, @@ -132,6 +166,7 @@ def __init__(self, if dist.get_rank() == 0: logger.info(f"Reduce bucket size {reduce_bucket_size}") logger.info(f"Allgather bucket size {allgather_bucket_size}") + logger.info(f"CPU Offload: {cpu_offload}") # The fused optimizer does all the work. We need this layer for two reason: # 1. maintain same user API from apex.fp16_utils # 2. keep common stuff here in case we need to add ne552w fused optimizer later @@ -151,6 +186,8 @@ def __init__(self, self.overlap_comm = overlap_comm + self.cpu_offload = cpu_offload + self.dp_process_group = dp_process_group self.partition_count = dist.get_world_size(group=self.dp_process_group) @@ -185,7 +222,8 @@ def __init__(self, #a single 32-bit partition of the parallel partitioned parameters #that this process will update self.single_partition_of_fp32_groups = [] - + if cpu_offload: + self.averaged_gradients_on_cpu = {} #param partition info #These are the parameters in each group that will not be updated by this process directly @@ -208,6 +246,7 @@ def __init__(self, # padding on each partition for alignment purposes self.groups_padding = [] + # loop to deal with groups for i, param_group in enumerate(self.optimizer.param_groups): # push this group to list before modify @@ -253,13 +292,24 @@ def __init__(self, self.parallel_partitioned_fp16_groups.append(data_parallel_partitions) # a partition of the fp32 master weights that will be updated by this process - self.single_partition_of_fp32_groups.append( - self.parallel_partitioned_fp16_groups[i] - [partition_id].clone().float().detach()) + if self.cpu_offload: + self.single_partition_of_fp32_groups.append( + async_copy_to(self.parallel_partitioned_fp16_groups[i][partition_id], + 'cpu').float()) + self.averaged_gradients_on_cpu[i] = [ + torch.empty_like( + self.parallel_partitioned_fp16_groups[i][partition_id], + device='cpu') + ] + else: + self.single_partition_of_fp32_groups.append( + self.parallel_partitioned_fp16_groups[i] + [partition_id].clone().float().detach()) # modify optimizer of have flat master weight self.single_partition_of_fp32_groups[ i].requires_grad = True # keep this in case internal optimizer uses it + param_group['params'] = [self.single_partition_of_fp32_groups[i]] partition_size = len(self.fp16_groups_flat[i]) / dist.get_world_size( @@ -276,6 +326,8 @@ def __init__(self, self.reduction_event = torch.cuda.Event(enable_timing=False, blocking=False) self.reduction_stream = torch.cuda.Stream() + self.cpu_computation_stream = torch.cuda.Stream() + self.migration_stream = torch.cuda.Stream() self.callback_queued = False self.param_dict = {} @@ -380,10 +432,16 @@ def _release_ipg_buffers(self): def initialize_optimizer_states(self): for i, group in enumerate(self.fp16_groups): - single_grad_partition = torch.zeros( - int(self.partition_size[i]), - dtype=self.single_partition_of_fp32_groups[i].dtype, - device=torch.cuda.current_device()) + if self.cpu_offload: + single_grad_partition = torch.zeros( + int(self.partition_size[i]), + dtype=self.single_partition_of_fp32_groups[i].dtype, + device='cpu') + else: + single_grad_partition = torch.zeros( + int(self.partition_size[i]), + dtype=self.single_partition_of_fp32_groups[i].dtype, + 
device=torch.cuda.current_device()) self.single_partition_of_fp32_groups[i].grad = single_grad_partition self.optimizer.step() @@ -445,14 +503,15 @@ def independent_gradient_partition_epilogue(self): if self.overlap_comm: torch.cuda.synchronize() - for i, _ in enumerate(self.fp16_groups): - self.averaged_gradients[i] = self.get_flat_partition( - self.params_in_partition[i], - self.first_offset[i], - self.partition_size[i], - dtype=torch.half, - device=torch.cuda.current_device(), - return_tensor_list=True) + if self.cpu_offload is False: + for i, _ in enumerate(self.fp16_groups): + self.averaged_gradients[i] = self.get_flat_partition( + self.params_in_partition[i], + self.first_offset[i], + self.partition_size[i], + dtype=torch.half, + device=torch.cuda.current_device(), + return_tensor_list=True) self._release_ipg_buffers() @@ -561,6 +620,7 @@ def report_ipg_memory_usage(self, tag, param_elems): ###############Idependent Partition Gradient ######################## def reduce_independent_p_g_buckets_and_remove_grads(self, param, i): + if self.elements_in_ipg_bucket + param.numel() > self.reduce_bucket_size: self.report_ipg_memory_usage("In ipg_remove_grads before reduce_ipg_grads", param.numel()) @@ -585,7 +645,15 @@ def reduce_independent_p_g_buckets_and_remove_grads(self, param, i): self.elements_in_ipg_bucket, param.numel()) new_grad_tensor.copy_(param.grad.view(-1)) - param.grad.data = new_grad_tensor.data.view_as(param.grad) + #jie: + if self.cpu_offload: + with stream(self.migration_stream): + averaged_gradients_on_cpu[i][param_id] = async_copy_to( + new_grad_tensor.data.view_as(param.grad), + 'cpu', + self.cpu_computation_stream) + else: + param.grad.data = new_grad_tensor.data.view_as(param.grad) self.elements_in_ipg_bucket += param.numel() self.grads_in_ipg_bucket.append(param.grad) @@ -697,6 +765,15 @@ def copy_grads_in_partition(self, param): total_size += param_in_partition.numel() see_memory_usage(f"before copying {total_size} gradients into partition") + #jie: + ''' + if self.cpu_offload: + self.grads_in_partition = torch.empty(int(total_size), + dtype=torch.half, + device=torch.cuda.current_device()) + '' + else: + ''' self.grads_in_partition = torch.empty(int(total_size), dtype=torch.half, device=torch.cuda.current_device()) @@ -707,6 +784,11 @@ def copy_grads_in_partition(self, param): self.grads_in_partition_offset, param.numel()) new_grad_tensor.copy_(param.grad.view(-1)) + ''' + if self.cpu_offload: + averaged_gradients_on_cpu[self.get_param_id()] = new_grad_tensor.data.view_as(param.grad) + else: + ''' param.grad.data = new_grad_tensor.data.view_as(param.grad) self.grads_in_partition_offset += param.numel() @@ -1122,6 +1204,11 @@ def step(self, closure=None): partition_id = dist.get_rank(group=self.dp_process_group) for i, group in enumerate(self.fp16_groups): + if self.cpu_offload: + self.averaged_gradients[i] = self.averaged_gradients_on_cpu[i] + #norm_groups.append( + # self.get_grad_norm_direct(self.averaged_gradients[i], + # self.params_in_partition[i])) norm_groups.append( self.get_grad_norm_direct(self.averaged_gradients[i], self.params_in_partition[i])) @@ -1147,20 +1234,31 @@ def step(self, closure=None): #release all the gradient since we have already created a necessary copy in dp_grad_partition self.free_grad_in_param_list(self.params_in_partition[i]) - self.averaged_gradients[i] = None + if self.cpu_offload is False: + self.averaged_gradients[i] = None single_partition_grad_groups.append(single_grad_partition) 
self.unscale_and_clip_grads(single_partition_grad_groups, norm_groups) + self.cpu_computation_stream.wait_stream(self.migration_stream) timers('optimizer_step').start() - self.optimizer.step() - #get rid of the fp32 gradients. Not needed anymore - for group in self.single_partition_of_fp32_groups: - group.grad = None + with torch.cuda.stream(self.cpu_computation_stream): + self.optimizer.step() - for fp16_partitions, fp32_partition in zip(self.parallel_partitioned_fp16_groups, self.single_partition_of_fp32_groups): - fp16_partitions[partition_id].data.copy_(fp32_partition.data) + if self.cpu_offload: + stream = torch.cuda.current_stream() + with torch.cuda.stream(self.migration_stream): + for fp16_partitions, fp32_partition in zip(self.parallel_partitioned_fp16_groups, self.single_partition_of_fp32_groups): + fp16_partitions[partition_id] = async_copy_to( + fp32_partition, + torch.cuda.current_device(), + stream) + #for averaged_gradients_cpu, fp32_partition in zip(self.averaged_gradients_on_cpu, self.single_partition_of_fp32_groups): + # averaged_gradients_cpu = [fp32_partition] + else: + for fp16_partitions, fp32_partition in zip(self.parallel_partitioned_fp16_groups, self.single_partition_of_fp32_groups): + fp16_partitions[partition_id].data.copy_(fp32_partition.data) timers('optimizer_step').stop() timers('optimizer_allgather').start() @@ -1243,6 +1341,8 @@ def has_overflow_serial(self, params, is_grad_list=False): def has_overflow_partitioned_grads_serial(self): for i in range(len(self.fp16_groups)): + if self.cpu_offload: + self.averaged_gradients[i] = self.averaged_gradients_on_cpu[i] for j, grad in enumerate(self.averaged_gradients[i]): if grad is not None and self._has_inf_or_nan(grad.data, j): return True @@ -1317,7 +1417,7 @@ def backward(self, loss, retain_graph=False): device=torch.cuda.current_device()) self.ipg_buffer.append(buf_1) self.ipg_index = 0 - + torch.cuda.empty_cache() self.loss_scaler.backward(loss.float(), retain_graph=retain_graph) def check_overflow(self, partition_gradients=True): @@ -1369,8 +1469,11 @@ def _get_groups_without_padding(self, groups_with_padding): def _get_state_without_padding(self, state_with_padding, padding): lean_state = {} for key, value in state_with_padding.items(): - lean_length = value.numel() - padding - lean_state[key] = value[:lean_length] + if torch.is_tensor(value): + lean_length = value.numel() - padding + lean_state[key] = value[:lean_length] + else: + lean_state[key] = value return lean_state @@ -1412,6 +1515,12 @@ def state_dict(self): self.single_partition_of_fp32_groups) state_dict['single_partition_of_fp32_groups'] = fp32_groups_without_padding + if self.cpu_offload: + state_dict_tmp = async_copy_to(state_dict, + 'cpu', + torch.cuda.current_stream()) + state_dict = state_dict_tmp + return state_dict # Restore base optimizer fp32 weights from checkpoint by: @@ -1448,11 +1557,16 @@ def refresh_fp32_params(self): def _partition_base_optimizer_state(self, state_key, all_partition_states): partition_id = dist.get_rank(group=self.dp_process_group) alignment = dist.get_world_size(group=self.dp_process_group) - flat_merged_partitions = flatten_dense_tensors_aligned( - all_partition_states, - alignment) - dp_partitions = self.get_data_parallel_partitions(flat_merged_partitions) - return dp_partitions[partition_id] + + if torch.is_tensor(all_partition_states[0]): + flat_merged_partitions = flatten_dense_tensors_aligned( + all_partition_states, + alignment) + dp_partitions = self.get_data_parallel_partitions(flat_merged_partitions) + 
return dp_partitions[partition_id] + else: + # Assume non-tensor states are not partitioned and equal across ranks, so return first one + return all_partition_states[0] # Restore base optimizer state from checkpoint by # 1) Merging optimizer state from checkpoints of all partitions @@ -1477,8 +1591,10 @@ def _restore_base_optimizer_state(self, all_state_dict): for i, group in enumerate(self.optimizer.param_groups): p = group['params'][0] for key, saved in base_optimizer_group_states[i].items(): - current = self.optimizer.state[p][key] - current.data.copy_(saved.data) + if torch.is_tensor(self.optimizer.state[p][key]): + self.optimizer.state[p][key].data.copy_(saved.data) + else: + self.optimizer.state[p][key] = saved def load_state_dict(self, state_dict_list, diff --git a/deepspeed/pt/deepspeed_zero_utils.py b/deepspeed/pt/deepspeed_zero_utils.py new file mode 100755 index 000000000000..034a9f4ccb22 --- /dev/null +++ b/deepspeed/pt/deepspeed_zero_utils.py @@ -0,0 +1,33 @@ +import torch +from torch.autograd import Variable +import collections + + +def async_migrate_to(obj, dev, main_stream=None): + if torch.is_tensor(obj): + obj = Variable(obj) + if isinstance(obj, Variable): + v = obj.cuda(dev, async=True) + if main_stream is not None: + v.data.record_stream(main_stream) + return v + elif isinstance(obj, collections.Mapping): + return {k: async_copy_to(o, dev, main_stream) for k, o in obj.items()} + elif isinstance(obj, collections.Sequence): + return [async_copy_to(o, dev, main_stream) for o in obj] + else: + return obj + + +def async_copy_to(obj, dev, main_stream=None): + if torch.is_tensor(obj): + obj = Variable(obj) + if isinstance(obj, Variable): + target = torch.empty_like(obj, device=dev).copy_(obj) + if main_stream is not None: + target.data.record_stream(main_stream) + return target + elif isinstance(obj, collections.Mapping): + return {k: async_copy_to(o, dev, main_stream) for k, o in obj.items()} + elif isinstance(obj, collections.Sequence): + return [async_copy_to(o, dev, main_stream) for o in obj] diff --git a/tests/model/Megatron_GPT2/ds_config_func_bs4_zero1.json b/tests/model/Megatron_GPT2/ds_config_func_bs4_zero1.json index 2a3b9ca5a0be..c3322eca8138 100755 --- a/tests/model/Megatron_GPT2/ds_config_func_bs4_zero1.json +++ b/tests/model/Megatron_GPT2/ds_config_func_bs4_zero1.json @@ -3,14 +3,9 @@ "gradient_accumulation_steps": 1, "steps_per_print": 1, "zero_optimization": { - "stage":1 - }, - "optimizer": { - "type": "Adam", - "params": { - "lr": 0.00015 - } + "stage": 1 }, + "zero_allow_untested_optimizer": true, "gradient_clipping": 1.0, "fp16": { "enabled": true, diff --git a/tests/model/Megatron_GPT2/ds_config_func_bs4_zero2.json b/tests/model/Megatron_GPT2/ds_config_func_bs4_zero2.json index fde222a3cca2..f6a6db57daf2 100755 --- a/tests/model/Megatron_GPT2/ds_config_func_bs4_zero2.json +++ b/tests/model/Megatron_GPT2/ds_config_func_bs4_zero2.json @@ -3,17 +3,12 @@ "gradient_accumulation_steps": 1, "steps_per_print": 1, "zero_optimization": { - "stage":2, + "stage": 2, "reduce_bucket_size": 7000000, "allgather_bucket_size": 7000000, "reduce_scatter": true }, - "optimizer": { - "type": "Adam", - "params": { - "lr": 0.00015 - } - }, + "zero_allow_untested_optimizer": true, "gradient_clipping": 1.0, "fp16": { "enabled": true, diff --git a/tests/model/Megatron_GPT2/ds_config_func_bs4_zero2_offload.json b/tests/model/Megatron_GPT2/ds_config_func_bs4_zero2_offload.json new file mode 100755 index 000000000000..ad054d31bb66 --- /dev/null +++ 
b/tests/model/Megatron_GPT2/ds_config_func_bs4_zero2_offload.json @@ -0,0 +1,21 @@ +{ + "train_batch_size": 4, + "gradient_accumulation_steps": 1, + "steps_per_print": 1, + "zero_optimization": { + "stage": 2, + "reduce_bucket_size": 7000000, + "allgather_bucket_size": 7000000, + "reduce_scatter": true, + "cpu_offload": true + }, + "zero_allow_untested_optimizer": true, + "gradient_clipping": 1.0, + "fp16": { + "enabled": true, + "loss_scale": 0, + "loss_scale_window": 1000, + "hysteresis": 2, + "min_loss_scale": 1 + } +} diff --git a/tests/model/Megatron_GPT2/ds_config_func_bs8_no_zero.json b/tests/model/Megatron_GPT2/ds_config_func_bs8_no_zero.json index 99637973cd60..63b30c225753 100755 --- a/tests/model/Megatron_GPT2/ds_config_func_bs8_no_zero.json +++ b/tests/model/Megatron_GPT2/ds_config_func_bs8_no_zero.json @@ -3,13 +3,7 @@ "gradient_accumulation_steps": 1, "steps_per_print": 1, "zero_optimization": { - "stage":0 - }, - "optimizer": { - "type": "Adam", - "params": { - "lr": 0.00015 - } + "stage": 0 }, "gradient_clipping": 1.0, "fp16": { diff --git a/tests/model/Megatron_GPT2/ds_config_func_bs8_zero1.json b/tests/model/Megatron_GPT2/ds_config_func_bs8_zero1.json index 8d44659a9ee3..342fd665ccae 100755 --- a/tests/model/Megatron_GPT2/ds_config_func_bs8_zero1.json +++ b/tests/model/Megatron_GPT2/ds_config_func_bs8_zero1.json @@ -2,15 +2,10 @@ "train_batch_size": 8, "gradient_accumulation_steps": 1, "steps_per_print": 1, - "zero_optimization":{ - "stage":1 - }, - "optimizer": { - "type": "Adam", - "params": { - "lr": 0.00015 - } + "zero_optimization": { + "stage": 1 }, + "zero_allow_untested_optimizer": true, "gradient_clipping": 1.0, "fp16": { "enabled": true, diff --git a/tests/model/Megatron_GPT2/ds_config_func_bs8_zero2.json b/tests/model/Megatron_GPT2/ds_config_func_bs8_zero2.json index fde90e8274b8..0e2582fa102f 100755 --- a/tests/model/Megatron_GPT2/ds_config_func_bs8_zero2.json +++ b/tests/model/Megatron_GPT2/ds_config_func_bs8_zero2.json @@ -3,17 +3,12 @@ "gradient_accumulation_steps": 1, "steps_per_print": 1, "zero_optimization": { - "stage":2, + "stage": 2, "reduce_bucket_size": 7000000, "allgather_bucket_size": 7000000, "reduce_scatter": true }, - "optimizer": { - "type": "Adam", - "params": { - "lr": 0.00015 - } - }, + "zero_allow_untested_optimizer": true, "gradient_clipping": 1.0, "fp16": { "enabled": true, @@ -26,5 +21,4 @@ "partition_activations": true, "contiguous_memory_optimization": true } - } diff --git a/tests/model/Megatron_GPT2/ds_config_func_bs8_zero2_offload.json b/tests/model/Megatron_GPT2/ds_config_func_bs8_zero2_offload.json new file mode 100755 index 000000000000..5c66ed7cc585 --- /dev/null +++ b/tests/model/Megatron_GPT2/ds_config_func_bs8_zero2_offload.json @@ -0,0 +1,25 @@ +{ + "train_batch_size": 8, + "gradient_accumulation_steps": 1, + "steps_per_print": 1, + "zero_optimization": { + "stage": 2, + "reduce_bucket_size": 7000000, + "allgather_bucket_size": 7000000, + "reduce_scatter": true, + "cpu_offload": true + }, + "zero_allow_untested_optimizer": true, + "gradient_clipping": 1.0, + "fp16": { + "enabled": true, + "loss_scale": 0, + "loss_scale_window": 1000, + "hysteresis": 2, + "min_loss_scale": 1 + }, + "activation_checkpointing": { + "partition_activations": true, + "contiguous_memory_optimization": true + } +} diff --git a/tests/model/Megatron_GPT2/ds_config_func_scheduler.json b/tests/model/Megatron_GPT2/ds_config_func_scheduler.json index 60c810786bf0..2d2ab356e57c 100755 --- a/tests/model/Megatron_GPT2/ds_config_func_scheduler.json +++ 
b/tests/model/Megatron_GPT2/ds_config_func_scheduler.json @@ -3,14 +3,9 @@ "gradient_accumulation_steps": 1, "steps_per_print": 1, "zero_optimization": { - "stage":2 - }, - "optimizer": { - "type": "Adam", - "params": { - "lr": 0.00015 - } + "stage": 2 }, + "zero_allow_untested_optimizer": true, "gradient_clipping": 1.0, "scheduler": { "type": "WarmupLR", @@ -20,7 +15,6 @@ "warmup_num_steps": 10 } }, - "fp16": { "enabled": true, "loss_scale": 0, diff --git a/tests/model/Megatron_GPT2/ds_config_perf_bs16.json b/tests/model/Megatron_GPT2/ds_config_perf_bs16.json index f160ccd8e610..a40f3e4c7d44 100644 --- a/tests/model/Megatron_GPT2/ds_config_perf_bs16.json +++ b/tests/model/Megatron_GPT2/ds_config_perf_bs16.json @@ -2,7 +2,10 @@ "train_batch_size": 16, "gradient_accumulation_steps": 1, "steps_per_print": 1, - "zero_optimization": 1, + "zero_optimization": { + "stage": 1 + }, + "zero_allow_untested_optimizer": true, "disable_allgather": true, "optimizer": { "type": "Adam", diff --git a/tests/model/Megatron_GPT2/ds_config_perf_bs32.json b/tests/model/Megatron_GPT2/ds_config_perf_bs32.json index 6e23fe687bc8..096a0d3645cd 100755 --- a/tests/model/Megatron_GPT2/ds_config_perf_bs32.json +++ b/tests/model/Megatron_GPT2/ds_config_perf_bs32.json @@ -3,8 +3,9 @@ "gradient_accumulation_steps": 1, "steps_per_print": 1, "zero_optimization": { - "stage":1 + "stage": 1 }, + "zero_allow_untested_optimizer": true, "disable_allgather": true, "optimizer": { "type": "Adam", diff --git a/tests/model/Megatron_GPT2/ds_config_perf_bs8.json b/tests/model/Megatron_GPT2/ds_config_perf_bs8.json index 514496958e14..e793e221e1e7 100644 --- a/tests/model/Megatron_GPT2/ds_config_perf_bs8.json +++ b/tests/model/Megatron_GPT2/ds_config_perf_bs8.json @@ -2,7 +2,10 @@ "train_batch_size": 8, "gradient_accumulation_steps": 1, "steps_per_print": 1, - "zero_optimization": 1, + "zero_optimization": { + "stage": 1 + }, + "zero_allow_untested_optimizer": true, "disable_allgather": true, "optimizer": { "type": "Adam", diff --git a/tests/model/Megatron_GPT2/ds_gpt2_test.sh b/tests/model/Megatron_GPT2/ds_gpt2_test.sh index 5c901f855a33..a8af44df9c7e 100755 --- a/tests/model/Megatron_GPT2/ds_gpt2_test.sh +++ b/tests/model/Megatron_GPT2/ds_gpt2_test.sh @@ -91,9 +91,9 @@ gpt_options=" \ ${ds_opt} \ ${zero_opt} \ " - +DEEPSPEED_PORT=29600 work_dir="../../../DeepSpeedExamples/Megatron-LM/" -run_cmd="(cd ${work_dir} && deepspeed --num_nodes $nodes --num_gpus $gpus pretrain_gpt2.py ${gpt_options})" +run_cmd="(cd ${work_dir} && deepspeed --master_port ${DEEPSPEED_PORT} --num_nodes $nodes --num_gpus $gpus pretrain_gpt2.py ${gpt_options})" echo ${run_cmd} eval ${run_cmd} diff --git a/tests/model/Megatron_GPT2/run_checkpoint_test.py b/tests/model/Megatron_GPT2/run_checkpoint_test.py index 116e58b98fa2..cf11af6c2ae4 100755 --- a/tests/model/Megatron_GPT2/run_checkpoint_test.py +++ b/tests/model/Megatron_GPT2/run_checkpoint_test.py @@ -97,6 +97,29 @@ def test_mp2_gpu4_node1_with_zero2(self): succ = self.run_test(test_config, 0.01) self.assertTrue(succ) + def test_mp2_gpu4_node1_with_zero2_offload(self): + test_config = { + "mp": 2, + "gpus": 4, + "nodes": 1, + "bs": 8, + "steps": 1100, + "layers": LAYERS, + "hidden_size": HIDDEN_SIZE, + "seq_length": 256, + "heads": ATTN_HEADS, + "deepspeed": True, + "tag": "ds_zero2_offload", + "zero": True, + "other_args": "", + "checkpoint_name": "ckpt_mp2_gpu8_w_zero2_offload", + "checkpoint_interval": 1000, + "json": "ds_config_func_bs8_zero2_offload.json", + "cpu_optimizer": True, + } + succ = 
self.run_test(test_config, 0.01) + self.assertTrue(succ) + def test_mp1_gpu2_load_gpu1_node1_with_zero1(self): test_config = { "mp": 1, @@ -110,7 +133,7 @@ def test_mp1_gpu2_load_gpu1_node1_with_zero1(self): "seq_length": 256, "heads": ATTN_HEADS, "deepspeed": True, - "tag": "ds_zero2", + "tag": "ds_zero1", "zero": True, "other_args": "", "checkpoint_name": "ckpt_mp1_gpu2_gpu1_w_zero1", @@ -133,7 +156,7 @@ def test_mp1_gpu2_load_gpu4_node1_with_zero1(self): "seq_length": 256, "heads": ATTN_HEADS, "deepspeed": True, - "tag": "ds_zero2", + "tag": "ds_zero1", "zero": True, "other_args": "", "checkpoint_name": "ckpt_mp1_gpu2_gpu4_w_zero1", @@ -166,6 +189,30 @@ def test_mp1_gpu2_load_gpu1_node1_with_zero2(self): succ = self.run_test(test_config, 0.01) self.assertTrue(succ) + def test_mp1_gpu2_load_gpu1_node1_with_zero2_offload(self): + test_config = { + "mp": 1, + "gpus": 2, + "load_gpus": 1, + "nodes": 1, + "bs": 8, + "steps": 1100, + "layers": LAYERS, + "hidden_size": HIDDEN_SIZE, + "seq_length": 256, + "heads": ATTN_HEADS, + "deepspeed": True, + "tag": "ds_zero2_offload", + "zero": True, + "other_args": "", + "checkpoint_name": "ckpt_mp1_gpu2_gpu1_w_zero2_offload", + "checkpoint_interval": 1000, + "json": "ds_config_func_bs8_zero2_offload.json", + "cpu_optimizer": True, + } + succ = self.run_test(test_config, 0.01) + self.assertTrue(succ) + def test_mp1_gpu2_load_gpu4_node1_with_zero2(self): test_config = { "mp": 1, @@ -189,6 +236,30 @@ def test_mp1_gpu2_load_gpu4_node1_with_zero2(self): succ = self.run_test(test_config, 0.01) self.assertTrue(succ) + def test_mp1_gpu2_load_gpu4_node1_with_zero2_offload(self): + test_config = { + "mp": 1, + "gpus": 2, + "load_gpus": 4, + "nodes": 1, + "bs": 8, + "steps": 1100, + "layers": LAYERS, + "hidden_size": HIDDEN_SIZE, + "seq_length": 256, + "heads": ATTN_HEADS, + "deepspeed": True, + "tag": "ds_zero2_offload", + "zero": True, + "other_args": "", + "checkpoint_name": "ckpt_mp1_gpu2_gpu4_w_zero2_offload", + "checkpoint_interval": 1000, + "json": "ds_config_func_bs8_zero2_offload.json", + "cpu_optimizer": True, + } + succ = self.run_test(test_config, 0.01) + self.assertTrue(succ) + def test_mp2_gpu4_load_gpu2_node1_with_zero1(self): test_config = { "mp": 2, @@ -258,6 +329,30 @@ def test_mp2_gpu4_load_gpu2_node1_with_zero2(self): succ = self.run_test(test_config, 0.01) self.assertTrue(succ) + def test_mp2_gpu4_load_gpu2_node1_with_zero2_offload(self): + test_config = { + "mp": 2, + "gpus": 4, + "load_gpus": 2, + "nodes": 1, + "bs": 8, + "steps": 1100, + "layers": LAYERS, + "hidden_size": HIDDEN_SIZE, + "seq_length": 256, + "heads": ATTN_HEADS, + "deepspeed": True, + "tag": "ds_zero2_offload", + "zero": True, + "other_args": "", + "checkpoint_name": "ckpt_mp2_gpu4_gpu2_w_zero2_offload", + "checkpoint_interval": 1000, + "json": "ds_config_func_bs8_zero2_offload.json", + "cpu_optimizer": True, + } + succ = self.run_test(test_config, 0.01) + self.assertTrue(succ) + def test_mp2_gpu2_load_gpu4_node1_with_zero2(self): test_config = { "mp": 2, @@ -281,6 +376,30 @@ def test_mp2_gpu2_load_gpu4_node1_with_zero2(self): succ = self.run_test(test_config, 0.01) self.assertTrue(succ) + def test_mp2_gpu2_load_gpu4_node1_with_zero2_offload(self): + test_config = { + "mp": 2, + "gpus": 2, + "load_gpus": 4, + "nodes": 1, + "bs": 8, + "steps": 1100, + "layers": LAYERS, + "hidden_size": HIDDEN_SIZE, + "seq_length": 256, + "heads": ATTN_HEADS, + "deepspeed": True, + "tag": "ds_zero2_offload", + "zero": True, + "other_args": "", + "checkpoint_name": 
"ckpt_mp2_gpu2_gpu4_w_zero2_offload", + "checkpoint_interval": 1000, + "json": "ds_config_func_bs8_zero2_offload.json", + "cpu_optimizer": True, + } + succ = self.run_test(test_config, 0.01) + self.assertTrue(succ) + def test_mp2_gpu4_node1_without_zero(self): test_config = { "mp": 2, @@ -306,7 +425,8 @@ def test_mp2_gpu4_node1_without_zero(self): def gen_name(self, test_config, prefix): save_dir = "checkpoint_test_logs" tag = test_config["tag"] - file_name = f"_{tag}.log" + checkpoint_name = test_config["checkpoint_name"] + file_name = f"_{tag}_{checkpoint_name}.log" return os.path.join(save_dir, prefix + file_name) def run_test(self, test_config, r_tol): @@ -334,10 +454,15 @@ def run_test(self, test_config, r_tol): except: print("No old checkpoint") + if "cpu_optimizer" in test_config and test_config["cpu_optimizer"]: + cpu_optimizer_flag = " --cpu-optimizer" + else: + cpu_optimizer_flag = "" + #-----------------Saving Checkpoint-----------------# - #building checkpoint arguments + # building checkpoint arguments test_config[ - "other_args"] = f"\"--save {checkpoint_folder} --save-interval {checkpoint_interval}\"" + "other_args"] = f"\"--save {checkpoint_folder} --save-interval {checkpoint_interval} {cpu_optimizer_flag}\"" prefix = "gpt2_saving_checkpoint" @@ -356,10 +481,11 @@ def run_test(self, test_config, r_tol): #-----------------Loading Checkpoint-----------------# - #building checkpoint arguments - test_config["other_args"] = f"\"--load {checkpoint_folder}\"" + # building checkpoint arguments + test_config[ + "other_args"] = f"\"--load {checkpoint_folder} {cpu_optimizer_flag} \"" - #set checkpoint load iteration + # set checkpoint load iteration try: cmd = f"echo {checkpoint_interval} > {checkpoint_name}/latest_checkpointed_iteration.txt" print(f"{self.id()} running cmd: {cmd}") @@ -411,20 +537,32 @@ def check_parity(self, base_file, test_file, r_tol): def checkpoint_suite(): suite = unittest.TestSuite() + suite.addTest(GPT2CheckpointTestCase('test_mp2_gpu4_node1_with_zero1')) suite.addTest(GPT2CheckpointTestCase('test_mp2_gpu4_node1_with_zero2')) + suite.addTest(GPT2CheckpointTestCase('test_mp2_gpu4_node1_with_zero2_offload')) # Shrink DP suite.addTest(GPT2CheckpointTestCase('test_mp1_gpu2_load_gpu1_node1_with_zero1')) suite.addTest(GPT2CheckpointTestCase('test_mp1_gpu2_load_gpu1_node1_with_zero2')) + suite.addTest( + GPT2CheckpointTestCase('test_mp1_gpu2_load_gpu1_node1_with_zero2_offload')) + suite.addTest(GPT2CheckpointTestCase('test_mp2_gpu4_load_gpu2_node1_with_zero1')) suite.addTest(GPT2CheckpointTestCase('test_mp2_gpu4_load_gpu2_node1_with_zero2')) + suite.addTest( + GPT2CheckpointTestCase('test_mp2_gpu4_load_gpu2_node1_with_zero2_offload')) # Expand DP suite.addTest(GPT2CheckpointTestCase('test_mp1_gpu2_load_gpu4_node1_with_zero1')) suite.addTest(GPT2CheckpointTestCase('test_mp1_gpu2_load_gpu4_node1_with_zero2')) + suite.addTest( + GPT2CheckpointTestCase('test_mp1_gpu2_load_gpu4_node1_with_zero2_offload')) + suite.addTest(GPT2CheckpointTestCase('test_mp2_gpu2_load_gpu4_node1_with_zero1')) suite.addTest(GPT2CheckpointTestCase('test_mp2_gpu2_load_gpu4_node1_with_zero2')) + suite.addTest( + GPT2CheckpointTestCase('test_mp2_gpu2_load_gpu4_node1_with_zero2_offload')) suite.addTest(GPT2CheckpointTestCase('test_mp2_gpu4_node1_without_zero')) diff --git a/tests/unit/test_checkpointing.py b/tests/unit/test_checkpointing.py index d80bd897ee1a..9cce6e40c5ff 100755 --- a/tests/unit/test_checkpointing.py +++ b/tests/unit/test_checkpointing.py @@ -28,25 +28,25 @@ def 
compare_model_states(saved_model, loaded_model): compare_deepspeed_states(saved_model, loaded_model) for p0, p1 in zip(saved_model.module.parameters(), loaded_model.module.parameters()): - assert torch.allclose(p0,p1,atol=1e-07), f"FP16 model state {p0} is not equal to {p1}" + assert torch.allclose(p0, p1, atol=1e-07), f"FP16 model state {p0} is not equal to {p1}" if isinstance(saved_model.optimizer, FP16_DeepSpeedZeroOptimizer): for p0, p1 in zip(saved_model.optimizer.single_partition_of_fp32_groups, loaded_model.optimizer.single_partition_of_fp32_groups): - assert torch.allclose(p0,p1,atol=1e-07), f"Fp32 model states {p0} is not equal to {p1}" + assert torch.allclose(p0, p1, atol=1e-07), f"Fp32 model states {p0} is not equal to {p1}" elif isinstance(saved_model.optimizer, FP16_DeepSpeedZeroOptimizer_Stage1): for partition0, partition1 in zip(saved_model.optimizer.local_sub_partitions_of_fp32_groups, loaded_model.optimizer.local_sub_partitions_of_fp32_groups): for p0, p1 in zip(partition0, partition1): - assert torch.allclose(p0,p1,atol=1e-07), f"Fp32 model states {p0} is not equal to {p1}" + assert torch.allclose(p0, p1, atol=1e-07), f"Fp32 model states {p0} is not equal to {p1}" elif isinstance(saved_model.optimizer, FP16_Optimizer): for p0, p1 in zip(saved_model.optimizer.fp32_groups_flat, loaded_model.optimizer.fp32_groups_flat): - assert torch.allclose(p0,p1,atol=1e-07), f"FP32 model states {p0} is not equal to {p1}" + assert torch.allclose(p0, p1, atol=1e-07), f"FP32 model states {p0} is not equal to {p1}" elif isinstance(saved_model.optimizer, FP16_UnfusedOptimizer): for params0, params1 in zip(saved_model.optimizer.fp32_groups, loaded_model.optimizer.fp32_groups): for p0, p1 in zip(params0, params1): - assert torch.allclose(p0,p1,atol=1e-07), f"FP32 model states {p0} is not equal to {p1}" + assert torch.allclose(p0, p1, atol=1e-07), f"FP32 model states {p0} is not equal to {p1}" elif isinstance(saved_model.optimizer, torch.optim.Optimizer): pass else: @@ -97,9 +97,9 @@ def checkpoint_correctness_verification(args, load_lr_scheduler_states=False, fp16=True): dtype = torch.half if fp16 else torch.float32 - ds_model, _, _,_ = deepspeed.initialize(args=args, - model=model, - model_parameters=model.parameters()) + ds_model, _, _, _ = deepspeed.initialize(args=args, + model=model, + model_parameters=model.parameters()) data_loader = random_dataloader(model=ds_model, total_samples=50, hidden_dim=hidden_dim, @@ -117,9 +117,9 @@ def checkpoint_correctness_verification(args, trained_model.save_checkpoint(save_folder, save_tag) - loaded_model, _, _,_ = deepspeed.initialize(args=args, - model=model, - model_parameters=model.parameters()) + loaded_model, _, _, _ = deepspeed.initialize(args=args, + model=model, + model_parameters=model.parameters()) loaded_model.load_checkpoint(save_folder, save_tag, @@ -235,8 +235,16 @@ def _test_checkpoint_fused_optimizer(args, model, hidden_dim, load_optimizer_sta load_optimizer_states=False) -@pytest.mark.parametrize("zero_stage", [1, 2]) -def test_checkpoint_zero_optimizer(tmpdir, zero_stage): +@pytest.mark.parametrize('zero_stage, use_cpu_offload', + [ + (1, + False), + (2, + False), + (2, + True), + ]) +def test_checkpoint_zero_optimizer(tmpdir, zero_stage, use_cpu_offload): config_dict = { "train_batch_size": 2, "steps_per_print": 1, @@ -254,8 +262,9 @@ def test_checkpoint_zero_optimizer(tmpdir, zero_stage): "enabled": True }, "zero_optimization": { - "stage": zero_stage - }, + "stage": zero_stage, + "cpu_offload": use_cpu_offload + } } args = 
args_from_dict(tmpdir, config_dict) hidden_dim = 10 @@ -276,8 +285,16 @@ def _test_checkpoint_zero_optimizer(args, model, hidden_dim, load_optimizer_stat load_optimizer_states=True) -@pytest.mark.parametrize("zero_stage", [1, 2]) -def test_checkpoint_zero_no_optimizer(tmpdir, zero_stage): +@pytest.mark.parametrize('zero_stage, use_cpu_offload', + [ + (1, + False), + (2, + False), + (2, + True), + ]) +def test_checkpoint_zero_no_optimizer(tmpdir, zero_stage, use_cpu_offload): config_dict = { "train_batch_size": 2, "steps_per_print": 1, @@ -295,8 +312,9 @@ def test_checkpoint_zero_no_optimizer(tmpdir, zero_stage): "enabled": True }, "zero_optimization": { - "stage": zero_stage - }, + "stage": zero_stage, + "cpu_offload": use_cpu_offload + } } args = args_from_dict(tmpdir, config_dict) hidden_dim = 10 @@ -320,8 +338,18 @@ def _test_checkpoint_zero_no_optimizer(args, load_optimizer_states=False) -@pytest.mark.parametrize("zero_stage", [0, 1, 2]) -def test_checkpoint_lr_scheduler(tmpdir, zero_stage): +@pytest.mark.parametrize('zero_stage, use_cpu_offload', + [ + (0, + False), + (1, + False), + (2, + False), + (2, + True), + ]) +def test_checkpoint_lr_scheduler(tmpdir, zero_stage, use_cpu_offload): config_dict = { "train_batch_size": 2, "steps_per_print": 1, @@ -339,7 +367,8 @@ def test_checkpoint_lr_scheduler(tmpdir, zero_stage): "enabled": True }, "zero_optimization": { - "stage": zero_stage + "stage": zero_stage, + "cpu_offload": use_cpu_offload }, "scheduler": { "type": "WarmupLR", @@ -376,8 +405,18 @@ def _test_checkpoint_lr_scheduler(args, load_lr_scheduler_states=True) -@pytest.mark.parametrize("zero_stage", [0, 1, 2]) -def test_checkpoint_no_lr_scheduler(tmpdir, zero_stage): +@pytest.mark.parametrize('zero_stage, use_cpu_offload', + [ + (0, + False), + (1, + False), + (2, + False), + (2, + True), + ]) +def test_checkpoint_no_lr_scheduler(tmpdir, zero_stage, use_cpu_offload): config_dict = { "train_batch_size": 2, "steps_per_print": 1, @@ -391,7 +430,8 @@ def test_checkpoint_no_lr_scheduler(tmpdir, zero_stage): "enabled": True }, "zero_optimization": { - "stage": zero_stage + "stage": zero_stage, + "cpu_offload": use_cpu_offload }, "scheduler": { "type": "WarmupLR", @@ -400,7 +440,7 @@ def test_checkpoint_no_lr_scheduler(tmpdir, zero_stage): "warmup_max_lr": 0.001, "warmup_num_steps": 1000 } - } + }, } args = args_from_dict(tmpdir, config_dict) hidden_dim = 10 diff --git a/tests/unit/test_dynamic_loss_scale.py b/tests/unit/test_dynamic_loss_scale.py index e12386271450..7575d6b49454 100755 --- a/tests/unit/test_dynamic_loss_scale.py +++ b/tests/unit/test_dynamic_loss_scale.py @@ -191,7 +191,6 @@ def _test_unfused_no_overflow(args): model, optim, _, _ = deepspeed.initialize(args=args, model=model, model_parameters=model.parameters()) - expected_loss_scale = 2**8 expected_scale_window = 2 # Ensure the dynamic loss scaler is correctly configured. 
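For reference, the behavioral rule this patch introduces, split between _do_error_check in deepspeed_config.py and _configure_basic_optimizer in deepspeed_light.py, is that cpu_offload is only accepted with ZeRO stage 2 and swaps apex FusedAdam for a CPU-capable torch.optim.Adam. The standalone sketch below restates that rule outside of DeepSpeed; build_adam is an illustrative name, not a function added by this patch, and the parametrized (zero_stage, use_cpu_offload) tests in the following hunks exercise exactly these combinations.

import torch

ZERO_OPTIMIZATION_GRADIENTS = 2  # ZeRO stage 2 (gradient partitioning), as in deepspeed_zero_config.py


def build_adam(model_parameters, optimizer_parameters, zero_stage, cpu_offload):
    # Sketch of the patch's optimizer selection; not the DeepSpeed API.
    if cpu_offload:
        # cpu_offload keeps the fp32 partition and optimizer state in host memory,
        # which this patch only supports together with gradient partitioning (stage 2).
        assert zero_stage == ZERO_OPTIMIZATION_GRADIENTS, \
            "cpu_offload is only supported with ZeRO stage {}".format(
                ZERO_OPTIMIZATION_GRADIENTS)
        # FusedAdam launches CUDA kernels, so a CPU-resident partition needs torch.optim.Adam.
        return torch.optim.Adam(model_parameters, **optimizer_parameters)
    from apex.optimizers.fused_adam import FusedAdam  # unchanged GPU path
    return FusedAdam(model_parameters, **optimizer_parameters)


# Example with the same Adam parameters the test configs use.
params = [torch.nn.Parameter(torch.zeros(10))]
cpu_adam = build_adam(params, {"lr": 0.00015}, zero_stage=2, cpu_offload=True)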
diff --git a/tests/unit/test_fp16.py b/tests/unit/test_fp16.py index 320d026bdd83..69a76a85830d 100755 --- a/tests/unit/test_fp16.py +++ b/tests/unit/test_fp16.py @@ -217,8 +217,16 @@ def _test_adamw_fp16_empty_grad(args, model, hidden_dim): _test_adamw_fp16_empty_grad(args=args, model=model, hidden_dim=hidden_dim) -@pytest.mark.parametrize("zero_stage", [0, 1, 2]) -def test_adam_fp16_zero_onecycle_compatibility(tmpdir, zero_stage): +@pytest.mark.parametrize('zero_stage, use_cpu_offload', + [ + (1, + False), + (2, + False), + (2, + True), + ]) +def test_adam_fp16_zero_onecycle_compatibility(tmpdir, zero_stage, use_cpu_offload): config_dict = { "train_batch_size": 1, "steps_per_print": 1, @@ -246,7 +254,8 @@ def test_adam_fp16_zero_onecycle_compatibility(tmpdir, zero_stage): "enabled": True }, "zero_optimization": { - "stage": zero_stage + "stage": zero_stage, + "cpu_offload": use_cpu_offload } } @@ -274,8 +283,16 @@ def _test_adam_fp16_zero_onecycle_compatibility(args, model, hidden_dim): hidden_dim=hidden_dim) -@pytest.mark.parametrize("zero_stage", [1, 2]) -def test_zero_static_scale(tmpdir, zero_stage): +@pytest.mark.parametrize('zero_stage, use_cpu_offload', + [ + (1, + False), + (2, + False), + (2, + True), + ]) +def test_zero_static_scale(tmpdir, zero_stage, use_cpu_offload): config_dict = { "train_batch_size": 4, "steps_per_print": 1, @@ -290,7 +307,8 @@ def test_zero_static_scale(tmpdir, zero_stage): "loss_scale": 138. }, "zero_optimization": { - "stage": zero_stage + "stage": zero_stage, + "cpu_offload": use_cpu_offload } } args = args_from_dict(tmpdir, config_dict) @@ -363,8 +381,16 @@ def _test_zero_static_scale(args): _test_zero_static_scale(args) -@pytest.mark.parametrize("zero_stage", [1, 2]) -def test_zero_allow_untested_optimizer(tmpdir, zero_stage): +@pytest.mark.parametrize('zero_stage, use_cpu_offload', + [ + (1, + False), + (2, + False), + (2, + True), + ]) +def test_zero_allow_untested_optimizer(tmpdir, zero_stage, use_cpu_offload): config_dict = { "train_batch_size": 4, "steps_per_print": 1, @@ -372,7 +398,8 @@ def test_zero_allow_untested_optimizer(tmpdir, zero_stage): "enabled": True, }, "zero_optimization": { - "stage": zero_stage + "stage": zero_stage, + "cpu_offload": use_cpu_offload }, "zero_allow_untested_optimizer": False } @@ -392,8 +419,16 @@ def _test_zero_allow_untested_optimizer(args): _test_zero_allow_untested_optimizer(args) -@pytest.mark.parametrize("zero_stage", [1, 2]) -def test_zero_empty_partition(tmpdir, zero_stage): +@pytest.mark.parametrize('zero_stage, use_cpu_offload', + [ + (1, + False), + (2, + False), + (2, + True), + ]) +def test_zero_empty_partition(tmpdir, zero_stage, use_cpu_offload): config_dict = { "train_micro_batch_size_per_gpu": 1, "gradient_accumulation_steps": 1, @@ -408,7 +443,8 @@ def test_zero_empty_partition(tmpdir, zero_stage): } }, "zero_optimization": { - "stage": zero_stage + "stage": zero_stage, + "cpu_offload": use_cpu_offload } } args = args_from_dict(tmpdir, config_dict)