deepspeed/runtime/zero/partition_parameters.py (58 additions, 19 deletions)
@@ -864,19 +864,20 @@ def _partition_gradient(self, param, partition_buffer=None, accumulate=False):


class GatheredParameters:
-    def __init__(self, param, modifier_rank=None, fwd_module=None, enabled=True):
-        """A context that collects a parameter that was partitioned via a
-        :class:`deepspeed.zero.Init` context. The parameter is partitioned
+    def __init__(self, params, modifier_rank=None, fwd_module=None, enabled=True):
+        """A context that collects parameters that were partitioned via a
+        :class:`deepspeed.zero.Init` context. The parameters are partitioned
        again upon exit.

        Args:
-            param (``torch.nn.Parameter``): The parameter to collect.
+            params (``torch.nn.Parameter``): A single parameter or a list of parameters to collect.
+                It's assumed that all parameters are zero params.
            modifier_rank (int, optional): If specified, this rank's parameter will be
-                broadcasted after the context. This argument is required if ``param`` is
-                modified all processes should have a consistent view of the data. Defaults
+                broadcasted on exit from the context. This argument is required if ``params`` are
+                modified, so that all processes have a consistent view of the data. Defaults
                to ``None``.
-            fwd_module (``torch.nn.Module``, optional): If specified, ``param`` will be
-                registered as an external parameter of ``fwd_module``. See :meth:`deepspeed.zero.register_external_parameter`.
+            fwd_module (``torch.nn.Module``, optional): If specified, ``params`` will be
+                registered as external parameters of ``fwd_module``. See :meth:`deepspeed.zero.register_external_parameter`.
            enabled (bool, optional): If ``False``, this context is a no-op. Defaults to ``True``.

        Examples
@@ -911,41 +912,79 @@ def forward(self, input):
                                                           fwd_module=self):
                        y = self.layer2(x, self.layer1.weight)
                    return y

+
+        #. Pretrained model loading
+
+            .. code-block:: python
+
+                with deepspeed.zero.Init():
+                    model = MyModel()
+
+                state_dict = torch.load(model_path, map_location="cpu")
+
+                def load(module: nn.Module, prefix=""):
+                    # because zero3 puts placeholders in model params, this context
+                    # manager gathers (unpartitions) the params of the current layer, then loads from
+                    # the state dict and then re-partitions them again
+                    with deepspeed.zero.GatheredParameters(list(module.parameters(recurse=False)), modifier_rank=0):
+                        if torch.distributed.get_rank() == 0:
+                            module._load_from_state_dict(state_dict, prefix)
+
+                    for name, child in module._modules.items():
+                        if child is not None:
+                            load(child, prefix + name + ".")
+
+                load(model, prefix="")
+
+        If this approach is not used, then the full model will first get copied to each GPU. For models
+        bigger than the memory of a single gpu this method is required.
        """

        self.enabled = enabled
        if not enabled:
            return

-        # This is a no-op, just return.
-        if not is_zero_param(param):
+        if not isinstance(params, list):
+            params = [params]
+
+        # enable if at least one is zero-param, otherwise a noop
+        if not any(is_zero_param(p) for p in params):
            self.enabled = False
            return

-        self.param = param
+        self.params = params
        self.src_rank = None
        if modifier_rank is not None:
-            if self.param.ds_process_group == torch.distributed.group.WORLD:
+            if self.params[0].ds_process_group == torch.distributed.group.WORLD:
                self.src_rank = modifier_rank
            else:
                # A group was specified; convert DP rank to global rank
-                self.src_rank = _get_global_rank(self.param.ds_process_group,
+                self.src_rank = _get_global_rank(self.params[0].ds_process_group,
                                                 modifier_rank)
        self.fwd_module = fwd_module
        if self.fwd_module is not None:
            # is a no-op if already registered
-            register_external_parameter(self.fwd_module, self.param)
+            for p in self.params:
+                register_external_parameter(self.fwd_module, p)

    def __enter__(self):
        if not self.enabled:
            return
-        self.param.all_gather()
+        self.params[0].all_gather(param_list=self.params)

    def __exit__(self, *exc):
        if not self.enabled:
            return
-        if self.src_rank is not None:
-            torch.distributed.broadcast(self.param,
-                                        self.src_rank,
-                                        group=self.param.ds_process_group)
-        self.param.partition(has_been_updated=self.src_rank is not None)
+        if self.src_rank is None:
+            return
+
+        handles = [
+            torch.distributed.broadcast(p,
+                                        self.src_rank,
+                                        group=p.ds_process_group,
+                                        async_op=True) for p in self.params
+        ]
+        for h in handles:
+            h.wait()
+        self.params[0].partition(param_list=self.params, has_been_updated=True)
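
For reference, here is a minimal sketch of the list-of-parameters usage this change enables. ``MyModel`` and ``layer1`` are illustrative placeholders rather than anything from the patch, and the snippet assumes the process group is already initialized (for example via ``deepspeed.init_distributed()``).

.. code-block:: python

    import torch
    import deepspeed

    with deepspeed.zero.Init():
        model = MyModel()  # placeholder module; its parameters are partitioned here

    # Gather all parameters of one submodule in a single context.
    layer_params = list(model.layer1.parameters())

    with deepspeed.zero.GatheredParameters(layer_params, modifier_rank=0):
        # Inside the context the full tensors are materialized on every rank.
        if torch.distributed.get_rank() == 0:
            for p in layer_params:
                torch.nn.init.zeros_(p)
    # On exit, rank 0's values are broadcast to the other ranks and the
    # parameters are re-partitioned.

Because ``modifier_rank`` is set, ``__exit__`` issues one asynchronous broadcast per parameter, waits on the handles, and then re-partitions the whole list in a single ``partition`` call.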