diff --git a/DeepSpeedExamples b/DeepSpeedExamples
index 47766e0d7758..a940b3347a64 160000
--- a/DeepSpeedExamples
+++ b/DeepSpeedExamples
@@ -1 +1 @@
-Subproject commit 47766e0d7758a53c2582e517ae1c930bcc012518
+Subproject commit a940b3347a64da65f41fefee39e6d19748490ed1
diff --git a/deepspeed/pt/deepspeed_checkpointing.py b/deepspeed/pt/deepspeed_checkpointing.py
index 2a5bb2ab688b..746adae1c599 100755
--- a/deepspeed/pt/deepspeed_checkpointing.py
+++ b/deepspeed/pt/deepspeed_checkpointing.py
@@ -602,11 +602,11 @@ def reset():
         size_offsets = []
 
 
-def _configure_using_config_file(deepspeed_config):
+def _configure_using_config_file(deepspeed_config, mpu=None):
     global num_layers, PARTITION_ACTIVATIONS, CONTIGUOUS_CHECKPOINTING, \
         PA_TO_CPU, SYNCHRONIZE, PROFILE_TIME
 
-    config = DeepSpeedConfig(deepspeed_config).activation_checkpointing_config
+    config = DeepSpeedConfig(deepspeed_config, mpu=mpu).activation_checkpointing_config
     logger.info(config.repr())
     PARTITION_ACTIVATIONS = config.partition_activations
     CONTIGUOUS_CHECKPOINTING = config.contiguous_memory_optimization
@@ -684,12 +684,12 @@ def configure(
 
     _configure_defaults()
 
-    if deepspeed_config is not None:
-        _configure_using_config_file(deepspeed_config)
-
     if mpu_ is not None:
         mpu = mpu_
 
+    if deepspeed_config is not None:
+        _configure_using_config_file(deepspeed_config, mpu=mpu)
+
     if partition_activations is not None:
         PARTITION_ACTIVATIONS = partition_activations
 
diff --git a/deepspeed/pt/deepspeed_config.py b/deepspeed/pt/deepspeed_config.py
index f618124d74c7..7b8e43c19583 100755
--- a/deepspeed/pt/deepspeed_config.py
+++ b/deepspeed/pt/deepspeed_config.py
@@ -16,7 +16,8 @@
 TENSOR_CORE_ALIGN_SIZE = 8
 ADAM_OPTIMIZER = 'adam'
 LAMB_OPTIMIZER = 'lamb'
-DEEPSPEED_OPTIMIZERS = [ADAM_OPTIMIZER, LAMB_OPTIMIZER]
+TORCH_ADAM_OPTIMIZER = 'torch_adam'
+DEEPSPEED_OPTIMIZERS = [ADAM_OPTIMIZER, LAMB_OPTIMIZER, TORCH_ADAM_OPTIMIZER]
 
 
 def get_amp_enabled(param_dict):
@@ -457,12 +458,18 @@ def _do_error_check(self):
         if self.zero_enabled:
             assert self.fp16_enabled, "DeepSpeedConfig: ZeRO is only supported if fp16 is enabled"
             assert self.zero_optimization_stage <= MAX_STAGE_ZERO_OPTIMIZATION, "DeepSpeedConfig: Maximum supported ZeRO stage is {}".format(MAX_STAGE_ZERO_OPTIMIZATION)
+            if self.zero_config.cpu_offload is True:
+                assert self.zero_optimization_stage == ZERO_OPTIMIZATION_GRADIENTS, "DeepSpeedConfig: cpu_offload is only supported with ZeRO stage {}".format(ZERO_OPTIMIZATION_GRADIENTS)
 
         assert self.train_micro_batch_size_per_gpu, "DeepSpeedConfig: {} is not defined".format(TRAIN_MICRO_BATCH_SIZE_PER_GPU)
 
-        assert self.gradient_accumulation_steps, 'DeepSpeedConfig: {} is not defined'.format(
+        assert self.gradient_accumulation_steps, "DeepSpeedConfig: {} is not defined".format(
             GRADIENT_ACCUMULATION_STEPS)
 
+        if self.optimizer_name == TORCH_ADAM_OPTIMIZER:
+            assert self.zero_enabled, "DeepSpeedConfig: ZeRO must be enabled when using TORCH_ADAM_OPTIMIZER"
+            assert self.zero_config.cpu_offload, "DeepSpeedConfig: cpu_offload must be enabled when using TORCH_ADAM_OPTIMIZER"
+
     def _do_warning_check(self):
         fp16_enabled = self.fp16_enabled or self.zero_enabled
 
diff --git a/deepspeed/pt/deepspeed_light.py b/deepspeed/pt/deepspeed_light.py
index 20d680796127..fa8ee6cb36e7 100755
--- a/deepspeed/pt/deepspeed_light.py
+++ b/deepspeed/pt/deepspeed_light.py
@@ -22,7 +22,7 @@
 from deepspeed.pt.fp16_unfused_optimizer import FP16_UnfusedOptimizer
 from deepspeed.pt.deepspeed_fused_lamb import FusedLamb
 from deepspeed.pt.deepspeed_config import DeepSpeedConfig, \
-    ADAM_OPTIMIZER, LAMB_OPTIMIZER, DEEPSPEED_OPTIMIZERS
+    ADAM_OPTIMIZER, LAMB_OPTIMIZER, TORCH_ADAM_OPTIMIZER, DEEPSPEED_OPTIMIZERS
 from deepspeed.pt.deepspeed_dataloader import DeepSpeedDataLoader
 from deepspeed.pt.deepspeed_constants import \
@@ -107,7 +107,6 @@ def __init__(self,
                  collate_fn=None,
                  config_params=None):
         super(DeepSpeedLight, self).__init__()
-
         self.client_optimizer = optimizer
         self.client_model_parameters = model_parameters
         self.client_lr_scheduler = lr_scheduler
@@ -293,6 +292,9 @@ def zero_reduce_scatter(self):
     def zero_overlap_comm(self):
         return self._config.zero_config.overlap_comm
 
+    def zero_cpu_offload(self):
+        return self._config.zero_config.cpu_offload
+
     def zero_optimization_stage(self):
         return self._config.zero_optimization_stage
 
@@ -492,6 +494,7 @@ def _configure_distributed_model(self, model):
 
     # Configure optimizer
     def _configure_optimizer(self, client_optimizer, model_parameters):
+
         if client_optimizer is not None:
             basic_optimizer = client_optimizer
             logger.info('Using client Optimizer as basic optimizer')
@@ -505,13 +508,21 @@ def _configure_optimizer(self, client_optimizer, model_parameters):
 
         if self.zero_optimization():
             assert not self.amp_enabled(), "Amp and ZeRO are not currently compatible, please use (legacy) fp16 mode which performs similar to amp opt_mode=O2"
-            if self.optimizer_name() != ADAM_OPTIMIZER:
+            if self.optimizer_name() not in [ADAM_OPTIMIZER, TORCH_ADAM_OPTIMIZER]:
                 assert self.zero_allow_untested_optimizer(), \
                     'You are using an untested ZeRO Optimizer. Please add <"zero_allow_untested_optimizer": true> in the configuration file to use it.'
 
                 logger.warning(
                     "**** You are using ZeRO with an untested optimizer, proceed with caution *****"
                 )
+            if self.zero_cpu_offload():
+                if self.optimizer_name() != TORCH_ADAM_OPTIMIZER:
+                    assert self.zero_allow_untested_optimizer(), \
+                        'You are using ZeRO-Offload with an untested Optimizer. Please add <"zero_allow_untested_optimizer": true> in the configuration file to use it.'
+
+                    logger.warning(
+                        "**** You are using ZeRO-Offload with an untested optimizer, proceed with caution *****"
+                    )
             self.optimizer = self._configure_zero_optimizer(basic_optimizer)
         elif self.amp_enabled():
             assert not self.fp16_enabled(), "Cannot enable both amp with (legacy) fp16 mode"
@@ -523,8 +534,8 @@ def _configure_optimizer(self, client_optimizer, model_parameters):
             self.optimizer = self._configure_fp16_optimizer(basic_optimizer)
         else:
             self.optimizer = basic_optimizer
-
-        # logger.info('DeepSpeed Final Optimizer = {}'.format(self.optimizer.state_dict()))
+        logger.info('DeepSpeed Final Optimizer = {}'.format(self.optimizer))
+        logger.info('DeepSpeed Final Optimizer = {}'.format(self.optimizer.state_dict()))
 
     def _configure_basic_optimizer(self, model_parameters):
         optimizer_parameters = self.optimizer_params()
@@ -537,6 +548,8 @@ def _configure_basic_optimizer(self, model_parameters):
             optimizer = FusedAdam(model_parameters, **optimizer_parameters)
         elif self.optimizer_name() == LAMB_OPTIMIZER:
             optimizer = FusedLamb(model_parameters, **optimizer_parameters)
+        elif self.optimizer_name() == TORCH_ADAM_OPTIMIZER:
+            optimizer = torch.optim.Adam(model_parameters, **optimizer_parameters)
         else:
             torch_optimizer = getattr(torch.optim, self.optimizer_name())
             optimizer = torch_optimizer(model_parameters, **optimizer_parameters)
@@ -613,6 +626,7 @@ def _configure_zero_optimizer(self, optimizer):
                 dp_process_group=self.data_parallel_group,
                 reduce_scatter=self.zero_reduce_scatter(),
                 overlap_comm=self.zero_overlap_comm(),
+                cpu_offload=self.zero_cpu_offload(),
                 mpu=self.mpu,
                 postscale_gradients=self.postscale_gradients(),
                 gradient_predivide_factor=self.gradient_predivide_factor())
@@ -843,7 +857,6 @@ def step(self):
                 master_params = amp.master_params(self.optimizer)
                 torch.nn.utils.clip_grad_norm_(parameters=master_params,
                                                max_norm=self.gradient_clipping())
-
         self.optimizer.step()
 
         #zero grad in basic optimizer could be unreliable and may not exhibit
@@ -946,6 +959,9 @@ def _get_optimizer_param(self, param_name):
     def get_lr(self):
         return self._get_optimizer_param('lr')
 
+    def get_type(self):
+        return self._get_optimizer_param('type')
+
     def get_mom(self):
         return self._get_optimizer_param('betas')
 
diff --git a/deepspeed/pt/deepspeed_zero_config.py b/deepspeed/pt/deepspeed_zero_config.py
index 4f654d3b8c30..69b22d4d2bef 100755
--- a/deepspeed/pt/deepspeed_zero_config.py
+++ b/deepspeed/pt/deepspeed_zero_config.py
@@ -24,6 +24,7 @@
     "overlap_comm": [true|false],
     "reduce_bucket_size": 500000000
     "load_from_fp32_weights": [true|false]
+    "cpu_offload": [true|false]
   }
 }
 '''
@@ -63,21 +64,22 @@
 ZERO_OPTIMIZATION_LOAD_FROM_FP32_WEIGHTS = 'load_from_fp32_weights'
 ZERO_OPTIMIZATION_LOAD_FROM_FP32_WEIGHTS_DEFAULT = True
 
+ZERO_OPTIMIZATION_CPU_OFFLOAD = 'cpu_offload'
+ZERO_OPTIMIZATION_CPU_OFFLOAD_DEFAULT = False
+
 ZERO_OPTIMIZATION_DEFAULT = {
-    ZERO_OPTIMIZATION_STAGE:
-    ZERO_OPTIMIZATION_STAGE_DEFAULT,
+    ZERO_OPTIMIZATION_STAGE: ZERO_OPTIMIZATION_STAGE_DEFAULT,
     ZERO_OPTIMIZATION_CONTIGUOUS_GRADIENTS:
     ZERO_OPTIMIZATION_CONTIGUOUS_GRADIENTS_DEFAULT,
-    ZERO_OPTIMIZATION_REDUCE_SCATTER:
-    ZERO_OPTIMIZATION_REDUCE_SCATTER_DEFAULT,
-    ZERO_OPTIMIZATION_REDUCE_BUCKET_SIZE:
-    ZERO_OPTIMIZATION_REDUCE_BUCKET_SIZE_DEFAULT,
+    ZERO_OPTIMIZATION_REDUCE_SCATTER: ZERO_OPTIMIZATION_REDUCE_SCATTER_DEFAULT,
+    ZERO_OPTIMIZATION_REDUCE_BUCKET_SIZE: ZERO_OPTIMIZATION_REDUCE_BUCKET_SIZE_DEFAULT,
     ZERO_OPTIMIZATION_ALLGATHER_PARTITIONS:
     ZERO_OPTIMIZATION_ALLGATHER_PARTITIONS_DEFAULT,
     ZERO_OPTIMIZATION_ALLGATHER_BUCKET_SIZE:
    ZERO_OPTIMIZATION_ALLGATHER_BUCKET_SIZE_DEFAULT,
     ZERO_OPTIMIZATION_LOAD_FROM_FP32_WEIGHTS:
-    ZERO_OPTIMIZATION_LOAD_FROM_FP32_WEIGHTS_DEFAULT
+    ZERO_OPTIMIZATION_LOAD_FROM_FP32_WEIGHTS_DEFAULT,
+    ZERO_OPTIMIZATION_CPU_OFFLOAD: ZERO_OPTIMIZATION_CPU_OFFLOAD_DEFAULT
 }
 
 
@@ -93,6 +95,7 @@ def __init__(self, param_dict):
         self.allgather_bucket_size = None
         self.overlap_comm = None
         self.load_from_fp32_weights = None
+        self.cpu_offload = None
 
         if ZERO_OPTIMIZATION in param_dict.keys():
             zero_config_dict = param_dict[ZERO_OPTIMIZATION]
@@ -157,7 +160,12 @@ def _initialize(self, zero_config_dict):
             zero_config_dict,
             ZERO_OPTIMIZATION_ALLGATHER_BUCKET_SIZE,
             ZERO_OPTIMIZATION_ALLGATHER_BUCKET_SIZE_DEFAULT)
+
         self.load_from_fp32_weights = get_scalar_param(
             zero_config_dict,
             ZERO_OPTIMIZATION_LOAD_FROM_FP32_WEIGHTS,
             ZERO_OPTIMIZATION_LOAD_FROM_FP32_WEIGHTS_DEFAULT)
+
+        self.cpu_offload = get_scalar_param(zero_config_dict,
+                                            ZERO_OPTIMIZATION_CPU_OFFLOAD,
+                                            ZERO_OPTIMIZATION_CPU_OFFLOAD_DEFAULT)
diff --git a/deepspeed/pt/deepspeed_zero_optimizer.py b/deepspeed/pt/deepspeed_zero_optimizer.py
index cbfb249b501d..600394cc7802 100755
--- a/deepspeed/pt/deepspeed_zero_optimizer.py
+++ b/deepspeed/pt/deepspeed_zero_optimizer.py
@@ -9,11 +9,11 @@
 import math
 from torch._six import inf
 from torch.autograd import Variable
+import collections
 
 from deepspeed.pt.loss_scaler import LossScaler, DynamicLossScaler
 from deepspeed.pt.deepspeed_utils import see_memory_usage, is_model_parallel_parameter
 from deepspeed.pt.deepspeed_zero_config import ZERO_OPTIMIZATION_GRADIENTS
-
 #Toggle this to true to enable correctness test
 #with gradient partitioning and without
 pg_correctness_test = False
@@ -99,6 +99,37 @@ def print_rank_msg(msg):
     print(f"rank {dist.get_rank()} - {msg}")
 
 
+#jie: async move to target device
+def async_migrate_to(obj, dev, main_stream=None):
+    if torch.is_tensor(obj):
+        obj = Variable(obj)
+    if isinstance(obj, Variable):
+        v = obj.cuda(dev, non_blocking=True)
+        if main_stream is not None:
+            v.data.record_stream(main_stream)
+        return v
+    elif isinstance(obj, collections.Mapping):
+        return {k: async_copy_to(o, dev, main_stream) for k, o in obj.items()}
+    elif isinstance(obj, collections.Sequence):
+        return [async_copy_to(o, dev, main_stream) for o in obj]
+    else:
+        return obj
+
+
+def async_copy_to(obj, dev, main_stream=None):
+    if torch.is_tensor(obj):
+        obj = Variable(obj)
+    if isinstance(obj, Variable):
+        target = torch.empty_like(obj, device=dev).copy_(obj)
+        if main_stream is not None:
+            target.data.record_stream(main_stream)
+        return target
+    elif isinstance(obj, collections.Mapping):
+        return {k: async_copy_to(o, dev, main_stream) for k, o in obj.items()}
+    elif isinstance(obj, collections.Sequence):
+        return [async_copy_to(o, dev, main_stream) for o in obj]
+
+
 class FP16_DeepSpeedZeroOptimizer(object):
     """
     DeepSpeedZeroOptimizer designed to reduce the memory footprint
@@ -123,6 +154,7 @@ def __init__(self,
                  dp_process_group=None,
                  reduce_scatter=True,
                  overlap_comm=False,
+                 cpu_offload=False,
                  mpu=None,
                  clip_grad=0.0,
                  allreduce_always_fp32=False,
@@ -151,6 +183,8 @@ def __init__(self,
 
         self.overlap_comm = overlap_comm
 
+        self.cpu_offload = cpu_offload
+
         self.dp_process_group = dp_process_group
 
         self.partition_count = dist.get_world_size(group=self.dp_process_group)
@@ -185,7 +219,8 @@ def __init__(self,
         #a single 32-bit partition of the parallel partitioned parameters
         #that this process will update
         self.single_partition_of_fp32_groups = []
-
+        if cpu_offload:
+            self.averaged_gradients_on_cpu = {}
         #param partition info
 
         #These are the parameters in each group that will not be updated by this process directly
@@ -208,6 +243,7 @@ def __init__(self,
         # padding on each partition for alignment purposes
         self.groups_padding = []
 
+        # loop to deal with groups
         for i, param_group in enumerate(self.optimizer.param_groups):
             # push this group to list before modify
@@ -253,13 +289,24 @@ def __init__(self,
             self.parallel_partitioned_fp16_groups.append(data_parallel_partitions)
 
             # a partition of the fp32 master weights that will be updated by this process
-            self.single_partition_of_fp32_groups.append(
-                self.parallel_partitioned_fp16_groups[i]
-                [partition_id].clone().float().detach())
+            if self.cpu_offload:
+                self.single_partition_of_fp32_groups.append(
+                    async_copy_to(self.parallel_partitioned_fp16_groups[i][partition_id],
+                                  'cpu').float())
+                self.averaged_gradients_on_cpu[i] = [
+                    torch.empty_like(
+                        self.parallel_partitioned_fp16_groups[i][partition_id],
+                        device='cpu')
+                ]
+            else:
+                self.single_partition_of_fp32_groups.append(
+                    self.parallel_partitioned_fp16_groups[i]
+                    [partition_id].clone().float().detach())
 
             # modify optimizer of have flat master weight
             self.single_partition_of_fp32_groups[
                 i].requires_grad = True  # keep this in case internal optimizer uses it
+
             param_group['params'] = [self.single_partition_of_fp32_groups[i]]
 
             partition_size = len(self.fp16_groups_flat[i]) / dist.get_world_size(
@@ -276,6 +323,8 @@ def __init__(self,
         self.reduction_event = torch.cuda.Event(enable_timing=False, blocking=False)
 
         self.reduction_stream = torch.cuda.Stream()
+        self.cpu_computation_stream = torch.cuda.Stream()
+        self.migration_stream = torch.cuda.Stream()
         self.callback_queued = False
 
         self.param_dict = {}
@@ -380,10 +429,16 @@ def _release_ipg_buffers(self):
 
     def initialize_optimizer_states(self):
         for i, group in enumerate(self.fp16_groups):
-            single_grad_partition = torch.zeros(
-                int(self.partition_size[i]),
-                dtype=self.single_partition_of_fp32_groups[i].dtype,
-                device=torch.cuda.current_device())
+            if self.cpu_offload:
+                single_grad_partition = torch.zeros(
+                    int(self.partition_size[i]),
+                    dtype=self.single_partition_of_fp32_groups[i].dtype,
+                    device='cpu')
+            else:
+                single_grad_partition = torch.zeros(
+                    int(self.partition_size[i]),
+                    dtype=self.single_partition_of_fp32_groups[i].dtype,
+                    device=torch.cuda.current_device())
             self.single_partition_of_fp32_groups[i].grad = single_grad_partition
 
         self.optimizer.step()
@@ -445,14 +500,15 @@ def independent_gradient_partition_epilogue(self):
         if self.overlap_comm:
             torch.cuda.synchronize()
 
-        for i, _ in enumerate(self.fp16_groups):
-            self.averaged_gradients[i] = self.get_flat_partition(
-                self.params_in_partition[i],
-                self.first_offset[i],
-                self.partition_size[i],
-                dtype=torch.half,
-                device=torch.cuda.current_device(),
-                return_tensor_list=True)
+        if self.cpu_offload is False:
+            for i, _ in enumerate(self.fp16_groups):
+                self.averaged_gradients[i] = self.get_flat_partition(
+                    self.params_in_partition[i],
+                    self.first_offset[i],
+                    self.partition_size[i],
+                    dtype=torch.half,
+                    device=torch.cuda.current_device(),
+                    return_tensor_list=True)
 
         self._release_ipg_buffers()
@@ -561,6 +617,7 @@ def report_ipg_memory_usage(self, tag, param_elems):
     ###############Idependent Partition Gradient ########################
     def reduce_independent_p_g_buckets_and_remove_grads(self, param, i):
+
         if self.elements_in_ipg_bucket + param.numel() > self.reduce_bucket_size:
             self.report_ipg_memory_usage("In ipg_remove_grads before reduce_ipg_grads",
                                          param.numel())
@@ -585,7 +642,15 @@ def reduce_independent_p_g_buckets_and_remove_grads(self, param, i):
             self.elements_in_ipg_bucket,
             param.numel())
         new_grad_tensor.copy_(param.grad.view(-1))
-        param.grad.data = new_grad_tensor.data.view_as(param.grad)
+        #jie:
+        if self.cpu_offload:
+            with torch.cuda.stream(self.migration_stream):
+                self.averaged_gradients_on_cpu[i][param_id] = async_copy_to(
+                    new_grad_tensor.data.view_as(param.grad),
+                    'cpu',
+                    self.cpu_computation_stream)
+        else:
+            param.grad.data = new_grad_tensor.data.view_as(param.grad)
 
         self.elements_in_ipg_bucket += param.numel()
         self.grads_in_ipg_bucket.append(param.grad)
@@ -697,6 +762,15 @@ def copy_grads_in_partition(self, param):
                 total_size += param_in_partition.numel()
 
             see_memory_usage(f"before copying {total_size} gradients into partition")
+            #jie:
+            '''
+            if self.cpu_offload:
+                self.grads_in_partition = torch.empty(int(total_size),
+                                                      dtype=torch.half,
+                                                      device=torch.cuda.current_device())
+            ''
+            else:
+            '''
             self.grads_in_partition = torch.empty(int(total_size),
                                                   dtype=torch.half,
                                                   device=torch.cuda.current_device())
@@ -707,6 +781,11 @@ def copy_grads_in_partition(self, param):
                                                         self.grads_in_partition_offset,
                                                         param.numel())
         new_grad_tensor.copy_(param.grad.view(-1))
+        '''
+        if self.cpu_offload:
+            averaged_gradients_on_cpu[self.get_param_id()] = new_grad_tensor.data.view_as(param.grad)
+        else:
+        '''
         param.grad.data = new_grad_tensor.data.view_as(param.grad)
         self.grads_in_partition_offset += param.numel()
 
@@ -1122,6 +1201,11 @@ def step(self, closure=None):
         partition_id = dist.get_rank(group=self.dp_process_group)
 
         for i, group in enumerate(self.fp16_groups):
+            if self.cpu_offload:
+                self.averaged_gradients[i] = self.averaged_gradients_on_cpu[i]
+            #norm_groups.append(
+            #    self.get_grad_norm_direct(self.averaged_gradients[i],
+            #                              self.params_in_partition[i]))
             norm_groups.append(
                 self.get_grad_norm_direct(self.averaged_gradients[i],
                                           self.params_in_partition[i]))
@@ -1147,20 +1231,31 @@ def step(self, closure=None):
             #release all the gradient since we have already created a necessary copy in dp_grad_partition
             self.free_grad_in_param_list(self.params_in_partition[i])
 
-            self.averaged_gradients[i] = None
+            if self.cpu_offload is False:
+                self.averaged_gradients[i] = None
 
             single_partition_grad_groups.append(single_grad_partition)
 
         self.unscale_and_clip_grads(single_partition_grad_groups, norm_groups)
 
+        self.cpu_computation_stream.wait_stream(self.migration_stream)
         timers('optimizer_step').start()
-        self.optimizer.step()
-        #get rid of the fp32 gradients. Not needed anymore
-        for group in self.single_partition_of_fp32_groups:
-            group.grad = None
+        with torch.cuda.stream(self.cpu_computation_stream):
+            self.optimizer.step()
 
-        for fp16_partitions, fp32_partition in zip(self.parallel_partitioned_fp16_groups, self.single_partition_of_fp32_groups):
-            fp16_partitions[partition_id].data.copy_(fp32_partition.data)
+        if self.cpu_offload:
+            stream = torch.cuda.current_stream()
+            with torch.cuda.stream(self.migration_stream):
+                for fp16_partitions, fp32_partition in zip(self.parallel_partitioned_fp16_groups, self.single_partition_of_fp32_groups):
+                    fp16_partitions[partition_id] = async_copy_to(
+                        fp32_partition,
+                        torch.cuda.current_device(),
+                        stream)
+            #for averaged_gradients_cpu, fp32_partition in zip(self.averaged_gradients_on_cpu, self.single_partition_of_fp32_groups):
+            #    averaged_gradients_cpu = [fp32_partition]
+        else:
+            for fp16_partitions, fp32_partition in zip(self.parallel_partitioned_fp16_groups, self.single_partition_of_fp32_groups):
+                fp16_partitions[partition_id].data.copy_(fp32_partition.data)
 
         timers('optimizer_step').stop()
         timers('optimizer_allgather').start()
@@ -1243,6 +1338,8 @@ def has_overflow_serial(self, params, is_grad_list=False):
 
     def has_overflow_partitioned_grads_serial(self):
         for i in range(len(self.fp16_groups)):
+            if self.cpu_offload:
+                self.averaged_gradients[i] = self.averaged_gradients_on_cpu[i]
             for j, grad in enumerate(self.averaged_gradients[i]):
                 if grad is not None and self._has_inf_or_nan(grad.data, j):
                     return True
@@ -1317,7 +1414,7 @@ def backward(self, loss, retain_graph=False):
                                 device=torch.cuda.current_device())
             self.ipg_buffer.append(buf_1)
             self.ipg_index = 0
-
+        torch.cuda.empty_cache()
         self.loss_scaler.backward(loss.float(), retain_graph=retain_graph)
 
     def check_overflow(self, partition_gradients=True):
@@ -1369,8 +1466,12 @@ def _get_groups_without_padding(self, groups_with_padding):
     def _get_state_without_padding(self, state_with_padding, padding):
         lean_state = {}
         for key, value in state_with_padding.items():
-            lean_length = value.numel() - padding
-            lean_state[key] = value[:lean_length]
+            #jie: torch.optim.Adam() has "step" as a key in its state_dict
+            if key == "step":
+                lean_state[key] = value
+            else:
+                lean_length = value.numel() - padding
+                lean_state[key] = value[:lean_length]
 
         return lean_state
 
@@ -1412,6 +1513,13 @@ def state_dict(self):
             self.single_partition_of_fp32_groups)
         state_dict['single_partition_of_fp32_groups'] = fp32_groups_without_padding
 
+        if self.cpu_offload:
+            state_dict_tmp = async_copy_to(state_dict,
+                                           'cpu',
+                                           torch.cuda.current_stream())
+            state_dict = None
+            state_dict = state_dict_tmp
+
         return state_dict
 
     # Restore base optimizer fp32 weights from checkpoint by:
diff --git a/deepspeed/pt/deepspeed_zero_utils.py b/deepspeed/pt/deepspeed_zero_utils.py
new file mode 100755
index 000000000000..034a9f4ccb22
--- /dev/null
+++ b/deepspeed/pt/deepspeed_zero_utils.py
@@ -0,0 +1,33 @@
+import torch
+from torch.autograd import Variable
+import collections
+
+
+def async_migrate_to(obj, dev, main_stream=None):
+    if torch.is_tensor(obj):
+        obj = Variable(obj)
+    if isinstance(obj, Variable):
+        v = obj.cuda(dev, non_blocking=True)
+        if main_stream is not None:
+            v.data.record_stream(main_stream)
+        return v
+    elif isinstance(obj, collections.Mapping):
+        return {k: async_copy_to(o, dev, main_stream) for k, o in obj.items()}
+    elif isinstance(obj, collections.Sequence):
+        return [async_copy_to(o, dev, main_stream) for o in obj]
+    else:
+        return obj
+
+
+def async_copy_to(obj, dev, main_stream=None):
+    if torch.is_tensor(obj):
+        obj = Variable(obj)
+    if isinstance(obj, Variable):
+        target = torch.empty_like(obj, device=dev).copy_(obj)
+        if main_stream is not None:
+            target.data.record_stream(main_stream)
+        return target
+    elif isinstance(obj, collections.Mapping):
+        return {k: async_copy_to(o, dev, main_stream) for k, o in obj.items()}
+    elif isinstance(obj, collections.Sequence):
+        return [async_copy_to(o, dev, main_stream) for o in obj]
diff --git a/tests/unit/test_checkpointing.py b/tests/unit/test_checkpointing.py
index d80bd897ee1a..349648b37d60 100755
--- a/tests/unit/test_checkpointing.py
+++ b/tests/unit/test_checkpointing.py
@@ -235,13 +235,26 @@ def _test_checkpoint_fused_optimizer(args, model, hidden_dim, load_optimizer_sta
                                      load_optimizer_states=False)
 
 
-@pytest.mark.parametrize("zero_stage", [1, 2])
-def test_checkpoint_zero_optimizer(tmpdir, zero_stage):
+'''
+@pytest.mark.parametrize('zero_stage, use_cpu_offload, optimizer_type',
+                         [
+                             (1,
+                              False,
+                              "adam"),
+                             (2,
+                              False,
+                              "adam"),
+                             (2,
+                              True,
+                              "torch_adam"),
+                         ])
+#@pytest.mark.parametrize("zero_stage", [1, 2])
+def test_checkpoint_zero_optimizer(tmpdir, zero_stage, use_cpu_offload, optimizer_type):
     config_dict = {
         "train_batch_size": 2,
         "steps_per_print": 1,
         "optimizer": {
-            "type": "Adam",
+            "type": optimizer_type,
             "params": {
                 "lr": 0.00015,
                 "betas": [0.8,
@@ -254,8 +267,9 @@ def test_checkpoint_zero_optimizer(tmpdir, zero_stage):
             "enabled": True
         },
         "zero_optimization": {
-            "stage": zero_stage
-        },
+            "stage": zero_stage,
+            "cpu_offload": use_cpu_offload
+        }
     }
     args = args_from_dict(tmpdir, config_dict)
     hidden_dim = 10
@@ -274,15 +288,31 @@ def _test_checkpoint_zero_optimizer(args, model, hidden_dim, load_optimizer_stat
                                         model=model,
                                         hidden_dim=hidden_dim,
                                         load_optimizer_states=True)
-
-
-@pytest.mark.parametrize("zero_stage", [1, 2])
-def test_checkpoint_zero_no_optimizer(tmpdir, zero_stage):
+'''
+
+
+#@pytest.mark.parametrize("zero_stage", [1, 2])
+@pytest.mark.parametrize('zero_stage, use_cpu_offload, optimizer_type',
+                         [
+                             (1,
+                              False,
+                              "adam"),
+                             (2,
+                              False,
+                              "adam"),
+                             (2,
+                              True,
+                              "torch_adam"),
+                         ])
+def test_checkpoint_zero_no_optimizer(tmpdir,
+                                      zero_stage,
+                                      use_cpu_offload,
+                                      optimizer_type):
     config_dict = {
         "train_batch_size": 2,
         "steps_per_print": 1,
         "optimizer": {
-            "type": "Adam",
+            "type": optimizer_type,
             "params": {
                 "lr": 0.00015,
                 "betas": [0.8,
@@ -295,8 +325,9 @@ def test_checkpoint_zero_no_optimizer(tmpdir, zero_stage):
             "enabled": True
         },
         "zero_optimization": {
-            "stage": zero_stage
-        },
+            "stage": zero_stage,
+            "cpu_offload": use_cpu_offload
+        }
     }
     args = args_from_dict(tmpdir, config_dict)
     hidden_dim = 10
@@ -320,13 +351,28 @@ def _test_checkpoint_zero_no_optimizer(args,
                                         load_optimizer_states=False)
 
 
-@pytest.mark.parametrize("zero_stage", [0, 1, 2])
-def test_checkpoint_lr_scheduler(tmpdir, zero_stage):
+#@pytest.mark.parametrize("zero_stage", [0, 1, 2])
+@pytest.mark.parametrize('zero_stage, use_cpu_offload, optimizer_type',
+                         [
+                             (0,
+                              False,
+                              "adam"),
+                             (1,
+                              False,
+                              "adam"),
+                             (2,
+                              False,
+                              "adam"),
+                             (2,
+                              True,
+                              "torch_adam"),
+                         ])
+def test_checkpoint_lr_scheduler(tmpdir, zero_stage, use_cpu_offload, optimizer_type):
     config_dict = {
         "train_batch_size": 2,
         "steps_per_print": 1,
         "optimizer": {
-            "type": "Adam",
+            "type": optimizer_type,
             "params": {
                 "lr": 0.00015,
                 "betas": [0.8,
@@ -339,7 +385,8 @@ def test_checkpoint_lr_scheduler(tmpdir, zero_stage):
             "enabled": True
         },
         "zero_optimization": {
-            "stage": zero_stage
+            "stage": zero_stage,
+            "cpu_offload": use_cpu_offload
         },
         "scheduler": {
"type": "WarmupLR", @@ -376,13 +423,28 @@ def _test_checkpoint_lr_scheduler(args, load_lr_scheduler_states=True) -@pytest.mark.parametrize("zero_stage", [0, 1, 2]) -def test_checkpoint_no_lr_scheduler(tmpdir, zero_stage): +#@pytest.mark.parametrize("zero_stage", [0, 1, 2]) +@pytest.mark.parametrize('zero_stage, use_cpu_offload, optimizer_type', + [ + (0, + False, + "adam"), + (1, + False, + "adam"), + (2, + False, + "adam"), + (2, + True, + "torch_adam"), + ]) +def cpu_offload_lr_scheduler(tmpdir, zero_stage, use_cpu_offload, optimizer_type): config_dict = { "train_batch_size": 2, "steps_per_print": 1, "optimizer": { - "type": "Adam", + "type": optimizer_type, "params": { "lr": 1e-5 } @@ -391,7 +453,8 @@ def test_checkpoint_no_lr_scheduler(tmpdir, zero_stage): "enabled": True }, "zero_optimization": { - "stage": zero_stage + "stage": zero_stage, + "cpu_offload": use_cpu_offload }, "scheduler": { "type": "WarmupLR", @@ -400,7 +463,7 @@ def test_checkpoint_no_lr_scheduler(tmpdir, zero_stage): "warmup_max_lr": 0.001, "warmup_num_steps": 1000 } - } + }, } args = args_from_dict(tmpdir, config_dict) hidden_dim = 10 @@ -408,11 +471,11 @@ def test_checkpoint_no_lr_scheduler(tmpdir, zero_stage): model = SimpleModel(hidden_dim, empty_grad=False) @distributed_test(world_size=[2]) - def _test_checkpoint_no_lr_scheduler(args, - model, - hidden_dim, - load_optimizer_states, - load_lr_scheduler_states): + def _cpu_offload_lr_scheduler(args, + model, + hidden_dim, + load_optimizer_states, + load_lr_scheduler_states): checkpoint_correctness_verification( args, model, diff --git a/tests/unit/test_dynamic_loss_scale.py b/tests/unit/test_dynamic_loss_scale.py index e12386271450..7575d6b49454 100755 --- a/tests/unit/test_dynamic_loss_scale.py +++ b/tests/unit/test_dynamic_loss_scale.py @@ -191,7 +191,6 @@ def _test_unfused_no_overflow(args): model, optim, _, _ = deepspeed.initialize(args=args, model=model, model_parameters=model.parameters()) - expected_loss_scale = 2**8 expected_scale_window = 2 # Ensure the dynamic loss scaler is correctly configured. 
diff --git a/tests/unit/test_fp16.py b/tests/unit/test_fp16.py
index 320d026bdd83..8464a763e021 100755
--- a/tests/unit/test_fp16.py
+++ b/tests/unit/test_fp16.py
@@ -217,13 +217,27 @@ def _test_adamw_fp16_empty_grad(args, model, hidden_dim):
     _test_adamw_fp16_empty_grad(args=args, model=model, hidden_dim=hidden_dim)
 
 
-@pytest.mark.parametrize("zero_stage", [0, 1, 2])
-def test_adam_fp16_zero_onecycle_compatibility(tmpdir, zero_stage):
+@pytest.mark.parametrize('zero_stage, use_cpu_offload, optimizer_type',
+                         [
+                             (1,
+                              False,
+                              "adam"),
+                             (2,
+                              False,
+                              "adam"),
+                             (2,
+                              True,
+                              "torch_adam"),
+                         ])
+def test_adam_fp16_zero_onecycle_compatibility(tmpdir,
+                                               zero_stage,
+                                               use_cpu_offload,
+                                               optimizer_type):
     config_dict = {
         "train_batch_size": 1,
         "steps_per_print": 1,
         "optimizer": {
-            "type": "Adam",
+            "type": optimizer_type,
             "params": {
                 "lr": 0.00015
             }
@@ -246,7 +260,8 @@ def test_adam_fp16_zero_onecycle_compatibility(tmpdir, zero_stage):
             "enabled": True
         },
         "zero_optimization": {
-            "stage": zero_stage
+            "stage": zero_stage,
+            "cpu_offload": use_cpu_offload
         }
     }
 
@@ -274,13 +289,25 @@ def _test_adam_fp16_zero_onecycle_compatibility(args, model, hidden_dim):
                                                     hidden_dim=hidden_dim)
 
 
-@pytest.mark.parametrize("zero_stage", [1, 2])
-def test_zero_static_scale(tmpdir, zero_stage):
+#@pytest.mark.parametrize("zero_stage", [1, 2])
+@pytest.mark.parametrize('zero_stage, use_cpu_offload, optimizer_type',
+                         [
+                             (1,
+                              False,
+                              "adam"),
+                             (2,
+                              False,
+                              "adam"),
+                             (2,
+                              True,
+                              "torch_adam"),
+                         ])
+def test_zero_static_scale(tmpdir, zero_stage, use_cpu_offload, optimizer_type):
     config_dict = {
         "train_batch_size": 4,
         "steps_per_print": 1,
         "optimizer": {
-            "type": "Adam",
+            "type": optimizer_type,
             "params": {
                 "lr": 0.00015
             }
@@ -290,7 +317,8 @@ def test_zero_static_scale(tmpdir, zero_stage):
             "loss_scale": 138.
         },
         "zero_optimization": {
-            "stage": zero_stage
+            "stage": zero_stage,
+            "cpu_offload": use_cpu_offload
         }
     }
     args = args_from_dict(tmpdir, config_dict)
@@ -363,8 +391,16 @@ def _test_zero_static_scale(args):
     _test_zero_static_scale(args)
 
 
-@pytest.mark.parametrize("zero_stage", [1, 2])
-def test_zero_allow_untested_optimizer(tmpdir, zero_stage):
+@pytest.mark.parametrize('zero_stage, use_cpu_offload',
+                         [
+                             (1,
+                              False),
+                             (2,
+                              False),
+                             (2,
+                              True),
+                         ])
+def test_zero_allow_untested_optimizer(tmpdir, zero_stage, use_cpu_offload):
     config_dict = {
         "train_batch_size": 4,
         "steps_per_print": 1,
@@ -372,7 +408,8 @@ def test_zero_allow_untested_optimizer(tmpdir, zero_stage):
             "enabled": True,
         },
         "zero_optimization": {
-            "stage": zero_stage
+            "stage": zero_stage,
+            "cpu_offload": use_cpu_offload
         },
         "zero_allow_untested_optimizer": False
     }
@@ -392,8 +429,19 @@ def _test_zero_allow_untested_optimizer(args):
     _test_zero_allow_untested_optimizer(args)
 
 
-@pytest.mark.parametrize("zero_stage", [1, 2])
-def test_zero_empty_partition(tmpdir, zero_stage):
+@pytest.mark.parametrize('zero_stage, use_cpu_offload, optimizer_type',
+                         [
+                             (1,
+                              False,
+                              "adam"),
+                             (2,
+                              False,
+                              "adam"),
+                             (2,
+                              True,
+                              "torch_adam"),
+                         ])
+def test_zero_empty_partition(tmpdir, zero_stage, use_cpu_offload, optimizer_type):
     config_dict = {
         "train_micro_batch_size_per_gpu": 1,
         "gradient_accumulation_steps": 1,
@@ -402,13 +450,14 @@ def test_zero_empty_partition(tmpdir, zero_stage):
             "initial_scale_power": 8
         },
         "optimizer": {
-            "type": "Adam",
+            "type": optimizer_type,
             "params": {
                 "lr": 0.00015
             }
         },
         "zero_optimization": {
-            "stage": zero_stage
+            "stage": zero_stage,
+            "cpu_offload": use_cpu_offload
         }
     }
     args = args_from_dict(tmpdir, config_dict)
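[Reviewer note] For readers tracing the stream handling added to FP16_DeepSpeedZeroOptimizer.step(), the standalone sketch below shows how the new async_copy_to helper from deepspeed_zero_utils.py is used to move the updated fp32 master partition back to the GPU on a dedicated migration stream. It is illustrative only; every name other than async_copy_to is made up, and it assumes a CUDA-capable build of PyTorch.

# Sketch of the copy-back pattern used in step() when cpu_offload is enabled.
import torch
from deepspeed.pt.deepspeed_zero_utils import async_copy_to

migration_stream = torch.cuda.Stream()
compute_stream = torch.cuda.current_stream()

# After the optimizer step runs on the host, the fp32 master partition lives on CPU.
fp32_partition = torch.randn(1024, dtype=torch.float32, device='cpu')

with torch.cuda.stream(migration_stream):
    # async_copy_to allocates the destination on the target device, copies the data,
    # and calls record_stream(compute_stream) so the buffer is not reclaimed while the
    # compute stream still has queued work that reads it.
    gpu_partition = async_copy_to(fp32_partition,
                                  torch.cuda.current_device(),
                                  compute_stream)

# Consumers on the default stream must wait for the migration stream, mirroring
# self.cpu_computation_stream.wait_stream(self.migration_stream) in step().
compute_stream.wait_stream(migration_stream)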