Add support for quantize_() with Float8Linear module #1344
Conversation
🔗 Helpful Links: see artifacts and rendered test results at hud.pytorch.org/pr/pytorch/ao/1344
❌ 1 new failure as of commit f66e7ed with merge base ed76e9c.
I think this makes sense; can you add some tests?
Some alternatives would be to maintain the Float8Linear structure and then swap the weight class dtype (feels weird). I think this is a good argument for subclasses, since you can maintain the structure and then assert that all low-precision subclasses have a dequant method and call that, which will convert to fp32. I know Christian wants tensor.to(torch.float32) to do this, but I think that's too magical.
cc @vkuzo
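A minimal sketch of that subclass idea, with hypothetical class names (torchao's actual subclasses differ): every low-precision subclass exposes an explicit dequantize method that returns a plain fp32 tensor.

import torch

class LowPrecisionTensor(torch.Tensor):
    # Hypothetical base class: every low-precision subclass must
    # implement dequantize() and return a plain fp32 tensor.
    def dequantize(self) -> torch.Tensor:
        raise NotImplementedError

class Float8WeightTensor(LowPrecisionTensor):
    # Sketch of a float8 weight wrapper holding raw data plus a scale.
    @staticmethod
    def __new__(cls, raw, scale):
        # Wrapper subclass: advertises fp32 dtype, stores float8 inside.
        return torch.Tensor._make_wrapper_subclass(cls, raw.shape, dtype=torch.float32)

    def __init__(self, raw, scale):
        self._raw = raw      # e.g. torch.float8_e4m3fn data
        self._scale = scale

    def dequantize(self) -> torch.Tensor:
        # Explicit conversion back to fp32, rather than overloading
        # tensor.to(torch.float32) to do it implicitly.
        return self._raw.to(torch.float32) * self._scale

raw = torch.randn(4, 4).to(torch.float8_e4m3fn)
w = Float8WeightTensor(raw, torch.tensor(2.0))
assert w.dequantize().dtype == torch.float32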
torchao/quantization/quant_api.py
Outdated
@@ -222,6 +224,9 @@ def _replace_with_custom_fn_if_matches_filter(
     Returns:
         None
     """
+    # If model is Float8Linear, convert it to Linear before moving forward
+    if isinstance(model, Float8Linear):
+        model = dequantize_float8_training(model)
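For context, a plausible sketch of what a dequantize_float8_training helper could look like, based on the swap logic suggested below; the body and import path are assumptions, and the actual helper this PR adds to torchao/float8 may differ:

import torch
import torch.nn as nn
from torchao.float8.float8_linear import Float8Linear  # assumed import path

def dequantize_float8_training(model: nn.Module) -> nn.Module:
    # Recursively swap every Float8Linear back to a plain nn.Linear,
    # reusing the existing high-precision parameters.
    if isinstance(model, Float8Linear):
        # Construct on the meta device so no throwaway weight storage is
        # allocated; the parameters are replaced immediately below.
        with torch.device("meta"):
            new_module = nn.Linear(model.in_features, model.out_features)
        new_module.weight = model.weight
        new_module.bias = model.bias
        return new_module
    for name, child in model.named_children():
        setattr(model, name, dequantize_float8_training(child))
    return model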
can you just move your code snippet from the other file here:

if isinstance(model, Float8Linear):
    with torch.device("meta"):
        new_module = nn.Linear(model.in_features, model.out_features)
    new_module.weight = model.weight
    new_module.bias = model.bias
    model = new_module

and not need any changes to torchao/float8?
@vkuzo what do you think about having dequantizing a model as a separate API? it feels a bit weird to have this logic in _replace_with_custom_fn_if_matches_filter, which is supposed to be a simple module replacement function.
my 5c -
- we can make it work now without adding any public APIs, with minimal increase in complexity
- if it's important to have a public API for "remove low precision training from a model", we can have that conversation in parallel
wdyt
the motivation for adding a new API is to make the dequantizing step more explicit for the user, instead of hiding it in a module replacement function.
but I agree this can happen in parallel. it's also probably not worth spending time discussing right now; waiting until there are more use cases might be better.
would be good to hear some motivation on why this needs a public API versus doing the same thing without a new API
Force-pushed 8d1f189 → a184759 → 7373ae7 → dc1a233, then cab447c → dfdba92.
LGTM
Force-pushed 94655a9 → f66e7ed.
Added support for the quantize_() API to work with models trained with float8, using Float8Linear.
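A sketch of the end-to-end flow this enables; the convert_to_float8_training and int8_weight_only entry points are assumed here, and exact import paths may vary across torchao versions:

import torch.nn as nn
from torchao.float8 import convert_to_float8_training
from torchao.quantization import quantize_, int8_weight_only

model = nn.Sequential(nn.Linear(512, 512), nn.ReLU(), nn.Linear(512, 512))

# Training: swap eligible nn.Linear modules for Float8Linear in place.
# (Actual float8 training needs recent CUDA hardware; this only converts.)
convert_to_float8_training(model)
# ... float8 training loop runs here ...

# Post-training: quantize_() now accepts the Float8Linear modules,
# converting them back to plain Linear before applying the new scheme.
quantize_(model, int8_weight_only())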