diff --git a/deepspeed/comm/__init__.py b/deepspeed/comm/__init__.py index 1975aab8582b..3433084324d4 100644 --- a/deepspeed/comm/__init__.py +++ b/deepspeed/comm/__init__.py @@ -22,18 +22,26 @@ def get_global_rank(group, group_rank): from torch.distributed.distributed_c10d import _get_global_rank return _get_global_rank(group, group_rank) - def allgather_fn(output_tensor, input_tensor, group, async_op): + def allgather_fn(output_tensor, input_tensor, group=None, async_op=False): from torch.distributed import all_gather, get_world_size from torch import chunk output_tensors = list(chunk(output_tensor, get_world_size(group))) - return all_gather(output_tensors, input_tensor, group=group, async_op=True) + return all_gather(output_tensors, input_tensor, group=group, async_op=async_op) - def reduce_scatter_fn(output_tensor, input_tensor, group): + def reduce_scatter_fn(output_tensor, input_tensor, group=None, async_op=False): from torch.distributed import reduce_scatter, get_world_size from torch import chunk input_tensor_lst = list(chunk(input_tensor, get_world_size(group))) return reduce_scatter(output_tensor, input_tensor_lst, group=group) + def configure(deepspeed_config=None, + enabled=None, + prof_all=None, + prof_ops=None, + verbose=None): + utils.logger.warn( + "Communication logging is not supported in torch versions older than 1.8") + else: supported_torch_version = True from .comm import * diff --git a/deepspeed/comm/comm.py b/deepspeed/comm/comm.py index eaa286a6fd26..bb88e6addcb8 100644 --- a/deepspeed/comm/comm.py +++ b/deepspeed/comm/comm.py @@ -28,7 +28,7 @@ from enum import Enum import torch import os -import torch +import time from ..constants import TORCH_DISTRIBUTED_DEFAULT_PORT, default_pg_timeout from .constants import * @@ -46,6 +46,8 @@ class ReduceOp(Enum): UNUSED = 8 +from deepspeed.utils.comms_logging import CommsLogger +from deepspeed.utils import timer, get_caller_func from deepspeed.comm.backend import Backend from deepspeed.comm.torch import TorchBackend @@ -56,6 +58,12 @@ class ReduceOp(Enum): use_ds_backend = False cdb = None +# Create global timer for ops +timers = timer.SynchronizedWallClockTimer() +timer_summary = {} + +comms_logger = CommsLogger() + # Ensure we don't warn about base collectives more than once has_warned_all_gather = False has_warned_reduce_scatter = False @@ -67,6 +75,77 @@ class ReduceOp(Enum): # This should be set here so all rank/size information from the launcher can be propagated from deepspeed.comm.utils import * + +def _configure_using_config_file(config): + if config.comms_logger_enabled: + comms_logger.configure(config) + + +def configure( + deepspeed_config=None, + enabled=None, + prof_all=None, + prof_ops=None, + verbose=None, + debug=None, +): + + if deepspeed_config is not None: + _configure_using_config_file(deepspeed_config.comms_config) + + if enabled is not None: + comms_logger.enabled = enabled + + if prof_all is not None: + comms_logger.prof_all = prof_all + + if prof_ops is not None: + comms_logger.prof_ops = prof_ops + + if verbose is not None: + comms_logger.verbose = verbose + + if debug is not None: + comms_logger.debug = debug + + +# Logging wrapper for timing ops +def timed_op(func): + def log_wrapper(*args, **kwargs): + # Add enabled flag so that overhead to each comm op is two if conditions at most + if comms_logger.enabled: + if ('prof' in kwargs and kwargs['prof']) or comms_logger.prof_all or ( + 'log_name' in kwargs + and kwargs['log_name'] in comms_logger.prof_ops): + # Need func args for their defaults + func_args = get_default_args(func) + func_args.update(kwargs) + msg_size = get_msg_size_from_args(func, *args, **kwargs) + log_name = get_debug_log_name(func_args, comms_logger.debug) + timers(log_name).start() + # Return the op, then stop the op's timer + try: + return func(*args, **kwargs) + finally: + if comms_logger.enabled: + # Need to make op blocking for accurate logging + torch.cuda.synchronize() + # If we're using MPI, we can't simply sync the stream + if cdb.using_mpi: + cdb.barrier() + if ('prof' in kwargs and kwargs['prof']) or comms_logger.prof_all or ( + 'log_name' in kwargs + and kwargs['log_name'] in comms_logger.prof_ops): + log_name = get_debug_log_name(func_args, comms_logger.debug) + raw_name = func.__name__ + timers(log_name).stop() + # need temp var since 'elapsed' resets events + time_elapsed = timers(log_name).elapsed(reset=False) + comms_logger.append(raw_name, log_name, time_elapsed, msg_size) + + return log_wrapper + + # For compatibility with torch distributed's init_process_group, we shall retain the signature from PyTorch code. # DeepSpeed NCCL/MPI backend may not need all these params as we will have our own implementation. # Please read full torch.distributed API docs from https://pytorch.org/docs/stable/distributed.html @@ -142,12 +221,26 @@ def set_backend(backend): print(inst) -def broadcast(tensor, src, group=None, async_op=False): +@timed_op +def broadcast(tensor, + src, + group=None, + async_op=False, + prof=False, + log_name='broadcast', + debug=get_caller_func()): global cdb return cdb.broadcast(tensor=tensor, src=src, group=group, async_op=async_op) -def all_gather(tensor_list, tensor, group=None, async_op=False): +@timed_op +def all_gather(tensor_list, + tensor, + group=None, + async_op=False, + prof=False, + log_name='all_gather', + debug=get_caller_func()): global cdb return cdb.all_gather(tensor_list=tensor_list, tensor=tensor, @@ -162,12 +255,22 @@ def has_reduce_scatter_base(): return cdb.has_reduce_scatter_base -def reduce_scatter_fn(output_tensor: torch.Tensor, input_tensor: torch.Tensor, group): +def reduce_scatter_fn(output_tensor, + tensor, + group=None, + async_op=False, + prof=False, + debug=get_caller_func()): global cdb global has_warned_reduce_scatter assert cdb is not None and cdb.is_initialized(), 'DeepSpeed backend not set, please initialize it using init_process_group()' if cdb.has_reduce_scatter_base: - return cdb.reduce_scatter_base(output_tensor, input_tensor, group=group) + return reduce_scatter_base(output_tensor, + tensor, + group=group, + async_op=async_op, + prof=prof, + debug=debug) else: if not has_warned_reduce_scatter: utils.logger.warning( @@ -175,15 +278,43 @@ def reduce_scatter_fn(output_tensor: torch.Tensor, input_tensor: torch.Tensor, g "torch.distributed.all_gather which will result in suboptimal performance. " "please consider upgrading your pytorch installation.") has_warned_reduce_scatter = True - input_tensor_lst = list(torch.chunk(input_tensor, cdb.get_world_size(group))) - return cdb.reduce_scatter(output_tensor, input_tensor_lst, group=group) + input_tensor_lst = list(torch.chunk(tensor, cdb.get_world_size(group))) + return reduce_scatter(output_tensor, + input_tensor_lst, + group=group, + async_op=async_op, + prof=prof, + debug=debug) + + +@timed_op +def reduce_scatter_base(output_tensor, + tensor, + group=None, + async_op=False, + prof=False, + log_name='reduce_scatter_base', + debug=get_caller_func()): + global cdb + return cdb.reduce_scatter_base(output_tensor=output_tensor, + input_tensor=tensor, + group=group, + async_op=async_op) -def reduce_scatter_base(output_tensor, input_tensor, group=None): +@timed_op +def all_gather_base(output_tensor, + tensor, + group=None, + async_op=False, + prof=False, + log_name='all_gather_base', + debug=get_caller_func()): global cdb - return cdb.reduce_scatter_base(output_tensor=output_tensor, - input_tensor=input_tensor, - group=group) + return cdb.all_gather_base(output_tensor=output_tensor, + input_tensor=tensor, + group=group, + async_op=async_op) def has_allgather_base(): @@ -193,18 +324,20 @@ def has_allgather_base(): return cdb.has_allgather_base -def allgather_fn(output_tensor: torch.Tensor, - input_tensor: torch.Tensor, - group, - async_op): +def allgather_fn(output_tensor, + input_tensor, + group=None, + async_op=False, + debug=get_caller_func()): global cdb global has_warned_all_gather assert cdb is not None and cdb.is_initialized(), 'DeepSpeed backend not set, please initialize it using init_process_group()' if cdb.has_allgather_base: - return cdb.all_gather_base(output_tensor, - input_tensor, - group=group, - async_op=True) + return all_gather_base(output_tensor, + input_tensor, + group=group, + async_op=async_op, + debug=debug) else: if not has_warned_all_gather and get_rank() == 0: utils.logger.warning( @@ -213,55 +346,89 @@ def allgather_fn(output_tensor: torch.Tensor, "please consider upgrading your pytorch installation.") has_warned_all_gather = True output_tensors = list(torch.chunk(output_tensor, cdb.get_world_size(group))) - return cdb.all_gather(output_tensors, input_tensor, group=group, async_op=True) - - -def all_gather_base(output_tensor, input_tensor, group=None, async_op=False): - global cdb - return cdb.all_gather_base(output_tensor=output_tensor, - input_tensor=input_tensor, - group=group, - async_op=async_op) - - -def all_to_all_single( - output, - input, - output_split_sizes=None, - input_split_sizes=None, - group=None, - async_op=False, -): + return all_gather(output_tensors, + input_tensor, + group=group, + async_op=async_op, + debug=debug) + + +@timed_op +def all_to_all_single(output, + tensor, + output_split_sizes=None, + input_split_sizes=None, + group=None, + async_op=False, + prof=False, + log_name='all_to_all_single', + debug=get_caller_func()): global cdb return cdb.all_to_all_single(output=output, - input=input, + input=tensor, output_split_sizes=output_split_sizes, input_split_sizes=input_split_sizes, group=group, async_op=async_op) -def send(tensor, dst, group=None, tag=0): +@timed_op +def send(tensor, + dst, + group=None, + tag=0, + prof=False, + log_name='send', + debug=get_caller_func()): global cdb return cdb.send(tensor=tensor, dst=dst, group=group, tag=tag) -def recv(tensor, src=None, group=None, tag=0): +@timed_op +def recv(tensor, + src=None, + group=None, + tag=0, + prof=False, + log_name='recv', + debug=get_caller_func()): global cdb return cdb.recv(tensor=tensor, src=src, group=group, tag=tag) -def isend(tensor, dst, group=None, tag=0): +@timed_op +def isend(tensor, + dst, + group=None, + tag=0, + prof=False, + log_name='isend', + debug=get_caller_func()): global cdb return cdb.send(tensor=tensor, dst=dst, group=group, tag=tag) -def irecv(tensor, src=None, group=None, tag=0): +@timed_op +def irecv(tensor, + src=None, + group=None, + tag=0, + prof=False, + log_name='irecv', + debug=get_caller_func()): global cdb return cdb.recv(tensor=tensor, src=src, group=group, tag=tag) -def gather(tensor, gather_list=None, dst=0, group=None, async_op=False): +@timed_op +def gather(tensor, + gather_list=None, + dst=0, + group=None, + async_op=False, + prof=False, + log_name='gather', + debug=get_caller_func()): global cdb return cdb.gather(tensor=tensor, gather_list=gather_list, @@ -270,7 +437,15 @@ def gather(tensor, gather_list=None, dst=0, group=None, async_op=False): async_op=async_op) -def scatter(tensor, scatter_list=None, src=0, group=None, async_op=False): +@timed_op +def scatter(tensor, + scatter_list=None, + src=0, + group=None, + async_op=False, + prof=False, + log_name='scatter', + debug=get_caller_func()): global cdb return cdb.scatter(tensor=tensor, scatter_list=scatter_list, @@ -279,21 +454,42 @@ def scatter(tensor, scatter_list=None, src=0, group=None, async_op=False): async_op=async_op) -def barrier(group=None): +@timed_op +def barrier(group=None, prof=False, log_name='barrier', debug=get_caller_func()): global cdb return cdb.barrier() -# Local enum for Reduction operators -#from .utils import ReduceOp - - -def reduce(tensor, dst, op=ReduceOp.SUM, group=None, async_op=False): +def log_summary(): + global cdb + barrier(log_name='log_summary_barrier') + if cdb.get_rank() == 0: + comms_logger.log_all() + barrier(log_name='log_summary_barrier') + + +@timed_op +def reduce(tensor, + dst, + op=ReduceOp.SUM, + group=None, + async_op=False, + prof=False, + log_name='reduce', + debug=get_caller_func()): global cdb return cdb.reduce(tensor=tensor, dst=dst, op=op, group=group, async_op=async_op) -def reduce_scatter(output, input_list, op=ReduceOp.SUM, group=None, async_op=False): +@timed_op +def reduce_scatter(output, + input_list, + op=ReduceOp.SUM, + group=None, + async_op=False, + prof=False, + log_name='reduce_scatter', + debug=get_caller_func()): global cdb return cdb.reduce_scatter(output=output, input_list=input_list, @@ -302,7 +498,14 @@ def reduce_scatter(output, input_list, op=ReduceOp.SUM, group=None, async_op=Fal async_op=async_op) -def all_reduce(tensor, op=ReduceOp.SUM, group=None, async_op=False): +@timed_op +def all_reduce(tensor, + op=ReduceOp.SUM, + group=None, + async_op=False, + prof=False, + log_name='all_reduce', + debug=get_caller_func()): #if profile_comm: # context of the timers? # timers.start() @@ -379,7 +582,8 @@ def init_distributed(dist_backend="nccl", verbose=True, timeout=default_pg_timeout, init_method=None, - dist_init_required=None): + dist_init_required=None, + config=None): ''' Initialize dist backend, potentially performing MPI discovery if needed Arguments: @@ -389,9 +593,12 @@ def init_distributed(dist_backend="nccl", verbose: Optional (bool). verbose logging timeout: Optional (timedelta). Timeout for operations executed against the process group. Default value equals 30 minutes. init_method: Optional (string). Torch distributed, URL specifying how to initialize the process group. Default is “env://” if no init_method or store is specified. + config: Optional (dict). DeepSpeed configuration for setting up comms options (e.g. Comms profiling) ''' global cdb + configure(deepspeed_config=config) + if dist_init_required is None: dist_init_required = cdb is None or not cdb.is_initialized() diff --git a/deepspeed/comm/config.py b/deepspeed/comm/config.py new file mode 100644 index 000000000000..2c8ceaed2df9 --- /dev/null +++ b/deepspeed/comm/config.py @@ -0,0 +1,33 @@ +""" +Copyright (c) Microsoft Corporation +Licensed under the MIT license. +""" + +from typing import Optional +from deepspeed.runtime.config_utils import get_scalar_param +from pydantic import BaseModel, validator, ValidationError, create_model +from .constants import * + + +class CommsConfig(BaseModel): + class Config: + validate_all = True + validate_assignment = True + use_enum_values = True + extra = 'forbid' + + +class CommsLoggerConfig(CommsConfig): + enabled: bool = COMMS_LOGGER_ENABLED_DEFAULT + prof_all: bool = COMMS_LOGGER_PROF_ALL_DEFAULT + prof_ops: list = COMMS_LOGGER_PROF_OPS_DEFAULT + verbose: bool = COMMS_LOGGER_VERBOSE_DEFAULT + debug: bool = COMMS_LOGGER_DEBUG_DEFAULT + + +class DeepSpeedCommsConfig: + def __init__(self, ds_config): + self.comms_logger_enabled = 'comms_logger' in ds_config + + if self.comms_logger_enabled: + self.comms_logger = CommsLoggerConfig(**ds_config['comms_logger']) diff --git a/deepspeed/comm/constants.py b/deepspeed/comm/constants.py index d85f72e8cbaa..b3a526a5afbc 100644 --- a/deepspeed/comm/constants.py +++ b/deepspeed/comm/constants.py @@ -5,3 +5,40 @@ DEFAULT_AML_MASTER_PORT = "54965" DEFAULT_AML_NCCL_SOCKET_IFNAME = "^docker0,lo" + +######################################### +# Comms Logger +######################################### +# Comms Logger. By default, this feature is not enabled. +# Users can configure in ds_config.json as below example: +COMMS_LOGGER_FORMAT = ''' +The Comms Logger can be specified as: +"comms_logger": { + "enabled": true, + "verbose": false, + "prof_all": true, + "debug": false, + "prof_ops": ["all_reduce", "custom_all_reduce_name"] +} +''' +COMMS_LOGGER = "comms_logger" + +# Comms logger enable signal +COMMS_LOGGER_ENABLED = "enabled" +COMMS_LOGGER_ENABLED_DEFAULT = False + +# Comms logger verbose signal +COMMS_LOGGER_VERBOSE = "verbose" +COMMS_LOGGER_VERBOSE_DEFAULT = False + +# comms logger profile all ops signal +COMMS_LOGGER_PROF_ALL = "prof_all" +COMMS_LOGGER_PROF_ALL_DEFAULT = True + +# comms logger show all ops signal +COMMS_LOGGER_DEBUG = "debug" +COMMS_LOGGER_DEBUG_DEFAULT = False + +# comms logger profile specific ops in list +COMMS_LOGGER_PROF_OPS = "prof_ops" +COMMS_LOGGER_PROF_OPS_DEFAULT = [] diff --git a/deepspeed/comm/torch.py b/deepspeed/comm/torch.py index f7731536b696..6b4af25abd88 100644 --- a/deepspeed/comm/torch.py +++ b/deepspeed/comm/torch.py @@ -39,6 +39,7 @@ def init_process_group(self, backend, timeout, init_method): torch.distributed.init_process_group(backend, timeout=timeout, init_method=init_method) + self.using_mpi = torch.distributed.get_backend() == 'mpi' def all_reduce(self, tensor, @@ -96,11 +97,16 @@ def all_gather_base(self, output_tensor, input_tensor, group=None, async_op=Fals "please consider upgrading your pytorch installation.") pass - def reduce_scatter_base(self, output_tensor, input_tensor, group=None): + def reduce_scatter_base(self, + output_tensor, + input_tensor, + group=None, + async_op=False): if self.has_reduce_scatter_base: return torch.distributed._reduce_scatter_base(output_tensor, input_tensor, - group=group) + group=group, + async_op=async_op) else: utils.logger.warning( "unable to find torch.distributed._reduce_scatter_base. will fall back to " diff --git a/deepspeed/comm/utils.py b/deepspeed/comm/utils.py index 7ca5581a2029..1af067e91750 100644 --- a/deepspeed/comm/utils.py +++ b/deepspeed/comm/utils.py @@ -1,6 +1,8 @@ import os import enum import torch +import inspect +from deepspeed.utils import get_caller_func def older_torch(): @@ -64,6 +66,7 @@ def get_world_rank_from_launcher(): def get_world_size_from_launcher(): # DeepSpeed launcher will set it so get from there size = os.environ.get('WORLD_SIZE') + rank = os.environ.get('RANK') if size is None: size = os.environ.get('OMPI_COMM_WORLD_SIZE') @@ -72,4 +75,83 @@ def get_world_size_from_launcher(): if size is None: size = 1 + if rank == 0: + print(f"set world size to {size}") + return int(size) + + +def get_default_args(func): + signature = inspect.signature(func) + return { + k: v.default + for k, + v in signature.parameters.items() if v.default is not inspect.Parameter.empty + } + + +# We need this hacky function since torch doesn't consistently name or place the input tensor args +def get_tensor_position(func): + sig_params = inspect.signature(func).parameters + arg = None + # most colls + if 'tensor' in sig_params: + arg = 'tensor' + # reduce scatter coll + elif 'input_list' in sig_params: + arg = 'input_list' + # all_to_all and torch multiGPU colls + elif 'input_tensor_list' in sig_params: + arg = 'input_tensor_list' + if arg is None: + return -1 + else: + return list(sig_params).index(arg) + + +def get_tensor_kwarg(func, kwargs): + func_args = get_default_args(func) + func_args.update(kwargs) + arg = None + + if 'tensor' in func_args: + arg = func_args['tensor'] + elif 'input_list' in func_args: + arg = func_args['input_list'] + elif 'input_tensor_list' in func_args: + arg = func_args['input_tensor_list'] + return arg + + +def get_msg_size_from_args(func, *args, **kwargs): + # 3 cases: + # - tensor arg is in args + # - tensor arg is in kwargs + # - tensor arg is not present (e.g. barrier) + tensor_arg_position = -1 + tensor_arg = None + # check if tensor arg is in args + if len(args) > 0: + tensor_arg_position = get_tensor_position(func) + if tensor_arg_position > -1: + tensor_arg = args[get_tensor_position(func)] + # check if tensor arg is in kwargs + if tensor_arg is None and len(kwargs) > 0: + tensor_arg = get_tensor_kwarg(func, kwargs) + # if tensor arg is not present, no data is being transmitted + if tensor_arg is None: + return 0 + else: + # Sum of tensor sizes for list colls such as torch's all_to_all + # NOTE: msg_size for list colls will not be the actual size transmitted by a given MPI/NCCL call within the coll op. Instead, it's the total amount of data transmitted. + if type(tensor_arg) is list: + return sum(x.element_size() * x.nelement() for x in func_args['tensor_list']) + else: + return tensor_arg.element_size() * tensor_arg.nelement() + + +def get_debug_log_name(func_args, debug): + if debug: + return func_args['log_name'] + ' | [Caller Func: ' + get_caller_func() + ']' + else: + return func_args['log_name'] diff --git a/deepspeed/runtime/comm/coalesced_collectives.py b/deepspeed/runtime/comm/coalesced_collectives.py index e92af044f53e..6a2b8e31516d 100644 --- a/deepspeed/runtime/comm/coalesced_collectives.py +++ b/deepspeed/runtime/comm/coalesced_collectives.py @@ -12,13 +12,18 @@ import torch.nn.functional from deepspeed.utils import instrument_w_nvtx -from deepspeed.utils.logging import logger +from deepspeed.utils import logger -def _torch_reduce_scatter_fn(input_tensor: Tensor, output_tensor: Tensor, group): +def _torch_reduce_scatter_fn(input_tensor: Tensor, + output_tensor: Tensor, + group=None, + async_op=False, + prof=False): return instrument_w_nvtx(dist.reduce_scatter_fn)(output_tensor, input_tensor, - group=group) + group=group, + async_op=async_op) @instrument_w_nvtx @@ -82,7 +87,7 @@ def reduce_scatter_coalesced( # batched reduce-scatter call _torch_reduce_scatter_fn(tensor_partition_flat_buffer, tensor_partition_buffer_for_each_rank[this_rank], - group) + group=group) # reverse procedure of the interleaving done previously, done on the # result of the batched reduce-scatter diff --git a/deepspeed/runtime/config.py b/deepspeed/runtime/config.py index 9e68ac60ec02..f4e627cf58d2 100755 --- a/deepspeed/runtime/config.py +++ b/deepspeed/runtime/config.py @@ -24,6 +24,7 @@ from .zero.config import DeepSpeedZeroConfig from .zero.constants import * from .activation_checkpointing.config import DeepSpeedActivationCheckpointingConfig +from ..comm.config import DeepSpeedCommsConfig from ..monitor.config import DeepSpeedMonitorConfig from deepspeed import comm as dist @@ -806,6 +807,7 @@ def _initialize_params(self, param_dict): self.activation_checkpointing_config = DeepSpeedActivationCheckpointingConfig( param_dict) + self.comms_config = DeepSpeedCommsConfig(param_dict) self.monitor_config = DeepSpeedMonitorConfig(param_dict) self.gradient_clipping = get_gradient_clipping(param_dict) diff --git a/deepspeed/runtime/engine.py b/deepspeed/runtime/engine.py index 07638b33033f..bf37ab71adc3 100644 --- a/deepspeed/runtime/engine.py +++ b/deepspeed/runtime/engine.py @@ -286,6 +286,8 @@ def __init__( self._set_distributed_vars(args) + dist.configure(self._config) + self.monitor = MonitorMaster(self._config.monitor_config) see_memory_usage( diff --git a/deepspeed/runtime/pipe/engine.py b/deepspeed/runtime/pipe/engine.py index 94add6f9c8e4..43c65b15b525 100644 --- a/deepspeed/runtime/pipe/engine.py +++ b/deepspeed/runtime/pipe/engine.py @@ -14,7 +14,7 @@ import torch.optim as optim from deepspeed import comm as dist -from deepspeed.utils.logging import logger +from deepspeed.utils import logger from deepspeed.utils.timer import SynchronizedWallClockTimer, ThroughputTimer from deepspeed.inference.engine import InferenceEngine diff --git a/deepspeed/runtime/zero/partition_parameters.py b/deepspeed/runtime/zero/partition_parameters.py index 62cd21b3710f..de1c819fae58 100755 --- a/deepspeed/runtime/zero/partition_parameters.py +++ b/deepspeed/runtime/zero/partition_parameters.py @@ -43,7 +43,7 @@ partitioned_param_data_shape = [0] -def _dist_allgather_fn(input_tensor: Tensor, output_tensor: Tensor, group): +def _dist_allgather_fn(input_tensor: Tensor, output_tensor: Tensor, group=None): return instrument_w_nvtx(dist.allgather_fn)(output_tensor, input_tensor, group=group, @@ -834,8 +834,7 @@ def all_gather_coalesced(params: Iterable[Parameter], handle = _dist_allgather_fn( param.ds_tensor.to(torch.cuda.current_device()), param_buffer, - self.ds_process_group, - ) + self.ds_process_group) param.data = param_buffer.narrow(0, 0, param.ds_numel).view(param.ds_shape).to( diff --git a/deepspeed/runtime/zero/stage3.py b/deepspeed/runtime/zero/stage3.py index 6f7185413d88..fb284d120011 100755 --- a/deepspeed/runtime/zero/stage3.py +++ b/deepspeed/runtime/zero/stage3.py @@ -21,7 +21,7 @@ from torch.nn.parameter import Parameter from deepspeed.runtime import ZeROOptimizer -from deepspeed.utils.logging import logger +from deepspeed.utils import logger from deepspeed.runtime.fp16.loss_scaler import LossScaler, DynamicLossScaler from deepspeed.runtime.comm.coalesced_collectives import reduce_scatter_coalesced from deepspeed.runtime.utils import get_global_norm, see_memory_usage, is_model_parallel_parameter diff --git a/deepspeed/utils/__init__.py b/deepspeed/utils/__init__.py index fe7461002932..5e05bf46e9b6 100644 --- a/deepspeed/utils/__init__.py +++ b/deepspeed/utils/__init__.py @@ -1,4 +1,5 @@ from .logging import logger, log_dist +from .comms_logging import get_caller_func #from .distributed import init_distributed from .init_on_device import OnDevice from .groups import * diff --git a/deepspeed/utils/comms_logging.py b/deepspeed/utils/comms_logging.py new file mode 100644 index 000000000000..1c17868e09b1 --- /dev/null +++ b/deepspeed/utils/comms_logging.py @@ -0,0 +1,142 @@ +import logging +import sys +import os +import math +from deepspeed.utils import log_dist + + +def get_caller_func(frame=3): + import inspect, sys + return sys._getframe(frame).f_code.co_name + + +# Helper function to pretty-print message sizes +def convert_size(size_bytes): + if size_bytes == 0: + return "0B" + size_name = ("B", "KB", "MB", "GB", "TB", "PB", "EB", "ZB", "YB") + i = int(math.floor(math.log(size_bytes, 1024))) + p = math.pow(1024, i) + s = round(size_bytes / p, 2) + return "%s %s" % (s, size_name[i]) + + +# Helper function to calculate algbw and busbw. +# See https://gist.github.com/jeffra/b5e80466b4c86be00ea3b6f130fb7a36 and https://github.com/NVIDIA/nccl-tests/blob/master/doc/PERFORMANCE.md +def calc_bw_log(comm_op, size, duration): + import deepspeed.comm as dist + + n = dist.get_world_size() + tput = 0 + busbw = 0 + if comm_op == "all_to_all_single": + tput = (size / duration) + busbw = (size / duration) * ((n - 1) / n) + elif comm_op == "all_gather" or comm_op == "all_gather_base" or comm_op == "reduce_scatter" or comm_op == "reduce_scatter_base": + size *= n + tput = (size / duration) + busbw = (size / duration) * ((n - 1) / n) + elif comm_op == "all_reduce": + tput = (size * 2 / duration) + busbw = (size / duration) * (2 * (n - 1) / n) + elif comm_op == "send" or comm_op == "recv" or comm_op == "isend" or comm_op == "irecv" or comm_op == "broadcast" or comm_op == "reduce" or comm_op == "gather" or comm_op == "scatter" or comm_op == "barrier": + tput = (size / duration) + busbw = tput + else: + print_rank_0("wrong comm_op specified") + exit(0) + + # convert to Gbps + tput *= 8 + busbw *= 8 + + tput /= 1e6 + busbw /= 1e6 + + return tput, busbw + + +class CommsLogger: + def __init__(self): + from deepspeed.comm.constants import COMMS_LOGGER_VERBOSE_DEFAULT, COMMS_LOGGER_DEBUG_DEFAULT, COMMS_LOGGER_PROF_OPS_DEFAULT, COMMS_LOGGER_PROF_ALL_DEFAULT, COMMS_LOGGER_ENABLED_DEFAULT + self.comms_dict = {} + self.verbose = COMMS_LOGGER_VERBOSE_DEFAULT + self.debug = COMMS_LOGGER_DEBUG_DEFAULT + self.prof_ops = COMMS_LOGGER_PROF_OPS_DEFAULT + self.prof_all = COMMS_LOGGER_PROF_ALL_DEFAULT + self.enabled = COMMS_LOGGER_ENABLED_DEFAULT + + def configure(self, comms_config): + self.enabled = comms_config.comms_logger_enabled + if self.enabled: + self.verbose = comms_config.comms_logger.verbose + self.debug = comms_config.comms_logger.debug + self.prof_ops = comms_config.comms_logger.prof_ops + self.prof_all = comms_config.comms_logger.prof_all + + # There are three settings for the op profiler: + # - Global profiling (profile all comms) + # - Op-type profiling (e.g. profile all all_reduce comms) + # - Op profiling (e.g. profile a specific all_reduce op) + def start_profiling_comms(): + self.prof_all = True + + def stop_profiling_comms(): + self.prof_all = True + + # E.g. start_profiling_op('all_reduce') + def start_profiling_op(op_name_list): + self.prof_ops = list(set(comms_logger.prof_ops) | set(op_name_list)) + + def stop_profiling_op(op_name_list): + self.prof_ops = [op for op in comms_logger.prof_ops if op not in op_name_list] + + # Add log entry + def append(self, raw_name, record_name, latency, msg_size): + import deepspeed.comm as dist + algbw, busbw = calc_bw_log(raw_name, msg_size, latency) + if record_name in self.comms_dict.keys(): + # If this comm_op has already been logged with this message size, just add to existing record + if msg_size in self.comms_dict[record_name].keys(): + self.comms_dict[record_name][msg_size][0] += 1 + self.comms_dict[record_name][msg_size][1].append(latency) + self.comms_dict[record_name][msg_size][2].append(algbw) + self.comms_dict[record_name][msg_size][3].append(busbw) + # If this is a new message size for this comm_op, add new record under existing comm_op + else: + self.comms_dict[record_name][msg_size] = [1, [latency], [algbw], [busbw]] + else: + # Create entirely new record + self.comms_dict[record_name] = {msg_size: [1, [latency], [algbw], [busbw]]} + # If verbose, print every comm op + # TODO: Add to tensorboard + if self.verbose: + n = dist.get_world_size() + log_str = f"rank={dist.get_rank()} | comm op: " + record_name + " | time (ms): {:.2f}".format( + latency) + log_str += " | msg size: " + convert_size(msg_size) + log_str += " | algbw (Gbps): {:.2f} ".format(algbw) + log_str += " | busbw (Gbps): {:.2f} ".format(busbw) + log_dist(log_str, [0]) + + # Print summary at end of iteration, epoch, or training + def log_all(self): + from deepspeed.utils.timer import trim_mean + print( + f"{'Comm. Op': <20}{'Message Size': <20}{'Count': <20}{'Total Latency(ms)': <20}{'Avg Latency(ms)': <20}{'tput_avg (Gbps)': <20}{'busbw_avg (Gbps)': <20}" + ) + for record_name in self.comms_dict.keys(): + print(record_name) + for msg_size, vals in sorted(self.comms_dict[record_name].items()): + # vals[0] is the count for each msg size + count = vals[0] + # vals[1] is a list of latency records for each msg size + total_lat = sum(vals[1]) + # vals[2] and vals[3] are the lists of algbw and busbw, respectively + # Get rid of outliers when we print + avg_lat = trim_mean(vals[1], 0.1) + avg_algbw = trim_mean(vals[2], 0.1) + avg_busbw = trim_mean(vals[3], 0.1) + print( + f"{' ': <20}{convert_size(msg_size): <20}{count: <20}{total_lat: <20.2f}{avg_lat: <20.2f}{avg_algbw: <20.2f}{avg_busbw: <20.2f}" + ) diff --git a/deepspeed/utils/logging.py b/deepspeed/utils/logging.py index 55636173b0a4..b8217dc2bb94 100644 --- a/deepspeed/utils/logging.py +++ b/deepspeed/utils/logging.py @@ -2,8 +2,6 @@ import sys import os -from deepspeed import comm as dist - log_levels = { "debug": logging.DEBUG, "info": logging.INFO, @@ -47,6 +45,7 @@ def create_logger(name=None, level=logging.INFO): def log_dist(message, ranks=None, level=logging.INFO): + from deepspeed import comm as dist """Log message when one of following condition meets + not dist.is_initialized() @@ -70,6 +69,7 @@ def log_dist(message, ranks=None, level=logging.INFO): def print_json_dist(message, ranks=None, path=None): + from deepspeed import comm as dist """Print message when one of following condition meets + not dist.is_initialized() diff --git a/deepspeed/utils/timer.py b/deepspeed/utils/timer.py index 9c98dc930d9e..8685f144bbb2 100755 --- a/deepspeed/utils/timer.py +++ b/deepspeed/utils/timer.py @@ -10,8 +10,6 @@ from deepspeed.utils.logging import log_dist from deepspeed import comm as dist -from deepspeed.utils import logger - try: import psutil @@ -94,6 +92,9 @@ def mean(self): def __init__(self): self.timers = {} + def get_timers(self): + return self.timers + def __call__(self, name): if name not in self.timers: self.timers[name] = self.Timer(name) @@ -143,6 +144,7 @@ def __init__( monitor_memory=False, logging_fn=None, ): + from deepspeed.utils import logger self.start_time = 0 self.end_time = 0 self.started = False @@ -233,6 +235,9 @@ def trim_mean(data, trim_percent): """ assert trim_percent >= 0.0 and trim_percent <= 1.0 n = len(data) + # Account for edge case of empty list + if len(data) == 0: + return 0 data.sort() k = int(round(n * (trim_percent))) return mean(data[k:n - k]) diff --git a/docs/_config.yml b/docs/_config.yml index 0b3ae77595e6..e09827d6b31e 100644 --- a/docs/_config.yml +++ b/docs/_config.yml @@ -50,6 +50,7 @@ collections: - mixture-of-experts-inference.md - model-compression.md - monitor.md + - comms-logging.md - one-cycle.md - onebit-adam.md - zero-one-adam.md diff --git a/docs/_data/navigation.yml b/docs/_data/navigation.yml index dc4c8c27677a..aae93d4d75b7 100755 --- a/docs/_data/navigation.yml +++ b/docs/_data/navigation.yml @@ -40,6 +40,8 @@ lnav: url: /docs/config-json/#flops-profiler - title: 'Monitoring' url: /docs/config-json/#monitoring-module-tensorboard-wandb-csv + - title: 'Communication Logging' + url: /docs/config-json/#communication-logging - title: 'Model Compression' url: /docs/config-json/#compression - title: 'Tutorials' @@ -83,6 +85,8 @@ lnav: url: /tutorials/MoQ-tutorial/ - title: 'Monitoring' url: /tutorials/monitor + - title: 'Communication Logging' + url: /tutorials/comms-logging - title: 'One-Cycle Schedule' url: /tutorials/one-cycle/ - title: 'One-Bit Adam' diff --git a/docs/_pages/config-json.md b/docs/_pages/config-json.md index 642b0300f09f..11541aee9761 100755 --- a/docs/_pages/config-json.md +++ b/docs/_pages/config-json.md @@ -1047,6 +1047,50 @@ Example of **csv_monitor** configuration: "job_name": "train_bert" } ``` + +### Communication Logging + + +DeepSpeed provides a flexible communication logging tool which can automatically detect and record communication operations launched via `deepspeed.comm`. NOTE: All logging communication calls are synchronized in order to provide accurate timing information. This may hamper performance if your model heavily uses asynchronous communication operations. + +Once the logs are populated, they can be summarized with `deepspeed.comm.log_summary()`. For more detail and example usage, see the [tutorial](/tutorials/comms-logging/) + + + + +**comms_logger**: [dictionary] + +| Fields | Value |Default | +| ------ | ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ | ----- | +| enabled | Whether communication logging is enabled. | `false` | +| verbose | Whether to immediately print every communication operation | `false` | +| prof_all | Whether to profile all operations. | `true` | +| debug | Appends the caller function to each communication operation's `log_name`. | `false` | +| prof_ops | A list of communication operations to log (only the specified ops will be profiled). | `[]` | + + +Example of recommended **comms_logger** configuration: + +```json +"comms_logger": { + "enabled": true, + "verbose": false, + "prof_all": true, + "debug": false +} +``` + +Example of **comms_logger** configuration for logging specific operations only: + +```json +"comms_logger": { + "enabled": true, + "verbose": false, + "prof_all": false, + "debug": false, + "prof_ops": ["all_reduce", "all_gather"] +} +``` ### Compression **Note:** **Compression** has seven different components, including layer reduction, weight quantization, activation quantization, sparse pruning, row pruning, head pruning, and channel pruning. We explain them one by one with simple json examples. Read more about how to use the DeepSpeed Compression library in our [tutorial](/tutorials/model-compression/). diff --git a/docs/_pages/training.md b/docs/_pages/training.md index 116bb0d2c697..41178d54ea43 100644 --- a/docs/_pages/training.md +++ b/docs/_pages/training.md @@ -527,6 +527,24 @@ The DeepSpeed Monitor logs live training metrics to one or more monitoring backe The Monitor can also be added to log custom metrics and client codes. Please refer to the [Monitor](/tutorials/monitor) tutorial for more details. +### Communication Logging + +DeepSpeed provides logging of all communication operations launched within `deepspeed.comm`. The communication logger can be configured in the `deepspeed_config` file as follows: + +```json +{ + "comms_logger": { + "enabled": true, + "verbose": false, + "prof_all": true, + "debug": false + } +} + +``` + +Client codes can then print a summary with a call to `deepspeed.comm.log_summary()`. For more details and example usage, see the [Communication Logging](/tutorials/comms-logging) tutorial. + ## Sparse Attention DeepSpeed offers sparse attention to support long sequences. Please refer to the [Sparse Attention](/tutorials/sparse-attention/) tutorial. diff --git a/docs/_tutorials/comms-logging.md b/docs/_tutorials/comms-logging.md new file mode 100644 index 000000000000..52d93eda05bc --- /dev/null +++ b/docs/_tutorials/comms-logging.md @@ -0,0 +1,116 @@ +--- +title: "Communication Logging" +excerpt: "Log all DeepSpeed communication calls" +tags: profiling performance-tuning +--- + +In this tutorial, we introduce DeepSpeed communication logging and provide examples of its usage. + + - [Overview](#overview) + - [Usage](#usage) + +## Overview + +NOTE: All logging communication calls are synchronized in order to provide accurate timing information. This may hamper performance if your model heavily uses asynchronous communication operations. + +Logging communication calls is vital to ensure networking resources are fully utilized. The DeepSpeed communication logger enables the detection and logging of all communication operations launched under `deepspeed.comm`. Each communication operation can all be directly printed to the console immediately after completion (via the `verbose` config option), or a summary may be printed with a call to `deepspeed.comm.log_summary()` in the client code at the completion of training, an epoch, after N training iterations, etc. + +## Usage + +Communication logging in DeepSpeed is configured within the deepspeed [configuration file](/docs/config-json/#communication-logging). DeepSpeed will automatically log communication either all operations (`prof_all`), or user-specified operations (`prof_ops`). + + - [Configuration Setup](#configuration-setup) + - [Verbose Logging](#verbose-logging) + - [Log Summaries](#log-summaries) + +### Configuration Setup + +Communication logging can be configured in the DeepSpeed [configuration file](/docs/config-json/#communication-logging). Communication logging can be enabled by adding the following field to DeepSpeed's configuration json file. Refer to [Communication Logging](/docs/config-json/#communication-logging) for details. + +```json +"comms_logger": { + "enabled": true, + "verbose": false, + "prof_all": true, + "debug": false +} +``` + +There are currently two ways to view communication log records: + +1. Print all communication operations with `verbose` config option. See [Verbose Logging](#verbose-logging) +2. (Recommended) Print log summary with `deepspeed.comm.log_summary()` function call. See [Log Summaries](#log-summaries) + +### Verbose Logging + +If the `enabled` configuration option is selected, all communication operations will be immediately printed to the console. This mode is intended for detailed debugging, and is not recommended for most users. The following is an example snippet of `verbose` output: + +``` +[2022-06-26 01:39:55,722] [INFO] [logging.py:69:log_dist] [Rank 0] rank=0 | comm op: reduce_scatter_base | time (ms): 9.46 | msg size: 678.86 MB | algbw (Gbps): 1204.52 | busbw (Gbps): 1129.23 +[2022-06-26 01:39:56,470] [INFO] [logging.py:69:log_dist] [Rank 0] rank=0 | comm op: all_gather_base | time (ms): 0.11 | msg size: 6.0 MB | algbw (Gbps): 954.41 | busbw (Gbps): 894.76 +[2022-06-26 01:39:56,471] [INFO] [logging.py:69:log_dist] [Rank 0] rank=0 | comm op: all_gather_base | time (ms): 0.08 | msg size: 6.0 MB | algbw (Gbps): 1293.47 | busbw (Gbps): 1212.63 +``` + +For advanced users, the `debug` option will append the calling function of each communication operation to that operation's `log_name`. See [Log Summaries](#log-summaries) for an example of a `deepspeed.comm.log_summary()` call with `debug` enabled. + + +### Log Summaries + +It's recommended that users add a call to `deepspeed.comm.log_summary()` at training milestones (e.g. every epoch or N iterations). This enables high-level communication logging without having to sift through logs from `verbose`. + +The steps to add DeepSpeed communication log summaries are as follows: + +1. Modify configuration file with desired settings +2. (Optional) If your application contains `torch.distributed` calls that you wish to log, import `deepspeed.comm` package and modify `torch.distributed` calls to use `deepspeed.comm` (Note: The `deepspeed.comm` collective and pt2pt APIs exactly match `torch.distributed`) +3. Call `deepspeed.comm.log_summary` + +For example usage, see the following modified [DeepSpeedExamples/cifar](https://github.com/microsoft/DeepSpeedExamples/tree/master/cifar) example: + +```python +# Step 2: (Optional) Import deepspeed.comm +import deepspeed.comm as dist + +# Note that any communication operations using `import torch.distributed as dist` calls can remain unchanged, and will be automatically logged under deepspeed.comm! +dist.all_reduce(tensor) + +for epoch in range(2): + + running_loss = 0.0 + for i, data in enumerate(trainloader): + pre = time.time() + inputs, labels = data[0].to(model_engine.local_rank), data[1].to( + model_engine.local_rank) + if fp16: + inputs = inputs.half() + outputs = model_engine(inputs) + loss = criterion(outputs, labels) + + model_engine.backward(loss) + model_engine.step() + post = time.time() + # Step 3: Call `deepspeed.comm.log_summary()` + dist.log_summary() +``` + +The following is a truncated example output of `deepspeed.comm.log_summary()` at the end of 10 iterations of Megatron-DeepSpeed with ZeRO-3: + +``` +Comm. Op Message Size Count Total Latency(ms) Avg Latency(ms) tput_avg (Gbps) busbw_avg (Gbps) +broadcast + 2.0 KB 146 11.12 0.08 0.43 0.41 + 98.25 MB 1 8317.12 8317.12 0.20 0.19 +reduce_scatter_base + 678.86 MB 40 602.29 9.69 1468.06 1376.31 +``` + + +And the following is a call to `deepspeed.comm.log_summary` under the same configuration with `debug` enabled: + +``` +Comm. Op Message Size Count Total Latency(ms) Avg Latency(ms) tput_avg (Gbps) busbw_avg (Gbps) +broadcast | [Caller Func: _broadcast_model] + 2.0 KB 146 9.39 0.06 0.52 0.48 + 98.25 MB 1 8540.60 8540.60 0.19 0.18 +reduce_scatter_base | [Caller Func: reduce_scatter_fn] + 678.86 MB 80 1527.17 13.94 1211.75 1136.01 +```