deduplicate code for some torchao q/dq ops #173
Conversation
Force-pushed from 68dc554 to 28bd36c
Force-pushed from d181692 to b969c58
Do you want to try applying this to more APIs in one PR? I think this function isn't used in a torch.compile context, so we can't get a signal on whether the new API works well with compile.
Thank you!
OK sure, I can apply this to the rest tomorrow
get_group_qparams_symmetric
if zero_point is not None:
    y -= zero_point
return y * scale
eps = torch.finfo(torch.float32).eps
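For readers skimming the diff, here is a minimal, self-contained sketch of the dequantization pattern quoted above; the function name and defaults are illustrative only, not the actual torchao helper.

```python
import torch

def dequantize_sketch(q, scale, zero_point=None, output_dtype=torch.float32):
    # Cast up before subtracting so low-bit integer inputs cannot overflow.
    y = q.to(output_dtype)
    if zero_point is not None:
        y -= zero_point
    return y * scale

# Quick check: int8 values with scale 0.5 and zero_point 2.
q = torch.tensor([0, 2, 4], dtype=torch.int8)
print(dequantize_sketch(q, scale=0.5, zero_point=2))  # tensor([-1., 0., 1.])
```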
This seems unused; just in case you meant to use it. We could add a linter to CI to help catch this, but it's not super important at the moment. I'll add it to the list for 0.3.
quant = torch.clamp(x_zp, quant_min, quant_max).to(target_dtype)
eps = torch.finfo(torch.float32).eps
block_size = (1, x.shape[1])
scale_dtype = torch.float32
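For context on the quoted parameters, here is a rough, self-contained sketch of how block_size = (1, x.shape[1]) reduces to per-row scale/zero_point selection for a 2D tensor. The helper name is made up, and the real choose_qparams_affine in torchao supports more dtypes, mapping types, and edge cases.

```python
import torch

def choose_qparams_per_row_sketch(x, quant_min=-128, quant_max=127,
                                  eps=torch.finfo(torch.float32).eps,
                                  scale_dtype=torch.float32):
    # block_size == (1, x.shape[1]): each row of a 2D tensor is one block,
    # so min/max reduce over dim=1 and we get one (scale, zero_point) per row.
    min_val = torch.clamp(x.amin(dim=1), max=0.0)  # quantization range must include 0
    max_val = torch.clamp(x.amax(dim=1), min=0.0)
    scale = (max_val - min_val) / float(quant_max - quant_min)
    scale = torch.clamp(scale, min=eps).to(scale_dtype)
    zero_point = torch.round(quant_min - min_val / scale).to(torch.int32)
    zero_point = torch.clamp(zero_point, quant_min, quant_max)
    return scale, zero_point

x = torch.randn(4, 16)
scale, zp = choose_qparams_per_row_sketch(x)
print(scale.shape, zp.shape)  # torch.Size([4]) torch.Size([4])
```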
Should we also tie this to x.dtype? Previous code did scale = torch.clamp(scale, min=eps).to(x.dtype).
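A small illustration of the difference under discussion, assuming a bfloat16 input; the values and variable names are hypothetical:

```python
import torch

x = torch.randn(8, dtype=torch.bfloat16)
eps = torch.finfo(torch.float32).eps
raw_scale = (x.amax() - x.amin()) / 255.0

# Behaviour of the previous code quoted above: scale inherits the input dtype.
scale_old = torch.clamp(raw_scale, min=eps).to(x.dtype)        # bfloat16
# Behaviour with a fixed scale_dtype=torch.float32: scale stays in float32.
scale_new = torch.clamp(raw_scale, min=eps).to(torch.float32)  # float32

print(scale_old.dtype, scale_new.dtype)  # torch.bfloat16 torch.float32
```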
Maybe. I'm not sure if we'll have use cases that have a different dtype though; maybe I can make this default to x.dtype instead of torch.float32?
I'm not sure. If we always expect float32 then we could add an assert just so it doesn't fail.
I just changed this to x.dtype in choose_qparams_affine
eps has to be float32's eps to pass the tests; I guess we could change this later
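A quick illustration of why the eps choice matters for the tests: the float32 and bfloat16 machine epsilons differ by several orders of magnitude, so clamping a very small scale gives very different results.

```python
import torch

print(torch.finfo(torch.float32).eps)   # ~1.19e-07
print(torch.finfo(torch.bfloat16).eps)  # ~7.81e-03

scale = torch.tensor(1e-6)
print(torch.clamp(scale, min=torch.finfo(torch.float32).eps))   # stays ~1e-6
print(torch.clamp(scale, min=torch.finfo(torch.bfloat16).eps))  # bumped to ~7.8e-3
```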
Please see comments around the use sites. I also think it's worthwhile to compare the output of TORCH_LOGS='output_code' for one or two of these to see if the resulting code is still fused.
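For anyone reproducing this check, one way to dump the code Inductor generates is to run a small script with the TORCH_LOGS environment variable set. This is a rough sketch assuming a recent PyTorch; dq_mul is just a stand-in for one of the refactored ops.

```python
# Run as: TORCH_LOGS="output_code" python fusion_check.py
# Inductor then prints the Triton/C++ it generates, so the before/after versions
# of a refactored op can be diffed to confirm the kernels are still fused.
import torch

def dq_mul(q, scale, zero_point):
    # A dequantize-then-multiply pattern we would like to see fused into one kernel.
    return (q.to(torch.float32) - zero_point) * scale

compiled = torch.compile(dq_mul)
q = torch.randint(-128, 127, (1024,), dtype=torch.int8)
out = compiled(q, torch.tensor(0.05), torch.tensor(3))
print(out.shape)
```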
@cpuhrsch I'll need to fix the CI first btw, but thanks for the review
do we have performance benchmarks for these things?
@jerryzh168 - mostly within other repositories (aside from https://github.com/pytorch/ao/tree/739e62d197b25d40422fe23fad3df2c7d2efb9d7/tutorials/quantize_vit). But if the refactor here ends up generating the same code, it should perform the same way. We can optimize more after.
Verified that quantize_vit gives the same result: https://www.diffchecker.com/DoqCSkRC/
Summary: This just removes the implementation; we can have follow-up PRs to remove the call altogether after we have replaced all implementations with the new blockwise quant code.
Test Plan: CI
zero_point_dtype = torch.int32

qscheme_to_mapping_type = {
    torch.per_tensor_affine: MappingType.ASYMMETRIC,
Very, very nit: Hm, I'm wondering if MappingType is the right name... We can definitely do this in a follow-up.
So MappingType means how we map from floating point to quantized values. I'm open to other suggestions as well, although we may remove this and just split the function into two in the future, so we could discuss this a little bit later (after we've verified this with executorch).
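For readers outside the review, a rough sketch of what the two MappingType values distinguish; the real choose_qparams_affine is blockwise and far more configurable, so this only shows the idea.

```python
import torch

def qparams_sketch(x, mapping_type, quant_min=-128, quant_max=127):
    if mapping_type == "SYMMETRIC":
        # Scale from the max magnitude; zero_point pinned to 0.
        scale = x.abs().amax() / max(abs(quant_min), quant_max)
        zero_point = torch.tensor(0, dtype=torch.int32)
    else:  # "ASYMMETRIC"
        # Scale from the full (min, max) range; zero_point shifts it onto [quant_min, quant_max].
        min_val = x.amin().clamp(max=0.0)
        max_val = x.amax().clamp(min=0.0)
        scale = (max_val - min_val) / (quant_max - quant_min)
        zero_point = (quant_min - torch.round(min_val / scale)).to(torch.int32)
    return scale, zero_point

x = torch.randn(64)
print(qparams_sketch(x, "SYMMETRIC"))
print(qparams_sketch(x, "ASYMMETRIC"))
```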
@@ -218,8 +218,11 @@ def dequantize(self, dtype=None):
    """
    Obtain the dequantized version of the quantized tensor subclass
    """
    zero_points = torch.zeros(self.q_scales.shape, device=self.q_scales.device, dtype=self.q_scales.dtype)
I'm surprised this didn't cause a regression. Seems like a big change.
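One plausible reason there is no numerical change: for symmetric quantization the zero_point is zero everywhere, so subtracting an explicit all-zeros tensor is a no-op in the dequant math. A tiny check (shapes are hypothetical):

```python
import torch

q = torch.randint(-128, 127, (4, 8), dtype=torch.int8)
scales = torch.rand(4, 1) + 0.01

without_zp = q.to(torch.float32) * scales
zero_points = torch.zeros(scales.shape, device=scales.device, dtype=scales.dtype)
with_zp = (q.to(torch.float32) - zero_points) * scales

print(torch.equal(without_zp, with_zp))  # True: subtracting exact zeros is bit-identical
```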
Summary: This just removes the implementation; we can have follow-up PRs to remove the call altogether after we have replaced all implementations with the new blockwise quant code.
Test Plan: CI
Co-authored-by: cpuhrsch <[email protected]>
Summary:
This just removes the implementation of the following ops; we can have follow-up PRs to remove the calls altogether after we have replaced all implementations with the new blockwise quant code:
get_group_qparams_symmetric
dynamically_quantize_per_tensor
dynamically_quantize_per_channel
dequantize_per_tensor
dequantize_per_channel
Note that there are some tinygemm-specific ops that calculate zero_point in the float domain; we could think about how to replace these later, e.g. we could have a flag to indicate whether zero_point is calculated in the float domain or the quantized domain (a sketch of the two conventions follows below).
Test Plan:
CI
Reviewers:
Subscribers:
Tasks:
Tags:
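As a supplement to the note about tinygemm above, here is a rough sketch of the two zero_point conventions being contrasted. The exact tinygemm formula also recenters around the mid-point of the quantized range, so this only illustrates the float-domain vs. quantized-domain distinction.

```python
import torch

x = torch.randn(32)
qmin, qmax = 0, 15                               # e.g. a uint4 range
min_val, max_val = x.amin(), x.amax()
scale = (max_val - min_val) / (qmax - qmin)

# Quantized-domain zero_point (integer): dequant is (q - zp_int) * scale.
zp_int = (qmin - torch.round(min_val / scale)).clamp(qmin, qmax)
q1 = torch.clamp(torch.round(x / scale) + zp_int, qmin, qmax)
x_hat1 = (q1 - zp_int) * scale

# Float-domain zero_point: the offset stays a float and dequant is q * scale + zp_float.
zp_float = min_val
q2 = torch.clamp(torch.round((x - zp_float) / scale), qmin, qmax)
x_hat2 = q2 * scale + zp_float

# Reconstruction error is bounded by roughly one scale step for both conventions.
print((x - x_hat1).abs().max(), (x - x_hat2).abs().max())
```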