From e855e2504b0e4939b7e8bfefe7dcf16554df1ea1 Mon Sep 17 00:00:00 2001 From: lessw2020 Date: Wed, 24 Jan 2024 22:53:31 -0800 Subject: [PATCH 1/3] meta init - meta init working, reinflate not --- .gitignore | 1 + torchtrain/meta_init.py | 52 ++++++++++++++++++++ torchtrain/parallelisms/parallelize_llama.py | 9 ++-- train.py | 16 +++++- 4 files changed, 72 insertions(+), 6 deletions(-) create mode 100644 torchtrain/meta_init.py diff --git a/.gitignore b/.gitignore index 5e057e7caf..d939038aae 100644 --- a/.gitignore +++ b/.gitignore @@ -11,3 +11,4 @@ out wandb *.model *.json +*.watchman diff --git a/torchtrain/meta_init.py b/torchtrain/meta_init.py new file mode 100644 index 0000000000..b8ebe6727f --- /dev/null +++ b/torchtrain/meta_init.py @@ -0,0 +1,52 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +import torch +from torch import nn +from torch.distributed.fsdp._common_utils import _is_fsdp_flattened + +from contextlib import contextmanager +from torchtrain.logging_utils import rank0_log + +@contextmanager +def meta_model_init(): + """ init model on meta device """ + saved_register_parameter = nn.Module.register_parameter + saved_register_buffer = nn.Module.register_buffer + + def register_meta_param(module, name, param): + saved_register_parameter(module, name, param) + if param is not None: + param_cls = type(module._parameters[name]) + kwargs = module._parameters[name].__dict__ + module._parameters[name] = param_cls( + module._parameters[name].to(torch.device("meta")), **kwargs + ) + + def register_meta_buffer(module, name, buffer): + saved_register_buffer(module, name, buffer) + if buffer is not None: + module._buffers[name] = module._buffers[name].to(torch.device("meta")) + + try: + nn.Module.register_parameter = register_meta_param + #nn.Module.register_buffer = register_meta_buffer + yield + finally: + nn.Module.register_parameter = saved_register_parameter + #nn.Module.register_buffer = saved_register_buffer + + +@torch.no_grad() +def meta_to_real_init_fn(module: nn.Module): + + for submodule in module.modules(): + for param_name, param in submodule.named_parameters(recurse=False): + if not _is_fsdp_flattened(param) and param.is_meta: + materialized_param = nn.Parameter( + torch.empty_like(param, device=torch.device("cuda")) + ) + rank0_log(f"called to reinflate...{materialized_param=}") + #print(f"called to reinflate...{module=}") + #assert False, "good meta" + setattr(submodule, param_name, materialized_param) diff --git a/torchtrain/parallelisms/parallelize_llama.py b/torchtrain/parallelisms/parallelize_llama.py index 46bca9ff49..175c3b7d5a 100644 --- a/torchtrain/parallelisms/parallelize_llama.py +++ b/torchtrain/parallelisms/parallelize_llama.py @@ -19,7 +19,7 @@ from torch.distributed.fsdp.wrap import enable_wrap, wrap from torchtrain.logging_utils import rank0_log - +from torchtrain.meta_init import meta_to_real_init_fn # Uses PTD FSDP AC wrapper def checkpoint_wrapper(module, config): @@ -62,20 +62,21 @@ def parallelize_llama(model, args): # When torch.compile is active, it requires us to set use_orig_params=True "use_orig_params": True, "device_mesh": dp_mesh, + "param_init_fn":meta_to_real_init_fn, } with enable_wrap(wrapper_cls=FSDP, **fsdp_config): for layer_id, transformer_block in enumerate(model.layers): # apply AC to each layer # before wrapping with FSDP, we need to make sure the layer is on GPU - transformer_block = transformer_block.cuda() - transformer_block = checkpoint_wrapper(transformer_block, args) + # todo - config 
this: transformer_block = transformer_block.cuda() + # todo - transformer_block = checkpoint_wrapper(transformer_block, args) # Wraps each layer with FSDP model.layers[layer_id]= wrap(transformer_block) # wrap the rest layers with FSDP - model = wrap(model.cuda()) + model = wrap(model) # todo - was .cuda() rank0_log(f"Applied parallelisms to the model...") diff --git a/train.py b/train.py index 073e5b1755..1d7c647fd3 100644 --- a/train.py +++ b/train.py @@ -2,9 +2,11 @@ import os from dataclasses import dataclass, field from typing import List +from contextlib import contextmanager # torch imports import torch +import torch.nn as nn import torch.nn.functional as F from torch.utils.data import DataLoader @@ -19,7 +21,7 @@ ) from torchtrain.models import models_config, model_name_to_cls, model_name_to_tokenizer from torchtrain.parallelisms import models_parallelize_fns - +from torchtrain.meta_init import meta_model_init @dataclass class TrainState: @@ -40,6 +42,7 @@ def build_optimizer(model, args): return optimizer + def main(args): init_logger() @@ -62,7 +65,13 @@ def main(args): model_config = models_config[model_name][args.model_conf] model_config.vocab_size = tokenizer.n_words - model = model_cls.from_model_args(model_config) + _use_meta_init = True # todo - add to toml + + if _use_meta_init: + with meta_model_init(): + model = model_cls.from_model_args(model_config) + else: + model = model_cls.from_model_args(model_config) # apply PTD parallelisms + AC model = models_parallelize_fns[model_name](model, args) @@ -85,6 +94,9 @@ def main(args): # train loop model.train() + # use fsdp + + with maybe_run_profiler() as torch_profiler: while train_state.step < args.steps or args.steps == -1: train_state.step += 1 From a58aa5f9dfe36143de0c6f1c0cb546c1e70d517e Mon Sep 17 00:00:00 2001 From: lessw2020 Date: Sat, 27 Jan 2024 12:00:20 -0800 Subject: [PATCH 2/3] meta init all working --- run_llama_train.sh | 2 +- torchtrain/meta_init.py | 14 ++++------- torchtrain/models/llama/model.py | 2 ++ torchtrain/parallelisms/parallelize_llama.py | 25 ++++++++++++++------ train.py | 23 ++++++++++++------ 5 files changed, 42 insertions(+), 24 deletions(-) diff --git a/run_llama_train.sh b/run_llama_train.sh index 100b52944b..cb15031930 100755 --- a/run_llama_train.sh +++ b/run_llama_train.sh @@ -9,4 +9,4 @@ NGPU=8 MP=4 torchrun --nproc_per_node=${NGPU} \ -train.py --steps 10 --compile +train.py --steps 10 --compile --meta_init diff --git a/torchtrain/meta_init.py b/torchtrain/meta_init.py index b8ebe6727f..3eafb01221 100644 --- a/torchtrain/meta_init.py +++ b/torchtrain/meta_init.py @@ -6,11 +6,11 @@ from torch.distributed.fsdp._common_utils import _is_fsdp_flattened from contextlib import contextmanager -from torchtrain.logging_utils import rank0_log + @contextmanager def meta_model_init(): - """ init model on meta device """ + """init model on meta device""" saved_register_parameter = nn.Module.register_parameter saved_register_buffer = nn.Module.register_buffer @@ -30,23 +30,19 @@ def register_meta_buffer(module, name, buffer): try: nn.Module.register_parameter = register_meta_param - #nn.Module.register_buffer = register_meta_buffer + nn.Module.register_buffer = register_meta_buffer yield finally: nn.Module.register_parameter = saved_register_parameter - #nn.Module.register_buffer = saved_register_buffer + nn.Module.register_buffer = saved_register_buffer @torch.no_grad() def meta_to_real_init_fn(module: nn.Module): - for submodule in module.modules(): for param_name, param in 
submodule.named_parameters(recurse=False): if not _is_fsdp_flattened(param) and param.is_meta: materialized_param = nn.Parameter( - torch.empty_like(param, device=torch.device("cuda")) + torch.randn_like(param, device=torch.device("cuda")) ) - rank0_log(f"called to reinflate...{materialized_param=}") - #print(f"called to reinflate...{module=}") - #assert False, "good meta" setattr(submodule, param_name, materialized_param) diff --git a/torchtrain/models/llama/model.py b/torchtrain/models/llama/model.py index a3fbeb4bad..feec26aa40 100644 --- a/torchtrain/models/llama/model.py +++ b/torchtrain/models/llama/model.py @@ -24,6 +24,8 @@ class ModelArgs: max_batch_size: int = 32 max_seq_len: int = 32768 + use_meta_init: Optional[bool] = False # controlled via global settings + class RMSNorm(torch.nn.Module): def __init__(self, dim: int, eps: float = 1e-6): diff --git a/torchtrain/parallelisms/parallelize_llama.py b/torchtrain/parallelisms/parallelize_llama.py index 175c3b7d5a..dcebcf2e3e 100644 --- a/torchtrain/parallelisms/parallelize_llama.py +++ b/torchtrain/parallelisms/parallelize_llama.py @@ -26,11 +26,11 @@ def checkpoint_wrapper(module, config): return ptd_checkpoint_wrapper(module, checkpoint_impl=CheckpointImpl.NO_REENTRANT, preserve_rng_state=False) -def parallelize_llama(model, args): +def parallelize_llama(model, args, use_meta_init=False): """ Apply parallelisms to the model, including PTD parallelisms, and AC. - NOTE: the model passed in preferrablably shoule be a meta device model, + NOTE: the model passed in preferrablably should be a meta device model, otherwise the model needs to be small enough on GPU or can fit into CPU. # TODO: apply SP """ @@ -51,6 +51,8 @@ def parallelize_llama(model, args): dp_mesh = world_mesh # apply PTD parallelisms + meta_init_fn = meta_to_real_init_fn if use_meta_init else None + fsdp_config = { "mixed_precision": MixedPrecision( param_dtype=torch.bfloat16, @@ -62,21 +64,30 @@ def parallelize_llama(model, args): # When torch.compile is active, it requires us to set use_orig_params=True "use_orig_params": True, "device_mesh": dp_mesh, - "param_init_fn":meta_to_real_init_fn, + "param_init_fn": meta_init_fn, } with enable_wrap(wrapper_cls=FSDP, **fsdp_config): + + using_meta_init = fsdp_config["param_init_fn"] + for layer_id, transformer_block in enumerate(model.layers): # apply AC to each layer # before wrapping with FSDP, we need to make sure the layer is on GPU - # todo - config this: transformer_block = transformer_block.cuda() - # todo - transformer_block = checkpoint_wrapper(transformer_block, args) + # unless using meta init: + + if not using_meta_init: + transformer_block = transformer_block.cuda() + + transformer_block = checkpoint_wrapper(transformer_block, args) # Wraps each layer with FSDP model.layers[layer_id]= wrap(transformer_block) - # wrap the rest layers with FSDP - model = wrap(model) # todo - was .cuda() + # wrap the remaining layers with FSDP + if not using_meta_init: + model.cuda() + model = wrap(model) rank0_log(f"Applied parallelisms to the model...") diff --git a/train.py b/train.py index 1d7c647fd3..9e8a8cbea7 100644 --- a/train.py +++ b/train.py @@ -2,11 +2,9 @@ import os from dataclasses import dataclass, field from typing import List -from contextlib import contextmanager # torch imports import torch -import torch.nn as nn import torch.nn.functional as F from torch.utils.data import DataLoader @@ -23,6 +21,7 @@ from torchtrain.parallelisms import models_parallelize_fns from torchtrain.meta_init import meta_model_init + 
@dataclass class TrainState: step: int = 0 @@ -42,7 +41,6 @@ def build_optimizer(model, args): return optimizer - def main(args): init_logger() @@ -60,12 +58,15 @@ def main(args): ) # build model - # TODO: add meta initialization + model_cls = model_name_to_cls[model_name] model_config = models_config[model_name][args.model_conf] + model_config.vocab_size = tokenizer.n_words - _use_meta_init = True # todo - add to toml + # meta initialization + _use_meta_init = args.meta_init # todo - add to toml + model_config.use_meta_init = _use_meta_init # append this to model config if _use_meta_init: with meta_model_init(): @@ -74,9 +75,12 @@ def main(args): model = model_cls.from_model_args(model_config) # apply PTD parallelisms + AC - model = models_parallelize_fns[model_name](model, args) + model = models_parallelize_fns[model_name]( + model, args, use_meta_init=_use_meta_init + ) # build optimizer after apply parallelisms to the model + # TODO: add scheduler if needed optimizer = build_optimizer(model, args) @@ -96,7 +100,6 @@ def main(args): # use fsdp - with maybe_run_profiler() as torch_profiler: while train_state.step < args.steps or args.steps == -1: train_state.step += 1 @@ -172,5 +175,11 @@ def main(args): "--compile", action="store_true", help="Whether to compile the model." ) + parser.add_argument( + "--meta_init", + action="store_true", + help="Whether to use meta init for the model.", + ) + args = parser.parse_args() main(args) From 1c97ff555649930573de6fd9f5f97934e226e8ad Mon Sep 17 00:00:00 2001 From: lessw2020 Date: Sat, 27 Jan 2024 12:05:55 -0800 Subject: [PATCH 3/3] set default to not using meta_init --- run_llama_train.sh | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/run_llama_train.sh b/run_llama_train.sh index cb15031930..100b52944b 100755 --- a/run_llama_train.sh +++ b/run_llama_train.sh @@ -9,4 +9,4 @@ NGPU=8 MP=4 torchrun --nproc_per_node=${NGPU} \ -train.py --steps 10 --compile --meta_init +train.py --steps 10 --compile
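
For reference, a minimal end-to-end sketch of the pattern this series lands: build the model under meta_model_init() so no real memory is touched, then let FSDP materialize parameters on GPU through param_init_fn. This is an illustration, not the exact torchtrain code path; it assumes torchrun has already initialized the process group and set the CUDA device (as run_llama_train.sh does), and the TwoLayer module is a toy stand-in for the real Transformer.

import torch
import torch.nn as nn
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

from torchtrain.meta_init import meta_model_init, meta_to_real_init_fn


class TwoLayer(nn.Module):
    """Illustrative stand-in for the real Transformer model."""

    def __init__(self, dim: int = 1024):
        super().__init__()
        self.w1 = nn.Linear(dim, 4 * dim, bias=False)
        self.w2 = nn.Linear(4 * dim, dim, bias=False)

    def forward(self, x):
        return self.w2(torch.relu(self.w1(x)))


# 1. Build the model without allocating real storage: the patched
#    register_parameter hook places every parameter on the meta device.
with meta_model_init():
    model = TwoLayer()

assert all(p.is_meta for p in model.parameters())

# 2. FSDP sees the meta parameters and calls param_init_fn on each wrapped
#    module right before sharding it, so real storage is only ever allocated
#    on the local GPU, module by module.
model = FSDP(
    model,
    param_init_fn=meta_to_real_init_fn,
    use_orig_params=True,
)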
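
Patch 2 also switches meta_to_real_init_fn from torch.empty_like to torch.randn_like, since empty_like would hand FSDP uninitialized storage. A common alternative, sketched below as an assumption on my part rather than something this series does, is to keep each module's own init scheme instead of plain N(0, 1) values:

import torch
import torch.nn as nn


@torch.no_grad()
def reset_params_init_fn(module: nn.Module) -> None:
    # Materialize real (uninitialized) storage for everything under `module`
    # on the local GPU; note this re-allocates any already-materialized
    # parameters in the subtree as well.
    module.to_empty(device=torch.device("cuda"))
    # Re-run each submodule's own reset_parameters() where it defines one,
    # so the original init scheme is preserved.
    for submodule in module.modules():
        if callable(getattr(submodule, "reset_parameters", None)):
            submodule.reset_parameters()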
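
After wrapping, a quick sanity check (illustrative only, not part of the PR; `model` here is the FSDP-wrapped module from the first sketch) confirms that nothing is left on the meta device and that peak GPU memory reflects shard-by-shard materialization rather than a full model.cuda() copy:

# Post-wrap check on one rank.
leftover = [name for name, p in model.named_parameters() if p.is_meta]
assert not leftover, f"parameters still on the meta device: {leftover}"

gb = torch.cuda.max_memory_allocated() / 1e9
print(f"peak GPU memory after FSDP wrap: {gb:.2f} GB")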