Skip to content

Commit

Permalink
Update support for deep speed and multiple improvements.
Browse files Browse the repository at this point in the history
  • Loading branch information
MGAMZ committed Jan 11, 2025
1 parent 2a5a1fe commit be86710
Show file tree
Hide file tree
Showing 4 changed files with 123 additions and 112 deletions.
10 changes: 8 additions & 2 deletions mmengine/_strategy/deepspeed.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,9 +63,11 @@ def register_deepspeed_optimizers() -> List[str]:
@OPTIM_WRAPPERS.register_module()
class DeepSpeedOptimWrapper(BaseOptimWrapper):

def __init__(self, optimizer):
def __init__(self, optimizer, accumulative_counts):
super().__init__(optimizer)
self._model = None
self._inner_count = 0
self._accumulative_counts = accumulative_counts

@property
def model(self):
Expand All @@ -80,11 +82,13 @@ def model(self, value):
def update_params(self, loss) -> None: # type: ignore
"""Update parameters in :attr:`optimizer`."""
self.backward(loss)
self.step()
if self.should_update():
self.step()

def backward(self, loss: torch.Tensor, **kwargs) -> None:
""""Perform gradient back propagation."""
self.model.backward(loss)
self._inner_count += 1

def zero_grad(self, **kwargs) -> None:
raise NotImplementedError(
Expand All @@ -107,6 +111,8 @@ def load_state_dict(self, state_dict: dict) -> None:
if base_param_settings is not None:
self.base_param_settings = base_param_settings

def should_update(self) -> bool:
return (self._inner_count % self._accumulative_counts == 0)

@MODEL_WRAPPERS.register_module()
class MMDeepSpeedEngineWrapper:
Expand Down
218 changes: 110 additions & 108 deletions mmengine/config/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -1375,120 +1375,122 @@ def env_variables(self) -> dict:
@property
def pretty_text(self) -> str:
"""Get formatted python config text."""
try:
indent = 4

def _indent(s_, num_spaces):
s = s_.split('\n')
if len(s) == 1:
return s_
first = s.pop(0)
s = [(num_spaces * ' ') + line for line in s]
s = '\n'.join(s)
s = first + '\n' + s
return s

def _format_basic_types(k, v, use_mapping=False):
if isinstance(v, str):
v_str = repr(v)
else:
v_str = str(v)

indent = 4

def _indent(s_, num_spaces):
s = s_.split('\n')
if len(s) == 1:
return s_
first = s.pop(0)
s = [(num_spaces * ' ') + line for line in s]
s = '\n'.join(s)
s = first + '\n' + s
return s

def _format_basic_types(k, v, use_mapping=False):
if isinstance(v, str):
v_str = repr(v)
else:
v_str = str(v)

if use_mapping:
k_str = f"'{k}'" if isinstance(k, str) else str(k)
attr_str = f'{k_str}: {v_str}'
else:
attr_str = f'{str(k)}={v_str}'
attr_str = _indent(attr_str, indent)
if use_mapping:
k_str = f"'{k}'" if isinstance(k, str) else str(k)
attr_str = f'{k_str}: {v_str}'
else:
attr_str = f'{str(k)}={v_str}'
attr_str = _indent(attr_str, indent)

return attr_str
return attr_str

def _format_list_tuple(k, v, use_mapping=False):
if isinstance(v, list):
left = '['
right = ']'
else:
left = '('
right = ')'

v_str = f'{left}\n'
# check if all items in the list are dict
for item in v:
if isinstance(item, dict):
v_str += f'dict({_indent(_format_dict(item), indent)}),\n'
elif isinstance(item, tuple):
v_str += f'{_indent(_format_list_tuple(None, item), indent)},\n' # noqa: 501
elif isinstance(item, list):
v_str += f'{_indent(_format_list_tuple(None, item), indent)},\n' # noqa: 501
elif isinstance(item, str):
v_str += f'{_indent(repr(item), indent)},\n'
def _format_list_tuple(k, v, use_mapping=False):
if isinstance(v, list):
left = '['
right = ']'
else:
v_str += str(item) + ',\n'
if k is None:
return _indent(v_str, indent) + right
if use_mapping:
k_str = f"'{k}'" if isinstance(k, str) else str(k)
attr_str = f'{k_str}: {v_str}'
else:
attr_str = f'{str(k)}={v_str}'
attr_str = _indent(attr_str, indent) + right
return attr_str

def _contain_invalid_identifier(dict_str):
contain_invalid_identifier = False
for key_name in dict_str:
contain_invalid_identifier |= \
(not str(key_name).isidentifier())
return contain_invalid_identifier

def _format_dict(input_dict, outest_level=False):
r = ''
s = []

use_mapping = _contain_invalid_identifier(input_dict)
if use_mapping:
r += '{'
for idx, (k, v) in enumerate(
sorted(input_dict.items(), key=lambda x: str(x[0]))):
is_last = idx >= len(input_dict) - 1
end = '' if outest_level or is_last else ','
if isinstance(v, dict):
v_str = '\n' + _format_dict(v)
if use_mapping:
k_str = f"'{k}'" if isinstance(k, str) else str(k)
attr_str = f'{k_str}: dict({v_str}'
left = '('
right = ')'

v_str = f'{left}\n'
# check if all items in the list are dict
for item in v:
if isinstance(item, dict):
v_str += f'dict({_indent(_format_dict(item), indent)}),\n'
elif isinstance(item, tuple):
v_str += f'{_indent(_format_list_tuple(None, item), indent)},\n' # noqa: 501
elif isinstance(item, list):
v_str += f'{_indent(_format_list_tuple(None, item), indent)},\n' # noqa: 501
elif isinstance(item, str):
v_str += f'{_indent(repr(item), indent)},\n'
else:
attr_str = f'{str(k)}=dict({v_str}'
attr_str = _indent(attr_str, indent) + ')' + end
elif isinstance(v, (list, tuple)):
attr_str = _format_list_tuple(k, v, use_mapping) + end
v_str += str(item) + ',\n'
if k is None:
return _indent(v_str, indent) + right
if use_mapping:
k_str = f"'{k}'" if isinstance(k, str) else str(k)
attr_str = f'{k_str}: {v_str}'
else:
attr_str = _format_basic_types(k, v, use_mapping) + end

s.append(attr_str)
r += '\n'.join(s)
if use_mapping:
r += '}'
return r

cfg_dict = self.to_dict()
text = _format_dict(cfg_dict, outest_level=True)
if self._format_python_code:
# copied from setup.cfg
yapf_style = dict(
based_on_style='pep8',
blank_line_before_nested_class_or_def=True,
split_before_expression_after_opening_paren=True)
try:
if digit_version(yapf.__version__) >= digit_version('0.40.2'):
text, _ = FormatCode(text, style_config=yapf_style)
else:
text, _ = FormatCode(
text, style_config=yapf_style, verify=True)
except: # noqa: E722
raise SyntaxError('Failed to format the config file, please '
f'check the syntax of: \n{text}')
return text
attr_str = f'{str(k)}={v_str}'
attr_str = _indent(attr_str, indent) + right
return attr_str

def _contain_invalid_identifier(dict_str):
contain_invalid_identifier = False
for key_name in dict_str:
contain_invalid_identifier |= \
(not str(key_name).isidentifier())
return contain_invalid_identifier

def _format_dict(input_dict, outest_level=False):
r = ''
s = []

use_mapping = _contain_invalid_identifier(input_dict)
if use_mapping:
r += '{'
for idx, (k, v) in enumerate(
sorted(input_dict.items(), key=lambda x: str(x[0]))):
is_last = idx >= len(input_dict) - 1
end = '' if outest_level or is_last else ','
if isinstance(v, dict):
v_str = '\n' + _format_dict(v)
if use_mapping:
k_str = f"'{k}'" if isinstance(k, str) else str(k)
attr_str = f'{k_str}: dict({v_str}'
else:
attr_str = f'{str(k)}=dict({v_str}'
attr_str = _indent(attr_str, indent) + ')' + end
elif isinstance(v, (list, tuple)):
attr_str = _format_list_tuple(k, v, use_mapping) + end
else:
attr_str = _format_basic_types(k, v, use_mapping) + end

s.append(attr_str)
r += '\n'.join(s)
if use_mapping:
r += '}'
return r

cfg_dict = self.to_dict()
text = _format_dict(cfg_dict, outest_level=True)
if self._format_python_code:
# copied from setup.cfg
yapf_style = dict(
based_on_style='pep8',
blank_line_before_nested_class_or_def=True,
split_before_expression_after_opening_paren=True)
try:
if digit_version(yapf.__version__) >= digit_version('0.40.2'):
text, _ = FormatCode(text, style_config=yapf_style)
else:
text, _ = FormatCode(
text, style_config=yapf_style, verify=True)
except: # noqa: E722
raise SyntaxError('Failed to format the config file, please '
f'check the syntax of: \n{text}')
return text
except Exception as e:
return f'Error occurs when formatting config: {e}'

def __repr__(self):
return f'Config (path: {self.filename}): {self._cfg_dict.__repr__()}'
Expand Down
6 changes: 4 additions & 2 deletions mmengine/logging/message_hub.py
Original file line number Diff line number Diff line change
Expand Up @@ -342,8 +342,10 @@ def _get_valid_value(
else:
# check whether value is torch.Tensor but don't want
# to import torch in this file
assert hasattr(value, 'numel') and value.numel() == 1
value = value.item()
if hasattr(value, 'numel') and value.numel() == 1:
value = value.item()
else:
print_log(f"MessageHub got unexpceted log: {value}", level=logging.WARN)
return value # type: ignore

def state_dict(self) -> dict:
Expand Down
1 change: 1 addition & 0 deletions mmengine/model/averaged_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,7 @@ def update_parameters(self, model: nn.Module) -> None:
for k, p_avg in self.avg_parameters.items():
p_avg.data.copy_(src_parameters[k].data)
elif self.steps % self.interval == 0:
print(self.avg_parameters)
for k, p_avg in self.avg_parameters.items():
if p_avg.dtype.is_floating_point:
device = p_avg.device
Expand Down

0 comments on commit be86710

Please sign in to comment.