9 changes: 5 additions & 4 deletions deepspeed/ops/adam/cpu_adam.py
@@ -12,14 +12,14 @@ class DeepSpeedCPUAdam(torch.optim.Optimizer):
     def __init__(self,
                  model_params,
                  lr=1e-3,
-                 bettas=(0.9,
-                         0.999),
+                 betas=(0.9,
+                        0.999),
                  eps=1e-8,
                  weight_decay=0,
                  amsgrad=False):
 
         default_args = dict(lr=lr,
-                            betas=bettas,
+                            betas=betas,
                             eps=eps,
                             weight_decay=weight_decay,
                             amsgrad=amsgrad)
@@ -30,7 +30,7 @@ def __init__(self,

         global ds_opt_adam
         ds_opt_adam = importlib.import_module('deepspeed.ops.adam.cpu_adam_op')
-        ds_opt_adam.create_adam(self.opt_id, lr, bettas[0], bettas[1], eps, weight_decay)
+        ds_opt_adam.create_adam(self.opt_id, lr, betas[0], betas[1], eps, weight_decay)
 
     def __setstate__(self, state):
         super(DeepSpeedCPUAdam, self).__setstate__(state)
@@ -54,6 +54,7 @@ def step(self, closure=None, fp16_param_groups=None):
                 state = self.state[p]
                 # State initialization
                 if len(state) == 0:
+                    print(f'group {group_id} param {param_id} = {p.numel()}')
                     state['step'] = 0
                     # gradient momentums
                     state['exp_avg'] = torch.zeros_like(p.data, device='cpu')
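Net effect of the change above: the misspelled `bettas` keyword becomes the standard `betas`, matching `torch.optim.Adam`, and the `create_adam` call is updated to match (the debug print added in `step()` is unrelated to the rename). A minimal usage sketch of the corrected constructor, assuming the DeepSpeed CPU-Adam extension is built and importable:

import torch
from deepspeed.ops.adam import DeepSpeedCPUAdam

model = torch.nn.Linear(8, 8)
# Before the rename, passing the standard `betas` keyword raised a TypeError
# because the constructor parameter was spelled `bettas`.
optimizer = DeepSpeedCPUAdam(model.parameters(),
                             lr=1e-3,
                             betas=(0.9, 0.999),
                             eps=1e-8,
                             weight_decay=0)
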
19 changes: 13 additions & 6 deletions deepspeed/runtime/zero/stage2.py
@@ -15,6 +15,7 @@
 from deepspeed.runtime.fp16.loss_scaler import LossScaler, DynamicLossScaler
 from deepspeed.runtime.utils import see_memory_usage, is_model_parallel_parameter
 from deepspeed.runtime.zero.config import ZERO_OPTIMIZATION_GRADIENTS
+
 from deepspeed.utils import logger
 #Toggle this to true to enable correctness test
 #with gradient partitioning and without
@@ -155,6 +156,7 @@ def __init__(self,
         self.overlap_comm = overlap_comm or cpu_offload
 
         self.cpu_offload = cpu_offload
+
         self.device = torch.cuda.current_device() if not self.cpu_offload else 'cpu'
 
         self.dp_process_group = dp_process_group
@@ -1405,17 +1407,22 @@ def step(self, closure=None):
         self.unscale_and_clip_grads(single_partition_grad_groups, norm_groups)
         #torch.set_num_threads(12)
         timers('optimizer_step').start()
-        #self.optimizer.step(fp16_param_groups=self.parallel_partitioned_fp16_groups)
-        self.optimizer.step()
-        #get rid of the fp32 gradients. Not needed anymore
-        if not self.cpu_offload:
+        if self.cpu_offload:
+            # self.optimizer.step(fp16_param_groups=self.parallel_partitioned_fp16_groups)
+            self.optimizer.step()
+            for fp16_partitions, fp32_partition in zip(self.parallel_partitioned_fp16_groups, self.single_partition_of_fp32_groups):
+                fp16_partitions[partition_id].data.copy_(fp32_partition.data)
+        else:
+            self.optimizer.step()
+            #get rid of the fp32 gradients. Not needed anymore
             for group in self.single_partition_of_fp32_groups:
                 group.grad = None
 
-        for fp16_partitions, fp32_partition in zip(self.parallel_partitioned_fp16_groups, self.single_partition_of_fp32_groups):
-            fp16_partitions[partition_id].data.copy_(fp32_partition.data)
+            for fp16_partitions, fp32_partition in zip(self.parallel_partitioned_fp16_groups, self.single_partition_of_fp32_groups):
+                fp16_partitions[partition_id].data.copy_(fp32_partition.data)
 
         timers('optimizer_step').stop()
         timers.log(names=['optimizer_step'])
 
         if self.cpu_offload:
             self.reset_cpu_buffers()
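The restructured `step()` above boils down to the sketch below (names taken from the diff; the surrounding method contains more bookkeeping). In both branches the updated fp32 master partitions are copied back into the fp16 model partitions; the difference is that without offload the fp32 gradients are dropped immediately after the step, while with offload they are presumably retained in reusable CPU buffers, which the `reset_cpu_buffers()` call a few lines later clears.

if self.cpu_offload:
    self.optimizer.step()   # optimizer state and update run on the CPU
else:
    self.optimizer.step()
    # fp32 gradients are not needed once the step is taken
    for group in self.single_partition_of_fp32_groups:
        group.grad = None
# in either branch, refresh the fp16 partitions from the fp32 master copies
for fp16_partitions, fp32_partition in zip(self.parallel_partitioned_fp16_groups,
                                           self.single_partition_of_fp32_groups):
    fp16_partitions[partition_id].data.copy_(fp32_partition.data)
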
10 changes: 9 additions & 1 deletion deepspeed/runtime/zero/utils.py
@@ -2,6 +2,7 @@
 import torch.distributed as dist
 import apex
 from deepspeed.utils import logger
+from deepspeed.ops.adam import DeepSpeedCPUAdam
 
 
 def _initialize_parameter_parallel_groups(parameter_parallel_size=None):
@@ -22,8 +23,15 @@ def _initialize_parameter_parallel_groups(parameter_parallel_size=None):
     return my_group
 
 
-ZERO_SUPPORTED_OPTIMIZERS = [torch.optim.Adam, apex.optimizers.FusedAdam]
+ZERO_SUPPORTED_OPTIMIZERS = [
+    torch.optim.Adam,
+    apex.optimizers.FusedAdam,
+    DeepSpeedCPUAdam
+]
 
 
 def is_zero_supported_optimizer(optimizer):
+    print(
+        f'Checking ZeRO support for optimizer={optimizer.__class__.__name__} type={type(optimizer)}'
+    )
     return type(optimizer) in ZERO_SUPPORTED_OPTIMIZERS
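With `DeepSpeedCPUAdam` added to the whitelist, ZeRO no longer rejects the CPU optimizer during engine setup. A small sketch of the check (note that `is_zero_supported_optimizer` uses an exact `type()` match, so a subclass of a supported optimizer would still be rejected):

from deepspeed.ops.adam import DeepSpeedCPUAdam
from deepspeed.runtime.zero.utils import is_zero_supported_optimizer

# `params` stands in for any iterable of torch.nn.Parameter objects
opt = DeepSpeedCPUAdam(params)
assert is_zero_supported_optimizer(opt)  # True after this change
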
6 changes: 3 additions & 3 deletions tests/model/Megatron_GPT2/run_func_test.py
@@ -19,7 +19,7 @@

 def grep_loss_from_file(file_name):
     loss = 0.0
-
+    print(f'grepping {file_name}')
     with open(file_name, 'r') as f:
         lines = f.readlines()
         line_filter = "validation loss at the end of training for test data | LM loss:"
@@ -455,7 +455,7 @@ def run_partition_activations_test(self, test_config, r_tol):
         baseline_prefix += test_config["json"][0:-5]
         baseline_deepspeed_config = True
 
-        test_config["other_args"] = cpu_optimizer_flag
+        test_config["other_args"] = f"\"{cpu_optimizer_flag}\""
         base_file = self.gen_output_name(test_config,
                                          baseline_prefix,
                                          baseline_config=baseline_deepspeed_config)
@@ -565,7 +565,7 @@ def suite():
     suite.addTest(GPT2FuncTestCase('test_mp4_gpu4_node1_zero2_cpu_optimizer'))
 
     # Baseline = Megatron + Torch.Optim.Adam
-    # Test = Megatron + Torch.Optim.Adam + ZeRO-Offload
+    # Test = Megatron + DeepSpeedAdam + ZeRO-Offload
     suite.addTest(GPT2FuncTestCase('test_mp1_gpu1_node1_zero2_offload'))
     suite.addTest(GPT2FuncTestCase('test_mp1_gpu2_node1_zero2_offload'))
     suite.addTest(GPT2FuncTestCase('test_mp2_gpu4_node1_zero2_offload'))
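The quoting change wraps the optimizer flag in escaped double quotes, presumably so that a flag string containing spaces reaches the launcher's command line as a single token. An illustration with a hypothetical flag value (the real `cpu_optimizer_flag` is set elsewhere in the test):

cpu_optimizer_flag = "--cpu-optimizer"  # hypothetical value for illustration
test_config["other_args"] = f"\"{cpu_optimizer_flag}\""
# -> '"--cpu-optimizer"', i.e. the flag arrives quoted in the assembled shell command
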
2 changes: 1 addition & 1 deletion tests/model/run_sanity_check.py
@@ -30,7 +30,7 @@ def pytest_hack(runner_result):


 def test_megatron():
-    runner = unittest.TextTestRunner(failfast=True)
+    runner = unittest.TextTestRunner(failfast=False)
     pytest_hack(runner.run(Megatron_GPT2.suite()))


2 changes: 1 addition & 1 deletion tests/perf/adam_test.py
@@ -3,7 +3,7 @@
 import time
 
 device = 'cpu'
-model_size = 1 * 1024 ** 3
+model_size = 1 * 1024**3
 group_size = [model_size, 274432]
 
 param = [torch.nn.Parameter(torch.ones(size, device=device)) for size in group_size]
69 changes: 43 additions & 26 deletions tests/unit/test_checkpointing.py
@@ -235,21 +235,24 @@ def _test_checkpoint_fused_optimizer(args, model, hidden_dim, load_optimizer_sta
                                        load_optimizer_states=False)
 
 
-@pytest.mark.parametrize('zero_stage, use_cpu_offload',
+@pytest.mark.parametrize('zero_stage, use_cpu_offload, adam_optimizer',
                          [
                              (1,
-                              False),
+                              False,
+                              'Adam'),
                              (2,
-                              False),
+                              False,
+                              'Adam'),
                              (2,
-                              True),
+                              True,
+                              'deepspeed_adam'),
                          ])
-def test_checkpoint_zero_optimizer(tmpdir, zero_stage, use_cpu_offload):
+def test_checkpoint_zero_optimizer(tmpdir, zero_stage, use_cpu_offload, adam_optimizer):
     config_dict = {
         "train_batch_size": 2,
         "steps_per_print": 1,
         "optimizer": {
-            "type": "Adam",
+            "type": adam_optimizer,
             "params": {
                 "lr": 0.00015,
                 "betas": [0.8,
@@ -285,21 +288,27 @@ def _test_checkpoint_zero_optimizer(args, model, hidden_dim, load_optimizer_stat
                                        load_optimizer_states=True)
 
 
-@pytest.mark.parametrize('zero_stage, use_cpu_offload',
+@pytest.mark.parametrize('zero_stage, use_cpu_offload, adam_optimizer',
                          [
                              (1,
-                              False),
+                              False,
+                              "Adam"),
                              (2,
-                              False),
+                              False,
+                              "Adam"),
                              (2,
-                              True),
+                              True,
+                              'deepspeed_adam'),
                          ])
-def test_checkpoint_zero_no_optimizer(tmpdir, zero_stage, use_cpu_offload):
+def test_checkpoint_zero_no_optimizer(tmpdir,
+                                      zero_stage,
+                                      use_cpu_offload,
+                                      adam_optimizer):
     config_dict = {
         "train_batch_size": 2,
         "steps_per_print": 1,
         "optimizer": {
-            "type": "Adam",
+            "type": adam_optimizer,
             "params": {
                 "lr": 0.00015,
                 "betas": [0.8,
@@ -338,23 +347,27 @@ def _test_checkpoint_zero_no_optimizer(args,
                                        load_optimizer_states=False)
 
 
-@pytest.mark.parametrize('zero_stage, use_cpu_offload',
+@pytest.mark.parametrize('zero_stage, use_cpu_offload, adam_optimizer',
                          [
                              (0,
-                              False),
+                              False,
+                              'Adam'),
                              (1,
-                              False),
+                              False,
+                              'Adam'),
                              (2,
-                              False),
+                              False,
+                              'Adam'),
                              (2,
-                              True),
+                              True,
+                              'deepspeed_adam'),
                          ])
-def test_checkpoint_lr_scheduler(tmpdir, zero_stage, use_cpu_offload):
+def test_checkpoint_lr_scheduler(tmpdir, zero_stage, use_cpu_offload, adam_optimizer):
     config_dict = {
         "train_batch_size": 2,
         "steps_per_print": 1,
         "optimizer": {
-            "type": "Adam",
+            "type": adam_optimizer,
             "params": {
                 "lr": 0.00015,
                 "betas": [0.8,
@@ -405,23 +418,27 @@ def _test_checkpoint_lr_scheduler(args,
                                        load_lr_scheduler_states=True)
 
 
-@pytest.mark.parametrize('zero_stage, use_cpu_offload',
+@pytest.mark.parametrize('zero_stage, use_cpu_offload, adam_optimizer',
                          [
                              (0,
-                              False),
+                              False,
+                              'Adam'),
                              (1,
-                              False),
+                              False,
+                              'Adam'),
                              (2,
-                              False),
+                              False,
+                              'Adam'),
                              (2,
-                              True),
+                              True,
+                              'deepspeed_adam'),
                          ])
-def test_checkpoint_no_lr_scheduler(tmpdir, zero_stage, use_cpu_offload):
+def test_checkpoint_no_lr_scheduler(tmpdir, zero_stage, use_cpu_offload, adam_optimizer):
     config_dict = {
         "train_batch_size": 2,
         "steps_per_print": 1,
         "optimizer": {
-            "type": "Adam",
+            "type": adam_optimizer,
             "params": {
                 "lr": 1e-5
             }
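Each tuple in the extended parametrizations above becomes its own test case, and the new `adam_optimizer` value is spliced into the DeepSpeed config as the optimizer type, so the `use_cpu_offload=True` cases now exercise the `deepspeed_adam` path instead of plain `Adam`. A condensed sketch of the pattern (illustrative test name and config only):

import pytest

@pytest.mark.parametrize('zero_stage, use_cpu_offload, adam_optimizer',
                         [(1, False, 'Adam'),
                          (2, False, 'Adam'),
                          (2, True, 'deepspeed_adam')])
def test_example(zero_stage, use_cpu_offload, adam_optimizer):
    # the parametrized name flows into the DeepSpeed config verbatim
    config_dict = {"optimizer": {"type": adam_optimizer}}
    assert config_dict["optimizer"]["type"] in ('Adam', 'deepspeed_adam')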