1 change: 1 addition & 0 deletions .gitignore
@@ -11,3 +11,4 @@ out
wandb
*.model
*.json
*.watchman
48 changes: 48 additions & 0 deletions torchtrain/meta_init.py
@@ -0,0 +1,48 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

import torch
from torch import nn
from torch.distributed.fsdp._common_utils import _is_fsdp_flattened

from contextlib import contextmanager


@contextmanager
def meta_model_init():
[Collaborator] hmm I don't think we need this, a simple init like this should work:

    with torch.device("meta"):
        model = Model.from_args(...)

"""init model on meta device"""
saved_register_parameter = nn.Module.register_parameter
saved_register_buffer = nn.Module.register_buffer

def register_meta_param(module, name, param):
saved_register_parameter(module, name, param)
if param is not None:
param_cls = type(module._parameters[name])
kwargs = module._parameters[name].__dict__
module._parameters[name] = param_cls(
module._parameters[name].to(torch.device("meta")), **kwargs
)

def register_meta_buffer(module, name, buffer):
saved_register_buffer(module, name, buffer)
if buffer is not None:
module._buffers[name] = module._buffers[name].to(torch.device("meta"))

try:
nn.Module.register_parameter = register_meta_param
nn.Module.register_buffer = register_meta_buffer
yield
finally:
nn.Module.register_parameter = saved_register_parameter
nn.Module.register_buffer = saved_register_buffer
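
A quick sanity check of the context manager above (an assumed usage sketch, not part of the PR or its tests): any parameter registered while the patch is active should end up on the meta device, so no real memory is allocated.

    # assumed usage check: modules built inside the context get meta parameters
    import torch.nn as nn
    from torchtrain.meta_init import meta_model_init

    with meta_model_init():
        lin = nn.Linear(8, 8)
    assert lin.weight.is_meta and lin.bias.is_meta  # no real storage allocated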


@torch.no_grad()
def meta_to_real_init_fn(module: nn.Module):
[Collaborator] question: so we don't define reset_parameters on the nn.Module and instead use this function, is it because reset_parameters does not work with FSDP?

[Collaborator] reset_parameters() should work (now) with FSDP as long as reset_parameters() only initializes the module's directly owned parameters/buffers and not any of its children.
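
To illustrate the constraint in the comment above, here is a minimal sketch (editor's example, not part of this PR) of a reset_parameters() that only touches directly owned parameters, shaped after the repo's RMSNorm module:

    # sketch: per-module init that FSDP can call safely, because it only
    # initializes parameters owned directly by this module (no children)
    class RMSNorm(torch.nn.Module):
        def __init__(self, dim: int, eps: float = 1e-6):
            super().__init__()
            self.eps = eps
            self.weight = nn.Parameter(torch.ones(dim))

        def reset_parameters(self):
            torch.nn.init.ones_(self.weight)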

    for submodule in module.modules():
        for param_name, param in submodule.named_parameters(recurse=False):
[Collaborator] can you add a TODO here and link the issue. I think doing random init might not be good enough for pretraining and we should resolve the layer-depth init functions.

            if not _is_fsdp_flattened(param) and param.is_meta:
                materialized_param = nn.Parameter(
                    torch.randn_like(param, device=torch.device("cuda"))
                )
                setattr(submodule, param_name, materialized_param)
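
On the layer-depth init the reviewer mentions above: one commonly used scheme (a hypothetical sketch of what such a function could look like, not something this PR implements) shrinks the init std of residual-branch projections with network depth, GPT-2 style:

    # hypothetical depth-aware init helper; the PR currently uses plain
    # randn_like as a placeholder until per-layer init is resolved
    import math

    def scaled_normal_init_(param: torch.Tensor, n_layers: int, base_std: float = 0.02) -> None:
        # shrink residual-branch init as depth grows: std ~ base_std / sqrt(2 * n_layers)
        torch.nn.init.normal_(param, mean=0.0, std=base_std / math.sqrt(2 * n_layers))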
2 changes: 2 additions & 0 deletions torchtrain/models/llama/model.py
@@ -24,6 +24,8 @@ class ModelArgs:
    max_batch_size: int = 32
    max_seq_len: int = 32768

    use_meta_init: Optional[bool] = False  # controlled via global settings
[Collaborator] nit: this is more like a TrainOptions instead of a model-specific config, let's not include it here.



class RMSNorm(torch.nn.Module):
    def __init__(self, dim: int, eps: float = 1e-6):
24 changes: 18 additions & 6 deletions torchtrain/parallelisms/parallelize_llama.py
@@ -19,18 +19,18 @@
from torch.distributed.fsdp.wrap import enable_wrap, wrap

from torchtrain.logging_utils import rank0_log

from torchtrain.meta_init import meta_to_real_init_fn

# Uses PTD FSDP AC wrapper
def checkpoint_wrapper(module, config):
    return ptd_checkpoint_wrapper(module, checkpoint_impl=CheckpointImpl.NO_REENTRANT, preserve_rng_state=False)


def parallelize_llama(model, args):
def parallelize_llama(model, args, use_meta_init=False):
"""
Apply parallelisms to the model, including PTD parallelisms, and AC.

NOTE: the model passed in preferrablably shoule be a meta device model,
NOTE: the model passed in preferrablably should be a meta device model,
otherwise the model needs to be small enough on GPU or can fit into CPU.
# TODO: apply SP
"""
@@ -51,6 +51,8 @@ def parallelize_llama(model, args):
    dp_mesh = world_mesh

    # apply PTD parallelisms
    meta_init_fn = meta_to_real_init_fn if use_meta_init else None

    fsdp_config = {
        "mixed_precision": MixedPrecision(
            param_dtype=torch.bfloat16,
@@ -62,20 +64,30 @@ def parallelize_llama(model, args):
        # When torch.compile is active, it requires us to set use_orig_params=True
        "use_orig_params": True,
        "device_mesh": dp_mesh,
        "param_init_fn": meta_init_fn,
    }

    with enable_wrap(wrapper_cls=FSDP, **fsdp_config):

        using_meta_init = fsdp_config["param_init_fn"]

        for layer_id, transformer_block in enumerate(model.layers):
            # apply AC to each layer
            # before wrapping with FSDP, we need to make sure the layer is on GPU
            transformer_block = transformer_block.cuda()
            # unless using meta init:

            if not using_meta_init:
                transformer_block = transformer_block.cuda()

            transformer_block = checkpoint_wrapper(transformer_block, args)

            # Wraps each layer with FSDP
            model.layers[layer_id] = wrap(transformer_block)

        # wrap the rest layers with FSDP
        model = wrap(model.cuda())
        # wrap the remaining layers with FSDP
        if not using_meta_init:
            model.cuda()
        model = wrap(model)

    rank0_log(f"Applied parallelisms to the model...")

27 changes: 24 additions & 3 deletions train.py
@@ -19,6 +19,7 @@
)
from torchtrain.models import models_config, model_name_to_cls, model_name_to_tokenizer
from torchtrain.parallelisms import models_parallelize_fns
from torchtrain.meta_init import meta_model_init


@dataclass
@@ -57,17 +58,29 @@ def main(args):
    )

    # build model
    # TODO: add meta initialization

    model_cls = model_name_to_cls[model_name]
    model_config = models_config[model_name][args.model_conf]

    model_config.vocab_size = tokenizer.n_words

    model = model_cls.from_model_args(model_config)
    # meta initialization
    _use_meta_init = args.meta_init  # todo - add to toml
    model_config.use_meta_init = _use_meta_init  # append this to model config
[Collaborator] Curious why we save the meta init flag to the model config? I think it's something orthogonal to the model arch?

    if _use_meta_init:
        with meta_model_init():
            model = model_cls.from_model_args(model_config)
    else:
        model = model_cls.from_model_args(model_config)

    # apply PTD parallelisms + AC
    model = models_parallelize_fns[model_name](model, args)
    model = models_parallelize_fns[model_name](
        model, args, use_meta_init=_use_meta_init
    )

    # build optimizer after apply parallelisms to the model

    # TODO: add scheduler if needed
    optimizer = build_optimizer(model, args)

@@ -85,6 +98,8 @@ def main(args):
    # train loop
    model.train()

    # use fsdp
[Collaborator] nit: remove this?


    with maybe_run_profiler() as torch_profiler:
        while train_state.step < args.steps or args.steps == -1:
            train_state.step += 1
@@ -160,5 +175,11 @@ def main(args):
"--compile", action="store_true", help="Whether to compile the model."
)

parser.add_argument(
[Contributor] I'm wondering if we should even make this optional. It might be cleaner if we just always do meta init, then apply the various parallelisms, then materialize. Thoughts?

[Contributor Author] This is actually a great idea imo, not just for being cleaner but because meta-init has been a bit of a second-class citizen over the last year, yet we see lots of partners struggling with larger model training due to OOM on loading and unclear on how to leverage meta init.
Thus, always using meta init here would also ensure full support (i.e. any new feature that breaks with it gets caught) and provide an always-working, proper code example.
"--meta_init",
action="store_true",
help="Whether to use meta init for the model.",
)
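
For reference, a rough sketch of the "always meta-init" flow discussed in the thread above (a hypothetical restructuring using this PR's names, not what the PR currently does):

    # hypothetical: build on meta unconditionally, parallelize, then materialize
    with meta_model_init():
        model = model_cls.from_model_args(model_config)
    model = models_parallelize_fns[model_name](model, args, use_meta_init=True)
    # FSDP materializes parameters on CUDA via meta_to_real_init_fn (param_init_fn)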

    args = parser.parse_args()
    main(args)