diff --git a/deepspeed/runtime/zero/parameter_offload.py b/deepspeed/runtime/zero/parameter_offload.py
@@ -12,6 +12,7 @@
 from deepspeed.runtime.zero.partition_parameters import _init_external_params
 from deepspeed.runtime.zero.partition_parameters import *
 from deepspeed.runtime.zero.partitioned_param_coordinator import PartitionedParameterCoordinator, iter_params
+from deepspeed import comm as dist

 FWD_MODULE_STACK = list()

@@ -21,6 +22,10 @@ def is_builtin_type(obj):
     return obj.__class__.__module__ == '__builtin__' or obj.__class__.__module__ == "builtins"


+# ensure we only warn once, otherwise every iteration will trigger a warning
+warned = False
+
+
 #apply torch.autograd.Function that calls a backward_function to tensors in output
 def _apply_to_tensors_only(module, functional, backward_function, outputs):
     if isinstance(outputs, (tuple, list)):
@@ -45,10 +50,13 @@ def _apply_to_tensors_only(module, functional, backward_function, outputs):
         return functional.apply(module, backward_function, outputs)
     else:
         if not is_builtin_type(outputs):
-            logger.warning(
-                f"A module has unknown inputs or outputs type ({type(outputs)}) and the tensors embedded in it cannot be detected. "
-                "The ZeRO-3 hooks designed to trigger before or after backward pass of the module relies on knowing the input and "
-                "output tensors and therefore may not get triggered properly.")
+            global warned
+            if not warned and dist.get_rank() == 0:
+                logger.warning(
+                    f"A module has unknown inputs or outputs type ({type(outputs)}) and the tensors embedded in it cannot be detected. "
+                    "The ZeRO-3 hooks designed to trigger before or after the backward pass of the module rely on knowing the input and "
+                    "output tensors, and therefore may not get triggered properly.")
+                warned = True
         return outputs


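For context, the change above implements a warn-once-per-process guard: a module-level `warned` flag combined with a `dist.get_rank() == 0` check, so that under distributed training only rank 0 logs the message, and only on the first offending call rather than on every iteration. Below is a minimal standalone sketch of the same pattern; the names are hypothetical — `get_rank` stands in for `deepspeed.comm.get_rank()`, and `warn_unknown_type_once` is illustrative, not a DeepSpeed API.

```python
# Standalone sketch of the warn-once pattern from the diff above.
import logging

logger = logging.getLogger(__name__)

# Module-level flag: ensures the warning is emitted at most once per process.
warned = False


def get_rank():
    # Stand-in for deepspeed.comm.get_rank(); a single process is rank 0.
    return 0


def warn_unknown_type_once(obj):
    """Warn about an unrecognized output type, once, and only from rank 0."""
    global warned
    if not warned and get_rank() == 0:
        logger.warning(f"A module has unknown inputs or outputs type ({type(obj)}).")
        warned = True


if __name__ == "__main__":
    logging.basicConfig(level=logging.WARNING)
    for _ in range(3):
        warn_unknown_type_once(object())  # logs on the first call only
```

Keeping the flag at module scope mirrors the diff: every call checks it cheaply, only rank 0 ever flips it, and that is sufficient because the goal is to silence repeated log spam on each iteration, not to synchronize state across ranks.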