diff --git a/deepspeed/ops/adam/cpu_adam.py b/deepspeed/ops/adam/cpu_adam.py
index d51e2126cda2..76d8d323f6f5 100755
--- a/deepspeed/ops/adam/cpu_adam.py
+++ b/deepspeed/ops/adam/cpu_adam.py
@@ -12,14 +12,14 @@ class DeepSpeedCPUAdam(torch.optim.Optimizer):
     def __init__(self,
                  model_params,
                  lr=1e-3,
-                 bettas=(0.9,
-                         0.999),
+                 betas=(0.9,
+                        0.999),
                  eps=1e-8,
                  weight_decay=0,
                  amsgrad=False):
 
         default_args = dict(lr=lr,
-                            betas=bettas,
+                            betas=betas,
                             eps=eps,
                             weight_decay=weight_decay,
                             amsgrad=amsgrad)
@@ -30,7 +30,7 @@ def __init__(self,
         global ds_opt_adam
         ds_opt_adam = importlib.import_module('deepspeed.ops.adam.cpu_adam_op')
 
-        ds_opt_adam.create_adam(self.opt_id, lr, bettas[0], bettas[1], eps, weight_decay)
+        ds_opt_adam.create_adam(self.opt_id, lr, betas[0], betas[1], eps, weight_decay)
 
     def __setstate__(self, state):
         super(DeepSpeedCPUAdam, self).__setstate__(state)
@@ -54,6 +54,7 @@ def step(self, closure=None, fp16_param_groups=None):
                 state = self.state[p]
                 # State initialization
                 if len(state) == 0:
+                    print(f'group {group_id} param {param_id} = {p.numel()}')
                     state['step'] = 0
                     # gradient momentums
                     state['exp_avg'] = torch.zeros_like(p.data, device='cpu')
diff --git a/deepspeed/runtime/zero/stage2.py b/deepspeed/runtime/zero/stage2.py
index 9ed9c07df054..d39c17f37161 100755
--- a/deepspeed/runtime/zero/stage2.py
+++ b/deepspeed/runtime/zero/stage2.py
@@ -15,6 +15,7 @@
 from deepspeed.runtime.fp16.loss_scaler import LossScaler, DynamicLossScaler
 from deepspeed.runtime.utils import see_memory_usage, is_model_parallel_parameter
 from deepspeed.runtime.zero.config import ZERO_OPTIMIZATION_GRADIENTS
+from deepspeed.utils import logger
 
 #Toggle this to true to enable correctness test
 #with gradient partitioning and without
@@ -155,6 +156,7 @@ def __init__(self,
         self.overlap_comm = overlap_comm or cpu_offload
 
         self.cpu_offload = cpu_offload
+        self.device = torch.cuda.current_device() if not self.cpu_offload else 'cpu'
 
         self.dp_process_group = dp_process_group
 
@@ -1405,17 +1407,22 @@ def step(self, closure=None):
         self.unscale_and_clip_grads(single_partition_grad_groups, norm_groups)
         #torch.set_num_threads(12)
         timers('optimizer_step').start()
-        #self.optimizer.step(fp16_param_groups=self.parallel_partitioned_fp16_groups)
-        self.optimizer.step()
-        #get rid of the fp32 gradients. Not needed anymore
-        if not self.cpu_offload:
+        if self.cpu_offload:
+            # self.optimizer.step(fp16_param_groups=self.parallel_partitioned_fp16_groups)
+            self.optimizer.step()
+            for fp16_partitions, fp32_partition in zip(self.parallel_partitioned_fp16_groups, self.single_partition_of_fp32_groups):
+                fp16_partitions[partition_id].data.copy_(fp32_partition.data)
+        else:
+            self.optimizer.step()
+            #get rid of the fp32 gradients. Not needed anymore
             for group in self.single_partition_of_fp32_groups:
                 group.grad = None
 
-        for fp16_partitions, fp32_partition in zip(self.parallel_partitioned_fp16_groups, self.single_partition_of_fp32_groups):
-            fp16_partitions[partition_id].data.copy_(fp32_partition.data)
+            for fp16_partitions, fp32_partition in zip(self.parallel_partitioned_fp16_groups, self.single_partition_of_fp32_groups):
+                fp16_partitions[partition_id].data.copy_(fp32_partition.data)
 
         timers('optimizer_step').stop()
+        timers.log(names=['optimizer_step'])
 
         if self.cpu_offload:
             self.reset_cpu_buffers()
diff --git a/deepspeed/runtime/zero/utils.py b/deepspeed/runtime/zero/utils.py
index ba7cf8e87ebe..e51ad7c1eca1 100755
--- a/deepspeed/runtime/zero/utils.py
+++ b/deepspeed/runtime/zero/utils.py
@@ -2,6 +2,7 @@
 import torch.distributed as dist
 import apex
 from deepspeed.utils import logger
+from deepspeed.ops.adam import DeepSpeedCPUAdam
 
 
 def _initialize_parameter_parallel_groups(parameter_parallel_size=None):
@@ -22,8 +23,15 @@ def _initialize_parameter_parallel_groups(parameter_parallel_size=None):
     return my_group
 
 
-ZERO_SUPPORTED_OPTIMIZERS = [torch.optim.Adam, apex.optimizers.FusedAdam]
+ZERO_SUPPORTED_OPTIMIZERS = [
+    torch.optim.Adam,
+    apex.optimizers.FusedAdam,
+    DeepSpeedCPUAdam
+]
 
 
 def is_zero_supported_optimizer(optimizer):
+    print(
+        f'Checking ZeRO support for optimizer={optimizer.__class__.__name__} type={type(optimizer)}'
+    )
     return type(optimizer) in ZERO_SUPPORTED_OPTIMIZERS
diff --git a/tests/model/Megatron_GPT2/run_func_test.py b/tests/model/Megatron_GPT2/run_func_test.py
index a61b0e9ba85b..552314a0d037 100755
--- a/tests/model/Megatron_GPT2/run_func_test.py
+++ b/tests/model/Megatron_GPT2/run_func_test.py
@@ -19,7 +19,7 @@
 
 def grep_loss_from_file(file_name):
     loss = 0.0
-
+    print(f'grepping {file_name}')
     with open(file_name, 'r') as f:
         lines = f.readlines()
         line_filter = "validation loss at the end of training for test data | LM loss:"
@@ -455,7 +455,7 @@ def run_partition_activations_test(self, test_config, r_tol):
             baseline_prefix += test_config["json"][0:-5]
             baseline_deepspeed_config = True
 
-        test_config["other_args"] = cpu_optimizer_flag
+        test_config["other_args"] = f"\"{cpu_optimizer_flag}\""
         base_file = self.gen_output_name(test_config,
                                          baseline_prefix,
                                          baseline_config=baseline_deepspeed_config)
@@ -565,7 +565,7 @@ def suite():
     suite.addTest(GPT2FuncTestCase('test_mp4_gpu4_node1_zero2_cpu_optimizer'))
 
     # Baseline = Megatron + Torch.Optim.Adam
-    # Test = Megatron + Torch.Optim.Adam + ZeRO-Offload
+    # Test = Megatron + DeepSpeedAdam + ZeRO-Offload
     suite.addTest(GPT2FuncTestCase('test_mp1_gpu1_node1_zero2_offload'))
     suite.addTest(GPT2FuncTestCase('test_mp1_gpu2_node1_zero2_offload'))
     suite.addTest(GPT2FuncTestCase('test_mp2_gpu4_node1_zero2_offload'))
diff --git a/tests/model/run_sanity_check.py b/tests/model/run_sanity_check.py
index b7d12ba18f07..21c3bc5d9d43 100755
--- a/tests/model/run_sanity_check.py
+++ b/tests/model/run_sanity_check.py
@@ -30,7 +30,7 @@ def pytest_hack(runner_result):
 
 
 def test_megatron():
-    runner = unittest.TextTestRunner(failfast=True)
+    runner = unittest.TextTestRunner(failfast=False)
     pytest_hack(runner.run(Megatron_GPT2.suite()))
 
 
diff --git a/tests/perf/adam_test.py b/tests/perf/adam_test.py
index 89b70e96e7ac..0f29cab4662e 100755
--- a/tests/perf/adam_test.py
+++ b/tests/perf/adam_test.py
@@ -3,7 +3,7 @@
 import time
 
 device = 'cpu'
-model_size = 1 * 1024 ** 3
+model_size = 1 * 1024**3
 group_size = [model_size, 274432]
 param = [torch.nn.Parameter(torch.ones(size,
                                        device=device)) for size in group_size]
diff --git a/tests/unit/test_checkpointing.py b/tests/unit/test_checkpointing.py
index 4ce0d36284e0..07a8f643d134 100755
--- a/tests/unit/test_checkpointing.py
+++ b/tests/unit/test_checkpointing.py
@@ -235,21 +235,24 @@ def _test_checkpoint_fused_optimizer(args, model, hidden_dim, load_optimizer_sta
                                             load_optimizer_states=False)
 
 
-@pytest.mark.parametrize('zero_stage, use_cpu_offload',
+@pytest.mark.parametrize('zero_stage, use_cpu_offload, adam_optimizer',
                          [
                              (1,
-                              False),
+                              False,
+                              'Adam'),
                              (2,
-                              False),
+                              False,
+                              'Adam'),
                              (2,
-                              True),
+                              True,
+                              'deepspeed_adam'),
                          ])
-def test_checkpoint_zero_optimizer(tmpdir, zero_stage, use_cpu_offload):
+def test_checkpoint_zero_optimizer(tmpdir, zero_stage, use_cpu_offload, adam_optimizer):
     config_dict = {
         "train_batch_size": 2,
         "steps_per_print": 1,
         "optimizer": {
-            "type": "Adam",
+            "type": adam_optimizer,
             "params": {
                 "lr": 0.00015,
                 "betas": [0.8,
@@ -285,21 +288,27 @@ def _test_checkpoint_zero_optimizer(args, model, hidden_dim, load_optimizer_stat
                                          load_optimizer_states=True)
 
 
-@pytest.mark.parametrize('zero_stage, use_cpu_offload',
+@pytest.mark.parametrize('zero_stage, use_cpu_offload, adam_optimizer',
                          [
                              (1,
-                              False),
+                              False,
+                              "Adam"),
                              (2,
-                              False),
+                              False,
+                              "Adam"),
                              (2,
-                              True),
+                              True,
+                              'deepspeed_adam'),
                          ])
-def test_checkpoint_zero_no_optimizer(tmpdir, zero_stage, use_cpu_offload):
+def test_checkpoint_zero_no_optimizer(tmpdir,
+                                      zero_stage,
+                                      use_cpu_offload,
+                                      adam_optimizer):
     config_dict = {
         "train_batch_size": 2,
         "steps_per_print": 1,
         "optimizer": {
-            "type": "Adam",
+            "type": adam_optimizer,
             "params": {
                 "lr": 0.00015,
                 "betas": [0.8,
@@ -338,23 +347,27 @@ def _test_checkpoint_zero_no_optimizer(args,
                                              load_optimizer_states=False)
 
 
-@pytest.mark.parametrize('zero_stage, use_cpu_offload',
+@pytest.mark.parametrize('zero_stage, use_cpu_offload, adam_optimizer',
                          [
                              (0,
-                              False),
+                              False,
+                              'Adam'),
                              (1,
-                              False),
+                              False,
+                              'Adam'),
                              (2,
-                              False),
+                              False,
+                              'Adam'),
                              (2,
-                              True),
+                              True,
+                              'deepspeed_adam'),
                          ])
-def test_checkpoint_lr_scheduler(tmpdir, zero_stage, use_cpu_offload):
+def test_checkpoint_lr_scheduler(tmpdir, zero_stage, use_cpu_offload, adam_optimizer):
     config_dict = {
         "train_batch_size": 2,
         "steps_per_print": 1,
         "optimizer": {
-            "type": "Adam",
+            "type": adam_optimizer,
             "params": {
                 "lr": 0.00015,
                 "betas": [0.8,
@@ -405,23 +418,27 @@ def _test_checkpoint_lr_scheduler(args,
                                         load_lr_scheduler_states=True)
 
 
-@pytest.mark.parametrize('zero_stage, use_cpu_offload',
+@pytest.mark.parametrize('zero_stage, use_cpu_offload, adam_optimizer',
                          [
                              (0,
-                              False),
+                              False,
+                              'Adam'),
                              (1,
-                              False),
+                              False,
+                              'Adam'),
                              (2,
-                              False),
+                              False,
+                              'Adam'),
                              (2,
-                              True),
+                              True,
+                              'deepspeed_adam'),
                          ])
-def test_checkpoint_no_lr_scheduler(tmpdir, zero_stage, use_cpu_offload):
+def test_checkpoint_no_lr_scheduler(tmpdir, zero_stage, use_cpu_offload, adam_optimizer):
     config_dict = {
         "train_batch_size": 2,
         "steps_per_print": 1,
         "optimizer": {
-            "type": "Adam",
+            "type": adam_optimizer,
             "params": {
                 "lr": 1e-5
             }
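
For reference, a minimal usage sketch of the corrected DeepSpeedCPUAdam signature from the cpu_adam.py change above. The class, its betas/eps/weight_decay keywords, and its step() method come from the patch; the toy parameters and training loop are illustrative assumptions only, not part of the change.

    import torch
    from deepspeed.ops.adam import DeepSpeedCPUAdam

    # Toy CPU-resident parameters; DeepSpeedCPUAdam keeps its optimizer state on the CPU.
    params = [torch.nn.Parameter(torch.randn(4096)) for _ in range(2)]

    # The keyword is now spelled `betas`; the misspelled `bettas` no longer exists.
    optimizer = DeepSpeedCPUAdam(params,
                                 lr=1e-3,
                                 betas=(0.9, 0.999),
                                 eps=1e-8,
                                 weight_decay=0)

    for _ in range(3):
        loss = sum((p * p).sum() for p in params)  # dummy objective
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()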
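
The test_checkpointing.py changes select the optimizer through the DeepSpeed config string: "Adam" for the non-offload cases and 'deepspeed_adam' when use_cpu_offload is true. Below is a sketch of the kind of config dict the offload cases exercise; the "optimizer" block mirrors the config_dict in the tests (the second beta value is cut off in the hunk and omitted here), while the "zero_optimization" block is an assumption about how zero_stage and use_cpu_offload are wired into the config elsewhere in the test file.

    # Hypothetical config for the ZeRO stage 2 + CPU offload test case.
    config_dict = {
        "train_batch_size": 2,
        "steps_per_print": 1,
        "optimizer": {
            "type": "deepspeed_adam",  # the type string the offload test cases pass in
            "params": {
                "lr": 0.00015
            }
        },
        "zero_optimization": {  # assumed wiring of zero_stage / use_cpu_offload
            "stage": 2,
            "cpu_offload": True
        }
    }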