3 changes: 3 additions & 0 deletions deepspeed/__init__.py
@@ -115,6 +115,9 @@ def initialize(args=None,
__git_branch__),
ranks=[0])

# Disable zero.Init context if it's currently enabled
zero.partition_parameters.shutdown_init_context()

assert model is not None, "deepspeed.initialize requires a model"

if not isinstance(model, PipelineModule):
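With this change, deepspeed.initialize first disables any still-active zero.Init context (restoring torch.empty, the nn.Module construction hooks, and so on) before it builds the engine, so a model can be constructed and handed to the engine inside the same context. A minimal usage sketch of the pattern this unblocks — the ds_config values are illustrative placeholders and a distributed launch (e.g. via the deepspeed launcher) is assumed:

import torch
import deepspeed

ds_config = {
    "train_batch_size": 2,
    "optimizer": {"type": "Adam", "params": {"lr": 1e-4}},
}

# Build the model while zero.Init has construction patched, then call
# deepspeed.initialize() without leaving the context first; initialize()
# now shuts the patching down itself.
with deepspeed.zero.Init():
    model = torch.nn.Linear(8, 8)
    engine, *_ = deepspeed.initialize(model=model,
                                      config=ds_config,
                                      model_parameters=model.parameters())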
58 changes: 36 additions & 22 deletions deepspeed/runtime/zero/partition_parameters.py
@@ -35,6 +35,7 @@

param_count = 0
partitioned_param_data_shape = [0]
zero_init_enabled = False


def _dist_allgather_fn(input_tensor: Tensor, output_tensor: Tensor, group=None):
@@ -258,8 +259,10 @@ def __init__(self,
assert self.dtype in [torch.half, torch.bfloat16, torch.float], f"Invalid data type {self.dtype}, allowed values are [torch.half, torch.bfloat16, torch.float]"

def __enter__(self):
global zero_init_enabled
if not self.enabled:
return
zero_init_enabled = True

def apply_with_gather(orig_module_apply_fn: Callable) -> Callable:
"""many models make use of child modules like Linear or Embedding which
@@ -404,28 +407,7 @@ def __exit__(self, exc_type, exc_value, traceback):
if not self.enabled:
return

def _disable_class(cls):
cls.__init__ = cls._old_init

# Replace .__init__() for all existing subclasses of torch.nn.Module
for subclass in get_all_subclasses(torch.nn.modules.module.Module):
_disable_class(subclass)

# putting methods back the way we found them
torch.nn.modules.module.Module.__init_subclass__ = torch.nn.modules.module.Module._old_init_subclass
torch.nn.modules.module.Module.apply = torch.nn.modules.module.Module._old_apply

torch.Tensor.__new__ = torch.Tensor.__old_new__
torch.empty = _orig_torch_empty
torch.zeros = _orig_torch_zeros
torch.ones = _orig_torch_ones
torch.full = _orig_torch_full

# Undoing it here will undo it during training
# if self.mem_efficient_linear:
# torch.nn.functional.linear = self.linear_bk
# if self.mem_efficient_linear:
# torch.nn.functional.linear = self.linear_bk
shutdown_init_context()

if dist.get_rank() == 0:
logger.info("finished initializing model with %.2fB parameters",
@@ -454,6 +436,38 @@ def _set_dtype(self, ds_config, dtype):
self.dtype = dtype or torch.half


def shutdown_init_context():
    global zero_init_enabled

    if not zero_init_enabled:
        return

    def _disable_class(cls):
        cls.__init__ = cls._old_init

    # Replace .__init__() for all existing subclasses of torch.nn.Module
    for subclass in get_all_subclasses(torch.nn.modules.module.Module):
        _disable_class(subclass)

    # putting methods back the way we found them
    torch.nn.modules.module.Module.__init_subclass__ = torch.nn.modules.module.Module._old_init_subclass
    torch.nn.modules.module.Module.apply = torch.nn.modules.module.Module._old_apply

    torch.Tensor.__new__ = torch.Tensor.__old_new__
    torch.empty = _orig_torch_empty
    torch.zeros = _orig_torch_zeros
    torch.ones = _orig_torch_ones
    torch.full = _orig_torch_full

    # Undoing it here will undo it during training
    # if self.mem_efficient_linear:
    #     torch.nn.functional.linear = self.linear_bk
    # if self.mem_efficient_linear:
    #     torch.nn.functional.linear = self.linear_bk

    zero_init_enabled = False


class AllGatherHandle:
def __init__(self, handle, param: Parameter) -> None:
if param.ds_status != ZeroParamStatus.INFLIGHT:
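The module-level zero_init_enabled flag makes the teardown idempotent: shutdown_init_context() can now be reached both from Init.__exit__ and from deepspeed.initialize, and only the first call actually restores the patched symbols. A standalone sketch of that guard pattern, deliberately simplified — enable_patch, shutdown_patch and _patch_enabled are made-up names for illustration, not DeepSpeed's code:

import torch

_patch_enabled = False        # module-level guard, analogous to zero_init_enabled
_orig_empty = torch.empty     # remember the original before patching


def enable_patch():
    global _patch_enabled
    if _patch_enabled:
        return
    # Stand-in for the real zero.Init patching of tensor construction.
    torch.empty = lambda *args, **kwargs: _orig_empty(*args, **kwargs)
    _patch_enabled = True


def shutdown_patch():
    global _patch_enabled
    if not _patch_enabled:    # later calls are harmless no-ops
        return
    torch.empty = _orig_empty
    _patch_enabled = False


enable_patch()
shutdown_patch()   # restores torch.empty
shutdown_patch()   # safe to call again from another code path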
3 changes: 1 addition & 2 deletions op_builder/transformer_inference.py
@@ -1,5 +1,4 @@
import torch
from packaging import version
from .builder import CUDAOpBuilder, installed_cuda_version


@@ -18,7 +17,7 @@ def is_compatible(self, verbose=True):
cuda_okay = True
if not self.is_rocm_pytorch() and torch.cuda.is_available():
sys_cuda_major, _ = installed_cuda_version()
torch_cuda_major = version.parse(torch.version.cuda).major
torch_cuda_major = int(torch.version.cuda.split('.')[0])
cuda_capability = torch.cuda.get_device_properties(0).major
if cuda_capability >= 8:
if torch_cuda_major < 11 or sys_cuda_major < 11:
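Dropping the packaging import means the torch CUDA major version is parsed straight from the version string. A quick illustration of the replacement expression, assuming a CUDA build of torch where torch.version.cuda is a string such as "11.7":

import torch

cuda_str = torch.version.cuda            # e.g. "11.7" on CUDA builds, None on CPU-only builds
if cuda_str is not None:
    torch_cuda_major = int(cuda_str.split('.')[0])   # "11.7" -> 11, no packaging dependency
    print(torch_cuda_major)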
27 changes: 27 additions & 0 deletions tests/unit/test_zero_context.py
@@ -360,3 +360,30 @@ def test_subclass_param_init():
assert torch.equal(model.param, ones + 1)
assert torch.equal(model.param_pa, ones + 2)
assert torch.equal(model.param_grandpa, ones + 3)


@distributed_test(world_size=2)
def test_ds_init_w_zinit():
    ds_config = {
        "train_batch_size": 2,
        "steps_per_print": 1,
        "optimizer": {
            "type": "Adam",
            "params": {
                "lr": 0.00015
            }
        }
    }

    class Model(torch.nn.Module):
        def __init__(self):
            super(Model, self).__init__()
            self.linear = torch.nn.Linear(4, 4)

        def magic(self):
            return 42

    with deepspeed.zero.Init():
        model = Model()
        engine, *_ = deepspeed.initialize(model=model, config=ds_config, model_parameters=model.parameters())
        assert engine.magic() == 42
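The final assertion relies on the engine forwarding unknown attribute lookups to the wrapped module, which is how engine.magic() reaches Model.magic. A minimal standalone sketch of that forwarding pattern — EngineLikeWrapper is a made-up illustration, not DeepSpeed's engine implementation:

import torch


class EngineLikeWrapper(torch.nn.Module):
    """Toy wrapper that falls back to the wrapped module for unknown attributes."""
    def __init__(self, module: torch.nn.Module):
        super().__init__()
        self.module = module

    def __getattr__(self, name):
        try:
            return super().__getattr__(name)   # parameters, buffers, submodules
        except AttributeError:
            return getattr(self.module, name)  # e.g. custom methods like magic()


class Model(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(4, 4)

    def magic(self):
        return 42


wrapped = EngineLikeWrapper(Model())
assert wrapped.magic() == 42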