Model init with HuggingFace model #743

Open
neeldani opened this issue Dec 16, 2024 · 15 comments
Labels: bug, huggingface integration, module: checkpoint, question

Comments

@neeldani

neeldani commented Dec 16, 2024

I am writing a simple script to run FSDP2 (fully_shard) on the pythia-1b model available on HuggingFace. I am currently running the model on 1 node with 2 devices. I was following the meta-device initialisation from the FSDP2 docs. However, I think there is something wrong with my implementation since the peak memory usage with FSDP is the same as without FSDP (~1 GB). Further, I get an OOM on my device when I try the pythia-2.8b model. The following snippet shows how I am initialising the model on a meta device using HuggingFace APIs:

model_name = "EleutherAI/pythia-14m"

tokenizer = AutoTokenizer.from_pretrained(model_name)
tokenizer.pad_token = tokenizer.eos_token
config = AutoConfig.from_pretrained(model_name)

# Build the model on the meta device (no weight storage is allocated).
with init_empty_weights():
    model = AutoModelForCausalLM.from_config(config)

# Shard each transformer block, then the root module.
for module in model.modules():
    if isinstance(module, GPTNeoXLayer):
        fully_shard(module)

model = fully_shard(model, reshard_after_forward=True)

# path_to_safe_tensors: local path to the downloaded model.safetensors
model = load_checkpoint_and_dispatch(
    model, path_to_safe_tensors
)

This is not very straightforward, since the shards expect DTensors when the weights are loaded via load_checkpoint_and_dispatch. I am looking for suggestions on a good way to make FSDP2 work with HuggingFace models. I don't think accelerate supports FSDP2 yet.

@awgu
Collaborator

awgu commented Dec 16, 2024

cc: @weifengpy @mori360

@neeldani
Author

👋 Gentle bump on this - mainly to see if there is some workaround for the above issue 👀

@mori360
Contributor

mori360 commented Dec 20, 2024

However, I think there is something wrong with my implementation since the peak memory usage with FSDP is the same as without FSDP (~1 GB).

It depends on where the peak memory occurs. If it's at fully_shard, then the full_state_dict is sharded into a local_state_dict, which uses more memory (full_state_dict + local_state_dict > full_state_dict).
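The inequality above can be made concrete with a rough estimate. This is a minimal sketch, assuming fp32 weights (4 bytes per parameter) and perfectly even sharding; the function names and the round 1B parameter count are illustrative and do not come from the thread.

```python
def peak_bytes_full_then_shard(num_params: int, world_size: int, bytes_per_param: int = 4) -> int:
    """Bytes held on one rank if the full state dict is resident while its shard is created."""
    full = num_params * bytes_per_param
    local = full // world_size  # assumption: parameters split evenly across ranks
    return full + local


def peak_bytes_shard_only(num_params: int, world_size: int, bytes_per_param: int = 4) -> int:
    """Bytes held on one rank if only the local shard is ever materialised."""
    return num_params * bytes_per_param // world_size


GIB = 1024 ** 3
n = 1_000_000_000  # round figure standing in for a ~1B-parameter model
print(f"full then shard: {peak_bytes_full_then_shard(n, 2) / GIB:.2f} GiB")
print(f"shard only:      {peak_bytes_shard_only(n, 2) / GIB:.2f} GiB")
```

The estimate ignores activations, gradients, and optimizer state, so it only bounds the weight-loading step.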

I get an OOM on my device when I try with pythia-2.8b model

Could you give more details on the safe_tensors so that I could repro the huge memory cost?
Also, could you give a device flow so that I could follow up when you switch your devices to GPU?

@neeldani
Author

neeldani commented Dec 23, 2024

It depends on where the peak memory occurs. If it's at fully_shard, then the full_state_dict is sharded into a local_state_dict, which uses more memory (full_state_dict + local_state_dict > full_state_dict).

I see. Ideally, I am looking for an approach that lets me load the sharded model on each GPU without loading the full_state_dict.

Could you give more details on the safe_tensors so that I could repro the huge memory cost?

I downloaded the model.safetensors for the pythia-1b model from here. These weights are not sharded.

Also, could you give a device flow so that I could follow up when you switch your devices to GPU?

I am trying to mimic TorchTitan's implementation, but with a HuggingFace model:

  1. Load the empty model on the meta device
  2. Apply FSDP, move the sharded weights to the respective GPUs and materialise the weights
  3. Re-initialise the sharded weights on each GPU
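The three steps above can be sketched on a single CPU process. This is a simplified illustration, not the distributed flow itself: FSDP is left out (its place is marked in a comment), a small nn.Linear stands in for the HuggingFace model, and an in-memory dict stands in for the checkpoint file.

```python
import torch
from torch import nn

# Stand-in for a pretrained checkpoint (real code would read a safetensors or DCP file).
pretrained = nn.Linear(8, 4)
checkpoint = {k: v.clone() for k, v in pretrained.state_dict().items()}

# 1. Build the model on the meta device: shapes and dtypes only, no storage.
with torch.device("meta"):
    model = nn.Linear(8, 4)
assert model.weight.is_meta

# 2. (fully_shard would be applied here.) Allocate uninitialised storage on the target device.
model.to_empty(device="cpu")

# 3. Overwrite the uninitialised values, here with checkpoint weights instead of a fresh init.
model.load_state_dict(checkpoint)
```

After to_empty the parameter values are arbitrary memory contents, so step 3 is required before the model is used.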

This is a simple repro of my implementation which can be run using:

torchrun --nnodes=1 --nproc_per_node=2 reproduce.py

import os

import torch
from torch.distributed import init_process_group, destroy_process_group
from torch.distributed._composable.fsdp import fully_shard
from transformers import AutoConfig, AutoModelForCausalLM
from transformers.models.gpt_neox.modeling_gpt_neox import GPTNeoXLayer
from accelerate import init_empty_weights, load_checkpoint_and_dispatch

def get_num_params(model: torch.nn.Module, exclude_embedding: bool = False) -> int:
    num_params = sum(p.numel() for p in model.parameters())
    if exclude_embedding:
        num_params -= model.tok_embeddings.weight.numel()
    return num_params

def setup(local_rank, world_size):
    device = torch.device(f"cuda:{local_rank}")
    torch.cuda.set_device(device)
    init_process_group("nccl", rank=local_rank, world_size=world_size)

def load():
    local_rank = int(os.environ["LOCAL_RANK"])
    world_size = int(os.environ["WORLD_SIZE"])
    setup(local_rank, world_size)

    model_name = "EleutherAI/pythia-2.8b"
    config = AutoConfig.from_pretrained(model_name)
    
    with init_empty_weights():
        model = AutoModelForCausalLM.from_config(config)
    
    if local_rank == 0:
        print("Load models with empty weights")
        print("Device: ", model.device)
        print("Params: ", get_num_params(model))
        print("Peak mem: ", torch.cuda.max_memory_allocated() / (1024 ** 3))

    for module in model.modules():
        if isinstance(module, GPTNeoXLayer):
            fully_shard(module)
    
    model = fully_shard(model, reshard_after_forward=True)

    if local_rank == 0:
        print("Applied FSDP to the model")
        print("Device: ", model.device)
        print("Peak mem: ", torch.cuda.max_memory_allocated() / (1024 ** 3))
        print("# of params: ", get_num_params(model))

    model.to_empty(device='cuda')

    if local_rank == 0:
        print("Materialized the sharded tensors")
        print("Device: ", model.device)
        print("Peak mem: ", torch.cuda.max_memory_allocated() / (1024 ** 3))
        print("# of params: ", get_num_params(model))

    # This is the call that fails: the sharded parameters are DTensors.
    model = load_checkpoint_and_dispatch(model, "model.safetensors", device_map="auto", no_split_module_classes=["GPTNeoXLayer"])

if __name__ == "__main__":
    load()

The flow is very similar to TorchTitan's, except that TorchTitan makes an explicit call to re-initialise the weights after materialising them. Since I wish to load weights from a pretrained HF model, it's a bit challenging. The above code throws an error at the load_checkpoint_and_dispatch call, since the model expects DTensors as inputs.

@mori360
Contributor

mori360 commented Dec 27, 2024

Ideally I am looking for an approach which allows me to load the sharded models on each GPU without loading the full_state_dict

torch.distributed.checkpoint.state_dict.set_model_state_dict can load the sharded model without holding the whole full_state_dict at once: it loads param by param, which avoids the memory peak and helps avoid the OOM.
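Why param-by-param loading lowers the peak can be seen in a toy accounting model. This is a sketch under stated assumptions (even sharding, one full parameter resident at a time, freed once its shard is cut out); it does not measure set_model_state_dict, and both function names are made up for illustration.

```python
def peak_full_load(param_bytes: list[int], world_size: int) -> int:
    """Peak when every full parameter is resident while all local shards are created."""
    total = sum(param_bytes)
    return total + total // world_size


def peak_param_by_param(param_bytes: list[int], world_size: int) -> int:
    """Peak when only one full parameter is resident at any time."""
    held = 0  # bytes of local shards kept so far
    peak = 0
    for size in param_bytes:
        shard = size // world_size
        # The current full parameter and its new shard coexist briefly.
        peak = max(peak, held + size + shard)
        held += shard  # the full parameter is released, the shard stays
    return peak


sizes = [100, 100, 100, 100]  # four equally sized parameters, arbitrary units
print(peak_full_load(sizes, world_size=2), peak_param_by_param(sizes, world_size=2))
```

In this model the param-by-param peak approaches the size of the local shards plus one full parameter, rather than the whole model plus the shards.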

However, accelerate.load_checkpoint_and_dispatch does not support sharded models right now: it has no branch such as param_cls.__name__ in ["DTensor"] that would distribute the loaded tensor.
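A branch of the kind described above might look like the following. This is a hypothetical helper, not code from accelerate: rebuild_param is an invented name, the DTensor path assumes PyTorch 2.5 or newer (where distribute_tensor is importable from torch.distributed.tensor) and an initialised process group, and only the plain-tensor path is exercised here.

```python
import torch
from torch import nn


def rebuild_param(old_value: torch.Tensor, new_value: torch.Tensor) -> nn.Parameter:
    """Hypothetical helper: wrap new_value so it matches the kind of old_value."""
    param_cls = type(old_value)
    if param_cls.__name__ in ["DTensor"]:
        # Assumption: PyTorch >= 2.5 and torch.distributed already initialised.
        from torch.distributed.tensor import distribute_tensor

        # Shard the full tensor the same way the existing parameter is sharded.
        sharded = distribute_tensor(new_value, old_value.device_mesh, old_value.placements)
        return nn.Parameter(sharded, requires_grad=old_value.requires_grad)
    # Regular parameter: no distribution needed.
    return nn.Parameter(new_value, requires_grad=old_value.requires_grad)


old = nn.Parameter(torch.zeros(2, 2))
new = rebuild_param(old, torch.ones(2, 2))
```

Note that distributing a full tensor this way still requires the full tensor to exist on the calling rank for that one parameter.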

@fegin Please correct me if I'm wrong. Also, shall we update model.init_weight() in torchtitan, in the process from model.init_weight() to checkpoint.load(), to init weights param by param?

@tianyu-l added the question and bug labels on Jan 7, 2025

@fegin
Contributor

fegin commented Jan 8, 2025

Yes, @mori360, as you have implemented this feature, OOM should be avoidable with set_model_state_dict. But we will need the state_dict to be loaded with DCP and set_model_state_dict.

@Hannibal046

Hi, any progress here? What is the best practice for continued pretraining of an HF model with torchtitan?

@yzhangcs
Contributor

yzhangcs commented Feb 18, 2025

@neeldani Regarding your original issue, for now, the easiest approach would be to:

  1. Convert your HF model weights to the DCP format and save them in <path>/checkpoint/step-0. You can follow the instructions in this guide: How to Convert a LLaMA 3 Checkpoint for Use in TorchTitan. Replace <path> with your desired save location.
  2. Once the weights are converted, you can resume training directly by setting --training.load_step 0, similar to how you would with a seed checkpoint.

Does this make sense? @mori360 @fegin @tianyu-l @huyiwen, please correct me if I missed anything.

@tianyu-l
Contributor

tianyu-l commented Feb 19, 2025

Thanks @yzhangcs

What is the best practice for continued pretraining of an HF model with torchtitan?

I think the key thing to do is to convert a HF checkpoint into a DCP checkpoint, like what this script does #305 (comment)

I heard that DCP is going to support HF checkpointing format, but it may take some time to happen.
related PR for the non-distributed use case: pytorch/pytorch#146352
cc: @fegin @kwen2501 to confirm

@yzhangcs
Contributor

@tianyu-l I just wrote one for medium/small-sized models https://github.com/fla-org/flame/blob/main/convert_hf_to_dcp.py
like https://github.com/pytorch/torchtitan/blob/main/scripts/convert_llama_to_dcp.py.
I’m using the converted DCPs to finetune the Qwen model on fineweb-edu, and everything appears to be working as expected so far.

@mingdianliu

mingdianliu commented Apr 5, 2025

@neeldani @fegin @yzhangcs @awgu @Hannibal046 @tianyu-l @mori360

Dear All,

Thanks for making FSDP2 compatible with Hugging Face models. However, I ran into an issue while running the repro code. I would like to know if you have any insights into it.

import os

import torch
from torch.distributed import init_process_group, destroy_process_group
from torch.distributed._composable.fsdp import fully_shard
from transformers import AutoConfig, AutoModelForCausalLM
from transformers.models.gpt_neox.modeling_gpt_neox import GPTNeoXLayer
from accelerate import init_empty_weights, load_checkpoint_and_dispatch

def get_num_params(model: torch.nn.Module, exclude_embedding: bool = False) -> int:
    num_params = sum(p.numel() for p in model.parameters())
    if exclude_embedding:
        num_params -= model.tok_embeddings.weight.numel()
    return num_params

def setup(local_rank, world_size):
    device = torch.device(f"cuda:{local_rank}")
    torch.cuda.set_device(device)
    init_process_group("nccl", rank=local_rank, world_size=world_size)

def load():
    local_rank = int(os.environ["LOCAL_RANK"])
    world_size = int(os.environ["WORLD_SIZE"])
    setup(local_rank, world_size)

    model_name = "EleutherAI/pythia-2.8b"
    config = AutoConfig.from_pretrained(model_name)
    
    with init_empty_weights():
        model = AutoModelForCausalLM.from_config(config)
    
    for module in model.modules():
        if isinstance(module, GPTNeoXLayer):
            fully_shard(module)
    
    model = fully_shard(model, reshard_after_forward=True)
    model.to_empty(device='cuda')


if __name__ == "__main__":
    load()

The error is below:

[rank0]: Traceback (most recent call last):
[rank0]:   File "/workspace/NCCL/report_issue.py", line 41, in <module>
[rank0]:     load()
[rank0]:   File "/workspace/NCCL/report_issue.py", line 34, in load
[rank0]:     fully_shard(module)
[rank0]:   File "/usr/local/lib/python3.10/dist-packages/torch/distributed/_composable/contract.py", line 107, in wrapper
[rank0]:     updated = func(module, *args, **kwargs)
[rank0]:   File "/usr/local/lib/python3.10/dist-packages/torch/distributed/_composable/fsdp/fully_shard.py", line 114, in fully_shard
[rank0]:     _move_states_to_device(params, buffers, device, mesh_info)
[rank0]:   File "/usr/local/lib/python3.10/dist-packages/torch/distributed/_composable/fsdp/_fsdp_init.py", line 143, in _move_states_to_device
[rank0]:     tensor.data = tensor.to(device)
[rank0]: NotImplementedError: Cannot copy out of meta tensor; no data!

Python command:
torchrun --nnodes=1 --nproc_per_node=8 reproduce.py
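The error class in the traceback can be reproduced without FSDP. This is a small sketch of the general behaviour of meta tensors; it does not explain why the installed PyTorch version took this code path inside fully_shard.

```python
import torch

meta_weight = torch.empty(4, 4, device="meta")  # shape and dtype only, no data

try:
    meta_weight.to("cpu")  # same kind of call as tensor.to(device) in the traceback
    copy_failed = False
except NotImplementedError as err:
    copy_failed = True
    print(f"copy failed: {err}")

# Allocating fresh, uninitialised storage with the same shape does work.
materialised = torch.empty_like(meta_weight, device="cpu")
```

This is why meta-initialised modules are moved with to_empty rather than to: there are no values to copy, only shapes to allocate.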

@awgu
Collaborator

awgu commented Apr 5, 2025

@mingdianliu which version of PyTorch are you using? maybe you need a newer version

@mingdianliu

mingdianliu commented Apr 5, 2025

@mingdianliu which version of PyTorch are you using? maybe you need a newer version

@awgu Thank you very much! After upgrading PyTorch to 2.6.0, the code works on my side. I have one more follow-up question.

I have followed your instructions to convert the HF ckpt to a DCP ckpt. However, it takes too long to load the DCP ckpt (540 seconds for the Qwen2-VL-7B model on 2 nodes with 16 GPUs) with torch.distributed.checkpoint.load(state_dict, checkpoint_id=None, storage_reader=None). Is there a better method I can use to speed up ckpt loading?

In the code, I am using model.load_state_dict() to load the state_dict, which has latency comparable to set_model_state_dict().

import os

import torch
from torch.distributed import init_process_group, destroy_process_group
from torch.distributed._composable.fsdp import fully_shard
from torch.distributed.checkpoint import FileSystemReader
from torch.distributed.device_mesh import init_device_mesh

from transformers import AutoConfig, AutoModelForVision2Seq
from transformers.models.qwen2_vl.modeling_qwen2_vl import Qwen2VLDecoderLayer


def load():
    
    distributed_backend = "nccl" # gloo for cpu
    init_process_group(distributed_backend)

    local_rank = int(os.environ["LOCAL_RANK"])
    world_size = int(os.environ["WORLD_SIZE"])
    device = torch.device(f"cuda:{local_rank}")
    torch.cuda.set_device(device)

    model_name = "Qwen/Qwen2-VL-2B-Instruct"
    revision = "895c3a49bc3fa70a340399125c650a463535e71c"
    # model_name = "Qwen/Qwen2-VL-7B-Instruct"
    # revision = "a28a094eb66a9f2ac70eef346f040d8a79977472"
    # model_name = "Qwen/Qwen2-VL-72B-Instruct"
    # revision = "f9b556a74d58e6d9915f73227c21045c87342b42"

    config = AutoConfig.from_pretrained(
        model_name, 
        revision=revision, 
        torch_dtype=torch.bfloat16,
        attn_implementation="flash_attention_2")
    
    device_mesh = init_device_mesh("cuda", (world_size,))

    with torch.device("meta"):
        model = AutoModelForVision2Seq.from_config(config)

    for module in model.modules():
        if isinstance(module, Qwen2VLDecoderLayer):
            fully_shard(module, mesh=device_mesh, reshard_after_forward=True)
    
    model = fully_shard(model, mesh=device_mesh, reshard_after_forward=True)

    model.to_empty(device='cuda')

    model_state_dict = model.state_dict()
    model_dir = "path_to_DCP_ckpt_dir/2B"

    print("start torch.distributed.checkpoint.load")
    fs_storage_reader = FileSystemReader(model_dir)
    torch.distributed.checkpoint.load(
        state_dict=model_state_dict,
        storage_reader=fs_storage_reader,
        )

    model.load_state_dict(model_state_dict)

    print("Model loaded")

if __name__ == "__main__":
    load()
    destroy_process_group()

Python command:
torchrun --nnodes=2 --nproc_per_node=8 reproduce.py

I also tried model = load_checkpoint_and_dispatch(model, checkpoint=model_dir, device_map="auto", no_split_module_classes=["Qwen2VLDecoderLayer"], dtype=torch.bfloat16), but I run into the following error:

[rank7]: Traceback (most recent call last):
[rank7]:   File "/workspace/NCCL/reproduce.py", line 204, in <module>
[rank7]:     load()
[rank7]:   File "/workspace/NCCL/reproduce.py", line 159, in load
[rank7]:     model = load_checkpoint_and_dispatch(
[rank7]:   File "/usr/local/lib/python3.10/dist-packages/accelerate/big_modeling.py", line 620, in load_checkpoint_and_dispatch
[rank7]:     load_checkpoint_in_model(
[rank7]:   File "/usr/local/lib/python3.10/dist-packages/accelerate/utils/modeling.py", line 1982, in load_checkpoint_in_model
[rank7]:     set_module_tensor_to_device(
[rank7]:   File "/usr/local/lib/python3.10/dist-packages/accelerate/utils/modeling.py", line 377, in set_module_tensor_to_device
[rank7]:     new_value = param_cls(new_value, requires_grad=old_value.requires_grad).to(device)
[rank7]:   File "/usr/local/lib/python3.10/dist-packages/torch/_compile.py", line 32, in inner
[rank7]:     return disable_fn(*args, **kwargs)
[rank7]:   File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/eval_frame.py", line 745, in _fn
[rank7]:     return fn(*args, **kwargs)
[rank7]: TypeError: DTensor.__new__() missing 1 required positional argument: 'spec'

@mingdianliu

Dear community,

Thanks for your replies. This issue has been resolved. Loading was slow because the disk where I saved the DCP checkpoint was slow. After switching to a faster disk, the 72B model can be loaded, although the loading time is still a little long. I will try to optimize the loading time and will post any progress here.
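One way to check whether storage is the bottleneck is to time a plain sequential read of the checkpoint directory. This is a rough sketch using only the standard library; read_throughput_mb_s is an invented name, and a second run may be served from the OS page cache and look much faster than the disk really is.

```python
import os
import time


def read_throughput_mb_s(directory: str, chunk_bytes: int = 16 * 1024 * 1024) -> tuple[int, float]:
    """Read every file under directory once; return (total bytes, MB per second)."""
    total = 0
    start = time.perf_counter()
    for root, _dirs, files in os.walk(directory):
        for name in files:
            with open(os.path.join(root, name), "rb") as handle:
                # Read in fixed-size chunks so large shards do not need to fit in memory.
                while chunk := handle.read(chunk_bytes):
                    total += len(chunk)
    elapsed = time.perf_counter() - start
    mb_per_s = (total / 1e6) / elapsed if elapsed > 0 else float("inf")
    return total, mb_per_s
```

Usage would be something like total, rate = read_throughput_mb_s("path_to_DCP_ckpt_dir/2B"), run on the node that loads the checkpoint. If the rate is far below what the hardware should deliver, the disk rather than DCP is the likely cause.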

@fegin
Contributor

fegin commented Apr 15, 2025

@mingdianliu We are exploring an offline resharding converter to speed up the loading time, #1104.
