[Float8] Make Inference and Training code independent #808
Conversation
Note:
    If the input tensor is already in Float8 format, it is returned as is without re-casting.

class Float8MMConfig(NamedTuple):
Can we make the names ScaledMMConfig and Float8MMConfig less confusing? Ideally it should be clear why there are two objects and in which context the user should use which one.
Sure. I was going to go with ScaledMMConfigInference, but that is too verbose even for me. Any suggestions?
For training I went with a user-facing Float8GemmConfig, which is a dataclass and easy to explain (IMO). The way the data gets passed around is not user facing, so the details of ScaledMMConfig don't matter as much.
TBH, making ScaledMMConfig public and reusable between training and inference sounds right to me, if you really want this to be public without a dataclass wrapper.
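For context, a minimal sketch of the split being described; the field names below are illustrative assumptions, not the exact fields in torchao:

```python
from dataclasses import dataclass
from typing import NamedTuple


@dataclass(frozen=True)
class Float8GemmConfig:
    """User-facing training config: small, documented, easy to explain."""
    use_fast_accum: bool = False


class ScaledMMConfig(NamedTuple):
    """Internal plumbing passed alongside tensors; not part of the public API."""
    emulate: bool = False
    use_fast_accum: bool = False
    pad_inner_dim: bool = False
```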
It cannot be a dataclass today because that won't work with compile. Lazos has a PR to make frozen dataclasses proxyable, but until then it has to be a NamedTuple.
I love dataclasses, but I don't understand why they are more understandable than a namedtuple. Is the problem that this name is similar to other names?
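A rough sketch of the constraint being described, assuming the config has to flow through torch.compile; the dataclass variant is what it could look like once frozen dataclasses become proxyable:

```python
from dataclasses import dataclass
from typing import NamedTuple

import torch


# Works with torch.compile today: NamedTuples can be traced through dynamo.
class ScaledMMConfig(NamedTuple):
    use_fast_accum: bool = False


# The dataclass equivalent (per the comment above, not usable until frozen
# dataclasses are proxyable).
@dataclass(frozen=True)
class ScaledMMConfigDataclass:
    use_fast_accum: bool = False


def scaled_matmul(a: torch.Tensor, b: torch.Tensor, config: ScaledMMConfig) -> torch.Tensor:
    # Placeholder body: the config rides along with the tensors into the
    # compiled region, which is why it must be a type dynamo can handle.
    return a @ b


compiled = torch.compile(scaled_matmul)
out = compiled(torch.randn(4, 4), torch.randn(4, 4), ScaledMMConfig(use_fast_accum=True))
```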
I think it's less about dataclass versus NamedTuple, and more about field readability / understandability / future-proofness, etc. For example, if we made ScaledMMConfig have an output dtype enum instead of a boolean and ensured all the other args are consistent with other public APIs, I think that would be fine.
https://stackoverflow.com/a/18348004/1058521 is one minor reason: default values look nicer with dataclasses.
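As a rough illustration of the enum suggestion (the names here are hypothetical, not an existing torchao API):

```python
from dataclasses import dataclass
from enum import Enum

import torch


class Float8OutputDtype(Enum):
    """Explicit output dtype instead of a hard-to-read boolean flag."""
    FLOAT8_E4M3 = torch.float8_e4m3fn
    HIGH_PRECISION = torch.bfloat16


@dataclass(frozen=True)
class ScaledMMConfig:
    # Reads better than `fp8_output: bool = False`, and defaults are easy to spot.
    output_dtype: Float8OutputDtype = Float8OutputDtype.HIGH_PRECISION
    use_fast_accum: bool = False
```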
stack-info: PR: #808, branch: drisspg/stack/9
Stacked PRs:
Summary
Removes the old FP8 inference flow and switches over to the newly added support in AQT + the quantize_ APIs. It also completely decouples inference from the training code by creating a new Float8MMConfig and an addmm wrapper in the inference.py file.
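A minimal sketch of what such an addmm wrapper could look like; this is an illustration, not the PR's actual implementation, and it assumes the torch._scaled_mm keyword arguments available in recent PyTorch releases (the exact signature varies across versions):

```python
from typing import NamedTuple, Optional

import torch


class Float8MMConfig(NamedTuple):
    """Inference-side matmul config (fields shown here are assumptions)."""
    use_fast_accum: bool = False


def float8_addmm(
    input_fp8: torch.Tensor,       # float8 activation
    weight_fp8: torch.Tensor,      # float8 weight, already laid out for the mm
    input_scale: torch.Tensor,
    weight_scale: torch.Tensor,
    bias: Optional[torch.Tensor] = None,
    out_dtype: torch.dtype = torch.bfloat16,
    config: Float8MMConfig = Float8MMConfig(),
) -> torch.Tensor:
    # torch._scaled_mm performs the fp8 x fp8 matmul and applies the scales.
    return torch._scaled_mm(
        input_fp8,
        weight_fp8,
        scale_a=input_scale,
        scale_b=weight_scale,
        bias=bias,
        out_dtype=out_dtype,
        use_fast_accum=config.use_fast_accum,
    )
```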