diff --git a/deepspeed/runtime/engine.py b/deepspeed/runtime/engine.py
index b912d0e4cb87..892394e049ad 100644
--- a/deepspeed/runtime/engine.py
+++ b/deepspeed/runtime/engine.py
@@ -25,9 +25,12 @@
 from deepspeed.runtime.zero.stage_1_and_2 import DeepSpeedZeroOptimizer
 from deepspeed.runtime.zero.partition_parameters import ZeroParamStatus
 from deepspeed.runtime.zero.utils import is_zero_supported_optimizer, ZeRORuntimeException
+from deepspeed.runtime.zero.parameter_offload import DeepSpeedZeRoOffload
+
 from deepspeed.runtime.activation_checkpointing import (
     checkpointing as activation_checkpointing,
 )
+
 from deepspeed.runtime.fp16.fused_optimizer import FP16_Optimizer
 from deepspeed.runtime.fp16.unfused_optimizer import FP16_UnfusedOptimizer
 from deepspeed.runtime.bf16_optimizer import BF16_Optimizer
@@ -328,7 +331,8 @@ def __init__(
         self.save_non_zero_checkpoint = False
         self.save_zero_checkpoint = False
-        self._configure_checkpointing(dist_init_required)
+        if not isinstance(self.optimizer, DeepSpeedZeRoOffload):
+            self._configure_checkpointing(dist_init_required)

         if self.eigenvalue_enabled():
             self.eigenvalue = self._configure_eigenvalue()
@@ -1337,7 +1341,6 @@ def _configure_zero_optimizer(self, optimizer):
                     "Pipeline parallelism does not support overlapped communication, will be disabled."
                 )
                 overlap_comm = False
-
             optimizer = DeepSpeedZeroOptimizer(
                 optimizer,
                 timers=timers,
@@ -1374,33 +1377,47 @@ def _configure_zero_optimizer(self, optimizer):
             logger.info("Initializing ZeRO Stage 3") if dist.get_rank() == 0 else None
             from deepspeed.runtime.zero.stage3 import DeepSpeedZeroOptimizer_Stage3
-            optimizer = DeepSpeedZeroOptimizer_Stage3(
-                self.module,
-                optimizer,
-                timers=timers,
-                ds_config=self.config,
-                static_loss_scale=self.loss_scale(),
-                dynamic_loss_scale=self.dynamic_loss_scale(),
-                dynamic_loss_args=self.dynamic_loss_scale_args(),
-                clip_grad=self.gradient_clipping(),
-                contiguous_gradients=self.zero_contiguous_gradients(),
-                reduce_bucket_size=self.zero_reduce_bucket_size(),
-                prefetch_bucket_size=self.zero_prefetch_bucket_size(),
-                max_reuse_distance=self.zero_max_reuse_distance(),
-                max_live_parameters=self.zero_max_live_parameters(),
-                param_persistence_threshold=self.zero_param_persistence_threshold(),
-                dp_process_group=self.data_parallel_group,
-                reduce_scatter=self.zero_reduce_scatter(),
-                overlap_comm=self.zero_overlap_comm(),
-                offload_optimizer_config=self.zero_offload_optimizer(),
-                offload_param_config=self.zero_offload_param(),
-                sub_group_size=self.zero_sub_group_size(),
-                mpu=self.mpu,
-                postscale_gradients=self.postscale_gradients(),
-                gradient_predivide_factor=self.gradient_predivide_factor(),
-                gradient_accumulation_steps=self.gradient_accumulation_steps(),
-                aio_config=self.aio_config(),
-                communication_data_type=self.communication_data_type)
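+            # DummyOptim means the client supplied no real optimizer: ZeRO-3 is
+            # wanted only for parameter partitioning/offload (e.g. inference),
+            # so build the standalone offload engine rather than the full
+            # stage-3 optimizer.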
+            if isinstance(optimizer, DummyOptim):
+                optimizer = DeepSpeedZeRoOffload(
+                    self.module,
+                    timers=timers,
+                    ds_config=self.config,
+                    overlap_comm=self.zero_overlap_comm(),
+                    prefetch_bucket_size=self.zero_prefetch_bucket_size(),
+                    max_reuse_distance=self.zero_max_reuse_distance(),
+                    max_live_parameters=self.zero_max_live_parameters(),
+                    param_persistence_threshold=self.zero_param_persistence_threshold(),
+                    offload_param_config=self.zero_offload_param(),
+                    mpu=self.mpu)
+            else:
+                optimizer = DeepSpeedZeroOptimizer_Stage3(
+                    self.module,
+                    optimizer,
+                    timers=timers,
+                    ds_config=self.config,
+                    static_loss_scale=self.loss_scale(),
+                    dynamic_loss_scale=self.dynamic_loss_scale(),
+                    dynamic_loss_args=self.dynamic_loss_scale_args(),
+                    clip_grad=self.gradient_clipping(),
+                    contiguous_gradients=self.zero_contiguous_gradients(),
+                    reduce_bucket_size=self.zero_reduce_bucket_size(),
+                    prefetch_bucket_size=self.zero_prefetch_bucket_size(),
+                    max_reuse_distance=self.zero_max_reuse_distance(),
+                    max_live_parameters=self.zero_max_live_parameters(),
+                    param_persistence_threshold=self.zero_param_persistence_threshold(),
+                    dp_process_group=self.data_parallel_group,
+                    reduce_scatter=self.zero_reduce_scatter(),
+                    overlap_comm=self.zero_overlap_comm(),
+                    offload_optimizer_config=self.zero_offload_optimizer(),
+                    offload_param_config=self.zero_offload_param(),
+                    sub_group_size=self.zero_sub_group_size(),
+                    mpu=self.mpu,
+                    postscale_gradients=self.postscale_gradients(),
+                    gradient_predivide_factor=self.gradient_predivide_factor(),
+                    gradient_accumulation_steps=self.gradient_accumulation_steps(),
+                    aio_config=self.aio_config(),
+                    communication_data_type=self.communication_data_type)
         else:
             raise NotImplementedError("ZeRO stage {} not implemented".format(zero_stage))
diff --git a/deepspeed/runtime/zero/parameter_offload.py b/deepspeed/runtime/zero/parameter_offload.py
new file mode 100644
index 000000000000..688b81900e36
--- /dev/null
+++ b/deepspeed/runtime/zero/parameter_offload.py
@@ -0,0 +1,485 @@
+"""
+Copyright 2022 The Microsoft DeepSpeed Team.
+Licensed under the MIT license.
+"""
+
+import torch
+from torch.cuda import Stream
+from collections import OrderedDict
+from deepspeed.runtime.utils import see_memory_usage
+from deepspeed.runtime.zero.partition_parameters import _init_external_params
+from deepspeed.runtime.zero.partition_parameters import *
+from deepspeed.runtime.zero.offload_constants import *
+from deepspeed.runtime.zero.partitioned_param_coordinator import PartitionedParameterCoordinator, iter_params
+
+FWD_MODULE_STACK = list()
+
+
+def is_builtin_type(obj):
+    # https://stackoverflow.com/a/17795199
+    return obj.__class__.__module__ == '__builtin__' or obj.__class__.__module__ == "builtins"
+
+
+#apply torch.autograd.Function that calls a backward_function to tensors in output
+def _apply_to_tensors_only(module, functional, backward_function, outputs):
+    if isinstance(outputs, (tuple, list)):
+        touched_outputs = []
+        for output in outputs:
+            touched_output = _apply_to_tensors_only(module,
+                                                    functional,
+                                                    backward_function,
+                                                    output)
+            touched_outputs.append(touched_output)
+        return outputs.__class__(touched_outputs)
+    elif isinstance(outputs, dict):
+        # apply inplace to avoid recreating dict inherited objects
+        for key in outputs.keys():
+            outputs[key] = _apply_to_tensors_only(module,
+                                                  functional,
+                                                  backward_function,
+                                                  outputs[key])
+        return outputs
+
+    elif type(outputs) is torch.Tensor:
+        return functional.apply(module, backward_function, outputs)
+    else:
+        if not is_builtin_type(outputs):
+            logger.warning(
+                f"A module has unknown inputs or outputs type ({type(outputs)}) and the tensors embedded in it cannot be detected. "
+                "The ZeRO-3 hooks designed to trigger before or after the backward pass of the module rely on knowing the input and "
+                "output tensors and therefore may not get triggered properly.")
+        return outputs
+
+
+#for each tensor in outputs run the forward_function and register backward_function as hook
+def _apply_forward_and_backward_to_tensors_only(module,
+                                                forward_function,
+                                                backward_function,
+                                                outputs):
+    if type(outputs) is tuple:
+        touched_outputs = []
+        for output in outputs:
+            touched_output = _apply_forward_and_backward_to_tensors_only(
+                module,
+                forward_function,
+                backward_function,
+                output)
+            touched_outputs.append(touched_output)
+        return tuple(touched_outputs)
+    elif type(outputs) is torch.Tensor:
+        forward_function(outputs)
+        if outputs.requires_grad:
+            outputs.register_hook(backward_function)
+        return outputs
+    else:
+        return outputs
+
+
+class ZeROOrderedDict(OrderedDict):
+    def __init__(self, parent_module, *args, **kwargs):
+        """A replacement for ``collections.OrderedDict`` to detect external ZeRO params.
+
+        Args:
+            parent_module (``torch.nn.Module``): the module whose ``_parameters`` collection this dict replaces
+        """
+
+        super().__init__(*args, **kwargs)
+        self._parent_module = parent_module
+        self._in_forward = False
+
+    def __getitem__(self, key):
+        param = super().__getitem__(key)
+
+        # Params can be registered as None (e.g., bias)
+        if param is None:
+            return param
+
+        if param.ds_status == ZeroParamStatus.NOT_AVAILABLE:
+            if self._parent_module._parameters._in_forward:
+                register_external_parameter(FWD_MODULE_STACK[-1], param)
+                param.all_gather()
+                print_rank_0(
+                    f'Registering external parameter from getter {key} ds_id = {param.ds_id}',
+                    force=False)
+
+        return param
+
+
+def _inject_parameters(module, cls):
+    for module in module.modules():
+        if cls == ZeROOrderedDict:
+            new_param = cls(parent_module=module)
+        else:
+            new_param = cls()
+
+        for key, param in module._parameters.items():
+            new_param[key] = param
+        module._parameters = new_param
+
+
+class PreBackwardFunction(torch.autograd.Function):
+    @staticmethod
+    def forward(ctx, module, pre_backward_function, outputs):
+        ctx.module = module
+        ctx.pre_backward_function = pre_backward_function
+        if not hasattr(module, "applied_pre_backward_ref_cnt"):
+            module.applied_pre_backward_ref_cnt = 0
+        module.applied_pre_backward_ref_cnt += 1
+        #print(f"After Forward: {ctx.module.__class__.__name__}")
+        outputs = outputs.detach()
+        return outputs
+
+    @staticmethod
+    def backward(ctx, *args):
+        #print(f"Before Backward: {ctx.module.__class__.__name__}")
+        ctx.pre_backward_function(ctx.module)
+        return (None, None) + args
+
+
+class PostBackwardFunction(torch.autograd.Function):
+    @staticmethod
+    def forward(ctx, module, pre_backward_function, output):
+        ctx.module = module
+        if output.requires_grad:
+            #TODO: sometimes post backward does not seem to be triggered; debug in detail.
+            #Should only cause an increase in memory, not a correctness issue.
+            #if output.grad_fn.__class__.__name__ == 'ViewBackward':
+            #    ctx.view=True
+            #    print(f"Warning view tensor for input to module : {module.__class__.__name__}. Backward hooks may not trigger properly")
+            #assert len(module.parameters(recurse=False)), "The input tensor to the module is a view, and autograd Function or register_hook is not triggered with view tensors."
+            #if module.ds_grads_remaining == 0:
+            #    print(f"Before Forward: {ctx.module.__class__.__name__}")
+            module.ds_grads_remaining += 1
+            ctx.pre_backward_function = pre_backward_function
+        output = output.detach()
+        return output
+
+    @staticmethod
+    def backward(ctx, *args):
+        ctx.module.ds_grads_remaining = ctx.module.ds_grads_remaining - 1
+        if ctx.module.ds_grads_remaining == 0:
+            ctx.pre_backward_function(ctx.module)
+            #print(f"After Backward: {ctx.module.__class__.__name__}")
+        return (None, None) + args
+
+
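+# DeepSpeedZeRoOffload bundles the ZeRO stage 3 parameter machinery (conversion
+# to partitioned parameters, persistence thresholds, and the module hooks that
+# fetch, prefetch and release them) independently of any optimizer state, so it
+# can back both DeepSpeedZeroOptimizer_Stage3 and optimizer-free inference.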
+class DeepSpeedZeRoOffload(object):
+    def __init__(self,
+                 module,
+                 timers,
+                 ds_config,
+                 overlap_comm=True,
+                 prefetch_bucket_size=50000000,
+                 max_reuse_distance=1000000000,
+                 max_live_parameters=1000000000,
+                 param_persistence_threshold=100000,
+                 offload_param_config=None,
+                 mpu=None):
+
+        see_memory_usage("TensorOffload initialize beginning", force=True)
+
+        print_rank_0(f"initialized {__class__.__name__} with args: {locals()}",
+                     force=False)
+
+        self.module = module
+        self.dtype = list(module.parameters())[0].dtype
+        self.offload_device = None
+        self.offload_param_pin_memory = False
+        if offload_param_config is not None:
+            self.offload_device = offload_param_config[OFFLOAD_PARAM_DEVICE]
+            self.offload_param_pin_memory = offload_param_config[
+                OFFLOAD_PARAM_PIN_MEMORY]
+
+        self._convert_to_zero_parameters(ds_config, module, mpu)
+
+        for m in module.modules():
+            _init_external_params(m)
+
+        _inject_parameters(module, ZeROOrderedDict)
+
+        self.persistence_threshold = int(param_persistence_threshold)
+        self.persistent_parameters = self.mark_persistent_parameters()
+
+        self.param_coordinators = {}
+        self._prefetch_bucket_sz = int(prefetch_bucket_size)
+        self._max_reuse_distance_in_numel = int(max_reuse_distance)
+        self._max_available_parameters_in_numel = int(max_live_parameters)
+        self.__allgather_stream = Stream(
+        ) if overlap_comm else torch.cuda.default_stream()
+
+        self.forward_hooks = []
+        self.backward_hooks = []
+        self.setup_zero_stage3_hooks()
+        print_rank_0(
+            f'Created module hooks: forward = {len(self.forward_hooks)}, backward = {len(self.backward_hooks)}',
+            force=False)
+
+    @instrument_w_nvtx
+    def partition_all_parameters(self):
+        """Partition any parameters that are not yet partitioned. Modules whose
+        inputs do not require gradient computation never trigger the post-backward
+        hook, so their parameters would otherwise remain unpartitioned."""
+        self.get_param_coordinator(training=self.module.training).release_and_reset_all(
+            self.module)
+        for param in iter_params(self.module, recurse=True):
+            if param.ds_status != ZeroParamStatus.NOT_AVAILABLE:
+                raise RuntimeError(f"{param.ds_summary()} expected to be released")
+
+    def get_param_coordinator(self, training):
+        if training not in self.param_coordinators:
+            self.param_coordinators[training] = PartitionedParameterCoordinator(
+                prefetch_bucket_sz=self._prefetch_bucket_sz,
+                max_reuse_distance_in_numel=self._max_reuse_distance_in_numel,
+                max_available_parameters_in_numel=self.
+                _max_available_parameters_in_numel,
+                allgather_stream=self.__allgather_stream,
+                prefetch_nvme=self.offload_device == OFFLOAD_NVME_DEVICE,
+            )
+
+        return self.param_coordinators[training]
+
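+    # Parameters built outside zero.Init() are not yet ZeRO params; convert any
+    # such stragglers so every parameter carries ds_* metadata, reusing the
+    # setup of already-converted params when some exist.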
+    def _convert_to_zero_parameters(self, ds_config, module, mpu):
+        non_zero_params = [p for p in module.parameters() if not is_zero_param(p)]
+        if non_zero_params:
+            zero_params = [p for p in module.parameters() if is_zero_param(p)]
+            if zero_params:
+                zero_params[0].convert_to_zero_parameters(param_list=non_zero_params)
+            else:
+                group = None
+                if mpu:
+                    group = mpu.get_data_parallel_group()
+
+                Init(module=module,
+                     data_parallel_group=group,
+                     dtype=self.dtype,
+                     config_dict_or_path=ds_config,
+                     remote_device=self.offload_device,
+                     pin_memory=self.offload_param_pin_memory,
+                     mpu=mpu)
+
+    def destroy(self):
+        self._remove_module_hooks()
+
+    def _remove_module_hooks(self):
+        num_forward_hooks = len(self.forward_hooks)
+        num_backward_hooks = len(self.backward_hooks)
+
+        for hook in self.forward_hooks:
+            hook.remove()
+
+        for hook in self.backward_hooks:
+            hook.remove()
+
+        print_rank_0(
+            f'Deleted module hooks: forward = {num_forward_hooks}, backward = {num_backward_hooks}',
+            force=False)
+
+    def setup_zero_stage3_hooks(self):
+        self.hierarchy = 0
+
+        #reset step if in inference mode
+        @instrument_w_nvtx
+        def _end_of_forward_hook(module, *args):
+
+            if not torch._C.is_grad_enabled():
+                self.get_param_coordinator(training=False).reset_step()
+
+        #likely one of them should be enough but just to be safe
+        self._register_hooks_recursively(self.module)
+        self.module.register_forward_hook(_end_of_forward_hook)
+
+        # Add top module to stack trace
+        global FWD_MODULE_STACK
+        FWD_MODULE_STACK.append(self.module)
+
+    def mark_persistent_parameters(self):
+        persistent_params = []
+        total_persistent_parameters = 0
+        params_count = 0
+        for _, param in self.module.named_parameters(recurse=True):
+            if param.ds_numel < self.persistence_threshold:
+                params_count += 1
+                param.ds_persist = True
+                persistent_params.append(param)
+                total_persistent_parameters += param.ds_numel
+
+        print_rank_0(
+            f"Parameter Offload: Total persistent parameters: {total_persistent_parameters} in {params_count} params",
+            force=False)
+
+        return persistent_params
+
+    def _register_hooks_recursively(self, module, count=[0]):
+        my_count = count[0]
+        module.id = my_count
+
+        #print(f"{module.__class__} : {module.id}")
+
+        for child in module.children():
+            count[0] = count[0] + 1
+            self._register_hooks_recursively(child, count=count)
+
+        @instrument_w_nvtx
+        def _pre_forward_module_hook(module, *args):
+            self.pre_sub_module_forward_function(module)
+
+        @instrument_w_nvtx
+        def _post_forward_module_hook(module, input, output):
+            global FWD_MODULE_STACK
+            FWD_MODULE_STACK.pop()
+            if output is None:
+                output = []
+            elif not isinstance(output, (list, tuple)):
+                if torch.is_tensor(output):
+                    output = [output]
+                else:
+                    #print(f'got UNKNOWN type {type(output)}')
+                    outputs = []
+                    output = output if isinstance(output, dict) else vars(output)
+                    for name, val in output.items():
+                        if not name.startswith('__') and torch.is_tensor(val):
+                            outputs.append(val)
+                    output = outputs
+                    #print(f'convert output to {output}')
+
+            for item in filter(lambda item: is_zero_param(item), output):
+                if not any(id(item) in m._external_params for m in FWD_MODULE_STACK):
+                    item.is_external_param = True
+                    module_to_register = FWD_MODULE_STACK[-1]
+                    register_external_parameter(module_to_register, item)
+                    print_rank_0(
+                        f'Registering dangling parameter for module {module_to_register.__class__.__name__}, ds_id = {item.ds_id}.',
+                        force=False)
+
+                    # It's possible that the parameter was already external to the completed module. If so,
+                    # remove the registration as it will be covered by the outer module instead.
+                    if id(item) in module._external_params:
+                        print_rank_0(
+                            f'  Unregistering nested dangling parameter from module {module.__class__.__name__}, ds_id = {item.ds_id}',
+                            force=False)
+                        unregister_external_parameter(module, item)
+
+                    item.all_gather()
+
+            self.post_sub_module_forward_function(module)
+
+        def _pre_backward_module_hook(module, inputs, output):
+            @instrument_w_nvtx
+            def _run_before_backward_function(sub_module):
+                # some models (e.g. Albert) may run multiple forwards on the same layer in a loop
+                # before doing backwards, so each backward will need a pre-fetch - using reference
+                # counting to support this scenario
+                #print(f"COUNTER before: {sub_module.applied_pre_backward_ref_cnt}")
+                if sub_module.applied_pre_backward_ref_cnt > 0:
+                    self.pre_sub_module_backward_function(sub_module)
+                    sub_module.applied_pre_backward_ref_cnt -= 1
+                    #print(f"COUNTER after: {sub_module.applied_pre_backward_ref_cnt}")
+
+            return _apply_to_tensors_only(module,
+                                          PreBackwardFunction,
+                                          _run_before_backward_function,
+                                          output)
+
+        #This is an alternate to doing _post_backward_module_hook
+        #it uses tensor.register_hook instead of using torch.autograd.Function
+        def _alternate_post_backward_module_hook(module, inputs):
+            module.ds_grads_remaining = 0
+
+            #print(f"Before Forward {module.__class__.__name__}")
+
+            def _run_after_backward_hook(*unused):
+                module.ds_grads_remaining = module.ds_grads_remaining - 1
+                if module.ds_grads_remaining == 0:
+                    #print(f"After backward {module.__class__.__name__}")
+                    self.post_sub_module_backward_function(module)
+
+            def _run_before_forward_function(input):
+                if input.requires_grad:
+                    module.ds_grads_remaining += 1
+
+            return _apply_forward_and_backward_to_tensors_only(
+                module,
+                _run_before_forward_function,
+                _run_after_backward_hook,
+                inputs)
+
+        def _post_backward_module_hook(module, inputs):
+            module.ds_grads_remaining = 0
+
+            @instrument_w_nvtx
+            def _run_after_backward_function(sub_module):
+                if sub_module.ds_grads_remaining == 0:
+                    self.post_sub_module_backward_function(sub_module)
+
+            return _apply_to_tensors_only(module,
+                                          PostBackwardFunction,
+                                          _run_after_backward_function,
+                                          inputs)
+
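+        # Note that the backward hooks are attached from the forward direction:
+        # the pre-backward hook wraps module outputs (a forward hook) and the
+        # post-backward hook wraps module inputs (a forward pre-hook), since
+        # autograd traverses the graph in reverse order.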
+        # Pre forward hook
+        self.forward_hooks.append(
+            module.register_forward_pre_hook(_pre_forward_module_hook))
+
+        # Post forward hook
+        self.forward_hooks.append(
+            module.register_forward_hook(_post_forward_module_hook))
+
+        # Pre backward hook
+        self.backward_hooks.append(
+            module.register_forward_hook(_pre_backward_module_hook))
+
+        # post backward hook
+        self.backward_hooks.append(
+            module.register_forward_pre_hook(_post_backward_module_hook))
+
+    @torch.no_grad()
+    def pre_sub_module_forward_function(self, sub_module):
+        see_memory_usage(f"Before sub module function {sub_module.__class__.__name__}",
+                         force=False)
+
+        global FWD_MODULE_STACK
+        FWD_MODULE_STACK.append(sub_module)
+
+        param_coordinator = self.get_param_coordinator(training=sub_module.training)
+        param_coordinator.trace_prologue(sub_module)
+        if param_coordinator.is_record_trace():
+            param_coordinator.record_module(sub_module)
+        param_coordinator.fetch_sub_module(sub_module)
+
+        see_memory_usage(
+            f"Before sub module function {sub_module.__class__.__name__} after fetch",
+            force=False)
+
+    @torch.no_grad()
+    def post_sub_module_forward_function(self, sub_module):
+        see_memory_usage(
+            f"After sub module function {sub_module.__class__.__name__} {sub_module.id} before release",
+            force=False)
+
+        param_coordinator = self.get_param_coordinator(training=sub_module.training)
+        param_coordinator.release_sub_module(sub_module)
+
+        see_memory_usage(
+            f"After sub module function {sub_module.__class__.__name__} {sub_module.id} after release",
+            force=False)
+
+    @torch.no_grad()
+    def pre_sub_module_backward_function(self, sub_module):
+        param_coordinator = self.get_param_coordinator(training=sub_module.training)
+        param_coordinator.trace_prologue(sub_module)
+        if param_coordinator.is_record_trace():
+            param_coordinator.record_module(sub_module)
+        param_coordinator.fetch_sub_module(sub_module)
+
+    @torch.no_grad()
+    def post_sub_module_backward_function(self, sub_module):
+        see_memory_usage(
+            f"After sub module backward function {sub_module.__class__.__name__} {sub_module.id} before release",
+            force=False)
+
+        self.get_param_coordinator(
+            training=sub_module.training).release_sub_module(sub_module)
+
+        see_memory_usage(
+            f"After sub module backward function {sub_module.__class__.__name__} {sub_module.id} after release",
+            force=False)
diff --git a/deepspeed/runtime/zero/stage3.py b/deepspeed/runtime/zero/stage3.py
index 00c6411e1c35..e4afdfdbfb0c 100755
--- a/deepspeed/runtime/zero/stage3.py
+++ b/deepspeed/runtime/zero/stage3.py
@@ -24,9 +24,10 @@
 from deepspeed.utils.logging import logger
 from deepspeed.runtime.fp16.loss_scaler import LossScaler, DynamicLossScaler
 from deepspeed.runtime.comm.coalesced_collectives import reduce_scatter_coalesced
-from deepspeed.runtime.utils import get_global_norm, see_memory_usage, is_model_parallel_parameter, DummyOptim
+from deepspeed.runtime.utils import get_global_norm, see_memory_usage, is_model_parallel_parameter
 from deepspeed.runtime.zero.partition_parameters import *
 from deepspeed.runtime.zero.partition_parameters import _init_external_params
+from deepspeed.runtime.zero.parameter_offload import DeepSpeedZeRoOffload
 from deepspeed.runtime.zero.constants import ZERO_OPTIMIZATION_WEIGHTS
 from deepspeed.ops.adam import DeepSpeedCPUAdam
 from deepspeed.ops.op_builder import UtilsBuilder
@@ -41,7 +42,6 @@
 # with gradient partitioning and without
 pg_correctness_test = False

-FWD_MODULE_STACK = list()

 from deepspeed.utils.debug import debug_module2name_id, debug_param2name_id, debug_param2name_id_numel, debug_param2name_id_shape_device, debug_module2name_class, printflock, log_rank_file
@@ -74,154 +74,6 @@ def move_to_cpu(tensor_list):
         tensor.data = tensor.data.cpu()


-def is_builtin_type(obj):
-    # https://stackoverflow.com/a/17795199
-    return obj.__class__.__module__ == '__builtin__' or obj.__class__.__module__ == "builtins"
-
-
-#apply torch.autograd.Function that calls a backward_function to tensors in output
-def _apply_to_tensors_only(module, functional, backward_function, outputs):
-    if isinstance(outputs, (tuple, list)):
-        touched_outputs = []
-        for output in outputs:
-            touched_output = _apply_to_tensors_only(module,
-                                                    functional,
-                                                    backward_function,
-                                                    output)
-            touched_outputs.append(touched_output)
-        return outputs.__class__(touched_outputs)
-    elif isinstance(outputs, dict):
-        # apply inplace to avoid recreating dict inherited objects
-        for key in outputs.keys():
-            outputs[key] = _apply_to_tensors_only(module,
-                                                  functional,
-                                                  backward_function,
-                                                  outputs[key])
-        return outputs
-
-    elif type(outputs) is torch.Tensor:
-        return functional.apply(module, backward_function, outputs)
-    else:
-        if not is_builtin_type(outputs):
-            logger.warning(
-                f"A module has unknown inputs or outputs type ({type(outputs)}) and the tensors embedded in it cannot be detected. "
-                "The ZeRO-3 hooks designed to trigger before or after backward pass of the module relies on knowing the input and "
-                "output tensors and therefore may not get triggered properly.")
-        return outputs
-
-
-#for each tensor in outputs run the forward_function and register backward_function as hook
-def _apply_forward_and_backward_to_tensors_only(module,
-                                                forward_function,
-                                                backward_function,
-                                                outputs):
-    if type(outputs) is tuple:
-        touched_outputs = []
-        for output in outputs:
-            touched_output = _apply_forward_and_backward_to_tensors_only(
-                module,
-                forward_function,
-                backward_function,
-                output)
-            touched_outputs.append(touched_output)
-        return tuple(touched_outputs)
-    elif type(outputs) is torch.Tensor:
-        forward_function(outputs)
-        if outputs.requires_grad:
-            outputs.register_hook(backward_function)
-        return outputs
-    else:
-        return outputs
-
-
-class ZeROOrderedDict(OrderedDict):
-    def __init__(self, parent_module, *args, **kwargs):
-        """A replacement for ``collections.OrderedDict`` to detect external ZeRO params.
-
-        Args:
-            parent_module (``collections.OrderedDict``): the collection to replace
-        """
-
-        super().__init__(*args, **kwargs)
-        self._parent_module = parent_module
-        self._in_forward = False
-
-    def __getitem__(self, key):
-        param = super().__getitem__(key)
-
-        # Params can be registered as None (e.g., bias)
-        if param is None:
-            return param
-
-        if param.ds_status == ZeroParamStatus.NOT_AVAILABLE:
-            if self._parent_module._parameters._in_forward:
-                register_external_parameter(FWD_MODULE_STACK[-1], param)
-                param.all_gather()
-                print_rank_0(
-                    f'Registering external parameter from getter {key} ds_id = {param.ds_id}',
-                    force=False)
-
-        return param
-
-
-def _inject_parameters(module, cls):
-    for module in module.modules():
-        if cls == ZeROOrderedDict:
-            new_param = cls(parent_module=module)
-        else:
-            new_param = cls()
-
-        for key, param in module._parameters.items():
-            new_param[key] = param
-        module._parameters = new_param
-
-
-class PreBackwardFunction(torch.autograd.Function):
-    @staticmethod
-    def forward(ctx, module, pre_backward_function, outputs):
-        ctx.module = module
-        ctx.pre_backward_function = pre_backward_function
-        if not hasattr(module, "applied_pre_backward_ref_cnt"):
-            module.applied_pre_backward_ref_cnt = 0
-        module.applied_pre_backward_ref_cnt += 1
-        #print(f"After Forward: {ctx.module.__class__.__name__}")
-        outputs = outputs.detach()
-        return outputs
-
-    @staticmethod
-    def backward(ctx, *args):
-        #print(f"Before Backward: {ctx.module.__class__.__name__}")
-        ctx.pre_backward_function(ctx.module)
-        return (None, None) + args
-
-
-class PostBackwardFunction(torch.autograd.Function):
-    @staticmethod
-    def forward(ctx, module, pre_backward_function, output):
-        ctx.module = module
-        if output.requires_grad:
-            #TODO SOME TIMES post backward does not seem to be triggered debug in detail
-            #Should only cause increase in memory not correctness issue
-            #if output.grad_fn.__class__.__name__ == 'ViewBackward':
-            #    ctx.view=True
-            #    print(f"Warning view tensor for input to module : {module.__class__.__name__}. Backward hooks may not trigger properly")
-            #assert len(module.parameters(recurse=False)), "The input tensor to the module is a view, and autograd Function or register_hook is not triggered with view tensors."
-            #if module.ds_grads_remaining == 0:
-            #    print(f"Before Forward: {ctx.module.__class__.__name__}")
-            module.ds_grads_remaining += 1
-            ctx.pre_backward_function = pre_backward_function
-        output = output.detach()
-        return output
-
-    @staticmethod
-    def backward(ctx, *args):
-        ctx.module.ds_grads_remaining = ctx.module.ds_grads_remaining - 1
-        if ctx.module.ds_grads_remaining == 0:
-            ctx.pre_backward_function(ctx.module)
-            #print(f"After Backward: {ctx.module.__class__.__name__}")
-        return (None, None) + args
-
-
 INITIAL_MICRO_STEP_ID = -1

@@ -266,7 +118,7 @@ def __init__(self,
                  elastic_checkpoint=False,
                  aio_config=None):

-        see_memory_usage("Stage 3 initialize beginning", force=False)
+        see_memory_usage("Stage 3 initialize beginning", force=True)

         print_rank_0(f"initialized {__class__.__name__} with args: {locals()}",
                      force=False)
@@ -285,8 +137,8 @@ def __init__(self,
         # - master grad and unflat master weight never exist. TODO: a way to save out unflat master?
         if not torch.cuda.is_available:
             raise SystemError("Cannot use fp16 without CUDA.")
+        self.optimizer = init_optimizer

-        self.using_real_optimizer = not isinstance(self.optimizer, DummyOptim)

         # Load pre-built or JIT compile (un)flatten ops
         util_ops = UtilsBuilder().load()
@@ -309,20 +161,21 @@ def __init__(self,
         self.params_in_nvme_and_cpu = False
         self.max_params_in_cpu = 0

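+        # All parameter offload machinery (ZeRO param conversion, persistence,
+        # forward/backward hooks) now lives in DeepSpeedZeRoOffload; this
+        # optimizer keeps a handle to it and delegates instead of owning it.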
+        self.parameter_offload = DeepSpeedZeRoOffload(module,
+                                                      timers,
+                                                      ds_config,
+                                                      overlap_comm,
+                                                      prefetch_bucket_size,
+                                                      max_reuse_distance,
+                                                      max_live_parameters,
+                                                      param_persistence_threshold,
+                                                      offload_param_config)
+        self.persistent_parameters = self.parameter_offload.persistent_parameters
+
         self._configure_offloading(offload_optimizer_config, offload_param_config)

-        self._convert_to_zero_parameters(ds_config, module, mpu)
-
-        for m in module.modules():
-            _init_external_params(m)
-
         self.module = module
         self.elastic_checkpoint = elastic_checkpoint

-        # Replace ._parameters with a new class to enable auto-registration of
-        # external parameters
-        _inject_parameters(module, ZeROOrderedDict)
-
         self.__inf_or_nan_tracker: Tensor = torch.zeros(
             1,
             dtype=torch.bool,
@@ -335,41 +188,14 @@ def __init__(self,
         self.device = torch.cuda.current_device(
         ) if not self.offload_optimizer else OFFLOAD_CPU_DEVICE

         ### streams used for overlapping computation with communication
-        self.__allgather_stream = Stream(
-        ) if overlap_comm else torch.cuda.default_stream()
         self.__reduce_and_partition_stream = Stream(
         ) if overlap_comm else torch.cuda.default_stream()
         ############################################################################

-        see_memory_usage("Before Partitioned Parameter Coordinator", force=False)
-        self.param_coordinators = {}
-        self._prefetch_bucket_sz = int(prefetch_bucket_size)
-        self._max_reuse_distance_in_numel = int(max_reuse_distance)
-        self._max_available_parameters_in_numel = int(max_live_parameters)
-        see_memory_usage("After Partitioned Parameter Coordinator", force=False)
-
         self.__n_caching_allocator_flushes = 0

         #-------------Stage 3 Setup-------------------#
-        # parameters smaller than the threshold will be collectively gathered at the
-        # end of the optimizer step and will be kept till the end of the backward pass
-        # TODO maybe worth just replicating these parameters and doing all reduce for them
-        self.persistence_threshold = int(param_persistence_threshold)
-
-        self.persistent_parameters = self.persistent_parameters()
-
-        self.forward_hooks = []
-        self.backward_hooks = []
-        self.setup_zero_stage3_hooks()
-        print_rank_0(
-            f'Created module hooks: forward = {len(self.forward_hooks)}, backward = {len(self.backward_hooks)}',
-            force=False)
-
-        #resetting ds_tensor just in case parameters have been changed after initialization
-        #example .half() or .to()
-        #self.reset_ds_tensor()
-        #---------------------------------------------#

         self.timers = timers
@@ -426,6 +252,7 @@ def __init__(self,
         self.all_reduce_print = False

         self.prefetch_elements = int(prefetch_bucket_size)
+
         self.contiguous_gradients = contiguous_gradients

         # padding on each partition for alignment purposes
@@ -488,10 +315,9 @@ def __init__(self,
             f'Largest partitioned param numel = {largest_partitioned_param_numel}',
             force=False)

+        self._setup_for_real_optimizer()
         self.grad_position = {}
-        if self.using_real_optimizer:
-            self._setup_for_real_optimizer()
-            self.set_grad_positions()
+        self.set_grad_positions()

         if self.offload_optimizer:
             self.norm_for_param_grads = {}
@@ -517,7 +343,6 @@ def __init__(self,
             self.dynamic_loss_scale = False
             self.loss_scaler = LossScaler(scale=loss_scale_value)
-            cur_iter = 0
         else:
             if dynamic_loss_args is None:
                 self.loss_scaler = DynamicLossScaler()
@@ -532,21 +357,7 @@ def __init__(self,
         see_memory_usage(f"After initializing ZeRO optimizer", force=True)

     def destroy(self):
-        self._remove_module_hooks()
-
-    def _remove_module_hooks(self):
-        num_forward_hooks = len(self.forward_hooks)
-        num_backward_hooks = len(self.backward_hooks)
-
-        for hook in self.forward_hooks:
-            hook.remove()
-
-        for hook in self.backward_hooks:
-            hook.remove()
-
-        print_rank_0(
-            f'Deleted module hooks: forward = {num_forward_hooks}, backward = {num_backward_hooks}',
-            force=False)
+        self.parameter_offload.destroy()

     def _setup_for_real_optimizer(self):
         see_memory_usage("Before creating fp32 partitions", force=False)
@@ -641,17 +452,7 @@ def defragment(tensors: List[Tensor]) -> Tensor:
         return device_buffer

     def _get_param_coordinator(self, training):
-        if not training in self.param_coordinators:
-            self.param_coordinators[training] = PartitionedParameterCoordinator(
-                prefetch_bucket_sz=self._prefetch_bucket_sz,
-                max_reuse_distance_in_numel=self._max_reuse_distance_in_numel,
-                max_available_parameters_in_numel=self.
-                _max_available_parameters_in_numel,
-                allgather_stream=self.__allgather_stream,
-                prefetch_nvme=self.params_in_nvme_and_cpu,
-            )
-
-        return self.param_coordinators[training]
+        return self.parameter_offload.get_param_coordinator(training)

     def _configure_offloading(self, offload_optimizer_config, offload_param_config):
         ###################### offload optimizer setup ##################################
@@ -666,8 +467,6 @@ def _configure_offloading(self, offload_optimizer_config, offload_param_config):

         ###################### offload param setup ##################################
         if offload_param_config is not None:
-            if self.using_real_optimizer:
-                assert self.offload_optimizer, "parameter offload is only available with optimizer state offload"
             self.offload_param = True
             self.offload_param_pin_memory = offload_param_config[
                 OFFLOAD_PARAM_PIN_MEMORY]
@@ -678,32 +477,6 @@ def _configure_offloading(self, offload_optimizer_config, offload_param_config):
             f"FP16 params swapping is {self.params_in_nvme_and_cpu}, Max params in CPU is {self.max_params_in_cpu}",
             force=False)

-    def _convert_to_zero_parameters(self, ds_config, module, mpu):
-        non_zero_params = [p for p in module.parameters() if not is_zero_param(p)]
-        if non_zero_params:
-            zero_params = [p for p in module.parameters() if is_zero_param(p)]
-            if zero_params:
-                zero_params[0].convert_to_zero_parameters(param_list=non_zero_params)
-            else:
-                group = None
-                if mpu:
-                    group = mpu.get_data_parallel_group()
-
-                if self.params_in_nvme_and_cpu:
-                    remote_device = OFFLOAD_NVME_DEVICE
-                elif self.offload_param:
-                    remote_device = OFFLOAD_CPU_DEVICE
-                else:
-                    remote_device = None
-
-                Init(module=module,
-                     data_parallel_group=group,
-                     dtype=self.dtype,
-                     config_dict_or_path=ds_config,
-                     remote_device=remote_device,
-                     pin_memory=self.offload_param_pin_memory,
-                     mpu=mpu)
-
     def _configure_tensor_swapping(self, offload_optimizer_config, aio_config):
         nvme_swap_folder = os.path.join(
             offload_optimizer_config[OFFLOAD_OPTIMIZER_NVME_PATH],
@@ -1075,221 +848,6 @@ def _create_fp16_sub_groups(self, params_group):

         return sub_groups

-    # def reset_ds_tensor(self):
-    #     for name, param in self.module.named_parameters(recurse=True):
-    #         assert hasattr(param,'ds_id'), "Parameters have not been converted to be Zero 3 compatible"
-    #         assert (param.ds_status == ZeroParamStatus.NOT_AVAILABLE), "All the parameters must have been partitioned by now"
-    #         param.ds_tensor.data = param.data
-
-    def setup_zero_stage3_hooks(self):
-        self.hierarchy = 0
-
-        #reset step if in inference mode
-        @instrument_w_nvtx
-        def _end_of_forward_hook(module, *args):
-
-            if not torch._C.is_grad_enabled():
-                self._get_param_coordinator(training=False).reset_step()
-
-        #likely one of them should be enough but just to be safe
-        self._register_hooks_recursively(self.module)
-        self.module.register_forward_hook(_end_of_forward_hook)
-
-        # Add top module to stack trace
-        global FWD_MODULE_STACK
-        FWD_MODULE_STACK.append(self.module)
-
-    def persistent_parameters(self):
-        persistent_params = []
-        total_persistent_parameters = 0
-        params_count = 0
-        for _, param in self.module.named_parameters(recurse=True):
-            if param.ds_numel < self.persistence_threshold:
-                params_count += 1
-                param.ds_persist = True
-                persistent_params.append(param)
-                total_persistent_parameters += param.ds_numel
-
-        print_rank_0(
-            f"ZeRO 3: Total persistent parameters: {total_persistent_parameters} in {params_count} params",
-            force=False)
-        return persistent_params
-
-    def _register_hooks_recursively(self, module, count=[0]):
-        my_count = count[0]
-        module.id = my_count
-
-        #print(f"{module.__class__} : {module.id}")
-
-        for child in module.children():
-            count[0] = count[0] + 1
-            self._register_hooks_recursively(child, count=count)
-
-        @instrument_w_nvtx
-        def _pre_forward_module_hook(module, *args):
-            self.pre_sub_module_forward_function(module)
-
-        @instrument_w_nvtx
-        def _post_forward_module_hook(module, input, output):
-            global FWD_MODULE_STACK
-            FWD_MODULE_STACK.pop()
-            if output is None:
-                output = []
-            elif not isinstance(output, (list, tuple)):
-                if torch.is_tensor(output):
-                    output = [output]
-                else:
-                    #print(f'got UNKNOWN type {type(output)}')
-                    outputs = []
-                    output = output if isinstance(output, dict) else vars(output)
-                    for name, val in output.items():
-                        if not name.startswith('__') and torch.is_tensor(val):
-                            outputs.append(val)
-                    output = outputs
-                    #print(f'convert output to {output}')
-
-            for item in filter(lambda item: is_zero_param(item), output):
-                if not any(id(item) in m._external_params for m in FWD_MODULE_STACK):
-                    item.is_external_param = True
-                    module_to_register = FWD_MODULE_STACK[-1]
-                    register_external_parameter(module_to_register, item)
-                    print_rank_0(
-                        f'Registering dangling parameter for module {module_to_register.__class__.__name__}, ds_id = {item.ds_id}.',
-                        force=False)
-
-                    # It's possible that the parameter was already external to the completed module. If so, remove it the
-                    # registration as it will be covered by the outer module instead.
-                    if id(item) in module._external_params:
-                        print_rank_0(
-                            f'  Unregistering nested dangling parameter from module {module.__class__.__name__}, ds_id = {item.ds_id}',
-                            force=False)
-                        unregister_external_parameter(module, item)
-
-                    item.all_gather()
-
-            self.post_sub_module_forward_function(module)
-
-        def _pre_backward_module_hook(module, inputs, output):
-            @instrument_w_nvtx
-            def _run_before_backward_function(sub_module):
-                # some models (e.g. Albert) may run multiple forwards on the same layer in a loop
-                # before doing backwards, so each backward will need a pre-fetch - using reference
-                # counting to support this scenario
-                #print(f"COUNTER before: {sub_module.applied_pre_backward_ref_cnt}")
-                if sub_module.applied_pre_backward_ref_cnt > 0:
-                    self.pre_sub_module_backward_function(sub_module)
-                    sub_module.applied_pre_backward_ref_cnt -= 1
-                    #print(f"COUNTER after: {sub_module.applied_pre_backward_ref_cnt}")
-
-            return _apply_to_tensors_only(module,
-                                          PreBackwardFunction,
-                                          _run_before_backward_function,
-                                          output)
-
-        #This is an alternate to doing _post_backward_module_hook
-        #it uses tensor.register_hook instead of using torch.autograd.Function
-        def _alternate_post_backward_module_hook(module, inputs):
-            module.ds_grads_remaining = 0
-
-            #print(f"Before Forward {module.__class__.__name__}")
-
-            def _run_after_backward_hook(*unused):
-                module.ds_grads_remaining = module.ds_grads_remaining - 1
-                if module.ds_grads_remaining == 0:
-                    #print(f"After backward {module.__class__.__name__}")
-                    self.post_sub_module_backward_function(module)
-
-            def _run_before_forward_function(input):
-                if input.requires_grad:
-                    module.ds_grads_remaining += 1
-
-            return _apply_forward_and_backward_to_tensors_only(
-                module,
-                _run_before_forward_function,
-                _run_after_backward_hook,
-                inputs)
-
-        def _post_backward_module_hook(module, inputs):
-            module.ds_grads_remaining = 0
-
-            @instrument_w_nvtx
-            def _run_after_backward_function(sub_module):
-                if sub_module.ds_grads_remaining == 0:
-                    self.post_sub_module_backward_function(sub_module)
-
-            return _apply_to_tensors_only(module,
-                                          PostBackwardFunction,
-                                          _run_after_backward_function,
-                                          inputs)
-
-        # Pre forward hook
-        self.forward_hooks.append(
-            module.register_forward_pre_hook(_pre_forward_module_hook))
-
-        # Post forward hook
-        self.forward_hooks.append(
-            module.register_forward_hook(_post_forward_module_hook))
-
-        # Pre backward hook
-        self.backward_hooks.append(
-            module.register_forward_hook(_pre_backward_module_hook))
-
-        # post backward hook
-        self.backward_hooks.append(
-            module.register_forward_pre_hook(_post_backward_module_hook))
-
-    @torch.no_grad()
-    def pre_sub_module_forward_function(self, sub_module):
-        see_memory_usage(f"Before sub module function {sub_module.__class__.__name__}",
-                         force=False)
-
-        global FWD_MODULE_STACK
-        FWD_MODULE_STACK.append(sub_module)
-
-        param_coordinator = self._get_param_coordinator(training=sub_module.training)
-        param_coordinator.trace_prologue(sub_module)
-        if param_coordinator.is_record_trace():
-            param_coordinator.record_module(sub_module)
-        param_coordinator.fetch_sub_module(sub_module)
-
-        see_memory_usage(
-            f"Before sub module function {sub_module.__class__.__name__} after fetch",
-            force=False)
-
-    @torch.no_grad()
-    def post_sub_module_forward_function(self, sub_module):
-        see_memory_usage(
-            f"After sub module function {sub_module.__class__.__name__} {sub_module.id} before release",
-            force=False)
-
-        param_coordinator = self._get_param_coordinator(training=sub_module.training)
-        param_coordinator.release_sub_module(sub_module)
-
-        see_memory_usage(
-            f"After sub module function {sub_module.__class__.__name__} {sub_module.id} after release",
-            force=False)
-
-    @torch.no_grad()
-    def pre_sub_module_backward_function(self, sub_module):
-        param_coordinator = self._get_param_coordinator(training=sub_module.training)
-        param_coordinator.trace_prologue(sub_module)
-        if param_coordinator.is_record_trace():
-            param_coordinator.record_module(sub_module)
-        param_coordinator.fetch_sub_module(sub_module)
-
-    @torch.no_grad()
-    def post_sub_module_backward_function(self, sub_module):
-        see_memory_usage(
-            f"After sub module backward function {sub_module.__class__.__name__} {sub_module.id} before release",
-            force=False)
-
-        self._get_param_coordinator(
-            training=sub_module.training).release_sub_module(sub_module)
-
-        see_memory_usage(
-            f"After sub module backward function {sub_module.__class__.__name__} {sub_module.id} after release",
-            force=False)
-
     def _release_ipg_buffers(self):
         if self.contiguous_gradients:
             self.ipg_buffer = None
@@ -2558,14 +2116,7 @@ def get_fp32_grad_partitions(self) -> Dict[int, Dict[int, Tensor]]:

     @instrument_w_nvtx
     def _partition_all_parameters(self):
-        """Partitioning Parameters that were not partitioned usually if parameters
-        of modules whose input parameters do not require grad computation do not
-        trigger post call and will therefore will remain unpartitioned"""
-        self._get_param_coordinator(training=self.module.training).release_and_reset_all(
-            self.module)
-        for param in iter_params(self.module, recurse=True):
-            if param.ds_status != ZeroParamStatus.NOT_AVAILABLE:
-                raise RuntimeError(f"{param.ds_summary()} expected to be released")
+        self.parameter_offload.partition_all_parameters()

     def check_overflow(self, partition_gradients=True):
         self._check_overflow(partition_gradients)
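
Note: a minimal sketch of how the new optimizer-free path is exercised, assuming the standard deepspeed.initialize entry point and that the engine substitutes a DummyOptim when no optimizer is supplied; the config values below are illustrative, not prescriptive:

    import torch
    import deepspeed

    ds_config = {
        "train_micro_batch_size_per_gpu": 1,
        "fp16": {"enabled": True},
        "zero_optimization": {
            "stage": 3,
            "offload_param": {"device": "cpu", "pin_memory": True},
        },
    }

    model = torch.nn.Linear(8, 8)

    # No optimizer is passed, so _configure_zero_optimizer receives a DummyOptim
    # and the branch added above wraps the module in DeepSpeedZeRoOffload
    # instead of DeepSpeedZeroOptimizer_Stage3.
    engine, _, _, _ = deepspeed.initialize(model=model, config=ds_config)

    engine.module.eval()
    with torch.no_grad():
        out = engine(torch.randn(1, 8, dtype=torch.half, device="cuda"))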