VP Fixes for converter + Config management #6698

Merged 2 commits on May 26, 2023
99 changes: 83 additions & 16 deletions examples/nlp/language_modeling/megatron_change_num_partitions.py
@@ -56,7 +56,7 @@
--target_pipeline_model_parallel_size=1 \
--target_pipeline_model_parallel_split_rank=0 \
--precision=bf16

# Megatron GPT + Virtual Pipeline parallelism

python megatron_change_num_partitions.py \
@@ -138,17 +138,34 @@ def set_virtual_parallel_rank_safely(rank: int):

def force_cpu_model(cfg):
    with open_dict(cfg):
-        # temporarily
+        # temporarily set to cpu
        original_cpu_init = cfg.get('use_cpu_initialization', False)
-        original_amp_o2 = cfg.get('megatron_amp_O2', False)
+        if 'megatron_amp_O2' in cfg:
+            key = 'megatron_amp_O2'
+            original_amp_o2 = cfg.megatron_amp_O2
+        elif 'megatron_amp_02' in cfg:
+            key = 'megatron_amp_02'
+            original_amp_o2 = cfg.megatron_amp_02
+        else:
+            key, original_amp_o2 = None, None

        # Set new values
        cfg.use_cpu_initialization = True
-        cfg.megatron_amp_O2 = False
-    return cfg, {'original_cpu_init': original_cpu_init, 'original_amp_o2': original_amp_o2}
+        if key is not None:
+            cfg[key] = False
+
+        # Setup restore dict
+        restore_dict = {'use_cpu_initialization': original_cpu_init}  # 'megatron_amp_O2': original_amp_o2
+        if key is not None:
+            restore_dict[key] = original_amp_o2
+
+    return cfg, restore_dict


+def restore_model_config(cfg, original_dict):
+    with open_dict(cfg):
+        for key, val in original_dict.items():
+            logging.info(f"Restoring model config key ({key}) from {cfg[key]} to original value of {val}")
+            cfg[key] = val
+    return cfg
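Note the config-management fix above: the restore dict is now keyed by the actual config attribute names (including the misspelled legacy key 'megatron_amp_02' that some older configs carry), so restore_model_config() can write the saved values straight back; the old return value used ad-hoc keys that didn't match any config attribute. A minimal usage sketch, assuming the two helpers above are in scope (they rely on the script's open_dict and logging imports) and using an illustrative config:

```python
from omegaconf import OmegaConf

# Illustrative config carrying the misspelled legacy AMP key.
cfg = OmegaConf.create({'megatron_amp_02': True, 'use_cpu_initialization': False})

cfg, restore_dict = force_cpu_model(cfg)
print(cfg.use_cpu_initialization, cfg.megatron_amp_02)  # True False (forced for conversion)

cfg = restore_model_config(cfg, restore_dict)
print(cfg.use_cpu_initialization, cfg.megatron_amp_02)  # False True (original values restored)
```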

@@ -1034,6 +1051,8 @@ def main():
    os.path.join(model_filepath, args.ckpt_name)
)

+vp_state_dict = torch.load(checkpoint_path, map_location="cpu")
+
if hparams_filepath is not None:
    # Force the model onto CPU
    tmp_cfg = OmegaConf.load(hparams_filepath)
@@ -1078,9 +1097,10 @@ def main():
vp_params_tmp = []
for vp_idx in range(vp_size):
    set_virtual_parallel_rank_safely(vp_idx)
-    params = [p for p in model.model[vp_idx].parameters()]
-    # params = model.model[vp_idx].module.state_dict_for_save_checkpoint()
-    # params = [p for p in params.values()]
+    vp_params = vp_state_dict[f'model{vp_idx}']
+    model.model[vp_idx].module.load_state_dict(vp_params, strict=True)
+    model.model[vp_idx].module.to('cpu')
+    params = [p for p in model.model[vp_idx].module.parameters()]
    vp_params_tmp.append(params)
    # partitions[pp_rank][vp_idx].append(params)

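For context on what the new loading path expects: the checkpoint read by torch.load() above carries one state dict per virtual-pipeline chunk under the keys 'model0', 'model1', and so on, which is why the loop indexes vp_state_dict[f'model{vp_idx}'] and loads each chunk with strict=True so that any key mismatch fails fast instead of silently producing a half-initialized partition. A hedged inspection sketch (the checkpoint path is hypothetical; the real script builds it from model_filepath and args.ckpt_name):

```python
import torch

# Hypothetical path for illustration only.
vp_state_dict = torch.load("checkpoints/megatron_gpt_vp.ckpt", map_location="cpu")

# One state dict per virtual-pipeline chunk.
vp_keys = sorted(k for k in vp_state_dict if k.startswith("model"))
print(vp_keys)  # e.g. ['model0', 'model1', 'model2', 'model3'] when vp_size == 4
```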
@@ -1141,6 +1161,8 @@ def main():
model = model.to('cpu')
model._save_restore_connector = NLPSaveRestoreConnector()

+restore_model_config(model.cfg, restore_dict)
+
vp_param_count = 0
for vp in range(vp_size):
    for pp in range(pp_size):
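The point of the added restore_model_config() call in this hunk is ordering: the CPU-init and AMP flags that force_cpu_model() overrode for conversion are put back before the model is serialized, so the exported checkpoint keeps the user's original settings. A minimal sketch of that ordering, assuming NeMo's standard ModelPT.save_to() export (in this script the final write actually goes through merge_partition):

```python
# Sketch only: restore the config *before* export so the saved artifact
# does not inherit the temporary conversion-time flags.
model = model.to('cpu')
model._save_restore_connector = NLPSaveRestoreConnector()
restore_model_config(model.cfg, restore_dict)
model.save_to(args.target_file)  # hypothetical direct save, for illustration
```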
@@ -1159,24 +1181,69 @@
else:
    flat_partitions = {idx: [] for idx in range(pp_size)}

-    for pp in range(pp_size):
-        for tp in range(tp_size):
-            vp_cache = []
-            for vp in range(vp_size):
-                vp_cache.extend(partitions[vp][pp][tp])
-
-            flat_partitions[pp].append(vp_cache)
+    """
+    Under the VP convention
+    Notation :
+    Stage  = PP rank
+    Number = GPT model / layer index
+    Ignore TP - every PP has all TP corresponding to that PP
+    chunk_index = the physical index of any [] in the list. Ex: idx = 2 in the map below corresponds to [2: PP 0 VP 1]
+
+    For a PP 2 VP 4 model with 8 GPT layers:
+
+    Indices
+    # Stage 0: [0:PP 0 VP 0] [2:PP 0 VP 1] [4:PP 0 VP 2] [6:PP 0 VP 3]
+    # Stage 1: [1:PP 1 VP 0] [3:PP 1 VP 1] [5:PP 1 VP 2] [7:PP 1 VP 3]
+
+    after conversion will become
+
+    # Stage 0: [0,1,2,3:PP 0]
+    # Stage 1: [4,5,6,7:PP 1]
+    """
+    pp_index = 0
+    chunk_counter = 0
+    tp_cache = [[] for _ in range(tp_size)]
+
+    for vp in range(vp_size):
+        for pp in range(pp_size):
+            # Gather all TP ranks under this VP-PP combination.
+            # We accumulate TP parameters from multiple layers in this cache.
+            for tp in range(tp_size):
+                tp_cache[tp].extend(partitions[vp][pp][tp])
+
+            # This counter indexes the global position of a VP-PP combination in the map above
+            chunk_counter += 1
+
+            # Log the mapping from old VP x PP to the new PP index
+            logging.info(f"VP Conversion - vp: {vp} pp: {pp} -> pp_idx: {pp_index}")
+
+            # Every vp_size chunks, we can fill a new PP index in flat_partitions
+            if chunk_counter % vp_size == 0:
+                flat_partitions[pp_index].extend(tp_cache)
+                tp_cache = [[] for _ in range(tp_size)]
+                pp_index += 1
+
+                logging.debug(
+                    f"VP merge step: \n"
+                    f"vp: {vp} pp: {pp} pp_idx: {pp_index - 1} "
+                    f"len(flat_partitions): {len(flat_partitions[pp_index - 1])}"
+                )
+
+    logging.debug(f"PP Size len(flat_partitions)       : {len(flat_partitions)}")
+    logging.debug(f"TP Size len(flat_partitions[0])    : {len(flat_partitions[0])}")
+    logging.debug(f"Layers  len(flat_partitions[0][0]) : {len(flat_partitions[0][0])}")

    partitions = flat_partitions
+    del tp_cache
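To see the regrouping concretely, here is a self-contained toy version of the same chunk walk for the PP 2 x VP 4 example in the docstring (plain ints stand in for the per-TP parameter lists; all names are illustrative):

```python
pp_size, vp_size = 2, 4

# chunk id = vp * pp_size + pp, matching the "Indices" map in the docstring.
stages = {pp: [vp * pp_size + pp for vp in range(vp_size)] for pp in range(pp_size)}
print(stages)  # {0: [0, 2, 4, 6], 1: [1, 3, 5, 7]}  <- interleaved VP layout

flat = {idx: [] for idx in range(pp_size)}
pp_index = chunk_counter = 0
cache = []
for vp in range(vp_size):
    for pp in range(pp_size):
        cache.append(vp * pp_size + pp)      # visit chunks in global order 0, 1, 2, ...
        chunk_counter += 1
        if chunk_counter % vp_size == 0:     # every vp_size chunks fill one new PP stage
            flat[pp_index].extend(cache)
            cache = []
            pp_index += 1
print(flat)  # {0: [0, 1, 2, 3], 1: [4, 5, 6, 7]}  <- contiguous PP layout
```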

if tgt_tp_size > 1 or tgt_pp_size > 1:
    merge_partition(model, partitions)
else:
    # Write out the PP 1 TP 1 model to disk
    merge_partition(model, partitions, args.target_file)

+restore_model_config(model.cfg, restore_dict)
+
# Empty cache memory of all parameters from all PP TP partitions
partitions.clear()