6 changes: 3 additions & 3 deletions deepspeed/runtime/engine.py
@@ -435,9 +435,9 @@ def _set_distributed_vars(self):
     # Configure based on command line arguments
     def _configure_with_arguments(self, args, mpu):
         self.local_rank = args.local_rank if hasattr(args, 'local_rank') else 0
-        self._config = DeepSpeedConfig(args.deepspeed_config,
-                                       mpu,
-                                       param_dict=self.config_params)
+        config_file = args.deepspeed_config if hasattr(args,
+                                                       'deepspeed_config') else None
+        self._config = DeepSpeedConfig(config_file, mpu, param_dict=self.config_params)

     # Validate command line arguments
     def _do_args_sanity_check(self, args):
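This engine change is what makes a config-file-free setup possible: previously, an args namespace without a deepspeed_config attribute would raise an AttributeError, whereas now None is passed to DeepSpeedConfig and the configuration can come entirely from the config_params dict given to deepspeed.initialize. A minimal sketch of that dict-based path, assuming a single process; the toy model, optimizer choice, and config values below are illustrative, not part of this PR:

import argparse
import torch
import deepspeed

# Build args the same way the unit tests do, deliberately without deepspeed_config
parser = argparse.ArgumentParser()
args = parser.parse_args(args='')
args.deepspeed = True
args.local_rank = 0

# Everything a deepspeed_config JSON file would normally carry goes in this dict
config_dict = {"train_batch_size": 1, "fp16": {"enabled": True}}

model = torch.nn.Linear(10, 10)  # illustrative toy model
optimizer = torch.optim.AdamW(model.parameters())

model_engine, optimizer, _, _ = deepspeed.initialize(
    args=args, model=model, optimizer=optimizer, config_params=config_dict)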
11 changes: 8 additions & 3 deletions tests/unit/simple_model.py
@@ -161,16 +161,21 @@ def create_config_from_dict(tmpdir, config_dict):
     return config_path


-def args_from_dict(tmpdir, config_dict):
-    config_path = create_config_from_dict(tmpdir, config_dict)
+def create_deepspeed_args():
     parser = argparse.ArgumentParser()
     args = parser.parse_args(args='')
     args.deepspeed = True
-    args.deepspeed_config = config_path
     if torch.distributed.is_initialized():
         # We assume up to one full node executing unit tests
         assert torch.distributed.get_world_size() <= torch.cuda.device_count()
         args.local_rank = torch.distributed.get_rank()
     else:
         args.local_rank = 0
     return args
+
+
+def args_from_dict(tmpdir, config_dict):
+    args = create_deepspeed_args()
+    config_path = create_config_from_dict(tmpdir, config_dict)
+    args.deepspeed_config = config_path
+    return args
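With this split, existing file-based tests keep calling args_from_dict unchanged, while dict-config tests call create_deepspeed_args directly and never write a JSON file to tmpdir. A rough sketch of the two paths is below; the test name is hypothetical, tmpdir is pytest's builtin fixture, and in the real suite the initialize calls run under the distributed_test decorator from common.py rather than at module level:

import torch
import deepspeed
from simple_model import SimpleModel, args_from_dict, create_deepspeed_args


def test_config_paths(tmpdir):  # hypothetical test for illustration only
    config_dict = {"train_batch_size": 1, "fp16": {"enabled": True}}

    # File-based path (unchanged): config_dict is first serialized to JSON under tmpdir
    model = SimpleModel(10)
    args = args_from_dict(tmpdir, config_dict)
    engine, _, _, _ = deepspeed.initialize(args=args,
                                           model=model,
                                           optimizer=torch.optim.AdamW(model.parameters()))

    # Dict-based path (enabled by this PR): no file on disk, the dict is passed through.
    # A real test would use one path or the other, not both in sequence.
    model2 = SimpleModel(10)
    dict_args = create_deepspeed_args()
    engine2, _, _, _ = deepspeed.initialize(args=dict_args,
                                            model=model2,
                                            optimizer=torch.optim.AdamW(model2.parameters()),
                                            config_params=config_dict)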
37 changes: 36 additions & 1 deletion tests/unit/test_fp16.py
@@ -6,7 +6,7 @@
 import os
 from deepspeed.ops.adam import FusedAdam
 from common import distributed_test
-from simple_model import SimpleModel, SimpleOptimizer, random_dataloader, args_from_dict
+from simple_model import SimpleModel, SimpleOptimizer, random_dataloader, args_from_dict, create_deepspeed_args

 try:
     from apex import amp
@@ -194,6 +194,41 @@ def _test_adamw_fp16_basic(args, model, hidden_dim):
     _test_adamw_fp16_basic(args=args, model=model, hidden_dim=hidden_dim)


+def test_dict_config_adamw_fp16_basic():
+    config_dict = {
+        "train_batch_size": 1,
+        "steps_per_print": 1,
+        "fp16": {
+            "enabled": True
+        }
+    }
+    args = create_deepspeed_args()
+    hidden_dim = 10
+
+    model = SimpleModel(hidden_dim, empty_grad=False)
+
+    @distributed_test(world_size=[1])
+    def _test_adamw_fp16_basic(args, model, hidden_dim, config_dict):
+        optimizer = torch.optim.AdamW(params=model.parameters())
+        model, _, _, _ = deepspeed.initialize(args=args,
+                                              model=model,
+                                              optimizer=optimizer,
+                                              config_params=config_dict)
+        data_loader = random_dataloader(model=model,
+                                        total_samples=50,
+                                        hidden_dim=hidden_dim,
+                                        device=model.device)
+        for n, batch in enumerate(data_loader):
+            loss = model(batch[0], batch[1])
+            model.backward(loss)
+            model.step()
+
+    _test_adamw_fp16_basic(args=args,
+                           model=model,
+                           hidden_dim=hidden_dim,
+                           config_dict=config_dict)
+
+
 def test_adamw_fp16_empty_grad(tmpdir):
     config_dict = {
         "train_batch_size": 1,