Remove FSDP wrapping from sub-models. #34452
Conversation
SunMarc
left a comment
Thanks for fixing the issue @eljandoubi! Do you think there is a simpler way to handle this edge case, @muellerzr?
src/transformers/trainer.py
Outdated
You can use the unwrap_model function in transformers instead. Also, why do we need to set recursive to True? Please also leave a comment above, since this specific path only exists to make it work with auto_find_batch_size.
unwrap_model does not provide access to the recursive argument. Auto-wrap policies wrap submodules with FSDP, and unwrap_model is unable to remove them. You can test this on the toy example from the PyTorch FSDP tutorial for rank=0 and world_size=1, then experiment with the line I provided in a notebook.
import functools
import torch
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp.wrap import size_based_auto_wrap_policy
from transformers.modeling_utils import unwrap_model

# Net is the toy model from the PyTorch FSDP tutorial.
my_auto_wrap_policy = functools.partial(
    size_based_auto_wrap_policy, min_num_params=20000
)
torch.cuda.set_device(rank)
model = Net().to(rank)
print(model)
fsdp_model = FSDP(model, auto_wrap_policy=my_auto_wrap_policy)
print(fsdp_model)
# Only the outermost FSDP wrapper is removed; sub-modules wrapped by the
# auto-wrap policy stay wrapped.
unwrapped_model = unwrap_model(fsdp_model)
print(unwrapped_model)
vs.
You need to reinstantiate model and fsdp_model:
from accelerate.utils import extract_model_from_parallel

model = Net().to(rank)
fsdp_model = FSDP(model, auto_wrap_policy=my_auto_wrap_policy)
# recursive=True also strips the FSDP wrappers applied to sub-modules.
extract_model = extract_model_from_parallel(fsdp_model, recursive=True)
print(extract_model)
I'm talking about this function in transformers. It uses extract_model_from_parallel under the hood so it should be comparable.
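For illustration, here is a minimal sketch of what forwarding a recursive flag through to extract_model_from_parallel could look like. The function signature below is an assumption for the sake of the example, not the actual transformers implementation:

```python
from accelerate.utils import extract_model_from_parallel
from torch import nn

def unwrap_model(model: nn.Module, recursive: bool = False) -> nn.Module:
    # Sketch only: delegate to accelerate. With recursive=True, FSDP wrappers
    # around sub-modules created by an auto-wrap policy are removed as well,
    # not just the outermost container.
    return extract_model_from_parallel(model, recursive=recursive)
```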
@SunMarc @muellerzr Did you get a different result than I did?
muellerzr
left a comment
There was a problem hiding this comment.
Thanks for the fix, can you add a test in tests/test_trainer.py for this?
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.
SunMarc
left a comment
There was a problem hiding this comment.
Thanks! Left a suggestion for unwrap_model.
@SunMarc I migrated to unwrap_model.
LysandreJik
left a comment
Let's merge it if you're both ok with it @SunMarc @muellerzr
Please rebase this PR on main in order to pass the CI, @eljandoubi!
Force-pushed from 693ba36 to 0df20d6
@SunMarc @LysandreJik @muellerzr Is there any update on the pull request?
We were on a company-wide offsite! Merging as they all approved 🤗
* Remove FSDP wrapping from sub-models.
* Solve conflict in trainer.py
* make fixup
* Add unit test for fsdp_auto_wrap_policy when using auto_find_batch_size
* Put back extract_model_from_parallel
* Use transformers unwrap_model
What does this PR do?
Fixes #34113
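In short, when auto_find_batch_size retries training with a smaller batch size, the FSDP-wrapped model has to be fully unwrapped, including the sub-modules wrapped by the auto-wrap policy, before it is prepared again. A rough sketch of that idea (the helper name is hypothetical and this is not the actual trainer.py change; it only illustrates the recursive unwrapping via accelerate):

```python
from accelerate.utils import extract_model_from_parallel

def fully_unwrap_fsdp(trainer):
    # Hypothetical helper, not the actual Trainer code.
    # recursive=True removes the nested FSDP wrappers created by the
    # auto-wrap policy so the model can be prepared again cleanly on
    # the next auto_find_batch_size retry.
    if trainer.is_fsdp_enabled:
        trainer.model = extract_model_from_parallel(trainer.model_wrapped, recursive=True)
        trainer.model_wrapped = trainer.model
```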
Who can review?
Library: