diff --git a/deepspeed/comm/comm.py b/deepspeed/comm/comm.py index 61b001ad0662..0339dd5a4c75 100644 --- a/deepspeed/comm/comm.py +++ b/deepspeed/comm/comm.py @@ -68,7 +68,7 @@ class ReduceOp(Enum): def _configure_using_config_file(config): - if config.comms_logger_enabled: + if config.enabled: comms_logger.configure(config) diff --git a/deepspeed/comm/config.py b/deepspeed/comm/config.py index 138badebe5a9..96a5c9472346 100644 --- a/deepspeed/comm/config.py +++ b/deepspeed/comm/config.py @@ -3,31 +3,25 @@ # DeepSpeed Team -from pydantic import BaseModel -from .constants import * +from deepspeed.runtime.config_utils import DeepSpeedConfigModel -class CommsConfig(BaseModel): +def get_comms_config(param_dict): + return DeepSpeedCommsConfig(**param_dict.get("comms_logger", {})) - class Config: - validate_all = True - validate_assignment = True - use_enum_values = True - extra = 'forbid' +class DeepSpeedCommsConfig(DeepSpeedConfigModel): + enabled: bool = False + """ Whether communication logging is enabled. """ -class CommsLoggerConfig(CommsConfig): - enabled: bool = COMMS_LOGGER_ENABLED_DEFAULT - prof_all: bool = COMMS_LOGGER_PROF_ALL_DEFAULT - prof_ops: list = COMMS_LOGGER_PROF_OPS_DEFAULT - verbose: bool = COMMS_LOGGER_VERBOSE_DEFAULT - debug: bool = COMMS_LOGGER_DEBUG_DEFAULT + prof_all: bool = True + """ Whether to profile all operations. """ + prof_ops: list = [] + """ A list of communication operations to log (only the specified ops will be profiled). """ -class DeepSpeedCommsConfig: + verbose: bool = False + """ Whether to immediately print every communication operation. """ - def __init__(self, ds_config): - self.comms_logger_enabled = 'comms_logger' in ds_config - - if self.comms_logger_enabled: - self.comms_logger = CommsLoggerConfig(**ds_config['comms_logger']) + debug: bool = False + """ Appends the caller function to each communication operation's `log_name`. """ diff --git a/deepspeed/comm/constants.py b/deepspeed/comm/constants.py index ab309247befe..3466565cf145 100644 --- a/deepspeed/comm/constants.py +++ b/deepspeed/comm/constants.py @@ -10,40 +10,3 @@ DEFAULT_AML_MASTER_PORT = "54965" DEFAULT_AML_NCCL_SOCKET_IFNAME = "^docker0,lo" - -######################################### -# Comms Logger -######################################### -# Comms Logger. By default, this feature is not enabled. 
-# Users can configure in ds_config.json as below example:
-COMMS_LOGGER_FORMAT = '''
-The Comms Logger can be specified as:
-"comms_logger": {
-  "enabled": true,
-  "verbose": false,
-  "prof_all": true,
-  "debug": false,
-  "prof_ops": ["all_reduce", "custom_all_reduce_name"]
-}
-'''
-COMMS_LOGGER = "comms_logger"
-
-# Comms logger enable signal
-COMMS_LOGGER_ENABLED = "enabled"
-COMMS_LOGGER_ENABLED_DEFAULT = False
-
-# Comms logger verbose signal
-COMMS_LOGGER_VERBOSE = "verbose"
-COMMS_LOGGER_VERBOSE_DEFAULT = False
-
-# comms logger profile all ops signal
-COMMS_LOGGER_PROF_ALL = "prof_all"
-COMMS_LOGGER_PROF_ALL_DEFAULT = True
-
-# comms logger show all ops signal
-COMMS_LOGGER_DEBUG = "debug"
-COMMS_LOGGER_DEBUG_DEFAULT = False
-
-# comms logger profile specific ops in list
-COMMS_LOGGER_PROF_OPS = "prof_ops"
-COMMS_LOGGER_PROF_OPS_DEFAULT = []
diff --git a/deepspeed/runtime/config.py b/deepspeed/runtime/config.py
index 3c202a9acd07..1a3c14e72854 100755
--- a/deepspeed/runtime/config.py
+++ b/deepspeed/runtime/config.py
@@ -27,8 +27,8 @@
 )
 from .zero.config import get_zero_config, ZeroStageEnum
 from .activation_checkpointing.config import DeepSpeedActivationCheckpointingConfig
-from ..comm.config import DeepSpeedCommsConfig
+from ..comm.config import get_comms_config
 from ..monitor.config import get_monitor_config
 from deepspeed import comm as dist
 from deepspeed.runtime.config_utils import DeepSpeedConfigModel
 
@@ -784,8 +784,8 @@ def _initialize_params(self, param_dict):
 
         self.activation_checkpointing_config = DeepSpeedActivationCheckpointingConfig(param_dict)
 
-        self.comms_config = DeepSpeedCommsConfig(param_dict)
+        self.comms_config = get_comms_config(param_dict)
         self.monitor_config = get_monitor_config(param_dict)
 
         self.gradient_clipping = get_gradient_clipping(param_dict)
         self.fp16_enabled = get_fp16_enabled(param_dict)
diff --git a/deepspeed/utils/comms_logging.py b/deepspeed/utils/comms_logging.py
index 2400fa55b20e..a4b4193b8f33 100644
--- a/deepspeed/utils/comms_logging.py
+++ b/deepspeed/utils/comms_logging.py
@@ -61,21 +61,16 @@ def calc_bw_log(comm_op, size, duration):
 class CommsLogger:
 
     def __init__(self):
-        from deepspeed.comm.constants import COMMS_LOGGER_VERBOSE_DEFAULT, COMMS_LOGGER_DEBUG_DEFAULT, COMMS_LOGGER_PROF_OPS_DEFAULT, COMMS_LOGGER_PROF_ALL_DEFAULT, COMMS_LOGGER_ENABLED_DEFAULT
+        self.enabled = False
         self.comms_dict = {}
-        self.verbose = COMMS_LOGGER_VERBOSE_DEFAULT
-        self.debug = COMMS_LOGGER_DEBUG_DEFAULT
-        self.prof_ops = COMMS_LOGGER_PROF_OPS_DEFAULT
-        self.prof_all = COMMS_LOGGER_PROF_ALL_DEFAULT
-        self.enabled = COMMS_LOGGER_ENABLED_DEFAULT
 
     def configure(self, comms_config):
-        self.enabled = comms_config.comms_logger_enabled
+        self.enabled = comms_config.enabled
         if self.enabled:
-            self.verbose = comms_config.comms_logger.verbose
-            self.debug = comms_config.comms_logger.debug
-            self.prof_ops = comms_config.comms_logger.prof_ops
-            self.prof_all = comms_config.comms_logger.prof_all
+            self.verbose = comms_config.verbose
+            self.debug = comms_config.debug
+            self.prof_ops = comms_config.prof_ops
+            self.prof_all = comms_config.prof_all
 
     # There are three settings for the op profiler:
     # - Global profiling (profile all comms)
diff --git a/docs/code-docs/source/comms-logging.rst b/docs/code-docs/source/comms-logging.rst
new file mode 100644
index 000000000000..85271de6f1ad
--- /dev/null
+++ b/docs/code-docs/source/comms-logging.rst
@@ -0,0 +1,34 @@
+Communication Logging
+=====================
+
+DeepSpeed provides a flexible communication logging tool which can
+automatically detect and record communication operations launched via
+``deepspeed.comm``. NOTE: All logged communication operations are synchronized
+in order to provide accurate timing information. This may hamper performance
+if your model heavily uses asynchronous communication operations.
+
+Once the logs are populated, they can be summarized with
+``deepspeed.comm.log_summary()``. For more detail and example usage, see the
+`communication logging tutorial <https://www.deepspeed.ai/tutorials/comms-logging/>`_.
+
+The behavior of communication logging can be controlled with values in the
+``comms_logger`` dictionary of the main DeepSpeed config:
+
+.. _DeepSpeedCommsConfig:
+.. autopydantic_model:: deepspeed.comm.config.DeepSpeedCommsConfig
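+
+As an illustration, the ``comms_logger`` section of a DeepSpeed config might
+look like the following (the values shown are only an example):
+
+.. code-block:: json
+
+    "comms_logger": {
+        "enabled": true,
+        "verbose": false,
+        "prof_all": true,
+        "debug": false,
+        "prof_ops": ["all_reduce", "custom_all_reduce_name"]
+    }
+
+When ``prof_all`` is set to ``false``, only the operations or custom
+``log_name`` values listed in ``prof_ops`` are profiled.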
diff --git a/docs/code-docs/source/index.rst b/docs/code-docs/source/index.rst
index 67d5aa5fe9fb..490ec0a28815 100644
--- a/docs/code-docs/source/index.rst
+++ b/docs/code-docs/source/index.rst
@@ -98,6 +98,13 @@ Memory Usage
 
    memory
 
+Comms Logging
+-------------
+.. toctree::
+   :maxdepth: 2
+
+   comms-logging
+
 Monitoring
 ----------
 .. toctree::
diff --git a/tests/unit/comm/test_comms_logger.py b/tests/unit/comm/test_comms_logger.py
new file mode 100644
index 000000000000..bfb8866b88dd
--- /dev/null
+++ b/tests/unit/comm/test_comms_logger.py
@@ -0,0 +1,119 @@
+import torch
+import deepspeed.comm as dist
+import deepspeed
+
+from unit.common import DistributedTest
+from unit.simple_model import SimpleModel
+
+import time, logging, os
+
+
+def within_range(val, target, tolerance):
+    print(f'prof_on: {val}, prof_off: {target}')
+    return (val - target) / target < tolerance
+
+
+# This tolerance seems tight enough to catch comm logging overhead while loose enough to allow for comm instability.
+# Can increase if github runner comm instability leads to spurious failures.
+TOLERANCE = 0.05
+
+
+class TestCommsLoggingOverhead(DistributedTest):
+    world_size = [2, 4]
+
+    def test(self):
+        # Need comm warmups, or else whoever communicates first loses
+        NUM_WARMUPS = 5
+        # Enable the logger but keep prof_all off so that only the prof=True call below is logged
+        config_dict = {
+            "train_micro_batch_size_per_gpu": 1,
+            "comms_logger": {
+                "enabled": True,
+                "verbose": False,
+                "prof_all": False,
+                "debug": False
+            }
+        }
+
+        # dummy model
+        model = SimpleModel(4)
+        model, _, _, _ = deepspeed.initialize(config=config_dict,
+                                              model=model,
+                                              model_parameters=model.parameters())
+        x = torch.ones(4, 2**15).cuda() * (dist.get_rank() + 1)
+
+        for i in range(NUM_WARMUPS):
+            dist.all_reduce(x)
+        # Make sure the warmup collectives have finished before timing
+        torch.cuda.synchronize()
+
+        # Time allreduce without logging
+        start = time.time()
+        dist.all_reduce(x, prof=False)
+        torch.cuda.synchronize()
+        time_prof_off = time.time() - start
+        # Average the measured time across ranks so that every rank asserts on the same value
+        time_prof_off_tensor = torch.Tensor([time_prof_off]).cuda()
+        dist.all_reduce(time_prof_off_tensor, prof=False, op=dist.ReduceOp.AVG)
+        time_prof_off = time_prof_off_tensor.item()
+
+        # Time allreduce with logging
+        start = time.time()
+        dist.all_reduce(x, prof=True)
+        torch.cuda.synchronize()
+        time_prof_on = time.time() - start
+        time_prof_on_tensor = torch.Tensor([time_prof_on]).cuda()
+        dist.all_reduce(time_prof_on_tensor, prof=False, op=dist.ReduceOp.AVG)
+        time_prof_on = time_prof_on_tensor.item()
+
+        # Ensure logging doesn't add significant overhead
+        assert within_range(time_prof_on, time_prof_off, tolerance=TOLERANCE)
+
+
+class TestNumLoggingCalls(DistributedTest):
+    world_size = [2, 4]
+
+    def test(self, class_tmpdir):
+        num_all_reduce_calls = 4
+        num_broadcast_calls = 2
+
+        # Have the DeepSpeed logger output to both stdout and file so that we can verify log output
+        file_path = os.path.join(class_tmpdir,
+                                 f"comm_output_{int(os.environ['WORLD_SIZE'])}.log")
+        DSLogger = logging.getLogger('DeepSpeed')
+        fileHandler = logging.FileHandler(file_path)
+        DSLogger.addHandler(fileHandler)
+
+        config_dict = {
+            "train_micro_batch_size_per_gpu": 1,
+            "comms_logger": {
+                "enabled": True,
+                "verbose": True,
+                "prof_all": True,
+                "debug": False
+            }
+        }
+
+        # dummy model so that config options are picked up
+        model = SimpleModel(4)
+        model, _, _, _ = deepspeed.initialize(config=config_dict,
+                                              model=model,
+                                              model_parameters=model.parameters())
+
+        x = torch.ones(1, 3).cuda() * (dist.get_rank() + 1)
+
+        # Make comm calls
+        for i in range(num_all_reduce_calls):
+            dist.all_reduce(x, log_name="all_reduce_test")
+        for i in range(num_broadcast_calls):
+            dist.broadcast(x, 0, log_name="broadcast_test")
+
+        # Count the number of logs
+        with open(file_path, 'r') as f:
+            log_output = f.read()
+        num_all_reduce_logs = log_output.count('all_reduce_test')
+        num_broadcast_logs = log_output.count('broadcast_test')
+
+        # Ensure all comm calls are logged
+        assert num_all_reduce_logs == num_all_reduce_calls
+        assert num_broadcast_logs == num_broadcast_calls