Add meta init #21
`.gitignore`:

```diff
@@ -11,3 +11,4 @@ out
 wandb
 *.model
 *.json
+*.watchman
```
`torchtrain/meta_init.py` (new file):

```python
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

import torch
from torch import nn
from torch.distributed.fsdp._common_utils import _is_fsdp_flattened

from contextlib import contextmanager


@contextmanager
def meta_model_init():
    """init model on meta device"""
    saved_register_parameter = nn.Module.register_parameter
    saved_register_buffer = nn.Module.register_buffer

    def register_meta_param(module, name, param):
        saved_register_parameter(module, name, param)
        if param is not None:
            param_cls = type(module._parameters[name])
            kwargs = module._parameters[name].__dict__
            module._parameters[name] = param_cls(
                module._parameters[name].to(torch.device("meta")), **kwargs
            )

    def register_meta_buffer(module, name, buffer):
        saved_register_buffer(module, name, buffer)
        if buffer is not None:
            module._buffers[name] = module._buffers[name].to(torch.device("meta"))

    try:
        nn.Module.register_parameter = register_meta_param
        nn.Module.register_buffer = register_meta_buffer
        yield
    finally:
        nn.Module.register_parameter = saved_register_parameter
        nn.Module.register_buffer = saved_register_buffer


@torch.no_grad()
def meta_to_real_init_fn(module: nn.Module):
    for submodule in module.modules():
        for param_name, param in submodule.named_parameters(recurse=False):
            if not _is_fsdp_flattened(param) and param.is_meta:
                materialized_param = nn.Parameter(
                    torch.randn_like(param, device=torch.device("cuda"))
                )
                setattr(submodule, param_name, materialized_param)
```

Collaborator (on `def meta_to_real_init_fn(module: nn.Module):`): question: so we don't define

Collaborator (on `for param_name, param in submodule.named_parameters(recurse=False):`): can you add a TODO here and link the issue. I think doing random init might not be good enough for pretraining and we should resolve the layer depth init functions.
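For orientation, here is a minimal sketch of how these two pieces would typically be combined, assuming an initialized process group and a CUDA device (the `SimpleModel` class below is a stand-in, not part of this PR): the model is built under `meta_model_init()` so its parameters take no real memory, and `meta_to_real_init_fn` is handed to FSDP as `param_init_fn` so parameters are materialized just before sharding.

```python
# Sketch only: assumes torch.distributed is initialized and a CUDA device is available.
import torch
from torch import nn
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

from torchtrain.meta_init import meta_model_init, meta_to_real_init_fn


class SimpleModel(nn.Module):  # stand-in for model_cls.from_model_args(model_config)
    def __init__(self):
        super().__init__()
        self.proj = nn.Linear(4096, 4096)


with meta_model_init():
    model = SimpleModel()  # parameters are created on the meta device: no real allocation

assert all(p.is_meta for p in model.parameters())

fsdp_model = FSDP(
    model,
    param_init_fn=meta_to_real_init_fn,  # materializes meta params on CUDA before sharding
    device_id=torch.cuda.current_device(),
)
```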
In the model definition (`class ModelArgs`):

```diff
@@ -24,6 +24,8 @@ class ModelArgs:
     max_batch_size: int = 32
     max_seq_len: int = 32768
 
+    use_meta_init: Optional[bool] = False  # controlled via global settings
+
 
 class RMSNorm(torch.nn.Module):
     def __init__(self, dim: int, eps: float = 1e-6):
```

Collaborator (on `use_meta_init`): nit: this is more like a TrainOptions instead of a model specific config, let's not include it here.
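As a hypothetical illustration of that suggestion (the `TrainOptions` name is taken from the comment; no such class exists in this PR), the flag could live in a small run-options dataclass so `ModelArgs` stays purely architectural:

```python
from dataclasses import dataclass


@dataclass
class TrainOptions:
    # hypothetical container for run-level knobs that are orthogonal to model architecture
    use_meta_init: bool = False
    compile: bool = False
```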
In the training script (`train.py`):

```diff
@@ -19,6 +19,7 @@
 )
 from torchtrain.models import models_config, model_name_to_cls, model_name_to_tokenizer
 from torchtrain.parallelisms import models_parallelize_fns
+from torchtrain.meta_init import meta_model_init
 
 
 @dataclass
```
```diff
@@ -57,17 +58,29 @@ def main(args):
     )
 
     # build model
-    # TODO: add meta initialization
     model_cls = model_name_to_cls[model_name]
     model_config = models_config[model_name][args.model_conf]
 
     model_config.vocab_size = tokenizer.n_words
 
-    model = model_cls.from_model_args(model_config)
+    # meta initialization
+    _use_meta_init = args.meta_init  # todo - add to toml
+    model_config.use_meta_init = _use_meta_init  # append this to model config
+
+    if _use_meta_init:
+        with meta_model_init():
+            model = model_cls.from_model_args(model_config)
+    else:
+        model = model_cls.from_model_args(model_config)
 
     # apply PTD parallelisms + AC
-    model = models_parallelize_fns[model_name](model, args)
+    model = models_parallelize_fns[model_name](
+        model, args, use_meta_init=_use_meta_init
+    )
 
     # build optimizer after apply parallelisms to the model
     # TODO: add scheduler if needed
     optimizer = build_optimizer(model, args)
```

Collaborator (on `model_config.use_meta_init = _use_meta_init`): Curious why we save the meta init flag to the model config? I think it's something orthogonal to the model arch?
```diff
@@ -85,6 +98,8 @@ def main(args):
     # train loop
     model.train()
 
+    # use fsdp
+
     with maybe_run_profiler() as torch_profiler:
         while train_state.step < args.steps or args.steps == -1:
             train_state.step += 1
```

Collaborator (on `# use fsdp`): nit: remove this?
```diff
@@ -160,5 +175,11 @@ def main(args):
         "--compile", action="store_true", help="Whether to compile the model."
     )
 
+    parser.add_argument(
+        "--meta_init",
+        action="store_true",
+        help="Whether to use meta init for the model.",
+    )
+
     args = parser.parse_args()
     main(args)
```

Contributor (on `parser.add_argument(`): I'm wondering if we should even make this optional. It might be cleaner if we just always do meta init, then apply the various parallelisms, then materialize. Thoughts?

Contributor (Author): This is actually a great idea imo... not just for being cleaner, but because meta init has been a bit of a second-class citizen over the last year, yet we see lots of partners struggling with larger model training due to OOM on loading and unclear how to leverage meta init.
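As a rough sketch of the flow proposed in this thread (always build on meta, then apply parallelisms, then materialize), assuming materialization happens inside the parallelize functions via FSDP's `param_init_fn`; the helper name `build_model_always_meta` is hypothetical and not part of the PR:

```python
# Sketch of the proposed always-on flow: no --meta_init flag, no if/else in train.py.
from torch import nn

from torchtrain.meta_init import meta_model_init
from torchtrain.parallelisms import models_parallelize_fns


def build_model_always_meta(model_cls, model_config, model_name: str, args) -> nn.Module:
    # 1) always construct on the meta device: no real memory is allocated here
    with meta_model_init():
        model = model_cls.from_model_args(model_config)
    # 2) apply PTD parallelisms + AC; FSDP wrapping inside is assumed to materialize
    #    parameters (e.g. via param_init_fn=meta_to_real_init_fn)
    return models_parallelize_fns[model_name](model, args)
```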
Reviewer (on `meta_to_real_init_fn`): hmm, I don't think we need this; a simple init like this should work:
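The snippet the comment refers to is not shown here. Purely as an illustration (an assumption, not the reviewer's actual code), a common simpler pattern is to allocate real storage with `nn.Module.to_empty()` and then re-run each submodule's own `reset_parameters()`:

```python
# One possible "simple init" pattern (assumption, not necessarily the reviewer's snippet):
# allocate real storage for meta tensors, then re-run each module's own init.
from torch import nn


def simple_materialize(module: nn.Module, device: str = "cuda") -> nn.Module:
    module = module.to_empty(device=device)  # replaces meta tensors with uninitialized real ones
    for submodule in module.modules():
        if hasattr(submodule, "reset_parameters"):
            submodule.reset_parameters()  # reuse the module's default init instead of randn_like
    return module
```

Compared to `randn_like`, this reuses each layer's default initialization, which is closer to what the earlier comment about pretraining-quality init is asking for.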