From 9f34f775afba46d7f7e90f0639ed131b05fea39f Mon Sep 17 00:00:00 2001
From: Jeff Rasley
Date: Wed, 27 Jul 2022 17:06:33 -0700
Subject: [PATCH 1/2] shutdown zero.Init from within ds.init

---
 deepspeed/__init__.py                    |  3 +
 .../runtime/zero/partition_parameters.py | 58 ++++++++++++-------
 tests/unit/test_zero_context.py          | 27 +++++++++
 3 files changed, 66 insertions(+), 22 deletions(-)

diff --git a/deepspeed/__init__.py b/deepspeed/__init__.py
index 7a18f98a49e8..51fa42e6b819 100755
--- a/deepspeed/__init__.py
+++ b/deepspeed/__init__.py
@@ -115,6 +115,9 @@ def initialize(args=None,
                  __git_branch__),
              ranks=[0])
 
+    # Disable zero.Init context if it's currently enabled
+    zero.partition_parameters.shutdown_init_context()
+
     assert model is not None, "deepspeed.initialize requires a model"
 
     if not isinstance(model, PipelineModule):
diff --git a/deepspeed/runtime/zero/partition_parameters.py b/deepspeed/runtime/zero/partition_parameters.py
index bab0e95400d8..bc84fa90c144 100755
--- a/deepspeed/runtime/zero/partition_parameters.py
+++ b/deepspeed/runtime/zero/partition_parameters.py
@@ -35,6 +35,7 @@
 
 param_count = 0
 partitioned_param_data_shape = [0]
+zero_init_enabled = False
 
 
 def _dist_allgather_fn(input_tensor: Tensor, output_tensor: Tensor, group=None):
@@ -252,10 +253,12 @@ def __init__(self,
                  mem_efficient_linear=True,
                  ds_config=None,
                  dtype=None):
+        global zero_init_enabled
         self.mem_efficient_linear = mem_efficient_linear
         self.enabled = enabled
         self._set_dtype(ds_config, dtype)
         assert self.dtype in [torch.half, torch.bfloat16, torch.float], f"Invalid data type {self.dtype}, allowed values are [torch.half, torch.bfloat16, torch.float]"
+        zero_init_enabled = True
 
     def __enter__(self):
         if not self.enabled:
             return
@@ -404,28 +407,7 @@ def __exit__(self, exc_type, exc_value, traceback):
         if not self.enabled:
             return
 
-        def _disable_class(cls):
-            cls.__init__ = cls._old_init
-
-        # Replace .__init__() for all existing subclasses of torch.nn.Module
-        for subclass in get_all_subclasses(torch.nn.modules.module.Module):
-            _disable_class(subclass)
-
-        # putting methods back the way we found them
-        torch.nn.modules.module.Module.__init_subclass__ = torch.nn.modules.module.Module._old_init_subclass
-        torch.nn.modules.module.Module.apply = torch.nn.modules.module.Module._old_apply
-
-        torch.Tensor.__new__ = torch.Tensor.__old_new__
-        torch.empty = _orig_torch_empty
-        torch.zeros = _orig_torch_zeros
-        torch.ones = _orig_torch_ones
-        torch.full = _orig_torch_full
-
-        # un doing it here will undo it during training
-        # if self.mem_efficient_linear:
-        #     torch.nn.functional.linear = self.linear_bk
-        # if self.mem_efficient_linear:
-        #     torch.nn.functional.linear = self.linear_bk
+        shutdown_init_context()
 
         if dist.get_rank() == 0:
             logger.info("finished initializing model with %.2fB parameters",
@@ -454,6 +436,38 @@ def _set_dtype(self, ds_config, dtype):
             self.dtype = dtype or torch.half
 
 
+def shutdown_init_context():
+    global zero_init_enabled
+
+    if not zero_init_enabled:
+        return
+
+    def _disable_class(cls):
+        cls.__init__ = cls._old_init
+
+    # Replace .__init__() for all existing subclasses of torch.nn.Module
+    for subclass in get_all_subclasses(torch.nn.modules.module.Module):
+        _disable_class(subclass)
+
+    # putting methods back the way we found them
+    torch.nn.modules.module.Module.__init_subclass__ = torch.nn.modules.module.Module._old_init_subclass
+    torch.nn.modules.module.Module.apply = torch.nn.modules.module.Module._old_apply
+
+    torch.Tensor.__new__ = torch.Tensor.__old_new__
+    torch.empty = _orig_torch_empty
+    torch.zeros = _orig_torch_zeros
+    torch.ones = _orig_torch_ones
+    torch.full = _orig_torch_full
+
+    # un doing it here will undo it during training
+    # if self.mem_efficient_linear:
+    #     torch.nn.functional.linear = self.linear_bk
+    # if self.mem_efficient_linear:
+    #     torch.nn.functional.linear = self.linear_bk
+
+    zero_init_enabled = False
+
+
 class AllGatherHandle:
     def __init__(self, handle, param: Parameter) -> None:
         if param.ds_status != ZeroParamStatus.INFLIGHT:
diff --git a/tests/unit/test_zero_context.py b/tests/unit/test_zero_context.py
index e689005709d9..a8fb31a8c8e5 100644
--- a/tests/unit/test_zero_context.py
+++ b/tests/unit/test_zero_context.py
@@ -360,3 +360,30 @@ def test_subclass_param_init():
     assert torch.equal(model.param, ones + 1)
     assert torch.equal(model.param_pa, ones + 2)
     assert torch.equal(model.param_grandpa, ones + 3)
+
+
+@distributed_test(world_size=2)
+def test_ds_init_w_zinit():
+    ds_config = {
+        "train_batch_size": 2,
+        "steps_per_print": 1,
+        "optimizer": {
+            "type": "Adam",
+            "params": {
+                "lr": 0.00015
+            }
+        }
+    }
+
+    class Model(torch.nn.Module):
+        def __init__(self):
+            super(Model, self).__init__()
+            self.linear = torch.nn.Linear(4, 4)
+
+        def magic(self):
+            return 42
+
+    with deepspeed.zero.Init():
+        model = Model()
+        engine, *_ = deepspeed.initialize(model=model, config=ds_config, model_parameters=model.parameters())
+    assert engine.magic() == 42

From 5f537b14e31864203d91da055d3f2d3caa92aba4 Mon Sep 17 00:00:00 2001
From: Jeff Rasley
Date: Fri, 29 Jul 2022 16:34:15 -0700
Subject: [PATCH 2/2] set zero_init_enabled only at enter

---
 deepspeed/runtime/zero/partition_parameters.py | 4 ++--
 op_builder/transformer_inference.py            | 3 +--
 2 files changed, 3 insertions(+), 4 deletions(-)

diff --git a/deepspeed/runtime/zero/partition_parameters.py b/deepspeed/runtime/zero/partition_parameters.py
index bc84fa90c144..db135e0cf182 100755
--- a/deepspeed/runtime/zero/partition_parameters.py
+++ b/deepspeed/runtime/zero/partition_parameters.py
@@ -253,16 +253,16 @@ def __init__(self,
                  mem_efficient_linear=True,
                  ds_config=None,
                  dtype=None):
-        global zero_init_enabled
         self.mem_efficient_linear = mem_efficient_linear
         self.enabled = enabled
         self._set_dtype(ds_config, dtype)
         assert self.dtype in [torch.half, torch.bfloat16, torch.float], f"Invalid data type {self.dtype}, allowed values are [torch.half, torch.bfloat16, torch.float]"
-        zero_init_enabled = True
 
     def __enter__(self):
+        global zero_init_enabled
         if not self.enabled:
             return
+        zero_init_enabled = True
 
         def apply_with_gather(orig_module_apply_fn: Callable) -> Callable:
             """many models make use of child modules like Linear or Embedding which
diff --git a/op_builder/transformer_inference.py b/op_builder/transformer_inference.py
index e9df633174f3..42e909aeb6a0 100755
--- a/op_builder/transformer_inference.py
+++ b/op_builder/transformer_inference.py
@@ -1,5 +1,4 @@
 import torch
-from packaging import version
 from .builder import CUDAOpBuilder, installed_cuda_version
 
 
@@ -18,7 +17,7 @@ def is_compatible(self, verbose=True):
         cuda_okay = True
         if not self.is_rocm_pytorch() and torch.cuda.is_available():
             sys_cuda_major, _ = installed_cuda_version()
-            torch_cuda_major = version.parse(torch.version.cuda).major
+            torch_cuda_major = int(torch.version.cuda.split('.')[0])
             cuda_capability = torch.cuda.get_device_properties(0).major
             if cuda_capability >= 8:
                 if torch_cuda_major < 11 or sys_cuda_major < 11:
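
For context, the usage pattern exercised by the new test looks roughly like the sketch below. This is illustrative only, adapted from test_ds_init_w_zinit: the config values and the tiny Linear model are placeholders, and it needs to run under a distributed launcher (e.g. the deepspeed CLI) rather than plain python.

import torch
import deepspeed

ds_config = {
    "train_batch_size": 2,
    "optimizer": {
        "type": "Adam",
        "params": {
            "lr": 0.00015
        }
    }
}

class Model(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(4, 4)

    def forward(self, x):
        return self.linear(x)

# Build the model under zero.Init so parameters are partitioned at construction
# time. With this patch, deepspeed.initialize() calls shutdown_init_context()
# itself, restoring the patched torch/nn constructors even though initialize()
# runs while the zero.Init context is still active.
with deepspeed.zero.Init():
    model = Model()
    engine, optimizer, _, _ = deepspeed.initialize(model=model,
                                                   config=ds_config,
                                                   model_parameters=model.parameters())

The module-level zero_init_enabled flag keeps shutdown_init_context() a no-op when no zero.Init context was ever entered, so plain deepspeed.initialize callers are unaffected.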