diff --git a/src/speculators/train/trainer.py b/src/speculators/train/trainer.py index a926e378f..847a0db18 100644 --- a/src/speculators/train/trainer.py +++ b/src/speculators/train/trainer.py @@ -8,6 +8,7 @@ StateDictOptions, set_model_state_dict, ) +from torch.distributed.fsdp import FSDPModule from torch.utils.data import DataLoader from tqdm import TqdmExperimentalWarning from tqdm.rich import tqdm @@ -49,6 +50,33 @@ class TrainerConfig(NamedTuple): log_freq: int = 1 +def _materialize_fsdp_model(model: torch.nn.Module): + """Materialize and reset parameters for a freshly sharded FSDP model.""" + acc = torch.accelerator.current_accelerator() + device = "cuda" if acc is None else acc.type + + for m in model.layers.children(): # type: ignore[union-attr] + if not isinstance(m, FSDPModule): + continue + m.to_empty(device=device) # type: ignore[attr-defined] + for sub_module in m.modules(): # type: ignore[attr-defined] + if hasattr(sub_module, "reset_parameters"): + sub_module.reset_parameters() # type: ignore[operator] + + for name, module in model.named_children(): + if name == "layers": + continue + tensors = list(module.parameters(recurse=True)) + list( + module.buffers(recurse=True) + ) + has_meta = any(t.device.type == "meta" for t in tensors) + if has_meta: + module.to_empty(device=device) + for sub in module.modules(): + if hasattr(sub, "reset_parameters"): + sub.reset_parameters() # type: ignore[operator] + + class Trainer: def __init__( self, @@ -127,6 +155,7 @@ def setup_model(self): if load_checkpoint: self.checkpointer.load_model_state_dict(self.model) else: + _materialize_fsdp_model(self.model) # Broadcast full state dict from rank 0 to all ranks set_model_state_dict( self.model,