Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
60 changes: 26 additions & 34 deletions deepspeed/profiling/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,47 +4,39 @@
Licensed under the MIT license.
"""

from deepspeed.runtime.config_utils import get_scalar_param, DeepSpeedConfigObject
from deepspeed.profiling.constants import *
from pydantic import Field
from deepspeed.runtime.config_utils import DeepSpeedConfigModel


class DeepSpeedFlopsProfilerConfig(DeepSpeedConfigObject):
def __init__(self, param_dict):
super(DeepSpeedFlopsProfilerConfig, self).__init__()
def get_flops_profiler_config(param_dict):
    """Build a flops-profiler config object from a full DeepSpeed config dict.

    Pulls the "flops_profiler" sub-dict out of *param_dict* (treating a
    missing key as an empty dict, so every field falls back to its default)
    and validates it through DeepSpeedFlopsProfilerConfig.
    """
    section = param_dict.get("flops_profiler", {})
    return DeepSpeedFlopsProfilerConfig(**section)

self.enabled = None
self.profile_step = None
self.module_depth = None
self.top_modules = None

if FLOPS_PROFILER in param_dict.keys():
flops_profiler_dict = param_dict[FLOPS_PROFILER]
else:
flops_profiler_dict = {}
class DeepSpeedFlopsProfilerConfig(DeepSpeedConfigModel):
    """ Sets parameters for the flops profiler. """

    enabled: bool = False
    """ Enables the flops profiler. This also enables wall_clock_breakdown. """

    profile_step: int = Field(1, ge=1)
    """
    The global training step at which to profile. Note that warm up steps are
    needed for accurate time measurement.
    """

    module_depth: int = -1
    """
    The depth of the model at which to print the aggregated module information.
    When set to `-1`, it prints information from the top module to the
    innermost modules (the maximum depth).
    """

    top_modules: int = 1
    """ Limits the aggregated profile output to the number of top modules specified. """

    detailed: bool = True
    """ Whether to print the detailed model profile. """

    output_file: str = None
    """ Path to the output file. If None, the profiler prints to stdout. """
44 changes: 0 additions & 44 deletions deepspeed/profiling/constants.py

This file was deleted.

4 changes: 2 additions & 2 deletions deepspeed/runtime/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@
NUM_GPUS_PER_NODE_DEFAULT,
)

from ..profiling.config import DeepSpeedFlopsProfilerConfig
from ..profiling.config import get_flops_profiler_config
from ..autotuning.config import DeepSpeedAutotuningConfig
from ..nebula.config import DeepSpeedNebulaConfig

Expand Down Expand Up @@ -862,7 +862,7 @@ def _initialize_params(self, param_dict):
self.scheduler_name = get_scheduler_name(param_dict)
self.scheduler_params = get_scheduler_params(param_dict)

self.flops_profiler_config = DeepSpeedFlopsProfilerConfig(param_dict)
self.flops_profiler_config = get_flops_profiler_config(param_dict)
self.wall_clock_breakdown = (get_wall_clock_breakdown(param_dict)
| self.flops_profiler_config.enabled)
self.memory_breakdown = get_memory_breakdown(param_dict)
Expand Down