diff --git a/deepspeed/pt/deepspeed_light.py b/deepspeed/pt/deepspeed_light.py
index 1c8623fa27e7..ce27d7d68f49 100755
--- a/deepspeed/pt/deepspeed_light.py
+++ b/deepspeed/pt/deepspeed_light.py
@@ -979,7 +979,17 @@ def allreduce_no_retain(self, bucket, numel_per_bucket=500000000):
     def buffered_allreduce_fallback(self, grads=None, elements_per_buffer=500000000):
         grads = []
         for param_name, param in self.module.named_parameters():
-            if param.grad is not None:
+            if param.grad is None:
+                # In cases where there is an imbalance of empty grads across
+                # ranks we must create empty grads, this will ensure that every
+                # rank is reducing the same size. In some cases it may make
+                # sense in the future to support the ability to average not
+                # w.r.t. world size but with a different value.
+                grads.append(
+                    torch.zeros(param.size(),
+                                dtype=param.dtype,
+                                device=param.device))
+            else:
                 grad_data = param.grad.data
                 if self.sparse_gradients_enabled(
                 ) and param_name in self.csr_tensor_module_names:
diff --git a/tests/unit/simple_model.py b/tests/unit/simple_model.py
index c17ef493cb2a..7a2e3357af60 100644
--- a/tests/unit/simple_model.py
+++ b/tests/unit/simple_model.py
@@ -5,16 +5,21 @@
 
 
 class SimpleModel(torch.nn.Module):
-    def __init__(self, hidden_dim, empty_grad=False):
+    def __init__(self, hidden_dim, empty_grad=False, rank=0):
         super(SimpleModel, self).__init__()
         self.linear = torch.nn.Linear(hidden_dim, hidden_dim)
         if empty_grad:
-            self.layers2 = torch.nn.ModuleList([torch.nn.Linear(hidden_dim, hidden_dim)])
+            self.linear2 = torch.nn.Linear(hidden_dim, hidden_dim)
         self.cross_entropy_loss = torch.nn.CrossEntropyLoss()
+        self.rank = rank
+        self.empty_grad = empty_grad
 
     def forward(self, x, y):
         hidden_dim = x
-        hidden_dim = self.linear(hidden_dim)
+        if self.rank == 0 and self.empty_grad:
+            hidden_dim = self.linear(hidden_dim) + self.linear2(hidden_dim)
+        else:
+            hidden_dim = self.linear(hidden_dim)
         return self.cross_entropy_loss(hidden_dim, y)
 
 
diff --git a/tests/unit/test_fp16.py b/tests/unit/test_fp16.py
index 306786309921..320d026bdd83 100755
--- a/tests/unit/test_fp16.py
+++ b/tests/unit/test_fp16.py
@@ -33,9 +33,10 @@ def _test_lamb_fp32_grad_clip(args, model, hidden_dim):
         data_loader = random_dataloader(model=model,
                                         total_samples=50,
                                         hidden_dim=hidden_dim,
-                                        device=model.device)
+                                        device=model.device,
+                                        dtype=torch.float)
         for n, batch in enumerate(data_loader):
-            loss = model(batch[0].float(), batch[1])
+            loss = model(batch[0], batch[1])
             model.backward(loss)
             model.step()
 
@@ -81,7 +82,7 @@ def _test_lamb_fp16_basic(args, model, hidden_dim):
 
 def test_lamb_fp16_empty_grad(tmpdir):
     config_dict = {
-        "train_batch_size": 1,
+        "train_batch_size": 2,
         "steps_per_print": 1,
         "optimizer": {
             "type": "Lamb",
@@ -97,9 +98,9 @@ def test_lamb_fp16_empty_grad(tmpdir):
     args = args_from_dict(tmpdir, config_dict)
     hidden_dim = 10
 
-    model = SimpleModel(hidden_dim, empty_grad=True)
+    model = SimpleModel(hidden_dim, empty_grad=True, rank=args.local_rank)
 
-    @distributed_test(world_size=[1])
+    @distributed_test(world_size=[2])
     def _test_lamb_fp16_empty_grad(args, model, hidden_dim):
         model, _, _,_ = deepspeed.initialize(args=args,
                                              model=model,
@@ -116,6 +117,44 @@ def _test_lamb_fp16_empty_grad(args, model, hidden_dim):
     _test_lamb_fp16_empty_grad(args=args, model=model, hidden_dim=hidden_dim)
 
 
+def test_adam_fp32_empty_grad(tmpdir):
+    config_dict = {
+        "train_batch_size": 2,
+        "steps_per_print": 1,
+        "optimizer": {
+            "type": "Adam",
+            "params": {
+                "lr": 0.00015
+            }
+        },
+        "gradient_clipping": 1.0,
+        "fp16": {
+            "enabled": False
+        }
+    }
+    args = args_from_dict(tmpdir, config_dict)
+    hidden_dim = 10
+
+    model = SimpleModel(hidden_dim, empty_grad=True, rank=args.local_rank)
+
+    @distributed_test(world_size=[2])
+    def _test_adam_fp32_empty_grad(args, model, hidden_dim):
+        model, _, _,_ = deepspeed.initialize(args=args,
+                                             model=model,
+                                             model_parameters=model.parameters())
+        data_loader = random_dataloader(model=model,
+                                        total_samples=50,
+                                        hidden_dim=hidden_dim,
+                                        device=model.device,
+                                        dtype=torch.float)
+        for n, batch in enumerate(data_loader):
+            loss = model(batch[0], batch[1])
+            model.backward(loss)
+            model.step()
+
+    _test_adam_fp32_empty_grad(args=args, model=model, hidden_dim=hidden_dim)
+
+
 def test_adamw_fp16_basic(tmpdir):
     config_dict = {
         "train_batch_size": 1,
@@ -495,3 +534,41 @@ def _test_adam_amp_o2(args, model, hidden_dim):
         model.step()
 
     _test_adam_amp_o2(args=args, model=model, hidden_dim=hidden_dim)
+
+
+def test_adam_amp_o2_empty_grad(tmpdir):
+    config_dict = {
+        "train_batch_size": 2,
+        "steps_per_print": 1,
+        "optimizer": {
+            "type": "Adam",
+            "params": {
+                "lr": 0.00015
+            }
+        },
+        "gradient_clipping": 1.0,
+        "amp": {
+            "enabled": True,
+            "opt_level": "O2"
+        }
+    }
+    args = args_from_dict(tmpdir, config_dict)
+    hidden_dim = 10
+
+    model = SimpleModel(hidden_dim, empty_grad=False, rank=args.local_rank)
+
+    @distributed_test(world_size=[2])
+    def _test_adam_amp_o2_empty_grad(args, model, hidden_dim):
+        model, _, _,_ = deepspeed.initialize(args=args,
+                                             model=model,
+                                             model_parameters=model.parameters())
+        data_loader = random_dataloader(model=model,
+                                        total_samples=50,
+                                        hidden_dim=hidden_dim,
+                                        device=model.device)
+        for n, batch in enumerate(data_loader):
+            loss = model(batch[0], batch[1])
+            model.backward(loss)
+            model.step()
+
+    _test_adam_amp_o2_empty_grad(args=args, model=model, hidden_dim=hidden_dim)
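
For reference, and not part of the patch itself: a minimal standalone sketch of the idea behind the buffered_allreduce_fallback change above, assuming two ranks where only rank 0 exercises an extra layer (which is what the modified SimpleModel with empty_grad=True arranges). The helper name gather_grads_with_padding is hypothetical and used only for illustration; it shows how substituting zero tensors for None grads keeps the flattened reduction buffers the same size on every rank, per the in-code comment about imbalanced empty grads.

import torch

def gather_grads_with_padding(params):
    # Hypothetical helper mirroring the buffered_allreduce_fallback change:
    # substitute a zero tensor whenever a parameter never received a gradient,
    # so every rank contributes the same number of elements to the reduction.
    grads = []
    for p in params:
        if p.grad is None:
            grads.append(torch.zeros(p.size(), dtype=p.dtype, device=p.device))
        else:
            grads.append(p.grad.data)
    return grads

# Simulate two ranks that share the same parameter shapes, but where only
# "rank 0" used its second layer, leaving the other rank's grad as None.
rank0 = [torch.nn.Parameter(torch.randn(4, 4)) for _ in range(2)]
rank1 = [torch.nn.Parameter(torch.randn(4, 4)) for _ in range(2)]
for p in rank0:
    p.grad = torch.ones(4, 4)
rank1[0].grad = torch.ones(4, 4)  # rank1[1].grad stays None

flat0 = torch.cat([g.flatten() for g in gather_grads_with_padding(rank0)])
flat1 = torch.cat([g.flatten() for g in gather_grads_with_padding(rank1)])
assert flat0.numel() == flat1.numel()  # both ranks now reduce the same size
averaged = (flat0 + flat1) / 2.0  # what an allreduce averaged over world size yields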