diff --git a/scripts/model_merger.py b/scripts/model_merger.py index 210b85e15f4..2258d764721 100644 --- a/scripts/model_merger.py +++ b/scripts/model_merger.py @@ -47,6 +47,7 @@ import numpy as np import torch +from accelerate import init_empty_weights from safetensors.torch import load_file from torch.distributed._tensor import Placement, Shard from transformers import ( @@ -128,7 +129,7 @@ def patch_model_generation_config(self, model): def save_hf_model_and_tokenizer(self, state_dict: dict[str, torch.Tensor]): auto_model_class = self.get_transformers_auto_model_class() - with torch.device("meta"): + with init_empty_weights(): model = auto_model_class.from_config(self.model_config, torch_dtype=torch.bfloat16) model.to_empty(device="cpu") model = self.patch_model_generation_config(model)