[low-bit optim] Fix edge cases for FSDP2 integration #1269

Open · wants to merge 8 commits into main
Conversation

@gau-nernst gau-nernst commented Nov 12, 2024

This PR fixes two issues that came up in torchtune when fine-tuning Llama3.2-vision.

1. Sometimes there is a strange torch.compile() error with DTensor when the param has a .grad field. See pytorch/torchtune#1978 (comment).

This is actually an old issue, #652 (see #652 (comment) for more details). However, it does not always happen: our CI test passes on PyTorch 2.6 (but fails on PyTorch 2.5), fine-tuning Llama text models (not multimodal) in torchtune has no issues, yet fine-tuning Llama3.2-vision hits the error. It's not clear why or when this happens. The error message suggests that torch.compile() attempts dynamic-shape compilation, even though we explicitly pass dynamic=False.

The solution is to call .detach() on the param: the detached tensor shares the same weight storage but no longer carries a .grad field. Thanks to this, low-bit optim + FSDP2 now also passes the PyTorch 2.5 CI (previously it didn't).

I can't add a test for this, since I don't know how or when the error is triggered; a sketch of the workaround follows.
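For illustration only, here is a minimal sketch of the workaround, not the actual torchao code: the optimizer loop, state layout, and function names below are simplified placeholders.

```python
import torch

# Sketch of the workaround (simplified; not the exact torchao implementation).
# p.detach() returns a tensor sharing p's storage but with no .grad field,
# which sidesteps the torch.compile() + DTensor error described above.

@torch.compile(fullgraph=True, dynamic=False)
def single_param_adam(p, grad, step, exp_avg, exp_avg_sq, lr, beta1, beta2, eps):
    exp_avg.lerp_(grad, 1 - beta1)               # m = b1 * m + (1 - b1) * g
    exp_avg_sq.lerp_(grad.square(), 1 - beta2)   # v = b2 * v + (1 - b2) * g^2
    denom = (exp_avg_sq / (1 - beta2 ** step)).sqrt() + eps
    p.sub_(lr / (1 - beta1 ** step) * exp_avg / denom)

def optim_step(params, states, lr=1e-3, beta1=0.9, beta2=0.999, eps=1e-8):
    for p in params:
        if p.grad is None:
            continue
        s = states[p]
        s["step"] += 1  # step kept as a tensor so the graph is not re-specialized
        # The fix: pass p.detach() so the tensor seen by the compiled function
        # carries no .grad. In-place updates still write to p's storage.
        single_param_adam(p.detach(), p.grad, s["step"],
                          s["exp_avg"], s["exp_avg_sq"], lr, beta1, beta2, eps)
```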

2. Wrong DTensor creation when there is uneven sharding (i.e. the first dim is not divisible by the world size).

Usually LLMs don't have uneven shards, so this error didn't surface before. For ViT, however, it can happen due to pos_embed: in some implementations pos_embed includes the CLS token, so its first dim is num_visual_tokens + 1.

The fix is simple: pass the global shape (and stride) to DTensor.from_local(), as sketched below. An appropriate test has also been added.
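A minimal sketch of what the fix looks like (illustrative, not the exact torchao change; `wrap_local_state` is a hypothetical helper name, and the import path assumes PyTorch >= 2.5, where torch.distributed.tensor is public):

```python
import torch
from torch.distributed.tensor import DTensor

# Sketch (illustrative). With uneven sharding, e.g. a pos_embed of shape
# (197, 768) sharded on dim 0 across 4 ranks, the local shards have
# different sizes, so from_local() cannot infer the global shape from a
# rank-local tensor. Passing the global shape and stride explicitly fixes
# the reconstruction.

def wrap_local_state(local_state: torch.Tensor, param: DTensor) -> DTensor:
    return DTensor.from_local(
        local_state,
        device_mesh=param.device_mesh,
        placements=param.placements,
        run_check=False,
        shape=param.shape,      # global shape of the full (unsharded) tensor
        stride=param.stride(),  # global stride, matching the full tensor
    )
```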

pytorch-bot bot commented Nov 12, 2024

✅ No failures as of commit f89f67e with merge base 26648c2. See test results at hud.pytorch.org/pr/pytorch/ao/1269.
@facebook-github-bot added the CLA Signed label Nov 12, 2024
@gau-nernst added the topic: bug fix label and removed the bug label Nov 12, 2024
@gau-nernst gau-nernst changed the title [low-bit optim] Fix strange compiled Adam step + FSDP2 [low-bit optim] Fix edge cases for FSDP2 integration Nov 13, 2024
@msaroufim msaroufim requested a review from vkuzo November 14, 2024 04:17
@gau-nernst gau-nernst marked this pull request as ready for review November 14, 2024 06:41