
[quant] Add per block quantization primitives #159

Merged: @jerryzh168 merged 3 commits into pytorch:main from dedup on Apr 24, 2024

Conversation

@jerryzh168 (Contributor) commented Apr 22, 2024

Summary:
We want to use this to replace all quantize/dequantize/choose_qparams ops in https://github.com/pytorch-labs/ao/blob/main/torchao/quantization/quant_primitives.py and https://github.com/pytorch/pytorch/blob/main/torch/ao/quantization/fx/_decomposed.py

Note: this PR only adds the ops; we'll do the replacement in separate PRs and make sure it does not degrade performance or accuracy.
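
For orientation, a minimal usage sketch of the new per-block ops (shapes and qparam layout follow the examples discussed later in this thread; the exact signatures may differ from the final API):

```python
import torch
from torchao.quantization.quant_primitives import quantize_affine, dequantize_affine

x = torch.randn(3, 3, 10, 10)
block_size = (3, 3, 2, 10)             # each (3, 3, 2, 10) block shares one scale/zero_point
scale = torch.rand(1, 1, 5, 1) + 0.01  # one qparam per block (5 blocks along dim 2)
zero_point = torch.zeros(1, 1, 5, 1, dtype=torch.int32)

quantized = quantize_affine(x, block_size, scale, zero_point, torch.int8)
dequantized = dequantize_affine(quantized, block_size, scale, zero_point, torch.int8, output_dtype=torch.float32)
```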

Test Plan:
python test/quantization/test_quant_primitives.py

Reviewers:

Subscribers:

Tasks:

Tags:

@jerryzh168 requested a review from cpuhrsch April 22, 2024 23:46
@facebook-github-bot added the CLA Signed label Apr 22, 2024 (this label is managed by the Facebook bot; authors need to sign the CLA before a PR can be reviewed)
@jerryzh168 force-pushed the dedup branch 2 times, most recently from a02d061 to 794d9b5, April 23, 2024 00:14
return shape_for_reduction, reduction_dims


def quantize_affine_per_block(
Contributor:

I think it might be easier to first write a version of this that assumes `input.dim() == len(block_size) == scale.dim() == zero_point.dim()`, and then use various tools to implement broadcasting. But our ops should kind of imply the broadcasting here.
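
A minimal sketch of that suggestion, assuming every dimension divides evenly into blocks (the helper name and reshape strategy here are illustrative, not this PR's code):

```python
import torch

def quantize_affine_broadcast(input, block_size, scale, zero_point, quant_min, quant_max):
    # Split each dim of size S with block b into (S // b, b); a qparam viewed as
    # (..., S // b, 1, ...) then broadcasts over every block automatically.
    blocked_shape, qparam_shape = [], []
    for s, b in zip(input.shape, block_size):
        blocked_shape += [s // b, b]
        qparam_shape += [s // b, 1]
    blocked = input.view(blocked_shape)
    q = torch.clamp(
        torch.round(blocked / scale.view(qparam_shape)) + zero_point.view(qparam_shape),
        quant_min,
        quant_max,
    )
    return q.view(input.shape)
```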

@jerryzh168 (Author):

@cpuhrsch any ideas on how we can support broadcasting for the example I described here: #159 (comment)

@andrewor14 (Contributor) left a comment:

Looks great! Is the next step to rewrite existing quant primitives using this? How will this work for qdq ops currently living in pytorch?

torch.uint7: (0, 2**7-1),
})

def _get_qmin_qmax(dtype, quant_min, quant_max):
Contributor:
I feel this is more like _check_qmin_qmax? Alternatively make quant_min and quant_max default to None?

@jerryzh168 (Author):

Yeah, this actually combines two functions; I can split them as well.

@jerryzh168 (Author):

We're thinking about just not allowing quant_min/quant_max; I'll update this after we've made a decision, and add a TODO here for now.
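
A rough sketch of the combined get-and-check behavior under discussion (the table and helper names here are assumptions sketching the idea; the bounds shown are an abbreviated, illustrative subset of the dict being closed in the diff above):

```python
import torch

_DTYPE_TO_QVALUE_BOUNDS = {
    torch.int8: (-128, 127),
    torch.uint8: (0, 255),
    torch.int32: (-(2**31), 2**31 - 1),
}

def _get_and_check_qmin_qmax(dtype, quant_min=None, quant_max=None):
    """Default quant_min/quant_max from the dtype table, then validate overrides."""
    lo, hi = _DTYPE_TO_QVALUE_BOUNDS[dtype]
    quant_min = lo if quant_min is None else quant_min
    quant_max = hi if quant_max is None else quant_max
    assert lo <= quant_min <= quant_max <= hi, (
        f"quant_min/quant_max ({quant_min}, {quant_max}) out of range for {dtype}"
    )
    return quant_min, quant_max
```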

zero_point = zero_point.view(shape_after_reduction)

quant = torch.clamp(
torch.round(input / scale) + zero_point, quant_min, quant_max
Contributor:

I think this is slightly different from existing quant primitives, e.g. quantize_per_token does round after adding zp: https://github.com/pytorch/pytorch/blob/d40774f4ed4a45c70d49e66f4e1f197dfc274758/torch/ao/quantization/fx/_decomposed.py#L771

However, as written this is consistent with the existing torch.fake_quantize_per_channel_affine, which adds the zp after round. Which one do we want to follow?

@jerryzh168 (Author):

I think we should just choose one; we can start with this and check whether we can adjust the others to use it. We could make the dtypes more explicit as well.
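
For concreteness, the two orderings side by side (illustrative values, not this PR's code):

```python
import torch

x = torch.randn(4, 4)
scale, zp = 0.1, 3
qmin, qmax = -128, 127

# quantize_per_token style: round after adding the zero_point
q_a = torch.clamp(torch.round(x / scale + zp), qmin, qmax)
# fake_quantize_per_channel_affine style (what this PR does): add the zero_point after round
q_b = torch.clamp(torch.round(x / scale) + zp, qmin, qmax)
# With an integer zp the two mostly agree, but they can differ at exact .5 ties
# (torch.round rounds half to even) and whenever zp is fractional.
```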

)

self.assertTrue(torch.equal(quantized, quantized_ref))
self.assertTrue(torch.equal(dequantized, dequantized_ref))
Contributor:

Does this tell you how many elements were different and by how much? Should we use this instead?

torch.testing.assert_close(quantized, quantized_ref, atol=0, rtol=0)

@jerryzh168 (Author):

This one is exact equality, actually.

torchao/quantization/quant_primitives.py (thread resolved)
torchao/quantization/quant_primitives.py (outdated thread, resolved)
@cpuhrsch (Contributor) left a comment:

Accepting now with the intent of iterating on this over time

@jerryzh168 (Author):

> Looks great! Is the next step to rewrite existing quant primitives using this? How will this work for qdq ops currently living in pytorch?

Yeah, we'll rewrite the existing primitives in torchao to use this first, and then expand to pytorch later; we'll need to move the ops into pytorch in order to refactor the ones there.

@jerryzh168 (Author):

I also tried not including block_size as an arg and using keep_dim for the scales, but there was a problem, e.g. when we have:
input: (3, 3, 10, 10)
block_size: (3, 3, 2, 10)

scale size: (1, 1, 5, 1)

I'm not sure how we can broadcast the scale to size (1, 1, 10, 1) so that input can be divided by it.
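
(One possible workaround, sketched here rather than taken from the PR: expand the scale with repeat_interleave before dividing.)

```python
import torch

x = torch.randn(3, 3, 10, 10)
scale = torch.rand(1, 1, 5, 1) + 0.01

# Each scale entry along dim 2 covers a block of 10 // 5 == 2 rows, so repeating
# each entry twice yields a (1, 1, 10, 1) scale that broadcasts against x.
expanded = scale.repeat_interleave(x.shape[2] // scale.shape[2], dim=2)
q = torch.round(x / expanded)  # broadcasts to (3, 3, 10, 10)
```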

@HDCharles (Contributor) left a comment:

These look significantly more complicated than the old code, so I'm wondering if we can still torch.compile them into performant kernels.

Would like to see some ad-hoc microbenchmarks, at least to indicate that we're not immediately going to see a huge perf hit from this change, at the very least for per-token symmetric quant.

@jerryzh168 (Author):

> These look significantly more complicated than the old code, so I'm wondering if we can still torch.compile them into performant kernels.
>
> Would like to see some ad-hoc microbenchmarks, at least to indicate that we're not immediately going to see a huge perf hit from this change, at the very least for per-token symmetric quant.

@HDCharles this is just a starting point; I'm planning to replace the existing ops in separate PRs, and we can make improvements at that time, including making sure perf is good. Does that sound good?
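
Something like the following could serve as the requested ad-hoc microbenchmark once the replacement lands (a sketch only: the shapes, qparam layout, and CUDA availability are assumptions):

```python
import torch
from torchao.quantization.quant_primitives import quantize_affine

x = torch.randn(1024, 1024, device="cuda")
block_size = (1, 1024)  # per-token-style granularity
scale = x.abs().amax(dim=1, keepdim=True) / 127.0
zero_point = torch.zeros_like(scale, dtype=torch.int32)

compiled = torch.compile(quantize_affine)
for _ in range(3):  # warmup (also triggers compilation)
    compiled(x, block_size, scale, zero_point, torch.int8)
torch.cuda.synchronize()

start = torch.cuda.Event(enable_timing=True)
end = torch.cuda.Event(enable_timing=True)
start.record()
for _ in range(100):
    compiled(x, block_size, scale, zero_point, torch.int8)
end.record()
torch.cuda.synchronize()
print(f"avg quantize_affine: {start.elapsed_time(end) / 100:.4f} ms")
```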

@jerryzh168 merged commit f05c215 into pytorch:main Apr 24, 2024
13 checks passed
@jerryzh168 deleted the dedup branch April 24, 2024 23:21
reduction_dims = []
cur_dim = 0
for i in range(len(block_size)):
if block_size[i] != input_size[i] and block_size[i] > 1:
Contributor:

What about block_size[i] != input_size[i] and block_size[i] == 1, i.e. when the corresponding block size is 1? What would that mean?

@jerryzh168 (Author):

block_size[i] == 1 means that for the ith dimension, each slice will have its own qparams.

reduction_dims.append(cur_dim + 1)
cur_dim += 2
else:
# block_size[i] == input_size[i] or block_size[i] == 1
Contributor:

Ok, I see it here.
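
A tiny illustration of those semantics, with hypothetical shapes:

```python
input_size = (16, 64)
block_size = (16, 1)  # block of 1 along dim 1 -> per-column qparams

# Number of quantization blocks = product over dims of input_size[i] // block_size[i]
num_blocks = 1
for s, b in zip(input_size, block_size):
    num_blocks *= s // b
print(num_blocks)  # 64: one scale/zero_point per size-1 slice along dim 1
```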


def quantize_affine(
input: torch.Tensor,
block_size: List[int],
Contributor:

I thought we decided that you can use the scale/zero_point shape to infer this?

@jerryzh168 (Author):

There are some issues with broadcasting: #159 (comment). Let me know if you have ideas.

"""
# TODO: validations
quant_min, quant_max = _get_and_check_qmin_qmax(output_dtype, quant_min, quant_max)
shape_for_reduction, reduction_dims = _get_reduction_params(block_size, input.size())
Contributor:

Should you also validate that the block size corresponds to the scale/zp size?

@jerryzh168 (Author) commented May 6, 2024:

Let me just add a TODO for now so we don't overcomplicate the code; maybe we can remove some of the shape code in the future if broadcasting works.

dequant = input.to(torch.float32)
scale = scale.to(torch.float32)
if zero_point is not None:
zero_point = zero_point.to(torch.float32)
Contributor:

This does not feel accurate. I think we have had some discussion around this, that it should be `(input.to(torch.int32) - zero_point.to(torch.int32)).to(torch.float32) * scale`.

@jerryzh168 (Author) commented May 6, 2024:

I see, makes sense, will fix
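
That is, for the zero_point branch, something along these lines (a sketch of the suggested fix, not the merged code):

```python
import torch

def _dequantize_with_zero_point(q, scale, zero_point):
    # Subtract the zero_point at integer precision first, then convert to
    # float and scale, instead of casting q to float before subtracting.
    return (q.to(torch.int32) - zero_point.to(torch.int32)).to(torch.float32) * scale
```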

quantized = quantize_affine(input, block_size, scale, zero_point, dtype)
dequantized = dequantize_affine(quantized, block_size, scale, zero_point, dtype, output_dtype=torch.float32)
# we don't have corresponding ops in existing primitives, so just make sure it runs and it's close to float
torch.testing.assert_allclose(dequantized, input, rtol=2, atol=0.02)
Contributor:

Please add tests where you expect exceptions to be thrown.

@kimishpatel (Contributor) left a comment:

Sorry for reviewing this late; just leaving some comments. Mainly two:

  • Do we need block_size, or can it be derived?
  • My concern around the dequantize routine

@jerryzh168 (Author):

> Sorry for reviewing this late; just leaving some comments. Mainly two:
>
> • Do we need block_size, or can it be derived?
> • My concern around the dequantize routine

Sorry, just saw the comments.

  • Yeah, block_size can be derived from some helper functions I think; we could add those when we start replacing call sites.
  • Will fix.

@kimishpatel (Contributor):

> I also tried not including block_size as an arg and using keep_dim for the scales, but there was a problem, e.g. when we have: input: (3, 3, 10, 10), block_size: (3, 3, 2, 10)
>
> scale size: (1, 1, 5, 1)
>
> I'm not sure how we can broadcast the scale to size (1, 1, 10, 1) so that input can be divided by it.

input: (3, 3, 10, 10)
scale: (1, 1, 5, 1)

If we assume scale is always of a valid shape, then to broadcast scale.shape[2] == 5 against input.shape[2] == 10, we have to interpret scale as a blockwise scale with block size 2. Right?
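
Under that interpretation, block_size is fully derivable from the two shapes (a sketch assuming every dimension divides evenly):

```python
import torch

def derive_block_size(input: torch.Tensor, scale: torch.Tensor):
    assert input.dim() == scale.dim()
    return tuple(i // s for i, s in zip(input.shape, scale.shape))

x = torch.randn(3, 3, 10, 10)
scale = torch.rand(1, 1, 5, 1)
print(derive_block_size(x, scale))  # (3, 3, 2, 10)
```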

dbyoung18 pushed a commit to dbyoung18/ao that referenced this pull request Jul 31, 2024
yanbing-j pushed a commit to yanbing-j/ao that referenced this pull request Dec 9, 2024
Labels: CLA Signed
Projects: none yet
6 participants