
Improve FSDP support for low-bit optimizers #538

Merged
7 commits merged into pytorch:main from gau-nernst:improve_low_bit_optim on Jul 26, 2024

Conversation

gau-nernst
Collaborator

  • Use DTensor.from_local(run_check=False) to wrap quantized optim state (instead of swapping _local_tensor)
  • Make block_size a fixed attribute, calculated inside __init__ (instead of dynamically calculating it every time)
  • Implement all_gather_into_tensor and wait_tensor to support DTensor.full_tensor() (see the sketch after this list)
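Roughly, the wrapping looks like the following sketch. This is a minimal, illustrative example rather than the exact code in this PR: it assumes torch.distributed is already initialized (e.g. via torchrun) and uses a plain uint8 tensor as a stand-in for the quantized optim-state subclass.

```python
import torch
import torch.distributed as dist
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed._tensor import DTensor, Shard

# Assumes dist.init_process_group() has already run (e.g. launched via torchrun).
mesh = init_device_mesh("cuda", (dist.get_world_size(),))

# Each rank holds its own shard of the (quantized) optimizer state.
# A plain uint8 buffer stands in for the quantized subclass here.
local_codes = torch.zeros(1024, dtype=torch.uint8, device="cuda")

# run_check=False wraps the local shard directly, instead of mutating
# DTensor._local_tensor after the fact.
state = DTensor.from_local(local_codes, mesh, [Shard(0)], run_check=False)

# full_tensor() triggers all_gather_into_tensor + wait_tensor under the hood,
# which is why the quantized subclasses need to support those two ops.
full_state = state.full_tensor()
```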


pytorch-bot (bot) commented on Jul 25, 2024

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/ao/538

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit 6cec214 with merge base e8662e0:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@facebook-github-bot added the CLA Signed label (managed by the Facebook bot; authors need to sign the CLA before a PR can be reviewed) on Jul 25, 2024
@gau-nernst marked this pull request as ready for review on July 25, 2024 03:27
-    @property
-    def block_size(self):
-        return self.codes.numel() * 2 // self.scale.numel()
+        self.block_size = codes.numel() * 2 // scale.numel()
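For context, a small worked example of the formula in the changed line (the numbers are illustrative, not taken from this PR): with 4-bit quantization two codes are packed per uint8 element, so codes.numel() * 2 recovers the element count, and dividing by the number of per-block scales gives the block size.

```python
# Illustrative arithmetic only: a 4096-element optimizer state quantized
# to 4-bit with one scale per block of 128 elements.
numel = 4096
codes_numel = numel // 2           # 2048 uint8 values, two 4-bit codes each
scale_numel = numel // 128         # 32 scales, one per block
block_size = codes_numel * 2 // scale_numel
assert block_size == 128
```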
drisspg (Contributor)

curious q: Is there some description of the codes / scale tensors and their relation to each other?

I can see the pattern that codes has 0.5x (4-bit) and 1x (8-bit) the block_size * scale numel, but does this assert square blocks? I think some description here would be helpful.

gau-nernst (Collaborator, Author)

I will add some description. Basically, for 8-bit and FP8, codes has the same shape as the "outer shape", while for 4-bit, since there is bit-packing, I find it easier to let codes be a flattened 1D buffer and keep track of the shape manually.
To get the scale, the float tensor is flattened first and reshaped to (-1, block_size). This relaxes the requirement that the last dimension must be divisible by block_size: now only numel() (the total size) needs to be divisible by block_size. This matters especially when the block size is large (the 8-bit optim uses block_size=2048, as done in bnb). Since the optim update is element-wise, we don't really need to care whether the original tensor is 1D, 2D, or n-D (well, maybe there is some structure in an n-D tensor that makes flattening unwise). I believe the original implementation in bnb does this as well.
-> scale is always a 1D tensor, with size = original_tensor.numel() // block_size
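A minimal sketch of the flatten-then-reshape scaling described above (the helper name and the absmax-style scale are illustrative, not the exact torchao code):

```python
import torch

def blockwise_scale(x: torch.Tensor, block_size: int = 2048) -> torch.Tensor:
    # Only the total number of elements must be divisible by block_size,
    # not the last dimension.
    assert x.numel() % block_size == 0
    blocks = x.float().flatten().view(-1, block_size)  # (numel // block_size, block_size)
    return blocks.abs().amax(dim=-1)                   # 1D, one scale per block

x = torch.randn(16, 384)                    # last dim 384 is NOT divisible by 2048
scale = blockwise_scale(x, block_size=2048)
print(scale.shape)                          # torch.Size([3]) == x.numel() // 2048
```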

gau-nernst (Collaborator, Author)

@drisspg Added some docs. Lmk if it is still unclear.

awgu (Contributor) commented on Jul 25, 2024

The DTensor-related changes look good to me. I wonder, did you ever have to run .full_tensor() in the training path, or was it only used outside the training loop, e.g. for debugging?

gau-nernst (Collaborator, Author)

I don't use FSDP at work or in personal projects because I don't have access to multi-GPU machines, so I can't really answer your question 😅. I only added FSDP support for the low-bit optimizers due to requests from people.

At least in torchtune, I saw that .full_tensor() is used to retrieve the full optim state dict before saving a checkpoint: https://github.com/pytorch/torchtune/blob/0057fe7cf83e14f0b62538a8d4d20719e0a88639/torchtune/utils/_distributed.py#L437 (though it might be unnecessary, or even less efficient, compared to saving each shard separately, provided we resume training with the same setup).
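For illustration, a hedged sketch of gathering an unsharded optimizer state dict before checkpointing, along the lines of the torchtune code linked above (the helper name is illustrative, and the in-place mutation of the state dict is just for brevity):

```python
from torch.distributed._tensor import DTensor

def gather_full_optim_state_dict(optim):
    sd = optim.state_dict()
    for param_state in sd["state"].values():
        for key, value in param_state.items():
            if isinstance(value, DTensor):
                # .full_tensor() all-gathers the shards into a plain tensor;
                # this is where all_gather_into_tensor / wait_tensor are hit.
                param_state[key] = value.full_tensor()
    return sd
```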

@@ -31,17 +32,29 @@ def __new__(cls, codes: Tensor, scale: Tensor, qmap: Tensor, signed: bool, shape
)

def __init__(self, codes: Tensor, scale: Tensor, qmap: Tensor, signed: bool, shape):
"""Create quantized 4-bit optimizer state as proposed in https://arxiv.org/abs/2309.01507.
Contributor

btw the link is not valid, can you remove the . at the end?

gau-nernst (Collaborator, Author)

Done

@@ -28,15 +27,25 @@ def __new__(cls, codes: Tensor, scale: Tensor, qmap: Tensor, signed: bool):
)

def __init__(self, codes: Tensor, scale: Tensor, qmap: Tensor, signed: bool):
"""Create quantized 8-bit optimizer state as proposed in https://arxiv.org/abs/2110.02861.
Contributor

same for this one

msaroufim (Member) left a comment

Nice! We should look to get the low-bit optimizers out of prototype soon!

@msaroufim merged commit 4280843 into pytorch:main on Jul 26, 2024
13 checks passed
@gau-nernst deleted the improve_low_bit_optim branch on July 26, 2024 02:49
dbyoung18 pushed a commit to dbyoung18/ao that referenced this pull request Jul 31, 2024