diff --git a/deepspeed/monitor/__init__.py b/deepspeed/monitor/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/deepspeed/monitor/config.py b/deepspeed/monitor/config.py new file mode 100644 index 000000000000..b854b8580023 --- /dev/null +++ b/deepspeed/monitor/config.py @@ -0,0 +1,50 @@ +""" +Copyright (c) Microsoft Corporation +Licensed under the MIT license. +""" + +from typing import Optional +from deepspeed.runtime.config_utils import get_scalar_param +from pydantic import BaseModel, validator, ValidationError, create_model +from .constants import * + + +class MonitorConfig(BaseModel): + class Config: + validate_all = True + validate_assignment = True + use_enum_values = True + extra = 'forbid' + + +class TensorBoardConfig(MonitorConfig): + enabled: bool = TENSORBOARD_ENABLED_DEFAULT + output_path: str = TENSORBOARD_OUTPUT_PATH_DEFAULT + job_name: str = TENSORBOARD_JOB_NAME_DEFAULT + + +class WandbConfig(MonitorConfig): + enabled: bool = WANDB_ENABLED_DEFAULT + group: str = WANDB_GROUP_NAME_DEFAULT + team: str = WANDB_TEAM_NAME_DEFAULT + project: str = WANDB_PROJECT_NAME_DEFAULT + + +class CSVConfig(MonitorConfig): + enabled: bool = CSV_MONITOR_ENABLED_DEFAULT + output_path: str = CSV_MONITOR_OUTPUT_PATH_DEFAULT + job_name: str = CSV_MONITOR_JOB_NAME_DEFAULT + + +class DeepSpeedMonitorConfig: + def __init__(self, ds_config): + self.tensorboard_enabled = 'tensorboard' in ds_config + self.wandb_enabled = 'wandb' in ds_config + self.csv_monitor_enabled = 'csv_monitor' in ds_config + + if self.tensorboard_enabled: + self.tensorboard_config = TensorBoardConfig(**ds_config['tensorboard']) + if self.wandb_enabled: + self.wandb_config = WandbConfig(**ds_config['wandb']) + if self.csv_monitor_enabled: + self.csv_monitor_config = CSVConfig(**ds_config['csv_monitor']) diff --git a/deepspeed/monitor/constants.py b/deepspeed/monitor/constants.py new file mode 100644 index 000000000000..95cb970175c4 --- /dev/null +++ b/deepspeed/monitor/constants.py @@ -0,0 +1,85 @@ +######################################### +# Tensorboard +######################################### +# Tensorboard. By default, this feature is not enabled. +# Users can configure in ds_config.json as below example: +TENSORBOARD_FORMAT = ''' +Tensorboard can be specified as: +"tensorboard": { + "enabled": true, + "output_path": "/home/myname/foo", + "job_name": "model_lr2e-5_epoch3_seed2_seq64" +} +''' +TENSORBOARD = "tensorboard" + +# Tensorboard enable signal +TENSORBOARD_ENABLED = "enabled" +TENSORBOARD_ENABLED_DEFAULT = False + +# Tensorboard output path +TENSORBOARD_OUTPUT_PATH = "output_path" +TENSORBOARD_OUTPUT_PATH_DEFAULT = "" + +# Tensorboard job name +TENSORBOARD_JOB_NAME = "job_name" +TENSORBOARD_JOB_NAME_DEFAULT = "DeepSpeedJobName" + +######################################### +# Wandb +######################################### +# Wandb. By default, this feature is not enabled. 
+# Users can configure in ds_config.json as below example: +WANDB_FORMAT = ''' +Wandb can be specified as: +"wandb": { + "enabled": true, + "team_name": "deepspeed" + "project_name": "zero" + "group_name": "zero: stage 3", +} +''' +WANDB = "wandb" + +# Wandb enable signal +WANDB_ENABLED = "enabled" +WANDB_ENABLED_DEFAULT = False + +# Wandb team +WANDB_TEAM_NAME = "team" +WANDB_TEAM_NAME_DEFAULT = None + +# Wandb project +WANDB_PROJECT_NAME = "project" +WANDB_PROJECT_NAME_DEFAULT = "deepspeed" + +# Wandb group +WANDB_GROUP_NAME = "group" +WANDB_GROUP_NAME_DEFAULT = None + +######################################### +# csv monitor +######################################### +# Basic CSV monitor. By default, this feature is not enabled. +# Users can configure in ds_config.json as below example: +CSV_FORMAT = ''' +The basic csv monitor can be specified as: +"csv_monitor": { + "enabled": true, + "output_path": "/home/myname/foo", + "job_name": "model_lr2e-5_epoch3_seed2_seq64" +} +''' +CSV_MONITOR = "csv_monitor" + +# csv monitor enable signal +CSV_MONITOR_ENABLED = "enabled" +CSV_MONITOR_ENABLED_DEFAULT = False + +# csv monitor output path +CSV_MONITOR_OUTPUT_PATH = "output_path" +CSV_MONITOR_OUTPUT_PATH_DEFAULT = "" + +# csv_monitor job name +CSV_MONITOR_JOB_NAME = "job_name" +CSV_MONITOR_JOB_NAME_DEFAULT = "DeepSpeedJobName" diff --git a/deepspeed/monitor/csv_monitor.py b/deepspeed/monitor/csv_monitor.py new file mode 100644 index 000000000000..b2b05260e445 --- /dev/null +++ b/deepspeed/monitor/csv_monitor.py @@ -0,0 +1,62 @@ +from .monitor import Monitor +import os + +import deepspeed.comm as dist + + +class csvMonitor(Monitor): + def __init__(self, monitor_config): + super().__init__(monitor_config) + import csv + self.filenames = [] + self.enabled = monitor_config.csv_monitor_config.enabled + self.output_path = monitor_config.csv_monitor_config.output_path + self.job_name = monitor_config.csv_monitor_config.job_name + self.log_dir = self.setup_log_dir() + + def setup_log_dir(self, base=os.path.join(os.path.expanduser("~"), "csv_monitor")): + if self.enabled and dist.get_rank() == 0: + if self.output_path is not None: + log_dir = os.path.join(self.output_path, self.job_name) + # NOTE: This code path currently is never used since the default tensorboard_output_path is an empty string and not None. Saving it in case we want this functionality in the future. 
+ else: + if "DLWS_JOB_ID" in os.environ: + infra_job_id = os.environ["DLWS_JOB_ID"] + elif "DLTS_JOB_ID" in os.environ: + infra_job_id = os.environ["DLTS_JOB_ID"] + else: + infra_job_id = "unknown-job-id" + + csv_monitor_dir_name = os.path.join(infra_job_id, "logs") + log_dir = os.path.join(base, csv_monitor_dir_name, self.job_name) + os.makedirs(log_dir, exist_ok=True) + return log_dir + + def write_events(self, event_list): + if self.enabled and dist.get_rank() == 0: + import csv + # We assume each event_list element is a tensorboard-style tuple in the format: (log_name: String, value, step: Int) + for event in event_list: + log_name = event[0] + value = event[1] + step = event[2] + + # Set the header to the log_name + # Need this check because the deepspeed engine currently formats log strings to separate with '/' + if '/' in log_name: + record_splits = log_name.split('/') + header = record_splits[len(record_splits) - 1] + else: + header = log_name + + # sanitize common naming conventions into filename + filename = log_name.replace('/', '_').replace(' ', '_') + fname = self.log_dir + '/' + filename + '.csv' + + # Open file and record event. Insert header if this is the first time writing + with open(fname, 'a+') as csv_monitor_file: + csv_monitor_writer = csv.writer(csv_monitor_file) + if filename not in self.filenames: + self.filenames.append(filename) + csv_monitor_writer.writerow(['step', header]) + csv_monitor_writer.writerow([step, value]) diff --git a/deepspeed/monitor/monitor.py b/deepspeed/monitor/monitor.py new file mode 100644 index 000000000000..a5ac271861ff --- /dev/null +++ b/deepspeed/monitor/monitor.py @@ -0,0 +1,47 @@ +""" + Support different forms of monitoring such as wandb and tensorboard +""" + +from abc import ABC, abstractmethod +import deepspeed.comm as dist + + +class Monitor(ABC): + @abstractmethod + def __init__(self, monitor_config): + self.monitor_config = monitor_config + + @abstractmethod + def write_events(self, event_list): + pass + + +from .wandb import WandbMonitor +from .tensorboard import TensorBoardMonitor +from .csv_monitor import csvMonitor + + +class MonitorMaster(Monitor): + def __init__(self, monitor_config): + super().__init__(monitor_config) + self.tb_monitor = None + self.wandb_monitor = None + self.csv_monitor = None + self.enabled = monitor_config.tensorboard_enabled or monitor_config.csv_monitor_enabled or monitor_config.wandb_enabled + + if dist.get_rank() == 0: + if monitor_config.tensorboard_enabled: + self.tb_monitor = TensorBoardMonitor(monitor_config) + if monitor_config.wandb_enabled: + self.wandb_monitor = WandbMonitor(monitor_config) + if monitor_config.csv_monitor_enabled: + self.csv_monitor = csvMonitor(monitor_config) + + def write_events(self, event_list): + if dist.get_rank() == 0: + if self.tb_monitor is not None: + self.tb_monitor.write_events(event_list) + if self.wandb_monitor is not None: + self.wandb_monitor.write_events(event_list) + if self.csv_monitor is not None: + self.csv_monitor.write_events(event_list) diff --git a/deepspeed/monitor/tensorboard.py b/deepspeed/monitor/tensorboard.py new file mode 100644 index 000000000000..447143e53b05 --- /dev/null +++ b/deepspeed/monitor/tensorboard.py @@ -0,0 +1,52 @@ +from .utils import check_tb_availability +from .monitor import Monitor +import os + +import deepspeed.comm as dist + + +class TensorBoardMonitor(Monitor): + def __init__(self, monitor_config): + super().__init__(monitor_config) + check_tb_availability() + + self.summary_writer = None + self.enabled = 
monitor_config.tensorboard_config.enabled + self.output_path = monitor_config.tensorboard_config.output_path + self.job_name = monitor_config.tensorboard_config.job_name + + if self.enabled and dist.get_rank() == 0: + self.get_summary_writer() + + def get_summary_writer(self, + base=os.path.join(os.path.expanduser("~"), + "tensorboard")): + if self.enabled and dist.get_rank() == 0: + from torch.utils.tensorboard import SummaryWriter + if self.output_path is not None: + log_dir = os.path.join(self.output_path, self.job_name) + # NOTE: This code path currently is never used since the default output_path is an empty string and not None. Saving it in case we want this functionality in the future. + else: + if "DLWS_JOB_ID" in os.environ: + infra_job_id = os.environ["DLWS_JOB_ID"] + elif "DLTS_JOB_ID" in os.environ: + infra_job_id = os.environ["DLTS_JOB_ID"] + else: + infra_job_id = "unknown-job-id" + + summary_writer_dir_name = os.path.join(infra_job_id, "logs") + log_dir = os.path.join(base, summary_writer_dir_name, self.output_path) + os.makedirs(log_dir, exist_ok=True) + self.summary_writer = SummaryWriter(log_dir=log_dir) + return self.summary_writer + + def write_events(self, event_list, flush=True): + if self.enabled and self.summary_writer is not None and dist.get_rank() == 0: + for event in event_list: + self.summary_writer.add_scalar(*event) + if flush: + self.summary_writer.flush() + + def flush(self): + if self.enabled and self.summary_writer is not None and dist.get_rank() == 0: + self.summary_writer.flush() diff --git a/deepspeed/monitor/utils.py b/deepspeed/monitor/utils.py new file mode 100644 index 000000000000..f519a71823a9 --- /dev/null +++ b/deepspeed/monitor/utils.py @@ -0,0 +1,18 @@ +def check_tb_availability(): + try: + # torch.utils.tensorboard will fail if `tensorboard` is not available, + # see their docs for more details: https://pytorch.org/docs/1.8.0/tensorboard.html + import tensorboard + except ImportError: + print('If you want to use tensorboard logging, please `pip install tensorboard`') + raise + + +def check_wandb_availability(): + try: + import wandb + except ImportError: + print( + 'If you want to use wandb logging, please `pip install wandb` and follow the instructions at https://docs.wandb.ai/quickstart' + ) + raise diff --git a/deepspeed/monitor/wandb.py b/deepspeed/monitor/wandb.py new file mode 100644 index 000000000000..63f5879633b5 --- /dev/null +++ b/deepspeed/monitor/wandb.py @@ -0,0 +1,32 @@ +from .utils import check_wandb_availability +from .monitor import Monitor + +import deepspeed.comm as dist + + +class WandbMonitor(Monitor): + def __init__(self, monitor_config): + super().__init__(monitor_config) + check_wandb_availability() + import wandb + + self.enabled = monitor_config.wandb_config.enabled + self.group = monitor_config.wandb_config.group + self.team = monitor_config.wandb_config.team + self.project = monitor_config.wandb_config.project + + if self.enabled and dist.get_rank() == 0: + wandb.init(project=self.project, group=self.group, entity=self.team) + + def log(self, data, step=None, commit=None, sync=None): + if self.enabled and dist.get_rank() == 0: + import wandb + return wandb.log(data, step=step, commit=commit, sync=sync) + + def write_events(self, event_list): + if self.enabled and dist.get_rank() == 0: + for event in event_list: + label = event[0] + value = event[1] + step = event[2] + self.log({label: value}, step=step) diff --git a/deepspeed/runtime/config.py b/deepspeed/runtime/config.py index 76da68bc4190..46038f448375 100755 
--- a/deepspeed/runtime/config.py +++ b/deepspeed/runtime/config.py @@ -24,6 +24,7 @@ from .zero.config import DeepSpeedZeroConfig from .zero.constants import * from .activation_checkpointing.config import DeepSpeedActivationCheckpointingConfig +from ..monitor.config import DeepSpeedMonitorConfig import deepspeed.comm as dist @@ -617,15 +618,6 @@ def get_memory_breakdown(param_dict): return get_scalar_param(param_dict, MEMORY_BREAKDOWN, MEMORY_BREAKDOWN_DEFAULT) -def get_tensorboard_enabled(param_dict): - if TENSORBOARD in param_dict.keys(): - return get_scalar_param(param_dict[TENSORBOARD], - TENSORBOARD_ENABLED, - TENSORBOARD_ENABLED_DEFAULT) - else: - return False - - def get_eigenvalue_config(param_dict): if get_quantize_enabled(param_dict): param_dict = param_dict[QUANTIZE_TRAINING] @@ -726,26 +718,6 @@ def get_eigenvalue_layer_num(param_dict): return EIGENVALUE_LAYER_NUM_DEFAULT -def get_tensorboard_output_path(param_dict): - if get_tensorboard_enabled(param_dict): - return get_scalar_param( - param_dict[TENSORBOARD], - TENSORBOARD_OUTPUT_PATH, - TENSORBOARD_OUTPUT_PATH_DEFAULT, - ) - else: - return TENSORBOARD_OUTPUT_PATH_DEFAULT - - -def get_tensorboard_job_name(param_dict): - if get_tensorboard_enabled(param_dict): - return get_scalar_param(param_dict[TENSORBOARD], - TENSORBOARD_JOB_NAME, - TENSORBOARD_JOB_NAME_DEFAULT) - else: - return TENSORBOARD_JOB_NAME_DEFAULT - - def get_checkpoint_params(param_dict): return param_dict.get(CHECKPOINT, {}) @@ -899,6 +871,8 @@ def _initialize_params(self, param_dict): self.activation_checkpointing_config = DeepSpeedActivationCheckpointingConfig( param_dict) + self.monitor_config = DeepSpeedMonitorConfig(param_dict) + self.gradient_clipping = get_gradient_clipping(param_dict) self.fp16_enabled = get_fp16_enabled(param_dict) self.bfloat16_enabled = get_bfloat16_enabled(param_dict) @@ -945,9 +919,6 @@ def _initialize_params(self, param_dict): | self.flops_profiler_config.enabled) self.memory_breakdown = get_memory_breakdown(param_dict) self.autotuning_config = DeepSpeedAutotuningConfig(param_dict) - self.tensorboard_enabled = get_tensorboard_enabled(param_dict) - self.tensorboard_output_path = get_tensorboard_output_path(param_dict) - self.tensorboard_job_name = get_tensorboard_job_name(param_dict) ( self.eigenvalue_enabled, diff --git a/deepspeed/runtime/constants.py b/deepspeed/runtime/constants.py index ee2e51c6109f..88b055b3e210 100755 --- a/deepspeed/runtime/constants.py +++ b/deepspeed/runtime/constants.py @@ -282,33 +282,6 @@ MEMORY_BREAKDOWN = 'memory_breakdown' MEMORY_BREAKDOWN_DEFAULT = False -######################################### -# Tensorboard -######################################### -# Tensorboard. By default, this feature is not enabled. 
-# Users can configure in ds_config.json as below example: -TENSORBOARD_FORMAT = ''' -Tensorboard can be specified as: -"tensorboard": { - "enabled": true, - "output_path": "/home/myname/foo", - "job_name": "model_lr2e-5_epoch3_seed2_seq64" -} -''' -TENSORBOARD = "tensorboard" - -# Tensorboard enable signal -TENSORBOARD_ENABLED = "enabled" -TENSORBOARD_ENABLED_DEFAULT = False - -# Tensorboard output path -TENSORBOARD_OUTPUT_PATH = "output_path" -TENSORBOARD_OUTPUT_PATH_DEFAULT = "" - -# Tensorboard job name -TENSORBOARD_JOB_NAME = "job_name" -TENSORBOARD_JOB_NAME_DEFAULT = "DeepSpeedJobName" - ######################################### # Eigenvalue ######################################### diff --git a/deepspeed/runtime/engine.py b/deepspeed/runtime/engine.py index 60d1a6140ba7..3ed13e75d0ed 100644 --- a/deepspeed/runtime/engine.py +++ b/deepspeed/runtime/engine.py @@ -52,6 +52,7 @@ from deepspeed.comm.comm import init_distributed from deepspeed.utils.timer import ThroughputTimer, SynchronizedWallClockTimer from deepspeed.utils.debug import debug_extract_module_and_param_names +from deepspeed.monitor.monitor import MonitorMaster from deepspeed.runtime.progressive_layer_drop import ProgressiveLayerDrop from deepspeed.runtime.utils import clip_grad_norm_ from deepspeed.runtime.eigenvalue import Eigenvalue @@ -260,8 +261,7 @@ def __init__( self._set_distributed_vars(args) - if self.tensorboard_enabled() and self.global_rank == 0: - self.summary_writer = self.get_summary_writer() + self.monitor = MonitorMaster(self._config.monitor_config) see_memory_usage( f"DeepSpeed Engine: Before configure distributed model", @@ -501,54 +501,6 @@ def curriculum_enabled(self): def curriculum_params(self): return self._config.curriculum_params - def tensorboard_enabled(self): - return self._config.tensorboard_enabled - - def tensorboard_output_path(self): - return self._config.tensorboard_output_path - - def tensorboard_job_name(self): - return self._config.tensorboard_job_name - - def get_summary_writer( - self, - name="DeepSpeedJobName", - base=os.path.join(os.path.expanduser("~"), - "tensorboard"), - ): - if self.tensorboard_output_path(): - base_dir = self.tensorboard_output_path() - job_name = self.tensorboard_job_name() - log_dir = os.path.join(base_dir, job_name) - else: - if self.tensorboard_job_name(): - name = self.tensorboard_job_name() - - # Infrastructure-specific job-id - if "DLWS_JOB_ID" in os.environ: - infra_job_id = os.environ["DLWS_JOB_ID"] - elif "DLTS_JOB_ID" in os.environ: - infra_job_id = os.environ["DLTS_JOB_ID"] - else: - infra_job_id = "unknown-job-id" - - summary_writer_dir_name = os.path.join(infra_job_id, "logs") - log_dir = os.path.join(base, summary_writer_dir_name, name) - - os.makedirs(log_dir, exist_ok=True) - try: - # torch.utils.tensorboard will fail if `tensorboard` is not available, - # see their docs for more details: https://pytorch.org/docs/1.8.0/tensorboard.html - import tensorboard - except ImportError: - print( - 'If you want to use tensorboard logging please `pip install tensorboard`' - ) - raise - from torch.utils.tensorboard import SummaryWriter - - return SummaryWriter(log_dir=log_dir) - def wall_clock_breakdown(self): return self._config.wall_clock_breakdown @@ -1713,7 +1665,7 @@ def backward(self, loss, allreduce_gradients=True, release_loss=False): loss = self._scale_loss_by_gas(loss.float()) # Log training Loss - if self.tensorboard_enabled(): + if self.monitor.enabled: if self.is_gradient_accumulation_boundary(): if self.global_rank == 0: self.summary_events 
= [( @@ -1721,9 +1673,7 @@ def backward(self, loss, allreduce_gradients=True, release_loss=False): loss.mean().item() * self.gradient_accumulation_steps(), self.global_samples, )] - for event in self.summary_events: # write_summary_events - self.summary_writer.add_scalar(event[0], event[1], event[2]) - self.summary_writer.flush() + self.monitor.write_events(self.summary_events) self._start_timers(self.engine_timers.backward_timers) @@ -1946,14 +1896,13 @@ def step(self, lr_kwargs=None): self._stop_timers(self.engine_timers.step_timers) # Log learning rate - if self.tensorboard_enabled(): + if self.monitor.enabled: if self.is_gradient_accumulation_boundary(): if self.global_rank == 0: self.summary_events = [(f"Train/Samples/lr", self.get_lr()[0], self.global_samples)] - for event in self.summary_events: # write_summary_events - self.summary_writer.add_scalar(event[0], event[1], event[2]) + if self.fp16_enabled() and hasattr(self.optimizer, "cur_scale"): self.summary_events.append(( f"Train/Samples/loss_scale", @@ -1965,16 +1914,12 @@ def step(self, lr_kwargs=None): self.eigenvalue_gas_boundary_resolution()): ev_values = self.block_eigenvalue.values() for i in range(len(ev_values)): - self.summary_writer.add_scalar( + self.summary_events.append(( f"Train/Eigenvalues/ModelBlockParam_{i}", self.ev_values[i][0], self.global_samples, - ) - self.summary_writer.flush() - - for event in self.summary_events: # write_summary_events - self.summary_writer.add_scalar(event[0], event[1], event[2]) - self.summary_writer.flush() + )) + self.monitor.write_events(self.summary_events) # Check flops profiling if flops_profiler_active: @@ -2002,8 +1947,8 @@ def step(self, lr_kwargs=None): if self.wall_clock_breakdown() or self.flops_profiler_enabled(): # Log global timing and reset if self.is_gradient_accumulation_boundary(): - if self.tensorboard_enabled(): - self._write_tensorboard() + if self.monitor.enabled: + self._write_monitor() if self.has_moe_layers: fwd_time = self.timers(FORWARD_GLOBAL_TIMER).elapsed( @@ -2046,7 +1991,7 @@ def _autotuning_exit(self): atexit.register(print, "Autotuning: done with running current ds config.") exit() - def _write_tensorboard(self): + def _write_monitor(self): if self.global_rank == 0: self.summary_events = [ ( @@ -2077,9 +2022,7 @@ def _write_tensorboard(self): self.global_samples, ), ] - for event in self.summary_events: # write_summary_events - self.summary_writer.add_scalar(event[0], event[1], event[2]) - self.summary_writer.flush() + self.monitor.write_events(self.summary_events) def _get_optimizer_param(self, param_name): result = [] diff --git a/deepspeed/runtime/pipe/engine.py b/deepspeed/runtime/pipe/engine.py index 2ea05d183ab1..f3b5344c5f69 100644 --- a/deepspeed/runtime/pipe/engine.py +++ b/deepspeed/runtime/pipe/engine.py @@ -365,16 +365,12 @@ def train_batch(self, data_iter=None): f'iter time (s): {iter_time:0.3f} ' f'samples/sec: {tput:0.3f}') - # Tensorboard - if self.tensorboard_enabled(): - if self.global_rank == 0: - self.summary_events = [(f'Train/Samples/train_loss', - self.agg_train_loss.mean().item(), - self.global_samples)] - for event in self.summary_events: # write_summary_events - self.summary_writer.add_scalar(event[0], event[1], event[2]) - if self.global_steps % self.steps_per_print() == 0: - self.summary_writer.flush() + # Monitoring + if self.global_rank == 0 and self.monitor.enabled: + self.summary_events = [(f'Train/Samples/train_loss', + self.agg_train_loss.mean().item(), + self.global_samples)] + 
self.monitor.write_events(self.summary_events) if self.wall_clock_breakdown( ) and self.global_steps % self.steps_per_print() == 0: @@ -458,14 +454,11 @@ def eval_batch(self, if compute_loss: eval_output = self._bcast_pipe_scalar(eval_output) - if self.tensorboard_enabled(): - if self.global_rank == 0: - self.summary_events = [(f'Train/Samples/eval_loss', - eval_output.mean().item(), - self.global_samples)] - for event in self.summary_events: # write_summary_events - self.summary_writer.add_scalar(event[0], event[1], event[2]) - self.summary_writer.flush() + if self.global_rank == 0 and self.monitor.enabled: + self.summary_events = [(f'Train/Samples/eval_loss', + eval_output.mean().item(), + self.global_samples)] + self.monitor.write_events(self.summary_events) # Restore the training iterator self.set_dataiterator(train_iterator) @@ -1171,17 +1164,15 @@ def _exec_optimizer_step(self, lr_kwargs=None): self.mem_status('AFTER STEP') - if self.tensorboard_enabled(): - if self.global_rank == 0: - self.summary_events = [(f'Train/Samples/lr', - self.get_lr()[0], - self.global_samples)] - if self.fp16_enabled() and hasattr(self.optimizer, 'cur_scale'): - self.summary_events.append((f'Train/Samples/loss_scale', - self.optimizer.cur_scale, - self.global_samples)) - for event in self.summary_events: # write_summary_events - self.summary_writer.add_scalar(event[0], event[1], event[2]) + if self.global_rank == 0 and self.monitor.enabled: + self.summary_events = [(f'Train/Samples/lr', + self.get_lr()[0], + self.global_samples)] + if self.fp16_enabled() and hasattr(self.optimizer, 'cur_scale'): + self.summary_events.append((f'Train/Samples/loss_scale', + self.optimizer.cur_scale, + self.global_samples)) + self.monitor.write_events(self.summary_events) if self.wall_clock_breakdown(): self.timers('step_microstep').stop() diff --git a/docs/_config.yml b/docs/_config.yml index dc79fc033b1a..456b16ff1d16 100644 --- a/docs/_config.yml +++ b/docs/_config.yml @@ -48,6 +48,7 @@ collections: - mixture-of-experts.md - mixture-of-experts-nlg.md - mixture-of-experts-inference.md + - monitor.md - one-cycle.md - onebit-adam.md - zero-one-adam.md diff --git a/docs/_data/navigation.yml b/docs/_data/navigation.yml index 20f00b66760f..babcb8da2283 100755 --- a/docs/_data/navigation.yml +++ b/docs/_data/navigation.yml @@ -61,8 +61,8 @@ lnav: url: /docs/config-json/#activation-checkpointing - title: 'Sparse Attention' url: /docs/config-json/#sparse-attention - - title: 'Logging to TensorBoard' - url: /docs/config-json/#tensorboard-options + - title: 'Monitoring' + url: /docs/config-json/#monitoring-module-tensorboard-wandb-csv - title: 'Tutorials' url: /tutorials/ children: @@ -100,6 +100,8 @@ lnav: url: /tutorials/mixture-of-experts-inference/ - title: 'Mixture-of-Quantization' url: /tutorials/MoQ-tutorial/ + - title: 'Monitoring' + url: /tutorials/monitor - title: 'One-Cycle Schedule' url: /tutorials/one-cycle/ - title: 'One-Bit Adam' diff --git a/docs/_pages/config-json.md b/docs/_pages/config-json.md index 53df586ec3e6..3b283b459b18 100755 --- a/docs/_pages/config-json.md +++ b/docs/_pages/config-json.md @@ -964,13 +964,15 @@ Configuring the asynchronous I/O module for offloading parameter and optimizer s | ---------------------------------------------------------------------------------------------------------------------------- | ------- | | List of which step to change difficulty level. One of the `schedule_config` when the `fixed_discrete` schedule_type is used. 
| N/A | -### Logging to Tensorboard +### Monitoring Module (TensorBoard, WandB, CSV) **Note:** Deepspeed logs to TensorBoard through PyTorch. Logging to TensorBoard requires that the `tensorboard` package is installed (read more in the [PyTorch documentation](https://pytorch.org/docs/1.8.0/tensorboard.html)). {: .notice--warning} +**Note:** Logging to WandB requires that the `wandb` package is installed (read more in the [WandB documentation](https://docs.wandb.ai/quickstart)). +{: .notice--warning} -Deepspeed can log training details into a [Tensorboard](https://www.tensorflow.org/tensorboard)-compatible file. Below is an overview of what deepspeed will log. +Deepspeed's Monitor module can log training details into a [Tensorboard](https://www.tensorflow.org/tensorboard)-compatible file, to [WandB](https://wandb.ai/site), or to simple CSV files. Below is an overview of what DeepSpeed will log automatically. | Field | Description |Conditions | | ------ | ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ | ----- | @@ -989,11 +991,11 @@ Deepspeed can log training details into a [Tensorboard](https://www.tensorflow.o | Fields | Value |Default | | ------ | ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ | ----- | | enabled | Whether logging to [Tensorboard](https://www.tensorflow.org/tensorboard) is enabled. | `false` | -| job_name | Name for the current job. This will become a new directory inside `output_path` | `"DeepSpeedJobName"` | -| output_path | Path to where the Tensorboard logs will be written. | `~/tensorboard/` | +| output_path | Path to where the Tensorboard logs will be written. If None, the output path is set under the training script's launching path. | `null` | +| job_name | Name for the current job. This will become a new directory inside `output_path`. | `"DeepSpeedJobName"` | -Example of ** tensorboard** configuration: +Example of **tensorboard** configuration: ```json "tensorboard": { @@ -1002,3 +1004,43 @@ Example of ** tensorboard** configuration: "job_name": "train_bert" } ``` + +**wandb**: [dictionary] + +| Fields | Value |Default | +| ------ | ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ | ----- | +| enabled | Whether logging to [WandB](https://wandb.ai/site) is enabled. | `false` | +| group | Name for the WandB group. This can be used to group together runs. | `None` | +| team | Name for the WandB team. | `None` | +| project | Name for the WandB project. 
| `deepspeed` | + + +Example of **wandb** configuration: + +```json +"wandb": { + "enabled": true, + "group": "my_group", + "team": "my_team", + "project": "my_project" +} +``` + +**csv_monitor**: [dictionary] + +| Fields | Value |Default | +| ------ | ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ | ----- | +| enabled | Whether logging to local CSV files is enabled. | `false` | +| output_path | Path to where the csv files will be written. If None, the output path is set under the training script's launching path. | `null` | +| job_name | Name for the current job. This will become a new directory inside `output_path` | `"DeepSpeedJobName"` | + + +Example of **csv_monitor** configuration: + +```json +"csv_monitor": { + "enabled": true, + "output_path": "output/ds_logs/", + "job_name": "train_bert" +} +``` diff --git a/docs/_pages/features.md b/docs/_pages/features.md index 4410f2b10268..c2da91340bda 100755 --- a/docs/_pages/features.md +++ b/docs/_pages/features.md @@ -322,6 +322,33 @@ The DeepSpeed Autotuner uses model information, system information, and heurist ``` The flops profiler can also be used as a standalone package. Please refer to the [Flops Profiler](/tutorials/flops-profiler) tutorial for more details. +### Monitor + +The DeepSpeed Monitor logs live training metrics to one or more monitoring backends, including PyTorch's [TensorBoard](https://pytorch.org/docs/1.8.0/tensorboard.html), [WandB](https://docs.wandb.ai/quickstart), or simply to CSV files. The Monitor can be configured with one or more backends in the `deepspeed_config` file as follows: + +```json +{ + "tensorboard": { + "enabled": true, + "output_path": "output/ds_logs/", + "job_name": "train_bert" + } + "wandb": { + "enabled": true, + "team": "my_team", + "group": "my_group", + "project": "my_project" + } + "csv_monitor": { + "enabled": true, + "output_path": "output/ds_logs/", + "job_name": "train_bert" + } +} + +``` + +The Monitor can also be added to log custom metrics and client codes. Please refer to the [Monitor](/tutorials/monitor) tutorial for more details. ## Sparse Attention DeepSpeed offers sparse attention to support long sequences. Please refer to the [Sparse Attention](/tutorials/sparse-attention/) tutorial. diff --git a/docs/_tutorials/monitor.md b/docs/_tutorials/monitor.md new file mode 100644 index 000000000000..a9c111f8eeec --- /dev/null +++ b/docs/_tutorials/monitor.md @@ -0,0 +1,105 @@ +--- +title: "Monitor" +excerpt: "Monitor your model's training metrics live and log for future analysis" +tags: profiling performance-tuning +--- + +In this tutorial, we introduce the DeepSpeed Monitor and provide examples of its usage. + + - [Overview](#overview) + - [Usage](#usage) + +## Overview + +Monitoring model and system metrics during training is vital to ensure hardware resources are fully utilized. The DeepSpeed Monitor enables live logging of metrics through one or more monitoring backends such as PyTorch's [TensorBoard](https://pytorch.org/docs/1.8.0/tensorboard.html), [WandB](https://docs.wandb.ai/quickstart), and simple CSV files. 
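+
+As a quick preview of the interface, every backend consumes the same event format: a list of `(label, value, step)` tuples passed to `write_events`. The snippet below is a minimal sketch run from inside a DeepSpeed-launched training script (the `ds_config.json` path and the `"Loss/train"` label are placeholder names); the full walkthrough, including use inside a training loop, is in the [Usage](#usage) section below.
+
+```python
+from deepspeed.runtime.config import DeepSpeedConfig
+from deepspeed.monitor.monitor import MonitorMaster
+
+# Build the monitor hierarchy from the same JSON config used for training.
+# Assumes the script was launched with deepspeed so that the distributed
+# backend is already initialized (the Monitor only logs from rank 0).
+ds_config = DeepSpeedConfig("ds_config.json")  # placeholder config path
+monitor = MonitorMaster(ds_config.monitor_config)
+
+# Every enabled backend (TensorBoard, WandB, CSV) receives the same tuples
+monitor.write_events([("Loss/train", 0.123, 10)])
+```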
+ +Below is a live monitoring view for TensorBoard: + +![TensorBoard Example Output](/assets/images/tensorboard_monitor.PNG){: .align-center} + +Below is a live monitoring view for WandB: + +![WandB Example Output](/assets/images/wandb_monitor.PNG){: .align-center} + +## Usage + +The DeepSpeed Monitor is configured within the deepspeed [configuration file](/docs/config-json/#monitoring-module-tensorboard-wandb-csv). DeepSpeed will automatically monitor key training metrics, including those tracked with the `wall_clock_breakdown` configuration option. In addition, users can log their own custom events and metrics. + + - [Automatic Monitoring](#automatic-monitoring) + - [Custom Monitoring](#custom-monitoring) + +### Automatic Monitoring + +When using DeepSpeed for model training, the Monitor can be configured in the DeepSpeed [configuration file](/docs/config-json/#monitoring-module-tensorboard-wandb-csv). No explicit API calls are needed to use the Monitor. The Monitor can be enabled by adding the following field to DeepSpeed's configuration json file. Refer to [Monitoring](/docs/config-json/#monitoring-module-tensorboard-wandb-csv) for details. + +```json +{ + "tensorboard": { + "enabled": true, + "output_path": "output/ds_logs/", + "job_name": "train_bert" + } + "wandb": { + "enabled": true, + "team": "my_team", + "group": "my_group", + "project": "my_project" + } + "csv_monitor": { + "enabled": true, + "output_path": "output/ds_logs/", + "job_name": "train_bert" + } +} +``` + +DeepSpeed will automatically log to all available and enabled monitoring backends listed in the config, and will generate live monitoring views such as those listed above. + +### Custom Monitoring + +In addition to automatic monitoring, users can log their own custom metrics in client scripts. Currently, there are two ways to initialize Monitor objects: + +1. (Recommended) - Create a `MonitorMaster(ds_config.monitor_config)` object, which automatically initializes all monitor backends present in the DeepSpeed configuration +2. Create a specific `TensorBoardMonitor(ds_config.monitor_config)`, `WandbMonitor(ds_config.monitor_config)`, `csvMonitor(ds_config.monitor_config)` object which will only initialize a specific monitor backend present in the DeepSpeed configuration + + +The steps to create a custom monitor are as follows: + +1. Add import to your desired Monitor +2. Initialize monitor with DeepSpeed config's `monitor_config` +3. Create a list of one or more 3-tuples in the format `[("label", value, ds_engine.global_samples), ...]`\* +4. Call `monitor.write_events` on the list from step 3 + +\* Note - Some Monitor backends don't support mixed sample values. 
Be sure to use your DeepSpeed engine object's `global_samples` attribute in each 3-tuple + +For example usage, see the following modified [DeepSpeedExamples/cifar](https://github.com/microsoft/DeepSpeedExamples/tree/master/cifar) example: + +```python +# Step 1: Import monitor (and DeepSpeed config, if needed) +from deepspeed.monitor.monitor import MonitorMaster +from deepspeed.runtime.config import DeepSpeedConfig + +# Step 2: Initialized monitor with DeepSpeed config (get DeepSpeed config object, if needed) +ds_config = DeepSpeedConfig("ds_config.json") +monitor = MonitorMaster(ds_config.monitor_config) + +for epoch in range(2): + + running_loss = 0.0 + for i, data in enumerate(trainloader): + pre = time.time() + inputs, labels = data[0].to(model_engine.local_rank), data[1].to( + model_engine.local_rank) + if fp16: + inputs = inputs.half() + outputs = model_engine(inputs) + loss = criterion(outputs, labels) + + model_engine.backward(loss) + model_engine.step() + post = time.time() + # Step 3: Create list of 3-tuple records (single entry in this case) + events = [("Time per step", post-pre, model_engine.global_samples)] + # Step 4: Call monitor.write_events on the list from step 3 + monitor.write_events(events) +``` diff --git a/docs/assets/images/tensorboard_monitor.PNG b/docs/assets/images/tensorboard_monitor.PNG new file mode 100644 index 000000000000..b62d96c335b1 Binary files /dev/null and b/docs/assets/images/tensorboard_monitor.PNG differ diff --git a/docs/assets/images/wandb_monitor.PNG b/docs/assets/images/wandb_monitor.PNG new file mode 100644 index 000000000000..f65aa6c5cda8 Binary files /dev/null and b/docs/assets/images/wandb_monitor.PNG differ diff --git a/requirements/requirements.txt b/requirements/requirements.txt index 895e252a454f..e40a19b622fc 100755 --- a/requirements/requirements.txt +++ b/requirements/requirements.txt @@ -4,5 +4,6 @@ numpy packaging psutil py-cpuinfo +pydantic torch tqdm diff --git a/tests/unit/test_monitor.py b/tests/unit/test_monitor.py new file mode 100644 index 000000000000..4a84990d28fe --- /dev/null +++ b/tests/unit/test_monitor.py @@ -0,0 +1,139 @@ +import pytest + +from deepspeed.monitor.constants import * + +from deepspeed.monitor.monitor import MonitorMaster +from deepspeed.monitor.tensorboard import TensorBoardMonitor +from deepspeed.monitor.wandb import WandbMonitor +from deepspeed.monitor.csv_monitor import csvMonitor + +from .simple_model import * +from .common import distributed_test +from deepspeed.runtime.config import DeepSpeedConfig +from deepspeed.monitor.config import DeepSpeedMonitorConfig + +try: + import tensorboard + _tb_available = True +except ImportError: + _tb_available = False +tb_available = pytest.mark.skipif(not _tb_available, + reason="tensorboard is not installed") + +try: + import wandb + _wandb_available = True +except ImportError: + _wandb_available = False +wandb_available = pytest.mark.skipif(not _wandb_available, + reason="wandb is not installed") + + +@tb_available +def test_tensorboard(tmpdir): + @distributed_test(world_size=2) + def _test_tensorboard(): + config_dict = { + "train_batch_size": 2, + "tensorboard": { + "enabled": True, + "output_path": "test_output/ds_logs/", + "job_name": "test" + } + } + args = args_from_dict(tmpdir, config_dict) + ds_config = DeepSpeedConfig(args.deepspeed_config) + tb_monitor = TensorBoardMonitor(ds_config.monitor_config) + assert tb_monitor.enabled == True + assert tb_monitor.output_path == "test_output/ds_logs/" + assert tb_monitor.job_name == "test" + + 
_test_tensorboard() + + +@tb_available +def test_empty_tensorboard(tmpdir): + @distributed_test(world_size=2) + def _test_empty_tensorboard(): + config_dict = {"train_batch_size": 2, "tensorboard": {}} + args = args_from_dict(tmpdir, config_dict) + ds_config = DeepSpeedConfig(args.deepspeed_config) + tb_monitor = TensorBoardMonitor(ds_config.monitor_config) + assert tb_monitor.enabled == TENSORBOARD_ENABLED_DEFAULT + assert tb_monitor.output_path == TENSORBOARD_OUTPUT_PATH_DEFAULT + assert tb_monitor.job_name == TENSORBOARD_JOB_NAME_DEFAULT + + _test_empty_tensorboard() + + +@wandb_available +def test_wandb(tmpdir): + @distributed_test(world_size=2) + def _test_wandb(): + config_dict = { + "train_batch_size": 2, + "wandb": { + "enabled": False, + "group": "my_group", + "team": "my_team", + "project": "my_project" + } + } + args = args_from_dict(tmpdir, config_dict) + ds_config = DeepSpeedConfig(args.deepspeed_config) + wandb_monitor = WandbMonitor(ds_config.monitor_config) + assert wandb_monitor.enabled == False + assert wandb_monitor.group == "my_group" + assert wandb_monitor.team == "my_team" + assert wandb_monitor.project == "my_project" + + _test_wandb() + + +@wandb_available +def test_empty_wandb(tmpdir): + @distributed_test(world_size=2) + def _test_empty_wandb(): + config_dict = {"train_batch_size": 2, "wandb": {}} + args = args_from_dict(tmpdir, config_dict) + ds_config = DeepSpeedConfig(args.deepspeed_config) + wandb_monitor = WandbMonitor(ds_config.monitor_config) + assert wandb_monitor.enabled == WANDB_ENABLED_DEFAULT + assert wandb_monitor.group == WANDB_GROUP_NAME_DEFAULT + assert wandb_monitor.team == WANDB_TEAM_NAME_DEFAULT + assert wandb_monitor.project == WANDB_PROJECT_NAME_DEFAULT + + _test_empty_wandb() + + +def test_csv_monitor(tmpdir): + @distributed_test(world_size=2) + def _test_csv_monitor(): + config_dict = { + "train_batch_size": 2, + "csv_monitor": { + "enabled": True, + "output_path": "test_output/ds_logs/", + "job_name": "test" + } + } + args = args_from_dict(tmpdir, config_dict) + ds_config = DeepSpeedConfig(args.deepspeed_config) + csv_monitor = csvMonitor(ds_config.monitor_config) + assert csv_monitor.enabled == True + assert csv_monitor.output_path == "test_output/ds_logs/" + assert csv_monitor.job_name == "test" + + _test_csv_monitor() + + +def test_empty_csv_monitor(tmpdir): + @distributed_test(world_size=2) + def _test_empty_csv_monitor(): + config_dict = {"train_batch_size": 2, "csv_monitor": {}} + args = args_from_dict(tmpdir, config_dict) + ds_config = DeepSpeedConfig(args.deepspeed_config) + csv_monitor = csvMonitor(ds_config.monitor_config) + assert csv_monitor.enabled == CSV_MONITOR_ENABLED_DEFAULT + assert csv_monitor.output_path == CSV_MONITOR_OUTPUT_PATH_DEFAULT + assert csv_monitor.job_name == CSV_MONITOR_JOB_NAME_DEFAULT