22 changes: 14 additions & 8 deletions deepspeed/runtime/engine.py
@@ -6,13 +6,16 @@
 import torch
 import warnings
 import torch.distributed as dist
+
+import apex
 from apex import amp
 from torch.nn.modules import Module
 from torch.distributed.distributed_c10d import _get_global_rank
 from tensorboardX import SummaryWriter

 from deepspeed.runtime.zero.stage2 import FP16_DeepSpeedZeroOptimizer
 from deepspeed.runtime.zero.stage1 import FP16_DeepSpeedZeroOptimizer_Stage1
+from deepspeed.runtime.zero.utils import is_zero_supported_optimizer
 from deepspeed.ops.adam import DeepSpeedCPUAdam
 from deepspeed.runtime.activation_checkpointing import checkpointing as activation_checkpointing
 from deepspeed.runtime.fp16.fused_optimizer import FP16_Optimizer
@@ -264,7 +267,7 @@ def train_micro_batch_size_per_gpu(self):
         return self._config.train_micro_batch_size_per_gpu

     def optimizer_name(self):
-        return self._config.optimizer_name
+        return self.client_optimizer.__class__.__name__ if self.client_optimizer else self._config.optimizer_name

     def optimizer_params(self):
         return self._config.optimizer_params
@@ -295,7 +298,7 @@ def zero_cpu_offload(self):

     def deepspeed_adam(self):
         return self._config.zero_config.deepspeed_adam
-
+
     def zero_optimization_stage(self):
         return self._config.zero_optimization_stage

@@ -506,7 +509,7 @@ def _configure_optimizer(self, client_optimizer, model_parameters):

         if self.zero_optimization():
             assert not self.amp_enabled(), "Amp and ZeRO are not currently compatible, please use (legacy) fp16 mode which performs similar to amp opt_mode=O2"
-            if self.optimizer_name() not in [ADAM_OPTIMIZER]:
+            if not is_zero_supported_optimizer(basic_optimizer):
                 assert self.zero_allow_untested_optimizer(), \
                     'You are using an untested ZeRO Optimizer. Please add <"zero_allow_untested_optimizer": true> in the configuration file to use it.'

@@ -536,11 +539,13 @@ def _configure_basic_optimizer(self, model_parameters):
             )
         if self.optimizer_name() == ADAM_OPTIMIZER:
             if self.zero_cpu_offload():
-                if False: #self.deepspeed_adam():
-                    optimizer = DeepSpeedCPUAdam(model_parameters, **optimizer_parameters)
+                if False: #self.deepspeed_adam():
+                    optimizer = DeepSpeedCPUAdam(model_parameters,
+                                                 **optimizer_parameters)
                 else:
-                    optimizer = torch.optim.Adam(model_parameters, **optimizer_parameters)
-
+                    optimizer = torch.optim.Adam(model_parameters,
+                                                 **optimizer_parameters)
+
             else:
                 from apex.optimizers.fused_adam import FusedAdam
                 optimizer = FusedAdam(model_parameters, **optimizer_parameters)
@@ -555,7 +560,8 @@ def _configure_fp16_optimizer(self, optimizer):
         initial_dynamic_scale = self.initial_dynamic_scale()
         dynamic_loss_args = self.dynamic_loss_scale_args()
         clip_grad = self.gradient_clipping()
-        if self.optimizer_name() == ADAM_OPTIMIZER:
+
+        if isinstance(optimizer, apex.optimizers.FusedAdam):
             if self.dynamic_loss_scale():
                 logger.info('Creating fp16 optimizer with dynamic loss scale')
                 timers = self.timers if self.wall_clock_breakdown() else None
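For illustration only (not part of this patch): the optimizer_name() change above means the engine now reports the class name of a client-supplied optimizer and only falls back to the config string when no client optimizer was given. A minimal stand-alone stub of that logic, with the engine internals elided:

import torch

class _EngineStub:
    # Stub mirroring the new optimizer_name() fallback; the real engine reads
    # the configured name from self._config rather than a plain attribute.
    def __init__(self, client_optimizer, config_optimizer_name):
        self.client_optimizer = client_optimizer
        self._config_optimizer_name = config_optimizer_name

    def optimizer_name(self):
        return self.client_optimizer.__class__.__name__ \
            if self.client_optimizer else self._config_optimizer_name

params = [torch.nn.Parameter(torch.zeros(4))]
print(_EngineStub(torch.optim.Adam(params, lr=1e-3), "adam").optimizer_name())  # 'Adam'
print(_EngineStub(None, "adam").optimizer_name())                               # 'adam'
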
9 changes: 8 additions & 1 deletion deepspeed/runtime/zero/utils.py
100644 → 100755
@@ -1,6 +1,6 @@
 import torch
 import torch.distributed as dist

+import apex
 from deepspeed.utils import logger

-
@@ -20,3 +20,10 @@ def _initialize_parameter_parallel_groups(parameter_parallel_size=None):
         if rank in ranks:
             my_group = group
     return my_group
+
+
+ZERO_SUPPORTED_OPTIMIZERS = [torch.optim.Adam, apex.optimizers.FusedAdam]
+
+
+def is_zero_supported_optimizer(optimizer):
+    return type(optimizer) in ZERO_SUPPORTED_OPTIMIZERS
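
For illustration only (not part of this patch): how the new helper behaves, assuming torch and apex (built with its CUDA extensions, which FusedAdam needs) are importable. Because the membership test uses type() rather than isinstance(), only exact instances of torch.optim.Adam or apex.optimizers.FusedAdam count as ZeRO-supported; anything else has to go through the zero_allow_untested_optimizer assertion added in engine.py above.

import torch
import apex

from deepspeed.runtime.zero.utils import is_zero_supported_optimizer

params = [torch.nn.Parameter(torch.zeros(8))]

print(is_zero_supported_optimizer(torch.optim.Adam(params, lr=1e-3)))           # True
print(is_zero_supported_optimizer(apex.optimizers.FusedAdam(params, lr=1e-3)))  # True
print(is_zero_supported_optimizer(torch.optim.SGD(params, lr=1e-3)))            # False
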
1 change: 0 additions & 1 deletion tests/model/Megatron_GPT2/ds_config_func_bs4_zero1.json
@@ -5,7 +5,6 @@
   "zero_optimization": {
     "stage": 1
   },
-  "zero_allow_untested_optimizer": true,
   "gradient_clipping": 1.0,
   "fp16": {
     "enabled": true,
1 change: 0 additions & 1 deletion tests/model/Megatron_GPT2/ds_config_func_bs4_zero2.json
@@ -8,7 +8,6 @@
     "allgather_bucket_size": 7000000,
     "reduce_scatter": true
   },
-  "zero_allow_untested_optimizer": true,
   "gradient_clipping": 1.0,
   "fp16": {
     "enabled": true,
@@ -9,7 +9,6 @@
     "reduce_scatter": true,
     "cpu_offload": true
   },
-  "zero_allow_untested_optimizer": true,
   "gradient_clipping": 1.0,
   "fp16": {
     "enabled": true,
11 changes: 2 additions & 9 deletions tests/model/Megatron_GPT2/ds_config_func_bs8_zero0_gas3.json
@@ -1,19 +1,13 @@
 {
-  "train_micro_batch_size_per_gpu":8,
+  "train_micro_batch_size_per_gpu": 8,
   "gradient_accumulation_steps": 3,
   "steps_per_print": 1,
   "zero_optimization": {
-    "stage":0,
+    "stage": 0,
     "reduce_bucket_size": 7000000,
     "allgather_bucket_size": 7000000,
     "reduce_scatter": true
   },
-  "optimizer": {
-    "type": "Adam",
-    "params": {
-      "lr": 0.00015
-    }
-  },
   "gradient_clipping": 1.0,
   "fp16": {
     "enabled": true,
@@ -26,5 +20,4 @@
     "partition_activations": true,
     "contiguous_memory_optimization": true
   }
-
 }
1 change: 0 additions & 1 deletion tests/model/Megatron_GPT2/ds_config_func_bs8_zero1.json
@@ -5,7 +5,6 @@
   "zero_optimization": {
     "stage": 1
   },
-  "zero_allow_untested_optimizer": true,
   "gradient_clipping": 1.0,
   "fp16": {
     "enabled": true,
11 changes: 2 additions & 9 deletions tests/model/Megatron_GPT2/ds_config_func_bs8_zero2_gas3.json
@@ -1,19 +1,13 @@
 {
-  "train_micro_batch_size_per_gpu":8,
+  "train_micro_batch_size_per_gpu": 8,
   "gradient_accumulation_steps": 3,
   "steps_per_print": 1,
   "zero_optimization": {
-    "stage":2,
+    "stage": 2,
     "reduce_bucket_size": 7000000,
     "allgather_bucket_size": 7000000,
     "reduce_scatter": true
   },
-  "optimizer": {
-    "type": "Adam",
-    "params": {
-      "lr": 0.00015
-    }
-  },
   "gradient_clipping": 1.0,
   "fp16": {
     "enabled": true,
@@ -26,5 +20,4 @@
     "partition_activations": true,
     "contiguous_memory_optimization": true
   }
-
 }
7 changes: 0 additions & 7 deletions tests/model/Megatron_GPT2/ds_config_perf_bs16.json
100644 → 100755
@@ -5,14 +5,7 @@
   "zero_optimization": {
     "stage": 1
   },
-  "zero_allow_untested_optimizer": true,
   "disable_allgather": true,
-  "optimizer": {
-    "type": "Adam",
-    "params": {
-      "lr": 0.00015
-    }
-  },
   "gradient_clipping": 1.0,
   "fp16": {
     "enabled": true,
7 changes: 0 additions & 7 deletions tests/model/Megatron_GPT2/ds_config_perf_bs32.json
@@ -5,14 +5,7 @@
   "zero_optimization": {
     "stage": 1
   },
-  "zero_allow_untested_optimizer": true,
   "disable_allgather": true,
-  "optimizer": {
-    "type": "Adam",
-    "params": {
-      "lr": 0.00015
-    }
-  },
   "gradient_clipping": 1.0,
   "fp16": {
     "enabled": true,
7 changes: 0 additions & 7 deletions tests/model/Megatron_GPT2/ds_config_perf_bs8.json
100644 → 100755
@@ -5,14 +5,7 @@
   "zero_optimization": {
     "stage": 1
   },
-  "zero_allow_untested_optimizer": true,
   "disable_allgather": true,
-  "optimizer": {
-    "type": "Adam",
-    "params": {
-      "lr": 0.00015
-    }
-  },
   "gradient_clipping": 1.0,
   "fp16": {
     "enabled": true,
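The optimizer blocks deleted from these test configs are meant to come from client code instead; the new unit test at the end of this diff exercises that path. A rough sketch of the pattern, for illustration only (the Linear model, learning rate, config path, and Namespace plumbing are placeholders, and actually running it needs a GPU for fp16 plus a distributed launch such as the deepspeed launcher):

import argparse

import torch
import deepspeed

# Hypothetical config file holding only the remaining settings
# (batch size, zero_optimization, fp16, ...), with no "optimizer" section.
args = argparse.Namespace(deepspeed_config='ds_config.json', local_rank=0)

model = torch.nn.Linear(10, 10)  # stand-in for the real model
client_optimizer = torch.optim.Adam(model.parameters(), lr=0.00015)

# ZeRO now wraps the client-constructed optimizer directly.
model_engine, optimizer, _, _ = deepspeed.initialize(args=args,
                                                     model=model,
                                                     optimizer=client_optimizer)
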
21 changes: 21 additions & 0 deletions tests/model/Megatron_GPT2/run_func_test.py
@@ -48,6 +48,24 @@ def setUp(self):
     def tearDown(self):
         os.chdir(self.save_dir)

+    def test_mp1_gpu2_node1_fp16(self):
+        test_config = {
+            "mp": 1,
+            "gpus": 2,
+            "nodes": 1,
+            "bs": 8,
+            "steps": 1000,
+            "layers": LAYERS,
+            "hidden_size": HIDDEN_SIZE,
+            "seq_length": SEQ_LEN,
+            "heads": ATTN_HEADS,
+            "deepspeed": False,
+            "json": "ds_config_func_bs8_no_zero.json",
+        }
+
+        succ = self.run_test(test_config, 0.01)
+        self.assertTrue(succ)
+
     def test_mp1_gpu1_node1_zero1(self):
         test_config = {
             "mp": 1,
@@ -348,6 +366,9 @@ def check_parity(self, base_file, test_file, r_tol):

 def suite():
     suite = unittest.TestSuite()
+
+    suite.addTest(GPT2FuncTestCase('test_mp1_gpu2_node1_fp16'))
+
     suite.addTest(GPT2FuncTestCase('test_mp1_gpu1_node1_zero1'))
     suite.addTest(GPT2FuncTestCase('test_mp1_gpu2_node1_zero1'))
     suite.addTest(GPT2FuncTestCase('test_mp2_gpu4_node1_zero1'))
36 changes: 36 additions & 0 deletions tests/unit/test_fp16.py
@@ -1,4 +1,5 @@
 import torch
+import apex
 import deepspeed
 import argparse
 import pytest
@@ -608,3 +609,38 @@ def _test_adam_amp_o2_empty_grad(args, model, hidden_dim):
             model.step()

     _test_adam_amp_o2_empty_grad(args=args, model=model, hidden_dim=hidden_dim)
+
+
+@pytest.mark.parametrize('zero_stage, optimizer_constructor',
+                         [(1,
+                           apex.optimizers.FusedAdam),
+                          (2,
+                           torch.optim.Adam),
+                          (2,
+                           apex.optimizers.FusedAdam)])
+def test_zero_supported_client_optimizer(tmpdir, zero_stage, optimizer_constructor):
+    config_dict = {
+        "train_batch_size": 2,
+        "steps_per_print": 1,
+        "fp16": {
+            "enabled": True
+        },
+        "zero_optimization": {
+            "stage": zero_stage
+        }
+    }
+    args = args_from_dict(tmpdir, config_dict)
+    hidden_dim = 10
+
+    model = SimpleModel(hidden_dim, empty_grad=False)
+
+    @distributed_test(world_size=[1])
+    def _test_zero_supported_client_optimizer(args, model, optimizer_constructor):
+        client_optimizer = optimizer_constructor(params=model.parameters())
+        model, _, _, _ = deepspeed.initialize(args=args,
+                                              model=model,
+                                              optimizer=client_optimizer)
+
+    _test_zero_supported_client_optimizer(args=args,
+                                          model=model,
+                                          optimizer_constructor=optimizer_constructor)