Model init with HuggingFace model #743
Comments
cc: @weifengpy @mori360 |
👋 Gentle bump on this - mainly to see if there is some workaround for the above issue 👀 |
It depends on where the peak memory occurs. If it is during fully_shard, the full_state_dict is sharded into a local_state_dict, so both are resident at the same time and peak memory is higher (full_state_dict + local_state_dict > full_state_dict).
Could you give more details on the safe_tensors so that I can repro the huge memory cost? |
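The inequality in the comment above can be made concrete with a little arithmetic. The sketch below is illustrative only: the parameter count, dtype size and world size are assumed round numbers, not measurements from this issue, and activations, gradients and optimizer state are ignored.

```python
def peak_full_then_shard(num_params, bytes_per_param, world_size):
    """Peak bytes per rank if the full state dict is resident while it is sharded."""
    full = num_params * bytes_per_param
    local = -(-full // world_size)  # ceiling division: this rank's shard
    return full + local


def peak_shard_only(num_params, bytes_per_param, world_size):
    """Peak bytes per rank if only the local shard is ever materialised."""
    full = num_params * bytes_per_param
    return -(-full // world_size)


# Assumed numbers: 1e9 parameters, fp32 (4 bytes each), 2 ranks.
with_full_copy = peak_full_then_shard(1_000_000_000, 4, 2)
shard_only = peak_shard_only(1_000_000_000, 4, 2)
print(with_full_copy / 1e9, "GB vs", shard_only / 1e9, "GB")
```

Under these assumptions, holding the full copy while sharding costs three times as much per rank as materialising only the shard, which is why the place where the peak occurs matters.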
I see. Ideally I am looking for an approach which allows me to load the sharded models on each GPU without loading the
I downloaded the model.safetensors for the
I am trying to mimic TorchTitan's implementation but with a HuggingFace model
This is a simple repro of my implementation which can be run using:
The flow is very similar to TorchTitan's, except that TorchTitan makes an explicit call to re-initialise the weights after materialising them. Since I want to load weights from a pretrained HF model, it's a bit challenging. The above code throws an error where I call |
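For readers following along, the meta-device flow described in this comment can be reduced to a single-process sketch without FSDP. This is not the repro from the thread: the tiny `nn.Linear` stands in for the HF model, and the helper names are hypothetical. It only shows the three steps (build on meta, materialise with `to_empty`, then load pretrained weights instead of re-initialising).

```python
import torch
import torch.nn as nn


def build_on_meta():
    """Construct the module on the meta device: shapes only, no storage."""
    with torch.device("meta"):
        return nn.Linear(4, 2)


def materialise_and_load(model, state_dict, device="cpu"):
    """Allocate uninitialised storage, then copy pretrained weights in."""
    model.to_empty(device=device)      # parameters now hold arbitrary values
    model.load_state_dict(state_dict)  # overwrite them with the real weights
    return model


# Stand-in for pretrained weights (assumption: a real run would read them
# from the HF checkpoint instead).
pretrained = nn.Linear(4, 2).state_dict()

model = build_on_meta()
was_meta = model.weight.is_meta
model = materialise_and_load(model, pretrained)
```

With `fully_shard` applied between the two steps, the parameters become DTensors, so a plain `load_state_dict` with full tensors no longer fits; that is the point where the thread turns to `set_model_state_dict` and DCP.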
However, @fegin please correct me if I'm wrong. Also, shall we change the torchtitan flow from model.init_weight() to checkpoint.load(), so that the weights are initialised param by param? |
Yes, @mori360, as you have implemented this feature, it should be possible to avoid the OOM with |
Hi, any progress here? What is the best practice for continuing to pretrain a HF model with torchtitan? |
@neeldani Regarding your original issue, for now, the easiest approach would be to:
Does this make sense? @mori360 @fegin @tianyu-l @huyiwen, please correct me if I missed anything. |
Thanks @yzhangcs
I think the key thing to do is to convert a HF checkpoint into a DCP checkpoint, like what this script does: #305 (comment). I heard that DCP is going to support the HF checkpoint format, but it may take some time to happen. |
@tianyu-l I just wrote one for medium/small-sized models https://github.com/fla-org/flame/blob/main/convert_hf_to_dcp.py |
@neeldani @fegin @yzhangcs @awgu @Hannibal046 @tianyu-l @mori360 Dear all, thanks for making FSDP2 compatible with HuggingFace models. However, I ran into an issue while running the repro code, and I would like to know whether you have any insights.
The error is below:
Python command: |
@mingdianliu which version of PyTorch are you using? maybe you need a newer version |
@awgu Thank you very much! After upgrading PyTorch to 2.6.0, the code works on my side. I have one follow-up question. I followed your instructions to convert the HF ckpt to a DCP ckpt. However, loading the DCP ckpt takes too long (540 seconds for the Qwen2-VL-7B model on 2 nodes with 16 GPUs) with torch.distributed.checkpoint.load(state_dict, checkpoint_id=None, storage_reader=None). Is there a better method I can use to speed up ckpt loading? In the code, I am using model.load_state_dict() to load the state_dict, which has latency comparable to set_model_state_dict().
Python command: Actually, I also have a shot at
|
Dear community, thanks for your replies. This issue has been resolved. Loading was slow because of the low-performing disk on which I saved the DCP checkpoint. After switching to a faster disk, a 72B model can be loaded, although the loading time is still a little long. I will try to optimize the loading time and will post any progress here. |
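A back-of-the-envelope check makes the disk explanation plausible. The sketch below takes the 540-second load reported earlier in the thread and assumes, purely for illustration, about 7e9 parameters stored in a 2-byte dtype; neither number is stated in the issue.

```python
def effective_read_mb_per_s(num_params, bytes_per_param, seconds):
    """Aggregate checkpoint bytes divided by wall-clock load time, in MB/s."""
    return num_params * bytes_per_param / seconds / 1e6


# Assumptions: 7e9 parameters, 2 bytes each; 540 s is the reported load time.
throughput = effective_read_mb_per_s(7_000_000_000, 2, 540)
print(f"{throughput:.1f} MB/s aggregate")
```

Under these assumptions the aggregate read rate across all 16 GPUs comes to a few tens of MB/s, which is far below what local NVMe storage typically sustains and so points at storage rather than at the loading API.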
@mingdianliu We are exploring an offline resharding converter to speed up the loading time, #1104. |
I am writing a simple script to run FSDP2 (fully_shard) on the pythia-1b model available on HuggingFace. I am currently running the model on 1 node with 2 devices. I was following the meta-device initialisation from the FSDP2 docs. However, I think there is something wrong with my implementation, since the peak memory usage with FSDP is the same as without FSDP (~1 GB). Further, I get an OOM on my device when I try with the pythia-2.8b model. Following is a snippet showing how I am initialising the model on a meta device using HuggingFace APIs:
This is not very straightforward, since the shards expect DTensors when the weights are being loaded via load_checkpoint_and_dispatch. I am looking for suggestions on a good way to make FSDP2 work with HuggingFace models. I don't think accelerate supports FSDP2 yet.