Apply isort and black reformatting
Signed-off-by: adityavavre <[email protected]>
adityavavre committed Aug 3, 2024
1 parent 19081eb commit 420cff5
Showing 1 changed file with 13 additions and 9 deletions.
22 changes: 13 additions & 9 deletions examples/nlp/language_modeling/megatron_gpt_upcycling.py
@@ -1,16 +1,15 @@
 import gc
 import re
 from collections import OrderedDict
-
 from copy import deepcopy
+from typing import List, Union
 
 import torch
 import torch.multiprocessing as mp
 from einops import rearrange, repeat
 from megatron.core import parallel_state
 from omegaconf import OmegaConf, open_dict
 from pytorch_lightning import Trainer
-from typing import Union, List
 
 from nemo.collections.nlp.models.language_modeling.megatron_gpt_model import MegatronGPTModel
 from nemo.collections.nlp.parts.megatron_trainer_builder import MegatronTrainerBuilder
@@ -27,7 +26,7 @@
 def modify_config_for_upcycling(gpt_cfg: OmegaConf, cfg: OmegaConf) -> OmegaConf:
     """
     Function that modifies the parallelism configuration of the base model to the MoE parallelism config
-    so that the weights can be upcycled and loaded correctly on each rank. This makes sure that each 
+    so that the weights can be upcycled and loaded correctly on each rank. This makes sure that each
     layer and weight matrix is sharded in the same way across the ranks. Expert parallelism is ignored
     while loading base model weights and is upcycled based on the EP size in the upcycling function.
     Args:
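
For context, a minimal sketch of what such a parallelism override can look like with OmegaConf. The keys below (tensor_model_parallel_size, pipeline_model_parallel_size) are standard NeMo/Megatron config names, but exactly which keys this script copies is an assumption, not shown in the hunk:

    from omegaconf import OmegaConf, open_dict

    # Illustrative only: copy the MoE run's parallelism sizes onto the base
    # model config so both checkpoints shard identically across ranks.
    gpt_cfg = OmegaConf.create({'tensor_model_parallel_size': 1, 'pipeline_model_parallel_size': 1})
    cfg = OmegaConf.create({'model': {'tensor_model_parallel_size': 2, 'pipeline_model_parallel_size': 2}})
    with open_dict(gpt_cfg):
        gpt_cfg.tensor_model_parallel_size = cfg.model.tensor_model_parallel_size
        gpt_cfg.pipeline_model_parallel_size = cfg.model.pipeline_model_parallel_size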
@@ -82,7 +81,10 @@ def load_state_dict_from_nemo(
         save_restore_connector=save_restore_connector,
     )
     if isinstance(instance.model, list):
-        state_dict = [deepcopy(instance.model[i].to(dtype=torch_dtype_from_precision(gpt_cfg.precision)).state_dict()) for i in range(len(instance.model))]
+        state_dict = [
+            deepcopy(instance.model[i].to(dtype=torch_dtype_from_precision(gpt_cfg.precision)).state_dict())
+            for i in range(len(instance.model))
+        ]
     else:
         state_dict = deepcopy(instance.model.to(dtype=torch_dtype_from_precision(gpt_cfg.precision)).state_dict())
     del instance
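
The pattern being reformatted here is worth noting: the module is cast to the target dtype before its state dict is deep-copied, so the snapshot already carries the precision the upcycled model will use and the source model can then be freed. A self-contained sketch of that pattern in plain torch (not the script's own helpers):

    from copy import deepcopy

    import torch

    def snapshot_state_dict(module: torch.nn.Module, dtype: torch.dtype) -> dict:
        # Cast first, then deep-copy: the copy owns its tensors, so the
        # original module can be deleted afterwards to release GPU memory.
        return deepcopy(module.to(dtype=dtype).state_dict())

    sd = snapshot_state_dict(torch.nn.Linear(4, 4), torch.bfloat16)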
@@ -91,6 +93,7 @@
     torch.cuda.empty_cache()
     return state_dict
 
+
 def upcycle_weights_for_moe(cfg: OmegaConf, state_dict: OrderedDict) -> OrderedDict:
     """
     Upcycle base model weights to MoE model weights that can be loaded into a MoE model instance.
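
The function walks the base model's state dict and rewrites keys into their MoE counterparts, as the hunks below show for the router and expert weights. A toy sketch of that key remapping; the dense fc1 key used here is an assumption about the base checkpoint's naming:

    import re

    old_key = 'decoder.layers.3.mlp.linear_fc1.weight'  # assumed dense key
    m = re.match(r'decoder\.layers\.(\d+)\.mlp\.', old_key)
    if m:
        # Every dense MLP additionally gets a per-layer router weight.
        new_key = 'decoder.layers.' + m.group(1) + '.mlp.router.weight'
        print(new_key)  # decoder.layers.3.mlp.router.weight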
@@ -146,7 +149,7 @@ def upcycle_weights_for_moe(cfg: OmegaConf, state_dict: OrderedDict) -> OrderedDict:
             new_key = 'decoder.layers.' + m.group(1) + '.mlp.router.weight'
             # Create a router for each fc1
             print('creating new router', new_key, 'layer', m.group(1))
-            router = torch.nn.Linear(v.size(1), cfg.model.num_moe_experts) 
+            router = torch.nn.Linear(v.size(1), cfg.model.num_moe_experts)
             # low init value helps upcycling
             if router_std > 0:
                 torch.nn.init.normal_(router.weight, mean=0.0, std=router_std)
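
The low-std router init deserves a word: with near-zero router logits, the softmax over experts starts almost uniform, so every expert (each a copy of the dense MLP) receives tokens and gradients early in training instead of a few experts winning immediately. A standalone sketch with illustrative sizes:

    import torch

    hidden_size, num_moe_experts, router_std = 1024, 8, 0.02  # illustrative values
    router = torch.nn.Linear(hidden_size, num_moe_experts)
    if router_std > 0:
        # Small weights -> near-zero logits -> near-uniform initial routing.
        torch.nn.init.normal_(router.weight, mean=0.0, std=router_std)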
@@ -208,7 +211,7 @@ def upcycle_weights_for_moe(cfg: OmegaConf, state_dict: OrderedDict) -> OrderedDict:
                     t += expert_std * torch.randn_like(t)
                 new_key_values.append((new_key, t))
             else:
-                new_key_values.append((new_key, v.detach().clone())) 
+                new_key_values.append((new_key, v.detach().clone()))
         else:
             new_key = 'decoder.layers.' + m.group(1) + '.mlp.experts.weight2'
             if transformer_impl == 'scattermoe':
@@ -219,7 +222,7 @@ def upcycle_weights_for_moe(cfg: OmegaConf, state_dict: OrderedDict) -> OrderedDict:
                 print(w2.shape)
                 new_key_values.append((new_key, w2))
             else:
-                w2 = scale_st * v.detach().clone().t() 
+                w2 = scale_st * v.detach().clone().t()
                 w2 = repeat(w2, 'f h -> e f h', e=num_moe_experts // granularity)
                 w2 = rearrange(w2, 'e (f g) h -> (e g) f h', g=granularity).contiguous()
                 new_key_values.append((new_key, w2.reshape(-1, v.shape[0]).contiguous()))
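
The repeat/rearrange pair implements granular upcycling of the fc2 weight: the transposed matrix is first replicated for num_moe_experts // granularity experts, then each replica's ffn dimension is split into granularity shards, yielding num_moe_experts smaller experts. A toy run with illustrative sizes shows the shapes:

    import torch
    from einops import rearrange, repeat

    num_moe_experts, granularity = 8, 2  # illustrative values
    h, f = 4, 16                         # toy hidden and ffn sizes

    w2 = torch.randn(h, f).t()                                         # (16, 4)
    w2 = repeat(w2, 'f h -> e f h', e=num_moe_experts // granularity)  # (4, 16, 4)
    w2 = rearrange(w2, 'e (f g) h -> (e g) f h', g=granularity)        # (8, 8, 4)
    print(w2.shape)  # torch.Size([8, 8, 4]): 8 experts, each over half the ffn rows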
@@ -250,7 +253,7 @@ def main(cfg) -> None:
 
     trainer = MegatronTrainerBuilder(cfg).create_trainer()
     save_restore_connector = NLPSaveRestoreConnector()
-
+
     # load base model state dict
     state_dict = load_state_dict_from_nemo(
         MegatronGPTModel, cfg, save_restore_connector=save_restore_connector, trainer=trainer
@@ -269,7 +272,7 @@ def main(cfg) -> None:
     else:
         state_dict = upcycle_weights_for_moe(cfg=cfg, state_dict=state_dict)
     state_dict = save_restore_connector.modify_state_dict(cfg, state_dict=state_dict)
-
+
     # load new instance with upcycled weights
     if isinstance(model_instance.model, list):
         if not isinstance(state_dict, list):
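
For readers unfamiliar with this branch: under interleaved (virtual) pipeline parallelism, model_instance.model is a list of module chunks, so the upcycled weights must be a matching list and each chunk loads its own piece. A hypothetical sketch of that load step, with illustrative modules standing in for the chunks:

    import torch

    chunks = [torch.nn.Linear(4, 4), torch.nn.Linear(4, 4)]  # stand-ins for module chunks
    state_dicts = [c.state_dict() for c in chunks]           # one upcycled dict per chunk
    for chunk, sd in zip(chunks, state_dicts):
        chunk.load_state_dict(sd)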
@@ -285,5 +288,6 @@ def main(cfg) -> None:
     exp_manager(trainer, cfg.exp_manager)
     trainer.fit(model_instance)
 
+
 if __name__ == "__main__":
     main()
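
A note on the bare main() call above: main is defined as def main(cfg) -> None, which is the NeMo convention where a decorator injects the parsed config. A minimal sketch of that pattern; the decorator and config names below are assumed from other NeMo example scripts, not shown in this diff:

    from nemo.core.config import hydra_runner

    @hydra_runner(config_path='conf', config_name='megatron_gpt_config')  # assumed names
    def main(cfg) -> None:
        ...  # cfg is supplied by the decorator at runtime

    if __name__ == '__main__':
        main()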
