diff --git a/deepspeed/runtime/zero/parameter_offload.py b/deepspeed/runtime/zero/parameter_offload.py index adeb70f9e61c..f0f0c4f56e40 100644 --- a/deepspeed/runtime/zero/parameter_offload.py +++ b/deepspeed/runtime/zero/parameter_offload.py @@ -12,6 +12,7 @@ from deepspeed.runtime.zero.partition_parameters import _init_external_params from deepspeed.runtime.zero.partition_parameters import * from deepspeed.runtime.zero.partitioned_param_coordinator import PartitionedParameterCoordinator, iter_params +from deepspeed import comm as dist FWD_MODULE_STACK = list() @@ -21,6 +22,10 @@ def is_builtin_type(obj): return obj.__class__.__module__ == '__builtin__' or obj.__class__.__module__ == "builtins" +# ensure we only warn once, otherwise every iteration will trigger a warning +warned = False + + #apply torch.autograd.Function that calls a backward_function to tensors in output def _apply_to_tensors_only(module, functional, backward_function, outputs): if isinstance(outputs, (tuple, list)): @@ -45,10 +50,13 @@ def _apply_to_tensors_only(module, functional, backward_function, outputs): return functional.apply(module, backward_function, outputs) else: if not is_builtin_type(outputs): - logger.warning( - f"A module has unknown inputs or outputs type ({type(outputs)}) and the tensors embedded in it cannot be detected. " - "The ZeRO-3 hooks designed to trigger before or after backward pass of the module relies on knowing the input and " - "output tensors and therefore may not get triggered properly.") + global warned + if not warned and dist.get_rank() == 0: + logger.warning( + f"A module has unknown inputs or outputs type ({type(outputs)}) and the tensors embedded in it cannot be detected. " + "The ZeRO-3 hooks designed to trigger before or after backward pass of the module relies on knowing the input and " + "output tensors and therefore may not get triggered properly.") + warned = True return outputs