diff --git a/verl/utils/checkpoint/fsdp_checkpoint_manager.py b/verl/utils/checkpoint/fsdp_checkpoint_manager.py index 7df4b39a60d..bf9206d5498 100644 --- a/verl/utils/checkpoint/fsdp_checkpoint_manager.py +++ b/verl/utils/checkpoint/fsdp_checkpoint_manager.py @@ -18,6 +18,7 @@ import torch import torch.distributed +from accelerate import init_empty_weights from torch.distributed.fsdp import FullStateDictConfig, ShardedOptimStateDictConfig, ShardedStateDictConfig, StateDictType from torch.distributed.fsdp import FullyShardedDataParallel as FSDP from transformers import GenerationConfig, PreTrainedTokenizer, ProcessorMixin @@ -197,7 +198,7 @@ def save_checkpoint(self, local_path: str, hdfs_path: str = None, global_step: i else: raise NotImplementedError(f"Unknown architecture {model_config['architectures']}") - with torch.device("meta"): + with init_empty_weights(): save_model = auto_model_cls.from_config(model_config, torch_dtype=torch.bfloat16) save_model.to_empty(device="cpu")