Improve FSDP support for low-bit optimizers #538
Conversation
🔗 Helpful Links: see artifacts and rendered test results at hud.pytorch.org/pr/pytorch/ao/538. Note: links to docs will display an error until the docs builds have completed.
✅ No failures as of commit 6cec214 with merge base e8662e0. This comment was automatically generated by Dr. CI and updates every 15 minutes.
```diff
-    @property
-    def block_size(self):
-        return self.codes.numel() * 2 // self.scale.numel()
+    self.block_size = codes.numel() * 2 // scale.numel()
```
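For context, a quick sketch (not taken from the PR; the sizes are made up) of the arithmetic behind this expression for the 4-bit case, where two codes are packed into each uint8 byte:

```python
import torch

numel, block_size = 4096, 128                       # hypothetical state size / block size

codes = torch.empty(numel // 2, dtype=torch.uint8)  # 2 x 4-bit codes packed per byte
scale = torch.empty(numel // block_size)            # one scale per block

# Unpacked element count divided by the number of blocks recovers block_size.
assert codes.numel() * 2 // scale.numel() == block_size
```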
Curious q: is there some description of the `codes` / `scale` tensors and their relation to each other? I can see the pattern that `codes` has 0.5x (4-bit) and 1x (8-bit) of `block_size * scale.numel()` elements, but does this assert square blocks? I think some description here would be helpful.
I will add some description. Basically, for 8-bit and FP8, `codes` has the same shape as the "outer shape", while for 4-bit, since there is bit-packing, I find it easier to let `codes` be a flattened 1D buffer and keep track of the shape manually.

To get the scale, the float tensor is actually flattened first and reshaped to `(-1, block_size)`. This is done to relax the requirement that the last dimension must be divisible by `block_size`: now we only need the numel (total size) to be divisible by `block_size`. This is especially needed when the block size is large (8-bit optim uses `block_size=2048`, as done in bnb). Since the optimizer update is element-wise, we don't really need to care whether the original tensor is 1D, 2D, or n-D (well, maybe there is some structure in an n-D tensor where flattening might not be so wise). I believe the original implementation in bnb does this as well.

As a result, `scale` is always a 1D tensor, with `size = original_tensor.numel() // block_size`.
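A minimal sketch of that flatten-then-reshape scale computation, assuming a simple absmax-per-block scheme; the function name and scaling rule here are illustrative, not the actual torchao code:

```python
import torch

def blockwise_scale(x: torch.Tensor, block_size: int) -> torch.Tensor:
    # Flatten first so only the total numel (not the last dim) must be
    # divisible by block_size.
    assert x.numel() % block_size == 0
    blocks = x.flatten().reshape(-1, block_size)  # (numel // block_size, block_size)
    return blocks.abs().amax(dim=1)               # 1D scale, one value per block

x = torch.randn(10, 1024)                 # last dim (1024) is not divisible by 2048,
scale = blockwise_scale(x, block_size=2048)  # but the total numel (10240) is
print(scale.shape)                        # torch.Size([5])
```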
@drisspg Added some docs. Lmk if it is still unclear.
I don't use FSDP at work or in personal projects because I don't have access to multi-GPU machines, so I can't really answer your question 😅. I only added FSDP support for low-bit optimizers due to requests from people. At least in torchtune, I saw that …
```diff
@@ -31,17 +32,29 @@ def __new__(cls, codes: Tensor, scale: Tensor, qmap: Tensor, signed: bool, shape
     )

     def __init__(self, codes: Tensor, scale: Tensor, qmap: Tensor, signed: bool, shape):
         """Create quantized 4-bit optimizer state as proposed in https://arxiv.org/abs/2309.01507.
```
Btw the link is not valid, can you remove the `.` at the end?
Done
```diff
@@ -28,15 +27,25 @@ def __new__(cls, codes: Tensor, scale: Tensor, qmap: Tensor, signed: bool):
     )

     def __init__(self, codes: Tensor, scale: Tensor, qmap: Tensor, signed: bool):
         """Create quantized 8-bit optimizer state as proposed in https://arxiv.org/abs/2110.02861.
```
same for this one
Nice! We should look to get the low bit optimizers out of prototype soon!
- Use `DTensor.from_local(run_check=False)` to wrap quantized optim state (instead of swapping `_local_tensor`)
- Make `block_size` a fixed attribute, calculated inside `__init__` (instead of dynamically calculating it every time)
- Implement `all_gather_into_tensor` and `wait_tensor` to support `DTensor.full_tensor()`
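A minimal sketch of the idea behind the first and third bullets, assuming the script is launched with `torchrun` on a multi-GPU host; the buffer name, shape, and dtype are illustrative, not taken from the PR:

```python
import os
import torch
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed._tensor import DTensor, Shard

torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))
world_size = int(os.environ["WORLD_SIZE"])
mesh = init_device_mesh("cuda", (world_size,))

# Local shard of a (quantized) optimizer-state buffer on this rank.
local_state = torch.zeros(1024, dtype=torch.uint8, device="cuda")

# run_check=False skips cross-rank metadata verification, which lets the local
# tensor (possibly a subclass) be wrapped directly rather than patched in via
# _local_tensor.
dstate = DTensor.from_local(local_state, mesh, [Shard(0)], run_check=False)

# full_tensor() all-gathers the shards; a tensor subclass needs to handle
# collectives such as all_gather_into_tensor / wait_tensor for this to work.
full = dstate.full_tensor()
print(full.shape)  # (1024 * world_size,)
```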