deduplicate code for some torchao q/dq ops #173

Merged: 2 commits merged into pytorch:main on Apr 29, 2024

Conversation

jerryzh168 (Contributor) commented on Apr 24, 2024:

Summary:
This just removes the implementations; we can have follow-up PRs to remove the calls altogether after all implementations have been replaced with the new blockwise quant code.

  • get_group_qparams_symmetric
  • dynamically_quantize_per_tensor
  • dynamically_quantize_per_channel
  • dequantize_per_tensor
  • dequantize_per_channel

Note that there are some tinygemm-specific ops that calculate zero_point in the float domain; we can think about how to replace those later, e.g. with a flag that indicates whether zero_point is calculated in the float domain or the quantized domain.
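A rough sketch of that distinction (the function name and the zero_point_in_float_domain flag are illustrative, not torchao's actual API, and the float-domain formula shows one common mid-point convention that may not match tinygemm exactly):

```python
import torch

def dequantize_affine_sketch(
    x_q: torch.Tensor,
    scale: torch.Tensor,
    zero_point: torch.Tensor,
    quant_min: int,
    quant_max: int,
    zero_point_in_float_domain: bool = False,  # hypothetical flag discussed above
) -> torch.Tensor:
    if not zero_point_in_float_domain:
        # quantized (integer) domain: zero_point is an integer offset
        # subtracted before scaling
        return (x_q.to(torch.float32) - zero_point) * scale
    # float domain (tinygemm-style, assumed mid-point convention): the quantized
    # value is re-centered on the middle of the integer range and zero_point is
    # a floating-point offset added after scaling
    mid_point = (quant_max + quant_min + 1) / 2
    return (x_q.to(torch.float32) - mid_point) * scale + zero_point
```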

Test Plan:
CI

Reviewers:

Subscribers:

Tasks:

Tags:

facebook-github-bot added the CLA Signed label on Apr 24, 2024
jerryzh168 force-pushed the dedup-1 branch 2 times, most recently from 68dc554 to 28bd36c on April 24, 2024
jerryzh168 force-pushed the dedup-1 branch 4 times, most recently from d181692 to b969c58 on April 25, 2024
cpuhrsch (Contributor) commented:

Do you want to try applying this to more APIs in one PR? I think this function isn't used in a torch.compile context, so we can't get a signal on whether the new API works well with compile.

mikekgfb left a comment:

Thank you!

jerryzh168 (Author) replied:

> Do you want to try applying this to more APIs in one PR? I think this function isn't used in a torch.compile context, so we can't get a signal on whether the new API works well with compile.

OK, sure, I can apply this to the rest tomorrow.

jerryzh168 changed the title from "deduplicate code for get_group_qparams_symmetric" to "deduplicate code for some torchao q/dq ops" on Apr 25, 2024
if zero_point is not None:
    y -= zero_point
return y * scale
eps = torch.finfo(torch.float32).eps
Contributor:

This seems unused; just in case you meant to use it. We could add a linter to CI to help catch this, but it's not super important at the moment. I'll add it to the list for 0.3.

quant = torch.clamp(x_zp, quant_min, quant_max).to(target_dtype)
eps = torch.finfo(torch.float32).eps
block_size = (1, x.shape[1])
scale_dtype = torch.float32
Contributor:

Should we also tie this to x.dtype? Previous code did scale = torch.clamp(scale, min=eps).to(x.dtype).

jerryzh168 (Author):

Maybe; I'm not sure we'll have use cases with a different dtype, though. Maybe I can make this default to x.dtype instead of torch.float32?

Contributor:

I'm not sure. If we always expect float32 then we could add an assert just so it doesn't fail.

jerryzh168 (Author):

I just changed this to x.dtype in choose_qparams_affine.

jerryzh168 (Author):

eps has to be float32's eps to pass the tests; I guess we could change that later.
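For reference, a minimal sketch of the behavior being discussed, assuming per-row symmetric quantization with block_size = (1, x.shape[1]); per_row_symmetric_scale is a hypothetical helper for illustration, not choose_qparams_affine itself:

```python
import torch

def per_row_symmetric_scale(x: torch.Tensor, quant_max: int = 127) -> torch.Tensor:
    # max absolute value per row, i.e. block_size = (1, x.shape[1])
    max_abs = x.abs().amax(dim=1, keepdim=True)
    scale = max_abs / quant_max
    # clamp with float32's eps so all-zero rows don't yield a zero scale,
    # then cast back to the input dtype, mirroring the previous
    # `scale = torch.clamp(scale, min=eps).to(x.dtype)` quoted above
    eps = torch.finfo(torch.float32).eps
    return torch.clamp(scale, min=eps).to(x.dtype)
```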

cpuhrsch (Contributor) left a review comment:

Please see comments around the use sites. I also think it's worthwhile to compare the output of TORCH_LOGS='output_code' for one or two of these to see if the resulting code is still fused.
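One way to do that check (a sketch; dq_linear, the file name, and the tensor shapes are made up for illustration): run a compiled function that includes the dequant path with TORCH_LOGS='output_code' before and after the refactor, capture the Inductor-generated code, and diff the two logs.

```python
# e.g.  TORCH_LOGS='output_code' python fusion_check.py 2> after.log
import torch

def dq_linear(x_q, scale, zero_point, w):
    # dequant followed by a matmul; ideally Inductor keeps the dequant fused
    x = (x_q.to(torch.float32) - zero_point) * scale
    return x @ w

compiled = torch.compile(dq_linear)
x_q = torch.randint(-128, 127, (64, 64), dtype=torch.int8)
scale = torch.rand(64, 1)
zero_point = torch.zeros(64, 1)
w = torch.randn(64, 64)
compiled(x_q, scale, zero_point, w)
```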

jerryzh168 (Author):

@cpuhrsch I'll need to fix the CI first, btw, but thanks for the review.

jerryzh168 (Author):

> Please see comments around the use sites. I also think it's worthwhile to compare the output of TORCH_LOGS='output_code' for one or two of these to see if the resulting code is still fused.

Do we have performance benchmarks for these things?

cpuhrsch (Contributor):

@jerryzh168 - mostly within other repositories (aside from https://github.com/pytorch/ao/tree/739e62d197b25d40422fe23fad3df2c7d2efb9d7/tutorials/quantize_vit). But if the refactor here ends up generating the same code, it should perform the same way. We can optimize more afterwards.

jerryzh168 (Author):

Verified that quantize_vit gives the same result: https://www.diffchecker.com/DoqCSkRC/

jerryzh168 requested a review from cpuhrsch on April 29, 2024
jerryzh168 dismissed cpuhrsch's stale review on April 29, 2024: addressed comments

zero_point_dtype = torch.int32

qscheme_to_mapping_type = {
    torch.per_tensor_affine: MappingType.ASYMMETRIC,
Contributor:

Very, very nit: Hmm, I'm wondering if MappingType is the right name... We can definitely do this in a follow-up.

jerryzh168 (Author):

So MappingType describes how we map from floating-point to quantized values. I'm open to other suggestions as well, although we may remove this and just split the function into two in the future, so we could discuss this a bit later (after we have verified this with executorch).
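For readers of the thread, a sketch of what the two mapping types mean; the enum mirrors the MappingType.ASYMMETRIC usage in the diff above, while qparams_sketch is illustrative only and skips details such as eps clamping:

```python
import enum
import torch

class MappingType(enum.Enum):
    SYMMETRIC = "symmetric"    # scale from max(|x|); zero_point is fixed
    ASYMMETRIC = "asymmetric"  # scale from (max - min); zero_point shifts min onto quant_min

def qparams_sketch(x: torch.Tensor, mapping_type: MappingType,
                   quant_min: int = -128, quant_max: int = 127):
    if mapping_type is MappingType.SYMMETRIC:
        scale = x.abs().max() / max(abs(quant_min), quant_max)
        zero_point = torch.tensor(0, dtype=torch.int32)
    else:
        scale = (x.max() - x.min()) / (quant_max - quant_min)
        zero_point = torch.round(quant_min - x.min() / scale).to(torch.int32)
    return scale, zero_point
```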

jerryzh168 merged commit 6bcf244 into pytorch:main on Apr 29, 2024
13 checks passed
@@ -218,8 +218,11 @@ def dequantize(self, dtype=None):
        """
        Obtain the dequantized version of the quantized tensor subclass
        """
        zero_points = torch.zeros(self.q_scales.shape, device=self.q_scales.device, dtype=self.q_scales.dtype)
Contributor:

I'm surprised this didn't cause a regression. Seems like a big change.
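A quick way to see why it shouldn't regress, at least numerically: with an all-zeros zero_point, affine dequantization (q - zero_point) * scale reduces exactly to the old scale-only q * scale. A self-contained check under that assumption (not the actual call site):

```python
import torch

q = torch.randint(-128, 127, (4, 8), dtype=torch.int8)
scales = torch.rand(4, 1)
zero_points = torch.zeros_like(scales)  # same shape/dtype as the scales, as in the diff

plain = q.to(scales.dtype) * scales                   # previous scale-only dequant
affine = (q.to(scales.dtype) - zero_points) * scales  # affine dequant with zero zero_point
assert torch.equal(plain, affine)
```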

dbyoung18 pushed a commit to dbyoung18/ao that referenced this pull request Jul 31, 2024
yanbing-j pushed a commit to yanbing-j/ao that referenced this pull request Dec 9, 2024
Labels: CLA Signed

4 participants