2 changes: 1 addition & 1 deletion deepspeed/comm/comm.py
@@ -68,7 +68,7 @@ class ReduceOp(Enum):


 def _configure_using_config_file(config):
-    if config.comms_logger_enabled:
+    if config.enabled:
         comms_logger.configure(config)

34 changes: 14 additions & 20 deletions deepspeed/comm/config.py
@@ -3,31 +3,25 @@

 # DeepSpeed Team

-from pydantic import BaseModel
 from .constants import *
+from deepspeed.runtime.config_utils import DeepSpeedConfigModel


-class CommsConfig(BaseModel):
+def get_comms_config(param_dict):
+    return DeepSpeedCommsConfig(**param_dict.get("comms_logger", {}))

-    class Config:
-        validate_all = True
-        validate_assignment = True
-        use_enum_values = True
-        extra = 'forbid'

+class DeepSpeedCommsConfig(DeepSpeedConfigModel):
+    enabled: bool = False
+    """ Whether communication logging is enabled. """

-class CommsLoggerConfig(CommsConfig):
-    enabled: bool = COMMS_LOGGER_ENABLED_DEFAULT
-    prof_all: bool = COMMS_LOGGER_PROF_ALL_DEFAULT
-    prof_ops: list = COMMS_LOGGER_PROF_OPS_DEFAULT
-    verbose: bool = COMMS_LOGGER_VERBOSE_DEFAULT
-    debug: bool = COMMS_LOGGER_DEBUG_DEFAULT
+    prof_all: bool = True
+    """ Whether to profile all operations. """

+    prof_ops: list = []
+    """ A list of communication operations to log (only the specified ops will be profiled). """

-class DeepSpeedCommsConfig:
+    verbose: bool = False
+    """ Whether to immediately print every communication operation. """

-    def __init__(self, ds_config):
-        self.comms_logger_enabled = 'comms_logger' in ds_config
-
-        if self.comms_logger_enabled:
-            self.comms_logger = CommsLoggerConfig(**ds_config['comms_logger'])
+    debug: bool = False
+    """ Appends the caller function to each communication operation's `log_name`. """
37 changes: 0 additions & 37 deletions deepspeed/comm/constants.py
@@ -10,40 +10,3 @@

DEFAULT_AML_MASTER_PORT = "54965"
DEFAULT_AML_NCCL_SOCKET_IFNAME = "^docker0,lo"

#########################################
# Comms Logger
#########################################
# Comms Logger. By default, this feature is not enabled.
# Users can configure in ds_config.json as below example:
COMMS_LOGGER_FORMAT = '''
The Comms Logger can be specified as:
"comms_logger": {
"enabled": true,
"verbose": false,
"prof_all": true,
"debug": false,
"prof_ops": ["all_reduce", "custom_all_reduce_name"]
}
'''
COMMS_LOGGER = "comms_logger"

# Comms logger enable signal
COMMS_LOGGER_ENABLED = "enabled"
COMMS_LOGGER_ENABLED_DEFAULT = False

# Comms logger verbose signal
COMMS_LOGGER_VERBOSE = "verbose"
COMMS_LOGGER_VERBOSE_DEFAULT = False

# comms logger profile all ops signal
COMMS_LOGGER_PROF_ALL = "prof_all"
COMMS_LOGGER_PROF_ALL_DEFAULT = True

# comms logger show all ops signal
COMMS_LOGGER_DEBUG = "debug"
COMMS_LOGGER_DEBUG_DEFAULT = False

# comms logger profile specific ops in list
COMMS_LOGGER_PROF_OPS = "prof_ops"
COMMS_LOGGER_PROF_OPS_DEFAULT = []
8 changes: 4 additions & 4 deletions deepspeed/runtime/config.py
@@ -27,8 +27,8 @@
 )
 from .zero.config import get_zero_config, ZeroStageEnum
 from .activation_checkpointing.config import DeepSpeedActivationCheckpointingConfig
-from ..comm.config import DeepSpeedCommsConfig
-from ..monitor.config import get_monitor_config
+from ..comm.config import get_comms_config
+from ..monitor.config import DeepSpeedMonitorConfig

 from deepspeed import comm as dist
 from deepspeed.runtime.config_utils import DeepSpeedConfigModel
@@ -784,8 +784,8 @@ def _initialize_params(self, param_dict):

         self.activation_checkpointing_config = DeepSpeedActivationCheckpointingConfig(param_dict)

-        self.comms_config = DeepSpeedCommsConfig(param_dict)
-        self.monitor_config = get_monitor_config(param_dict)
+        self.comms_config = get_comms_config(param_dict)
+        self.monitor_config = DeepSpeedMonitorConfig(param_dict)

         self.gradient_clipping = get_gradient_clipping(param_dict)
         self.fp16_enabled = get_fp16_enabled(param_dict)
17 changes: 6 additions & 11 deletions deepspeed/utils/comms_logging.py
@@ -61,21 +61,16 @@ def calc_bw_log(comm_op, size, duration):
 class CommsLogger:

     def __init__(self):
-        from deepspeed.comm.constants import COMMS_LOGGER_VERBOSE_DEFAULT, COMMS_LOGGER_DEBUG_DEFAULT, COMMS_LOGGER_PROF_OPS_DEFAULT, COMMS_LOGGER_PROF_ALL_DEFAULT, COMMS_LOGGER_ENABLED_DEFAULT
+        self.enabled = False
         self.comms_dict = {}
-        self.verbose = COMMS_LOGGER_VERBOSE_DEFAULT
-        self.debug = COMMS_LOGGER_DEBUG_DEFAULT
-        self.prof_ops = COMMS_LOGGER_PROF_OPS_DEFAULT
-        self.prof_all = COMMS_LOGGER_PROF_ALL_DEFAULT
-        self.enabled = COMMS_LOGGER_ENABLED_DEFAULT

     def configure(self, comms_config):
-        self.enabled = comms_config.comms_logger_enabled
+        self.enabled = comms_config.enabled
         if self.enabled:
-            self.verbose = comms_config.comms_logger.verbose
-            self.debug = comms_config.comms_logger.debug
-            self.prof_ops = comms_config.comms_logger.prof_ops
-            self.prof_all = comms_config.comms_logger.prof_all
+            self.verbose = comms_config.verbose
+            self.debug = comms_config.debug
+            self.prof_ops = comms_config.prof_ops
+            self.prof_all = comms_config.prof_all

     # There are three settings for the op profiler:
     # - Global profiling (profile all comms)
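
With the flattened config, CommsLogger.configure reads fields directly off the model rather than off a nested comms_logger attribute; in normal use this happens for the module-level comms_logger instance during engine setup (see _configure_using_config_file above). A small sketch of that call path (illustrative only, not part of the diff):

# Sketch: configuring a comms logger with the new flat config object.
from deepspeed.comm.config import DeepSpeedCommsConfig
from deepspeed.utils.comms_logging import CommsLogger

logger = CommsLogger()
logger.configure(DeepSpeedCommsConfig(enabled=True, prof_ops=["all_reduce"]))

print(logger.enabled)   # True
print(logger.prof_ops)  # ['all_reduce']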
18 changes: 18 additions & 0 deletions docs/code-docs/source/comms-logging.rst
@@ -0,0 +1,18 @@
Communication Logging
=====================

DeepSpeed provides a flexible communication logging tool which can
automatically detect and record communication operations launched via
`deepspeed.comm`. NOTE: All logged communication calls are synchronized in
order to provide accurate timing information. This may hamper performance if
your model heavily uses asynchronous communication operations.

Once the logs are populated, they can be summarized with
`deepspeed.comm.log_summary()`. For more detail and example usage, see the
`communication logging tutorial <https://www.deepspeed.ai/tutorials/comms-logging/>`_.

The behavior of communication logging can be controlled with values in the
`comms_logger` dictionary of the main DeepSpeed config:

.. _DeepSpeedCommsConfig:
.. autopydantic_model:: deepspeed.comm.config.DeepSpeedCommsConfig
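
As a quick illustration of the documented workflow, a short end-to-end sketch (illustrative only, not part of the diff; MyModel is a placeholder for any torch.nn.Module, and the config keys mirror the DeepSpeedCommsConfig fields above):

# Sketch: enable the comms logger via the DeepSpeed config, then summarize the logs.
import deepspeed
import deepspeed.comm as dist

ds_config = {
    "train_micro_batch_size_per_gpu": 1,
    "comms_logger": {
        "enabled": True,
        "verbose": False,
        "prof_all": True,
        "debug": False
    }
}

model = MyModel()  # placeholder model
engine, _, _, _ = deepspeed.initialize(config=ds_config, model=model,
                                       model_parameters=model.parameters())

# ... run training steps that issue deepspeed.comm collectives ...

dist.log_summary()  # prints an aggregate table of the logged communication ops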
7 changes: 7 additions & 0 deletions docs/code-docs/source/index.rst
@@ -98,6 +98,13 @@ Memory Usage

   memory

Comms Logging
-------------
.. toctree::
   :maxdepth: 2

   comms-logging

Monitoring
----------
.. toctree::
115 changes: 115 additions & 0 deletions tests/unit/comm/test_comms_logger.py
@@ -0,0 +1,115 @@
import torch
import deepspeed.comm as dist
import deepspeed

from unit.common import DistributedTest
from unit.simple_model import SimpleModel

import time, logging, os


def within_range(val, target, tolerance):
    print(f'prof_on: {val}, prof_off: {target}')
    # Relative difference between the measured value and the target
    return (val - target) / target < tolerance


# This tolerance seems tight enough to catch comm logging overhead while loose enough to allow for comm instability.
# Can increase if github runner comm instability leads to many false negatives.
TOLERANCE = 0.05


class TestCommsLoggingOverhead(DistributedTest):
    world_size = [2, 4]

    def test(self):
        # Need comm warmups, or else whoever communicates first loses
        NUM_WARMUPS = 5
        config_dict = {
            "train_micro_batch_size_per_gpu": 1,
            "comms_logger": {
                "enabled": False,
                "verbose": True,
                "prof_all": True,
                "debug": False
            }
        }

        # dummy model
        model = SimpleModel(4)
        model, _, _, _ = deepspeed.initialize(config=config_dict,
                                              model=model,
                                              model_parameters=model.parameters())
        x = torch.ones(4, 2**15).cuda() * (dist.get_rank() + 1)

        for i in range(NUM_WARMUPS):
            dist.all_reduce(x)

        # Time allreduce without logging
        start = time.time()
        dist.all_reduce(x, prof=False)
        torch.cuda.synchronize()
        time_prof_off = time.time() - start
        dist.all_reduce(torch.Tensor([time_prof_off]).cuda(),
                        prof=False,
                        op=dist.ReduceOp.AVG)

        # Time allreduce with logging
        start = time.time()
        dist.all_reduce(x, prof=True)
        torch.cuda.synchronize()
        time_prof_on = time.time() - start
        dist.all_reduce(torch.Tensor([time_prof_on]).cuda(),
                        prof=False,
                        op=dist.ReduceOp.AVG)

        # Ensure logging doesn't add significant overhead
        assert within_range(time_prof_on, time_prof_off, tolerance=TOLERANCE)


class TestNumLoggingCalls(DistributedTest):
    world_size = [2, 4]

    def test(self, class_tmpdir):
        num_all_reduce_calls = 4
        num_broadcast_calls = 2

        # Have the DeepSpeed logger output to both stdout and file so that we can verify log output
        file_path = os.path.join(class_tmpdir,
                                 f"comm_output_{int(os.environ['WORLD_SIZE'])}.log")
        DSLogger = logging.getLogger('DeepSpeed')
        fileHandler = logging.FileHandler(file_path)
        DSLogger.addHandler(fileHandler)

        config_dict = {
            "train_micro_batch_size_per_gpu": 1,
            "comms_logger": {
                "enabled": True,
                "verbose": True,
                "prof_all": True,
                "debug": False
            }
        }

        # dummy model so that config options are picked up
        model = SimpleModel(4)
        model, _, _, _ = deepspeed.initialize(config=config_dict,
                                              model=model,
                                              model_parameters=model.parameters())

        x = torch.ones(1, 3).cuda() * (dist.get_rank() + 1)

        # Make comm calls
        for i in range(num_all_reduce_calls):
            dist.all_reduce(x, log_name="all_reduce_test")
        for i in range(num_broadcast_calls):
            dist.broadcast(x, 0, log_name="broadcast_test")

        # Count the number of logs
        with open(file_path, 'r') as f:
            log_output = f.read()
            num_all_reduce_logs = log_output.count('all_reduce_test')
            num_broadcast_logs = log_output.count('broadcast_test')

        # Ensure all comm calls are logged
        assert num_all_reduce_logs == num_all_reduce_calls
        assert num_broadcast_logs == num_broadcast_calls