
Add support for quantize_() with Float8Linear module #1344

Merged
jainapurva merged 6 commits into main from fp8_linear_quantize on Nov 28, 2024

Conversation

jainapurva (Contributor)

Added support for the quantize_() API to work with models trained with float8, using Float8Linear.
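
A minimal sketch of the flow this enables, assuming the standard torchao entry points (convert_to_float8_training, quantize_, int8_weight_only); the toy model and config are illustrative only:

import torch.nn as nn
from torchao.float8 import convert_to_float8_training
from torchao.quantization import quantize_, int8_weight_only

# Toy model; convert_to_float8_training swaps each nn.Linear for a Float8Linear.
model = nn.Sequential(nn.Linear(64, 64), nn.Linear(64, 64))
convert_to_float8_training(model)

# ... float8 training loop elided ...

# With this change, quantize_ first converts Float8Linear modules back to
# plain nn.Linear, then applies post-training quantization as usual.
quantize_(model, int8_weight_only())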

@jainapurva added the topic: bug fix label on Nov 26, 2024

pytorch-bot bot commented Nov 26, 2024

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/ao/1344

Note: Links to docs will display an error until the docs builds have been completed.

❌ 1 New Failure

As of commit f66e7ed with merge base ed76e9c:

NEW FAILURE - The following job has failed:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@facebook-github-bot added the CLA Signed label on Nov 26, 2024
drisspg (Contributor) left a comment:


I think this makes sense; can you add some tests?

One alternative would be to maintain the Float8Linear structure and then swap the weight class dtype (feels weird).

I think this is a good argument for subclasses, since you can maintain the structure, assert that all low-precision subclasses have a dequant method, and call that to convert to fp32. I know Christian wants tensor.to(torch.float32) to do this, but I think it's too magical.

cc @vkuzo
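
A minimal sketch of that subclass contract, with hypothetical names (LowPrecisionTensor and maybe_dequantize are made up for illustration; real torchao subclasses differ):

import torch

class LowPrecisionTensor(torch.Tensor):
    # Hypothetical marker base class: every low-precision tensor subclass
    # must know how to materialize itself back to a plain fp32 tensor.
    def dequantize(self) -> torch.Tensor:
        raise NotImplementedError

def maybe_dequantize(t: torch.Tensor) -> torch.Tensor:
    # Explicit dequant hook, rather than overloading tensor.to(torch.float32).
    return t.dequantize() if isinstance(t, LowPrecisionTensor) else t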

@@ -222,6 +224,9 @@ def _replace_with_custom_fn_if_matches_filter(
     Returns:
         None
     """
+    # If model is Float8Linear, convert it to Linear before moving forward
+    if isinstance(model, Float8Linear):
+        model = dequantize_float8_training(model)
Contributor:

can you just move your code snippet from the other file here:

if isinstance(model, Float8Linear):
    # Construct the replacement on the meta device so no real memory is
    # allocated, then point it at the already-trained parameters.
    with torch.device("meta"):
        new_module = nn.Linear(model.in_features, model.out_features)
    new_module.weight = model.weight
    new_module.bias = model.bias
    model = new_module

and not need any changes to torchao/float8?

Contributor:

@vkuzo what do you think about having dequantizing a model as a separate API? It feels a bit weird to have this logic in _replace_with_custom_fn_if_matches_filter, which is supposed to be a simple module replacement function.
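
For illustration, a rough sketch of what such a separate API could look like (the name dequantize_ is hypothetical), reusing the module-swap logic from the suggestion above:

import torch
import torch.nn as nn
from torchao.float8.float8_linear import Float8Linear

def dequantize_(model: nn.Module) -> nn.Module:
    # Hypothetical public API: recursively swap Float8Linear modules back to
    # plain nn.Linear in place, keeping the trained parameters.
    for name, child in model.named_children():
        dequantize_(child)
        if isinstance(child, Float8Linear):
            with torch.device("meta"):
                new_child = nn.Linear(child.in_features, child.out_features)
            new_child.weight = child.weight
            new_child.bias = child.bias
            setattr(model, name, new_child)
    return model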

Contributor:

my 5c:

  1. we can make it work now without adding any public APIs, with a minimal increase in complexity
  2. if it's important to have a public API for "remove low precision training from a model", we can have that conversation in parallel

wdyt

Contributor:

the motivation for adding a new API is to make the dequantizing step more explicit for the user, instead of hiding it in a module replacement function.

but I agree this can happen in parallel. it's probably not worth spending time discussing this right now; waiting until there are more use cases might be better

vkuzo (Contributor) left a comment:


would be good to hear some motivation for why this needs a public API versus doing the same thing without a new one

@jainapurva force-pushed the fp8_linear_quantize branch 2 times, most recently from 8d1f189 to a184759, on November 27, 2024 at 00:23
@jainapurva marked this pull request as ready for review on November 27, 2024 at 23:09
jerryzh168 (Contributor) left a comment:


LGTM

@jainapurva merged commit c45d975 into main on Nov 28, 2024
17 of 18 checks passed