Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -1,16 +1,24 @@
diff --git a/megatron/core/optimizer/__init__.py b/megatron/core/optimizer/__init__.py
index 1846907e9..7fd0554c4 100644
index 1846907e9..8355608fe 100644
--- a/megatron/core/optimizer/__init__.py
+++ b/megatron/core/optimizer/__init__.py
@@ -55,6 +55,7 @@ def _get_param_groups(
@@ -3,6 +3,7 @@ import logging
import warnings
from typing import Callable, Dict, List, Optional, Tuple

+import os
import torch
from torch.optim import SGD as CPUSGD
from torch.optim import AdamW as CPUAdam
@@ -55,6 +56,7 @@ def _get_param_groups(
decoupled_lr: Optional[float],
decoupled_min_lr: Optional[float],
default_skip_embedding_weight_decay: bool = False,
+ vision_ration = 1.0,
) -> List[Dict]:
"""Create parameter groups for optimizer.

@@ -106,6 +107,8 @@ def _get_param_groups(
@@ -106,6 +108,8 @@ def _get_param_groups(
or len(param.shape) == 1
or (default_skip_embedding_weight_decay and "embedding" in name)
)
Expand All @@ -19,73 +27,135 @@ index 1846907e9..7fd0554c4 100644

if scale_lr_cond is not None:
scale_lr = scale_lr_cond(name, param)
@@ -128,8 +131,14 @@ def _get_param_groups(
param, 'is_embedding_or_output_parameter', False
@@ -129,37 +133,64 @@ def _get_param_groups(
):
is_decoupled_lr = True
+
+ is_vision_model_param = False
+ if "vision_model" in name:
+ is_vision_model_param = True
+ else:
+ is_vision_model_param = False

- key = (wd_mult, _lr_mult, is_expert_parallel, is_decoupled_lr)
+ key = (wd_mult, _lr_mult, is_expert_parallel, is_decoupled_lr, is_vision_model_param)
if key not in params_map:
params_map[key] = []
params_map[key].append(param)
@@ -147,7 +156,7 @@ def _get_param_groups(
- if key not in params_map:
- params_map[key] = []
- params_map[key].append(param)
+ if os.environ.get("ENABLE_SIMULATOR") == "1":
+ key = (wd_mult, _lr_mult, is_expert_parallel, is_decoupled_lr)
+ if key not in params_map:
+ params_map[key] = []
+ params_map[key].append(param)
+ else:
+ is_vision_model_param = False
+ if "vision_model" in name:
+ is_vision_model_param = True
+ else:
+ is_vision_model_param = False
+
+ key = (wd_mult, _lr_mult, is_expert_parallel, is_decoupled_lr, is_vision_model_param)
+ if key not in params_map:
+ params_map[key] = []
+ params_map[key].append(param)
+ if os.environ.get("ENABLE_SIMULATOR") == "1":
+ param_groups = []
+ for (wd_mult, _lr_mult, is_expert_parallel, is_decoupled_lr), params in params_map.items():
+ assert len(params) > 0
+ param_group = {
+ 'params': params,
+ 'wd_mult': wd_mult,
+ 'lr_mult': _lr_mult,
+ 'is_expert_parallel': is_expert_parallel,
+ 'is_decoupled_lr': is_decoupled_lr,
+ }
+ param_groups.append(param_group)

param_groups = []
for key in params_key:
# Distributed checkpoint requires all ranks to have the same param groups,
# so we need to align the param groups across ranks, otherwise we may have
# runtime error when loading the checkpoint or numerical error when resuming training.
- params_key = list(params_map.keys())
- gathered_params_key = [None for _ in range(torch.distributed.get_world_size())]
- torch.distributed.all_gather_object(gathered_params_key, params_key)
- for keys in gathered_params_key:
- for key in keys:
- if key not in params_key:
- params_key.append(key)
-
- param_groups = []
- for key in params_key:
- wd_mult, _lr_mult, is_expert_parallel, is_decoupled_lr = key
+ wd_mult, _lr_mult, is_expert_parallel, is_decoupled_lr, is_vision_model_param = key
params = params_map[key] if key in params_map else []
param_group = {
'params': params,
@@ -155,6 +164,7 @@ def _get_param_groups(
'lr_mult': _lr_mult,
'is_expert_parallel': is_expert_parallel,
'is_decoupled_lr': is_decoupled_lr,
+ 'is_vision_model_param': is_vision_model_param,
}
# Ensure param_group has required keys for matching when loading optimizer state
# See MegatronOptimizer._filter_and_reorder_param_groups.
@@ -167,6 +177,7 @@ def _get_param_groups(
- params = params_map[key] if key in params_map else []
- param_group = {
- 'params': params,
- 'wd_mult': wd_mult,
- 'lr_mult': _lr_mult,
- 'is_expert_parallel': is_expert_parallel,
- 'is_decoupled_lr': is_decoupled_lr,
- }
- # Ensure param_group has required keys for matching when loading optimizer state
- # See MegatronOptimizer._filter_and_reorder_param_groups.
- assert set(param_group.keys()) - set(param_group_identifier_keys) == {'params'}
- param_groups.append(param_group)
+ else:
+ params_key = list(params_map.keys())
+ gathered_params_key = [None for _ in range(torch.distributed.get_world_size())]
+ torch.distributed.all_gather_object(gathered_params_key, params_key)
+ for keys in gathered_params_key:
+ for key in keys:
+ if key not in params_key:
+ params_key.append(key)
+
+ param_groups = []
+ for key in params_key:
+ wd_mult, _lr_mult, is_expert_parallel, is_decoupled_lr, is_vision_model_param = key
+ params = params_map[key] if key in params_map else []
+ param_group = {
+ 'params': params,
+ 'wd_mult': wd_mult,
+ 'lr_mult': _lr_mult,
+ 'is_expert_parallel': is_expert_parallel,
+ 'is_decoupled_lr': is_decoupled_lr,
+ 'is_vision_model_param': is_vision_model_param,
+ }
+ # Ensure param_group has required keys for matching when loading optimizer state
+ # See MegatronOptimizer._filter_and_reorder_param_groups.
+ assert set(param_group.keys()) - set(param_group_identifier_keys) == {'params'}
+ param_groups.append(param_group)
+

param_groups = _update_min_and_max_lr_in_param_groups(
param_groups,
@@ -167,6 +198,7 @@ def _get_param_groups(
min_lr=min_lr,
decoupled_lr=decoupled_lr,
decoupled_min_lr=decoupled_min_lr,
+ vision_ration=vision_ration,
)

return param_groups
@@ -178,6 +189,7 @@ def _update_min_and_max_lr_in_param_groups(
@@ -178,6 +210,7 @@ def _update_min_and_max_lr_in_param_groups(
min_lr: float,
decoupled_lr: Optional[float],
decoupled_min_lr: Optional[float],
+ vision_ration = 0.1,
) -> List[Dict]:
"""
Updates `max_lr` and `min_lr` values in each parameter group, and returns new list.
@@ -206,7 +218,7 @@ def _update_min_and_max_lr_in_param_groups(
@@ -206,7 +239,10 @@ def _update_min_and_max_lr_in_param_groups(
param_group['max_lr'] = decoupled_lr
param_group['min_lr'] = decoupled_min_lr
else:
- param_group['max_lr'] = lr
+ param_group['max_lr'] = lr if not param_group['is_vision_model_param'] else lr * vision_ration # NOTE(lizhiyu): change the ration here
+ if os.environ.get("ENABLE_SIMULATOR") == "1":
+ param_group['max_lr'] = lr
+ else:
+ param_group['max_lr'] = lr if not param_group['is_vision_model_param'] else lr * vision_ration # NOTE(lizhiyu): change the ration here
param_group['min_lr'] = min_lr
return param_groups

@@ -255,6 +267,7 @@ def _get_param_groups_and_buffers(
@@ -255,6 +291,7 @@ def _get_param_groups_and_buffers(
decoupled_lr=config.decoupled_lr,
decoupled_min_lr=config.decoupled_min_lr,
default_skip_embedding_weight_decay=default_skip_embedding_weight_decay,
+ vision_ration=config.vision_ration, # NOTE(lizhiyu): The vision ration is used to scale the learning rate for vision model parameters. Added by FlagScale.
)
param_groups = list(filter(filter_fn, param_groups))
buffers = {}
@@ -511,6 +524,10 @@ def get_megatron_optimizer(
@@ -511,6 +548,10 @@ def get_megatron_optimizer(
intra_dp_cp_group = process_groups['intra_dp_cp_group']
intra_expt_dp_group = process_groups['intra_expt_dp_group']
mp_group = process_groups['mp_group']
Expand All @@ -96,7 +166,7 @@ index 1846907e9..7fd0554c4 100644
expt_tp_pp_group = process_groups['expt_tp_pp_group']
intra_dp_cp_group_gloo = process_groups['intra_dp_cp_group_gloo']
intra_expt_dp_group_gloo = process_groups['intra_expt_dp_group_gloo']
@@ -609,7 +626,11 @@ def get_megatron_optimizer(
@@ -609,7 +650,11 @@ def get_megatron_optimizer(
default_skip_embedding_weight_decay=default_skip_embedding_weight_decay,
)
if len(moe_param_groups) > 0:
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
diff --git a/megatron/training/arguments.py b/megatron/training/arguments.py
index 1120c7529..190fac52b 100644
index 1120c7529..ebb1467c3 100644
--- a/megatron/training/arguments.py
+++ b/megatron/training/arguments.py
@@ -67,6 +67,7 @@ def add_megatron_arguments(parser: argparse.ArgumentParser):
Expand Down Expand Up @@ -477,7 +477,7 @@ index 1120c7529..190fac52b 100644
dest='overlap_p2p_comm_warmup_flush')
group.add_argument('--distributed-backend', default='nccl',
- choices=['nccl', 'gloo'],
+ choices=['nccl', 'gloo', 'flagcx'],
+ choices=['nccl', 'gloo', 'flagcx', 'dummy'],
help='Which backend to use for distributed training.')
group.add_argument('--distributed-timeout-minutes', type=int, default=10,
help='Timeout minutes for torch.distributed.')
Expand Down Expand Up @@ -602,7 +602,7 @@ index 1120c7529..190fac52b 100644

return parser

@@ -3275,3 +3542,75 @@ def _add_sft_args(parser):
@@ -3275,3 +3542,78 @@ def _add_sft_args(parser):
group.add_argument('--sft-tokenizer-prompt-format', type=str, default="nemotron-h-aligned",
help='SFT prompt format.')
return parser
Expand Down Expand Up @@ -637,6 +637,9 @@ index 1120c7529..190fac52b 100644
+ group.add_argument('--auto-tune', action='store_true',
+ help='use auto tuner')
+
+ group.add_argument('--enable-simulator', action='store_true',
+ help='Use single process to simulate the distributed training.')
+
+ return parser
+
+
Expand Down
45 changes: 45 additions & 0 deletions flagscale/runner/auto_tuner/simulator/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
# Environment
Start from the root directory of the `FlagScale` repository:
1. Install the backend:
```
cd flagscale/runner/auto_tuner/simulator/custom_backend/
python setup.py develop
```

# Setup
2. Set the necessary parameters in `config_gen.py`. For example:
```
device_type_list = ["A", "B"]
device_num_list = [4, 4]
global_batch_size = 32
num_micro_batches = 8
num_layers = 4
```
# Run a Task
3. Start the auto-tuning:
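a. Set `PYTHONPATH`, both in the shell and in the `os.environ["PYTHONPATH"]` assignment inside `analylize_pipeline_time.py`: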
```
export PYTHONPATH=/***/FlagScale:$PYTHONPATH
export PYTHONPATH=$PYTHONPATH:/***/FlagScale/third_party/Megatron-LM

vim /***/FlagScale/flagscale/runner/auto_tuner/simulator/analylize_pipeline_time.py
os.environ["PYTHONPATH"] = (
"/***/FlagScale:"
"/***/FlagScale/third_party/Megatron-LM"
)
```
b. Choose the pipeline scheme in `config_gen.py`, then run it:

vim flagscale/runner/auto_tuner/simulator/config_gen.py

set scheme = vpp or scheme = 1F1B

python flagscale/runner/auto_tuner/simulator/config_gen.py

c. Example output:
```
{'mesh': [2, 1, 1, 1, 2, 1, 1, 1, 1, 4], 'device_types': ['A800', 'A800'], 'pp_layer_split': [8, 8, 5, 5, 5, 1], 'recompute_granularity': None, 'recompute_method': 'uniform', 'recompute_num_layers': 1, 'simulated_time': 57.52105478485333, 'theory_peak_memory': [110.487650304, 118.80914944, 158.35625472, 158.35625472, 158.35625472, 42.519842816], 'oom_error': True}
{'mesh': [2, 1, 1, 1, 2, 1, 1, 1, 1, 4], 'device_types': ['A800', 'A800'], 'pp_layer_split': [8, 7, 5, 5, 5, 2], 'recompute_granularity': None, 'recompute_method': 'uniform', 'recompute_num_layers': 1, 'simulated_time': 61.20105478485332, 'theory_peak_memory': [110.487650304, 109.345202176, 158.35625472, 158.35625472, 158.35625472, 61.447737344], 'oom_error': True}
{'mesh': [2, 1, 1, 1, 2, 1, 1, 1, 1, 4], 'device_types': ['A800', 'A800'], 'pp_layer_split': [8, 8, 5, 5, 4, 2], 'recompute_granularity': None, 'recompute_method': 'uniform', 'recompute_num_layers': 1, 'simulated_time': 54.73105478485331, 'theory_peak_memory': [110.487650304, 118.80914944, 158.35625472, 158.35625472, 119.365943296, 61.447737344], 'oom_error': True}
...
```
Loading
Loading