5 changes: 3 additions & 2 deletions deepspeed/__init__.py
@@ -32,7 +32,7 @@ def initialize(args,
                training_data=None,
                lr_scheduler=None,
                mpu=None,
-               dist_init_required=True,
+               dist_init_required=None,
                collate_fn=None):
     r"""Initialize the DeepSpeed Engine.

@@ -56,7 +56,8 @@ def initialize(args,
         mpu: Optional: A model parallelism unit object that implements
             get_{model,data}_parallel_{rank,group,world_size}()

-        dist_init_required: Optional: Initializes torch.distributed
+        dist_init_required: Optional: None will auto-initialize torch.distributed if needed,
+            otherwise the user can force it to be initialized or not via boolean.

         collate_fn: Optional: Merges a list of samples to form a
             mini-batch of Tensor(s). Used when using batched loading from a
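For context, a minimal caller-side sketch of the new contract (not part of this diff): the config path and model below are placeholders, and actually running it still requires the usual distributed environment variables. With the default dist_init_required=None, DeepSpeed initializes torch.distributed only when it has not been initialized yet; an explicit boolean keeps the old forcing behavior.

# Hypothetical usage sketch -- not from the PR; ds_config.json and the model are placeholders.
import argparse
import torch
import deepspeed

parser = argparse.ArgumentParser()
args = parser.parse_args(args='')
args.deepspeed_config = 'ds_config.json'  # placeholder config path
args.local_rank = 0

model = torch.nn.Linear(10, 10)  # placeholder model

# Default (dist_init_required=None): init_process_group() is called only if
# torch.distributed is not already initialized.
engine, _, _, _ = deepspeed.initialize(args=args,
                                       model=model,
                                       model_parameters=model.parameters())

# Explicit override: a launcher that already set up the process group can pass
# dist_init_required=False; passing True requests initialization and, per the
# engine change below, only warns if the process group already exists.
engine, _, _, _ = deepspeed.initialize(args=args,
                                       model=model,
                                       model_parameters=model.parameters(),
                                       dist_init_required=False)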
15 changes: 13 additions & 2 deletions deepspeed/pt/deepspeed_light.py
@@ -97,7 +97,7 @@ def __init__(self,
                  training_data=None,
                  lr_scheduler=None,
                  mpu=None,
-                 dist_init_required=True,
+                 dist_init_required=None,
                  collate_fn=None):
         super(DeepSpeedLight, self).__init__()

@@ -119,8 +119,19 @@ def __init__(self,
         self.gradient_average = True
         self.warn_unscaled_loss = True

+        if dist_init_required is None:
+            dist_init_required = not dist.is_initialized()
+
+        self.dist_backend = "nccl"
         if dist_init_required:
-            dist.init_process_group(backend="nccl")
+            if not dist.is_initialized():
+                logging.info("Initializing torch distributed with backend: {}".format(
+                    self.dist_backend))
+                dist.init_process_group(backend=self.dist_backend)
+            else:
+                logging.warning(
+                    "Was given dist_init_required=True but detected that torch"
+                    "distributed was already initialized, cannot initialize twice.")

         self._do_args_sanity_check(args)
         self._configure_with_arguments(args, mpu)
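Read in isolation, the engine-side decision added above amounts to the following standalone sketch; setup_distributed is a hypothetical helper used here only for illustration, not an API introduced by the PR.

import logging
import torch.distributed as dist

def setup_distributed(dist_init_required=None, backend="nccl"):
    # None means "initialize only if nobody has done so already".
    if dist_init_required is None:
        dist_init_required = not dist.is_initialized()

    if dist_init_required:
        if not dist.is_initialized():
            logging.info("Initializing torch distributed with backend: %s", backend)
            dist.init_process_group(backend=backend)
        else:
            # Asking to initialize twice is not an error; it is ignored with a
            # warning, matching the engine behavior above.
            logging.warning("torch.distributed is already initialized, "
                            "cannot initialize twice.")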
45 changes: 43 additions & 2 deletions tests/unit/test_config.py
@@ -142,8 +142,7 @@ def test_deprecated_deepscale_config(tmpdir):
     def _test_deprecated_deepscale_config(args, model, hidden_dim):
         model, _, _,_ = deepspeed.initialize(args=args,
                                              model=model,
-                                             model_parameters=model.parameters(),
-                                             dist_init_required=False)
+                                             model_parameters=model.parameters())
         data_loader = random_dataloader(model=model,
                                         total_samples=5,
                                         hidden_dim=hidden_dim,
@@ -154,3 +153,45 @@ def _test_deprecated_deepscale_config(args, model, hidden_dim):
             model.step()

     _test_deprecated_deepscale_config(args=args, model=model, hidden_dim=hidden_dim)
+
+
+def test_dist_init_true(tmpdir):
+    config_dict = {
+        "train_batch_size": 1,
+        "optimizer": {
+            "type": "Adam",
+            "params": {
+                "lr": 0.00015
+            }
+        },
+        "fp16": {
+            "enabled": True
+        }
+    }
+
+    config_path = create_config_from_dict(tmpdir, config_dict)
+    parser = argparse.ArgumentParser()
+    args = parser.parse_args(args='')
+    args.deepscale_config = config_path
+    args.local_rank = 0
+
+    hidden_dim = 10
+
+    model = SimpleModel(hidden_dim)
+
+    @distributed_test(world_size=[1])
+    def _test_dist_init_true(args, model, hidden_dim):
+        model, _, _,_ = deepspeed.initialize(args=args,
+                                             model=model,
+                                             model_parameters=model.parameters(),
+                                             dist_init_required=True)
+        data_loader = random_dataloader(model=model,
+                                        total_samples=5,
+                                        hidden_dim=hidden_dim,
+                                        device=model.device)
+        for n, batch in enumerate(data_loader):
+            loss = model(batch[0], batch[1])
+            model.backward(loss)
+            model.step()
+
+    _test_dist_init_true(args=args, model=model, hidden_dim=hidden_dim)
12 changes: 4 additions & 8 deletions tests/unit/test_fp16.py
@@ -32,8 +32,7 @@ def test_lamb_fp16_basic(tmpdir):
     def _test_lamb_fp16_basic(args, model, hidden_dim):
         model, _, _,_ = deepspeed.initialize(args=args,
                                              model=model,
-                                             model_parameters=model.parameters(),
-                                             dist_init_required=False)
+                                             model_parameters=model.parameters())
         data_loader = random_dataloader(model=model,
                                         total_samples=50,
                                         hidden_dim=hidden_dim,
@@ -70,8 +69,7 @@ def test_lamb_fp16_empty_grad(tmpdir):
     def _test_lamb_fp16_empty_grad(args, model, hidden_dim):
         model, _, _,_ = deepspeed.initialize(args=args,
                                              model=model,
-                                             model_parameters=model.parameters(),
-                                             dist_init_required=False)
+                                             model_parameters=model.parameters())
         data_loader = random_dataloader(model=model,
                                         total_samples=50,
                                         hidden_dim=hidden_dim,
@@ -102,8 +100,7 @@ def _test_adamw_fp16_basic(args, model, hidden_dim):
         optimizer = torch.optim.AdamW(params=model.parameters())
         model, _, _,_ = deepspeed.initialize(args=args,
                                              model=model,
-                                             optimizer=optimizer,
-                                             dist_init_required=False)
+                                             optimizer=optimizer)
         data_loader = random_dataloader(model=model,
                                         total_samples=50,
                                         hidden_dim=hidden_dim,
@@ -134,8 +131,7 @@ def _test_adamw_fp16_empty_grad(args, model, hidden_dim):
         optimizer = torch.optim.AdamW(params=model.parameters())
         model, _, _,_ = deepspeed.initialize(args=args,
                                              model=model,
-                                             optimizer=optimizer,
-                                             dist_init_required=False)
+                                             optimizer=optimizer)
         data_loader = random_dataloader(model=model,
                                         total_samples=50,
                                         hidden_dim=hidden_dim,