From d6d1bfe6f146dfba141e23cc118bdc6a26dc49db Mon Sep 17 00:00:00 2001
From: Jeff Rasley
Date: Tue, 2 Apr 2024 10:55:22 -0700
Subject: [PATCH 1/4] ensure capacity does not exceed number of tokens

---
 deepspeed/moe/sharded_moe.py |  4 +++-
 tests/unit/moe/test_moe.py   | 24 ++++++++++++++++++++++++
 2 files changed, 27 insertions(+), 1 deletion(-)

diff --git a/deepspeed/moe/sharded_moe.py b/deepspeed/moe/sharded_moe.py
index d6c023ec11d3..e685a0f574f3 100644
--- a/deepspeed/moe/sharded_moe.py
+++ b/deepspeed/moe/sharded_moe.py
@@ -209,13 +209,15 @@ def top1gating(logits: Tensor,
     # if we don't want to drop any tokens
     if not drop_tokens:
         new_capacity = torch.max(exp_counts).to(logits.device)
+        # Communicate across all processes to pick the maximum capacity.
         dist.all_reduce(new_capacity, op=dist.ReduceOp.MAX, group=dist.get_world_group())
         if groups._get_expert_model_parallel_world_size() == 1:
             # If the non-expert is tensor-parallel, we need to pad the capacity to 'tp'.
             # This is since we are going to activate drop_tokens() to drop duplicate tokens.
             tp = 1 if groups.mpu is None else groups.mpu.get_tensor_model_parallel_world_size()
             new_capacity = torch.ceil(new_capacity / tp).mul(tp).to(new_capacity.dtype)
-        capacity = new_capacity
+        # Make sure the capacity value does not exceed the number of tokens.
+        capacity = min(new_capacity, torch.tensor(mask1.size(0)))

     # Compute l_aux
     me = torch.mean(gates, dim=0)
diff --git a/tests/unit/moe/test_moe.py b/tests/unit/moe/test_moe.py
index 310a0df16381..f0544f513334 100644
--- a/tests/unit/moe/test_moe.py
+++ b/tests/unit/moe/test_moe.py
@@ -9,6 +9,8 @@
 import gc
 from unit.common import DistributedTest
 from unit.simple_model import SimplePRMoEModel, SimpleMoEModel, sequence_dataloader
+import deepspeed.comm as dist
+from deepspeed.moe.sharded_moe import top1gating
 from deepspeed.moe.utils import split_params_into_different_moe_groups_for_optimizer, is_moe_param
 from deepspeed.runtime.utils import required_torch_version

@@ -132,3 +134,25 @@ def test(self, ep_size, use_residual):
         loss = model(batch[0], batch[1])
         model.backward(loss)
         model.step()
+
+
+class TestTopk(DistributedTest):
+    world_size = 2
+
+    def test(self):
+        if dist.get_rank() == 0:
+            logits = torch.tensor([[0.8903, 0.0275], [0.9031, 0.5386]], device='cuda:0')
+        elif dist.get_rank() == 1:
+            logits = torch.tensor(
+                [[0.8903, 0.0275], [0.9031, 0.5386], [0.7312, 0.9047], [0.3370, 0.0347], [0.6334, 0.0201],
+                 [0.9307, 0.5607], [0.1691, 0.5992], [0.6501, 0.3025], [0.7642, 0.5446], [0.1114, 0.6924]],
+                device='cuda:1')
+
+        output = top1gating(logits=logits,
+                            capacity_factor=1,
+                            min_capacity=0,
+                            used_token=None,
+                            noisy_gate_policy=None,
+                            drop_tokens=False,
+                            use_rts=True,
+                            use_tutel=False)
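
With drop_tokens disabled, top1gating sizes its dispatch tensors from the all-reduced maximum of the per-expert counts, so a rank that routed only a couple of tokens could previously end up with a capacity larger than its own token count; that is the imbalance the TestTopk case above reproduces with a 2-token rank and a 10-token rank. The following single-process sketch illustrates the clamp this patch adds; the counts are illustrative placeholders standing in for values that top1gating derives from mask1 and the all-reduce, not values taken from the patch.

import torch

num_local_tokens = 2                 # what mask1.size(0) would be on this rank
new_capacity = torch.tensor(9)       # max capacity after dist.all_reduce(..., MAX)

# Before the patch: capacity = new_capacity, which can exceed num_local_tokens.
# After the patch: capacity is clamped to the local token count.
capacity = min(new_capacity, torch.tensor(num_local_tokens))
assert int(capacity) == 2
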
From c96e35007b0e28fc946436fcbdb2c8456ca823ed Mon Sep 17 00:00:00 2001
From: Jeff Rasley
Date: Tue, 2 Apr 2024 11:16:44 -0700
Subject: [PATCH 2/4] auto convert moe param groups

---
 deepspeed/moe/utils.py      | 26 ++++++++++++++++++++++++++
 deepspeed/runtime/engine.py |  4 +++-
 tests/unit/moe/test_moe.py  | 36 ++++++++++++++++++++++++++++++++++++
 3 files changed, 65 insertions(+), 1 deletion(-)

diff --git a/deepspeed/moe/utils.py b/deepspeed/moe/utils.py
index f52fe2e3442d..9564a0ea0fb8 100644
--- a/deepspeed/moe/utils.py
+++ b/deepspeed/moe/utils.py
@@ -150,3 +150,29 @@ def split_params_into_different_moe_groups_for_optimizer(

 def is_moe_param_group(param_group):
     return param_group.get('moe', False)
+
+
+def configure_moe_param_groups(model_parameters: List):
+    # peek at the first element to determine how to proceed
+    first = model_parameters[0]
+
+    # match torch.optim.Optimizer expectations
+    if not isinstance(first, (torch.Tensor, dict)):
+        raise TypeError("param argument that would be given to the optimizer should be "
+                        f"an iterable of Tensors or dicts, but got {type(first)}")
+
+    # Case 1: model_parameters is a list of torch.nn.Parameter
+    # -> need to create moe compatible param groups
+    if isinstance(first, torch.nn.Parameter):
+        param_group = {'params': model_parameters, 'name': 'dense-params'}
+        return split_params_into_different_moe_groups_for_optimizer(param_group)
+
+    # Case 2: model_parameters is a list of param groups List[dict]
+    # -> moe compatible param groups might already exist, if not create them
+    elif isinstance(first, dict):
+        # there are no moe groups created
+        if not any(['moe' in param_group for param_group in model_parameters]):
+            return split_params_into_different_moe_groups_for_optimizer(model_parameters)
+        else:
+            # moe groups exist, nothing to do
+            return model_parameters
diff --git a/deepspeed/runtime/engine.py b/deepspeed/runtime/engine.py
index 3ad37baeedcb..31bb58b64e04 100644
--- a/deepspeed/runtime/engine.py
+++ b/deepspeed/runtime/engine.py
@@ -94,7 +94,7 @@
 from ..ops.adam import FusedAdam
 from ..moe.sharded_moe import TopKGate, MOELayer
 from ..moe.layer import MoE
-from ..moe.utils import is_moe_param
+from ..moe.utils import is_moe_param, configure_moe_param_groups
 from ..git_version_info import version

 from deepspeed.profiling.flops_profiler.profiler import FlopsProfiler
@@ -1227,6 +1227,8 @@ def _do_optimizer_sanity_check(self, basic_optimizer):
     # Configure optimizer
     def _configure_optimizer(self, client_optimizer, model_parameters):
         if client_optimizer is None:
+            if self.has_moe_layers:
+                model_parameters = configure_moe_param_groups(model_parameters)
             basic_optimizer = self._configure_basic_optimizer(model_parameters)
             log_dist(f"Using DeepSpeed Optimizer param name {self.optimizer_name()} as basic optimizer", ranks=[0])
         else:
diff --git a/tests/unit/moe/test_moe.py b/tests/unit/moe/test_moe.py
index f0544f513334..5c27207fabe2 100644
--- a/tests/unit/moe/test_moe.py
+++ b/tests/unit/moe/test_moe.py
@@ -15,6 +15,42 @@
 from deepspeed.runtime.utils import required_torch_version


+@pytest.mark.parametrize("zero_stage", [0, 1, 2])
+class TestSimpleMoE(DistributedTest):
+    world_size = 1
+
+    def test(self, zero_stage):
+        if not required_torch_version(min_version=1.8):
+            pytest.skip("DeepSpeed MoE tests need torch 1.8 or higher to run correctly")
+
+        config_dict = {
+            "train_micro_batch_size_per_gpu": 1,
+            "steps_per_print": 1,
+            "optimizer": {
+                "type": "Adam",
+                "params": {
+                    "lr": 0.00015
+                }
+            },
+            "fp16": {
+                "enabled": True
+            },
+            "zero_optimization": {
+                "stage": zero_stage
+            }
+        }
+        # should automatically create moe param groups in deepspeed backend
+        hidden_dim = 16
+        model = SimpleMoEModel(hidden_dim=hidden_dim, ep_size=1)
+        model, optimizer, _, _ = deepspeed.initialize(config=config_dict, model=model)
+        data_loader = sequence_dataloader(model=model, total_samples=50, hidden_dim=hidden_dim, device=model.device)
+
+        for n, batch in enumerate(data_loader):
+            loss = model(batch[0], batch[1])
+            model.backward(loss)
+            model.step()
+
+
 @pytest.mark.parametrize("ep_size", [2, 4])
 @pytest.mark.parametrize("zero_stage", [0, 1, 2])
 @pytest.mark.parametrize("use_residual", [True, False])
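
The new configure_moe_param_groups helper only rewrites parameter groups when it has to: a flat list of parameters is wrapped into a single 'dense-params' group, and a list of param-group dicts is converted only when none of the dicts is already an MoE group. A rough sketch of the two call shapes it accepts; the torch.nn.Linear is an assumed toy module used purely to obtain parameters, since no MoE layer or distributed setup is needed to exercise the grouping logic itself.

import torch
from deepspeed.moe.utils import configure_moe_param_groups

# Toy module, only used to produce some torch.nn.Parameter objects.
dense = torch.nn.Linear(8, 8)

# Case 1: a flat list of Parameters is wrapped into a 'dense-params' group and
# passed through split_params_into_different_moe_groups_for_optimizer.
groups = configure_moe_param_groups(list(dense.parameters()))

# Case 2: param-group dicts without a 'moe' key are converted the same way;
# dicts that already carry 'moe' entries are returned unchanged.
groups = configure_moe_param_groups([{'params': list(dense.parameters()), 'name': 'dense-params'}])
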
From e942e76b3bfd9bbdaade35a0150564c12d6ed268 Mon Sep 17 00:00:00 2001
From: Jeff Rasley
Date: Tue, 2 Apr 2024 11:27:14 -0700
Subject: [PATCH 3/4] remove unrelated test

---
 tests/unit/moe/test_moe.py | 24 ------------------------
 1 file changed, 24 deletions(-)

diff --git a/tests/unit/moe/test_moe.py b/tests/unit/moe/test_moe.py
index 5c27207fabe2..4878a4888fcc 100644
--- a/tests/unit/moe/test_moe.py
+++ b/tests/unit/moe/test_moe.py
@@ -9,8 +9,6 @@
 import gc
 from unit.common import DistributedTest
 from unit.simple_model import SimplePRMoEModel, SimpleMoEModel, sequence_dataloader
-import deepspeed.comm as dist
-from deepspeed.moe.sharded_moe import top1gating
 from deepspeed.moe.utils import split_params_into_different_moe_groups_for_optimizer, is_moe_param
 from deepspeed.runtime.utils import required_torch_version

@@ -170,25 +168,3 @@ def test(self, ep_size, use_residual):
         loss = model(batch[0], batch[1])
         model.backward(loss)
         model.step()
-
-
-class TestTopk(DistributedTest):
-    world_size = 2
-
-    def test(self):
-        if dist.get_rank() == 0:
-            logits = torch.tensor([[0.8903, 0.0275], [0.9031, 0.5386]], device='cuda:0')
-        elif dist.get_rank() == 1:
-            logits = torch.tensor(
-                [[0.8903, 0.0275], [0.9031, 0.5386], [0.7312, 0.9047], [0.3370, 0.0347], [0.6334, 0.0201],
-                 [0.9307, 0.5607], [0.1691, 0.5992], [0.6501, 0.3025], [0.7642, 0.5446], [0.1114, 0.6924]],
-                device='cuda:1')
-
-        output = top1gating(logits=logits,
-                            capacity_factor=1,
-                            min_capacity=0,
-                            used_token=None,
-                            noisy_gate_policy=None,
-                            drop_tokens=False,
-                            use_rts=True,
-                            use_tutel=False)

From 2c42ef118d504747aceac844b591b9337b54905a Mon Sep 17 00:00:00 2001
From: Jeff Rasley
Date: Tue, 2 Apr 2024 15:19:24 -0700
Subject: [PATCH 4/4] Update test_moe.py

---
 tests/unit/moe/test_moe.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/tests/unit/moe/test_moe.py b/tests/unit/moe/test_moe.py
index 4878a4888fcc..648423405922 100644
--- a/tests/unit/moe/test_moe.py
+++ b/tests/unit/moe/test_moe.py
@@ -15,7 +15,7 @@

 @pytest.mark.parametrize("zero_stage", [0, 1, 2])
 class TestSimpleMoE(DistributedTest):
-    world_size = 1
+    world_size = 2

     def test(self, zero_stage):
         if not required_torch_version(min_version=1.8):
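
Taken together, the series lets a plain MoE model be handed to deepspeed.initialize without hand-built parameter groups, which is what the updated TestSimpleMoE exercises on two ranks. The standalone sketch below mirrors that flow under stated assumptions: TinyMoE, the expert count, and the config values are illustrative, not part of the patches, and the script is meant to be launched with the DeepSpeed launcher on two GPUs so the all-reduce in top1gating is actually exercised.

import torch
import deepspeed
from deepspeed.moe.layer import MoE

deepspeed.init_distributed()


class TinyMoE(torch.nn.Module):
    """Minimal stand-in for SimpleMoEModel: one MoE layer plus a dense head."""

    def __init__(self, hidden_dim=16):
        super().__init__()
        self.moe = MoE(hidden_size=hidden_dim,
                       expert=torch.nn.Linear(hidden_dim, hidden_dim),
                       num_experts=2,
                       ep_size=1,
                       k=1)
        self.head = torch.nn.Linear(hidden_dim, 1)

    def forward(self, x, y):
        hidden, _, _ = self.moe(x)
        return torch.nn.functional.mse_loss(self.head(hidden), y)


config = {
    "train_micro_batch_size_per_gpu": 1,
    "optimizer": {"type": "Adam", "params": {"lr": 0.00015}},
    "fp16": {"enabled": True},
    "zero_optimization": {"stage": 0},
}

model = TinyMoE()
# No explicit param groups: with MoE layers present and no client optimizer,
# the engine now runs configure_moe_param_groups() over the model parameters.
engine, optimizer, _, _ = deepspeed.initialize(config=config, model=model)

x = torch.randn(1, 4, 16, device=engine.device, dtype=torch.half)
y = torch.randn(1, 4, 1, device=engine.device, dtype=torch.half)
loss = engine(x, y)
engine.backward(loss)
engine.step()
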