12 changes: 11 additions & 1 deletion deepspeed/pt/deepspeed_light.py
@@ -979,7 +979,17 @@ def allreduce_no_retain(self, bucket, numel_per_bucket=500000000):
     def buffered_allreduce_fallback(self, grads=None, elements_per_buffer=500000000):
         grads = []
         for param_name, param in self.module.named_parameters():
-            if param.grad is not None:
+            if param.grad is None:
+                # In cases where there is an imbalance of empty grads across
+                # ranks we must create empty grads, this will ensure that every
+                # rank is reducing the same size. In some cases it may make
+                # sense in the future to support the ability to average not
+                # w.r.t. world size but with a different value.
+                grads.append(
+                    torch.zeros(param.size(),
+                                dtype=param.dtype,
+                                device=param.device))
+            else:
                 grad_data = param.grad.data
                 if self.sparse_gradients_enabled(
                 ) and param_name in self.csr_tensor_module_names:
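The guard above addresses ranks where some parameters never received a gradient (for example, a layer exercised only on rank 0): instead of skipping those parameters, the fallback now contributes a zero tensor of the same shape and dtype, so every rank reduces an identically sized buffer. A minimal standalone sketch of that idea (not DeepSpeed code; collect_grads_for_allreduce is a hypothetical helper name):

import torch

def collect_grads_for_allreduce(module):
    # One tensor per parameter; parameters whose .grad is None contribute
    # zeros so the reduced buffer has the same element count on every rank,
    # and averaging by world size still lines up.
    grads = []
    for name, param in module.named_parameters():
        if param.grad is None:
            grads.append(torch.zeros(param.size(),
                                     dtype=param.dtype,
                                     device=param.device))
        else:
            grads.append(param.grad.data)
    return grads

# No backward() has run here, so every grad is None, yet the collected
# buffers still match the parameter shapes.
print([g.shape for g in collect_grads_for_allreduce(torch.nn.Linear(4, 4))])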
11 changes: 8 additions & 3 deletions tests/unit/simple_model.py
@@ -5,16 +5,21 @@


 class SimpleModel(torch.nn.Module):
-    def __init__(self, hidden_dim, empty_grad=False):
+    def __init__(self, hidden_dim, empty_grad=False, rank=0):
         super(SimpleModel, self).__init__()
         self.linear = torch.nn.Linear(hidden_dim, hidden_dim)
         if empty_grad:
-            self.layers2 = torch.nn.ModuleList([torch.nn.Linear(hidden_dim, hidden_dim)])
+            self.linear2 = torch.nn.Linear(hidden_dim, hidden_dim)
         self.cross_entropy_loss = torch.nn.CrossEntropyLoss()
+        self.rank = rank
+        self.empty_grad = empty_grad

     def forward(self, x, y):
         hidden_dim = x
-        hidden_dim = self.linear(hidden_dim)
+        if self.rank == 0 and self.empty_grad:
+            hidden_dim = self.linear(hidden_dim) + self.linear2(hidden_dim)
+        else:
+            hidden_dim = self.linear(hidden_dim)
         return self.cross_entropy_loss(hidden_dim, y)


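With the new rank argument, only rank 0 routes activations through linear2, so on every other rank linear2's parameters finish backward() with grad still None; that is exactly the imbalanced-empty-grad case the engine change zero-fills. A quick single-process check of that behavior (assuming the test-local import from simple_model; no torch.distributed setup is needed to observe the None grads):

import torch
from simple_model import SimpleModel  # helper module used by tests/unit

# Pretend to be rank 1: linear2 is constructed but never used in forward().
model = SimpleModel(hidden_dim=10, empty_grad=True, rank=1)
x = torch.randn(4, 10)
y = torch.randint(0, 10, (4,))
model(x, y).backward()

print(model.linear.weight.grad is not None)   # True: linear is used on every rank
print(model.linear2.weight.grad is None)      # True: unused here, grad stays None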
87 changes: 82 additions & 5 deletions tests/unit/test_fp16.py
@@ -33,9 +33,10 @@ def _test_lamb_fp32_grad_clip(args, model, hidden_dim):
         data_loader = random_dataloader(model=model,
                                         total_samples=50,
                                         hidden_dim=hidden_dim,
-                                        device=model.device)
+                                        device=model.device,
+                                        dtype=torch.float)
         for n, batch in enumerate(data_loader):
-            loss = model(batch[0].float(), batch[1])
+            loss = model(batch[0], batch[1])
             model.backward(loss)
             model.step()

@@ -81,7 +82,7 @@ def _test_lamb_fp16_basic(args, model, hidden_dim):

 def test_lamb_fp16_empty_grad(tmpdir):
     config_dict = {
-        "train_batch_size": 1,
+        "train_batch_size": 2,
         "steps_per_print": 1,
         "optimizer": {
             "type": "Lamb",
@@ -97,9 +98,9 @@ def test_lamb_fp16_empty_grad(tmpdir):
     args = args_from_dict(tmpdir, config_dict)
     hidden_dim = 10

-    model = SimpleModel(hidden_dim, empty_grad=True)
+    model = SimpleModel(hidden_dim, empty_grad=True, rank=args.local_rank)

-    @distributed_test(world_size=[1])
+    @distributed_test(world_size=[2])
     def _test_lamb_fp16_empty_grad(args, model, hidden_dim):
         model, _, _,_ = deepspeed.initialize(args=args,
                                              model=model,
@@ -116,6 +117,44 @@ def _test_lamb_fp16_empty_grad(args, model, hidden_dim):
     _test_lamb_fp16_empty_grad(args=args, model=model, hidden_dim=hidden_dim)


+def test_adam_fp32_empty_grad(tmpdir):
+    config_dict = {
+        "train_batch_size": 2,
+        "steps_per_print": 1,
+        "optimizer": {
+            "type": "Adam",
+            "params": {
+                "lr": 0.00015
+            }
+        },
+        "gradient_clipping": 1.0,
+        "fp16": {
+            "enabled": False
+        }
+    }
+    args = args_from_dict(tmpdir, config_dict)
+    hidden_dim = 10
+
+    model = SimpleModel(hidden_dim, empty_grad=True, rank=args.local_rank)
+
+    @distributed_test(world_size=[2])
+    def _test_adam_fp32_empty_grad(args, model, hidden_dim):
+        model, _, _,_ = deepspeed.initialize(args=args,
+                                             model=model,
+                                             model_parameters=model.parameters())
+        data_loader = random_dataloader(model=model,
+                                        total_samples=50,
+                                        hidden_dim=hidden_dim,
+                                        device=model.device,
+                                        dtype=torch.float)
+        for n, batch in enumerate(data_loader):
+            loss = model(batch[0], batch[1])
+            model.backward(loss)
+            model.step()
+
+    _test_adam_fp32_empty_grad(args=args, model=model, hidden_dim=hidden_dim)
+
+
 def test_adamw_fp16_basic(tmpdir):
     config_dict = {
         "train_batch_size": 1,
@@ -495,3 +534,41 @@ def _test_adam_amp_o2(args, model, hidden_dim):
             model.step()

     _test_adam_amp_o2(args=args, model=model, hidden_dim=hidden_dim)
+
+
+def test_adam_amp_o2_empty_grad(tmpdir):
+    config_dict = {
+        "train_batch_size": 2,
+        "steps_per_print": 1,
+        "optimizer": {
+            "type": "Adam",
+            "params": {
+                "lr": 0.00015
+            }
+        },
+        "gradient_clipping": 1.0,
+        "amp": {
+            "enabled": True,
+            "opt_level": "O2"
+        }
+    }
+    args = args_from_dict(tmpdir, config_dict)
+    hidden_dim = 10
+
+    model = SimpleModel(hidden_dim, empty_grad=False, rank=args.local_rank)
+
+    @distributed_test(world_size=[2])
+    def _test_adam_amp_o2_empty_grad(args, model, hidden_dim):
+        model, _, _,_ = deepspeed.initialize(args=args,
+                                             model=model,
+                                             model_parameters=model.parameters())
+        data_loader = random_dataloader(model=model,
+                                        total_samples=50,
+                                        hidden_dim=hidden_dim,
+                                        device=model.device)
+        for n, batch in enumerate(data_loader):
+            loss = model(batch[0], batch[1])
+            model.backward(loss)
+            model.step()
+
+    _test_adam_amp_o2_empty_grad(args=args, model=model, hidden_dim=hidden_dim)
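The new tests share one recipe: a world size of 2, a SimpleModel built with the caller's local rank, and 50 random batches pushed through backward() and step(); passing simply means the ranks complete training without stalling on mismatched allreduce sizes. One way to run just these cases locally (a sketch, assuming the usual tests/unit pytest setup in which @distributed_test spawns the ranks itself):

import pytest

# Select the empty-grad tests from this file by keyword; -x stops on the
# first failure so a hang or size mismatch surfaces quickly.
pytest.main(["tests/unit/test_fp16.py", "-k", "empty_grad", "-x"])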