
SmoothQuant using tensor subclassing #1030

Merged: 49 commits merged into pytorch:main on Oct 23, 2024
Conversation

@Xia-Weiwen (Collaborator) commented on Oct 8, 2024

This PR implements SmoothQuant with tensor subclassing.
SmoothQuant is similar to AWQ with the following differences:

  • Matmul is computed in int8 instead of floating point (at least at the op level)
  • The smoothing factor is calculated differently from the equalization scales of AWQ (a sketch of the calculation follows this list)
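
For reference, here is a minimal sketch of the per-channel smoothing factor as defined in the SmoothQuant paper, s_j = max|X_j|^alpha / max|W_j|^(1-alpha). The observer-based computation in this PR may differ in detail, and the helper name below is made up for illustration:

```python
import torch

def smoothing_factor(act_absmax: torch.Tensor, weight: torch.Tensor, alpha: float = 0.5) -> torch.Tensor:
    # act_absmax: per-input-channel max(abs(activation)), shape [in_features]
    # weight:     linear weight, shape [out_features, in_features]
    w_absmax = weight.abs().amax(dim=0)                    # per input channel
    s = act_absmax.pow(alpha) / w_absmax.pow(1.0 - alpha)
    return s.clamp(min=1e-5)

# At inference, the activation is divided by s and the weight is multiplied by it:
#   y = (x / s) @ (w * s).t() + b, which is mathematically equivalent to x @ w.t() + b
```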

It provides the following API for quantization:

  • insert_smooth_quant_observer_ inserts observers into the model in-place.
  • smooth_quant applies SmoothQuant to each linear layer of the model; use it by calling torchao.quantization.quantize_ (see the usage sketch below).
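
A minimal end-to-end sketch based on the API names above. The import paths are assumed from the prototype's directory layout, and extra arguments (e.g. alpha, dynamic vs. static activation quantization, a filter for observed linears) are omitted; see the README and example.py for the authoritative usage.

```python
import torch
from torchao.quantization import quantize_
from torchao.prototype.smoothquant import insert_smooth_quant_observer_, smooth_quant

model = torch.nn.Sequential(torch.nn.Linear(64, 64)).eval()
calibration_data = [torch.randn(8, 64) for _ in range(10)]

insert_smooth_quant_observer_(model)          # 1) wrap linear layers with observers, in-place
with torch.no_grad():
    for batch in calibration_data:            # 2) calibrate to collect activation statistics
        model(batch)
quantize_(model, smooth_quant())              # 3) apply SmoothQuant to each observed linear
```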

Two more APIs are provided for tuning the quantization recipe (for advanced users); a round-trip sketch follows the list:

  • save_smooth_quant_recipe saves smoothing factors and quantization parameters of an observed model to a JSON file.
  • load_smooth_quant_recipe loads these parameters from a file into a model with observers inserted.
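
A sketch of the recipe round-trip described above, under the same assumptions as the previous snippet (signatures taken from this description; the JSON contents and exact arguments may differ):

```python
import torch
from torchao.prototype.smoothquant import (
    insert_smooth_quant_observer_,
    save_smooth_quant_recipe,
    load_smooth_quant_recipe,
)

model = torch.nn.Sequential(torch.nn.Linear(64, 64)).eval()
insert_smooth_quant_observer_(model)
with torch.no_grad():
    model(torch.randn(8, 64))                              # calibrate

# Dump smoothing factors and quantization parameters to JSON for inspection/tuning...
save_smooth_quant_recipe(model, "smoothquant_recipe.json")

# ...then load the (possibly hand-edited) recipe into another observed model.
model2 = torch.nn.Sequential(torch.nn.Linear(64, 64)).eval()
insert_smooth_quant_observer_(model2)
load_smooth_quant_recipe(model2, "smoothquant_recipe.json")
```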

More details can be found in the torchao/prototype/smoothquant/README.md.

Unit tests are added in test/prototype/test_smoothquant.py.
An example is provided in torchao/prototype/smoothquant/example.py.


pytorch-bot (bot) commented on Oct 8, 2024

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/ao/1030

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit cb9167a with merge base d4b2f33:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@facebook-github-bot added the "CLA Signed" label (managed by the Facebook bot; authors need to sign the CLA before a PR can be reviewed) on Oct 8, 2024
return insert_subclass


def save_smooth_quant_recipe(model: torch.nn.Module, save_path: str) -> Dict[str, torch.Tensor]:
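    # Per the PR description: collects the smoothing factors and quantization
    # parameters of an observed model and saves them to a JSON file at save_path;
    # the return type suggests they are also returned as a dict (body omitted
    # from this excerpt).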

Contributor

Do we need this, or is saving the state_dict of the observed model enough?

Collaborator Author

We want an API for modifying (tuning) the quantization parameters, i.e., the recipe. Do you have any concerns about adding this API?

Contributor

So the state_dict is supposed to be used by other APIs to tune quantization parameters? That's fine if you have this use case in mind. Is the model with SmoothQuantObservedLinear not serializable by itself?

Collaborator Author

SmoothQuantObservedLinear is serializable. However, a recipe is more flexible for tuning parameters. Thanks.

@Xia-Weiwen (Collaborator Author)

Hi @jerryzh168, I added a new tensor subclass, LinearActivationScaleQuantizedTensor, to support x -> x/scale -> quantize(x) under torch.compile.

If I use LinearActivationQuantizedTensor, the x/scale is done outside the class (by input_quant_func), and there is a dynamo error about the scale during torch.compile. I guess it's because the scale tensors are not in the graph in this case. Putting the scale in the weight tensor solves the problem. And WeightTensorWithLinearActivationScaleMetadata does not quantize the activation.

Do you have any concerns about adding this new class? Thanks.
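
To illustrate the idea with a hypothetical, heavily simplified subclass (not the torchao implementation of LinearActivationScaleQuantizedTensor): when the smoothing scale is stored on the weight subclass, the x / scale step executes inside the linear dispatch, so it is captured as part of the traced computation instead of depending on scale tensors held outside the graph.

```python
import torch
import torch.nn.functional as F

class WeightWithActScale(torch.Tensor):
    # Hypothetical, simplified subclass for illustration only: the activation
    # smoothing scale lives on the weight, so x / scale runs inside the
    # F.linear override (the real subclass would also quantize x to int8 here).

    def __new__(cls, weight: torch.Tensor, act_scale: torch.Tensor):
        instance = weight.as_subclass(cls)
        instance.act_scale = act_scale
        return instance

    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None):
        kwargs = kwargs or {}
        if func is F.linear:
            x, w = args[0], args[1]
            bias = args[2] if len(args) > 2 else None
            with torch._C.DisableTorchFunctionSubclass():
                return F.linear(x / w.act_scale, w, bias)   # smoothing inside dispatch
        with torch._C.DisableTorchFunctionSubclass():
            return func(*args, **kwargs)

w = WeightWithActScale(torch.randn(16, 64), torch.rand(64) + 0.5)
y = F.linear(torch.randn(8, 64), w)   # x is divided by the scale during dispatch
```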


## Benchmark
Running the example with `torch.compile` on an NVIDIA A10G GPU.
### meta-llama/Llama-2-7b-hf

Contributor

Can you also report the perplexity for dynamic and static quant without SmoothQuant?

Collaborator Author

Sure. I have collected more data and provided it in the README.

)


class SmoothQuantObserver(AffineQuantizedObserverBase):

Contributor

It seems that the SmoothQuant observer is not related to AQT; it probably doesn't need to inherit from AffineQuantizedObserverBase?

Collaborator Author

Thanks. It's changed.

@jgong5 (Collaborator) left a comment

Overall LGTM.

torchao/prototype/smoothquant/core.py (review comment outdated and resolved)
@Xia-Weiwen marked this pull request as ready for review on October 22, 2024, 12:50
@Xia-Weiwen requested a review from jerryzh168 on October 22, 2024, 12:50
@Xia-Weiwen changed the title from "[WIP] SmoothQuant using tensor subclassing" to "SmoothQuant using tensor subclassing" on Oct 22, 2024
@@ -1369,7 +1368,10 @@ def _linear_int8_act_int8_weight_impl(input_tensor, weight_tensor, bias):
w_vals_int8_t = weight_tensor.tensor_impl.int_data.contiguous().t()
w_scales = weight_tensor.tensor_impl.scale
tmp = x_vals_int8.reshape(-1, x_vals_int8.shape[-1])
y_dot_scaled = int_scaled_matmul(tmp, w_vals_int8_t, x_scales.reshape(-1, 1))
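# assumed rationale (per the review discussion below and the "Fix fp16 overflow
# issue in int_scaled_matmul" commit): fp16 intermediates can overflow when
# rescaling the int32 matmul output, so half-precision scales are promoted to
# fp32 for this step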
x_scales_dtype = x_scales.dtype
intermediate_dtype = torch.float if x_scales_dtype == torch.half else x_scales_dtype

Contributor

Can you add a comment on how we decide the intermediate dtype here?

Collaborator Author

Sure. Added.

| Quant Method | alpha=0.25 | alpha=0.5 | alpha=0.75 | alpha=None* |
|-|-|-|-|-|
| Dynamic | 21.2475 | 8.8288 | 9.6514 | 8.3574 |
| Static | 301.7118 | 18.0617 | 10.8343 | 278.9819 |

Contributor

So it looks like it's more effective for static quant.

Contributor

Did you also do a sanity check on perf to make sure this doesn't regress performance?

Collaborator Author

For performance, the ordering from fastest to slowest is:

  • Static quant
  • Static quant + SmoothQuant
  • Dynamic quant
  • Dynamic quant + SmoothQuant

It's expected that SmoothQuant is slower because it inserts a div into the graph. Is that OK?

Contributor

Yeah, that's fine as long as it's reasonable; it's just a sanity check.

@jerryzh168 (Contributor) left a comment

Looks good, thanks!

@Xia-Weiwen (Collaborator Author)

Hi @jerryzh168, I have updated this PR. Please take another look. Thanks.

@Xia-Weiwen (Collaborator Author)

BTW, I found that torchao's observer behaves differently from pytorch's observer when running on CUDA: torchao's observer keeps self.min_val and self.max_val on the same device as the input tensor, but pytorch's observer always keeps them on CPU. Is that something that needs a fix? Thanks.

@jerryzh168 (Contributor)

> BTW, I found that torchao's observer behaves differently from pytorch's observer when running on CUDA: torchao's observer keeps self.min_val and self.max_val on the same device as the input tensor, but pytorch's observer always keeps them on CPU. Is that something that needs a fix? Thanks.

I see. I feel that min_val/max_val being on the same device as the input makes more sense? Or are you saying we should add an option here?

@jerryzh168 merged commit 629aee1 into pytorch:main on Oct 23, 2024
17 checks passed
@Xia-Weiwen (Collaborator Author)

> BTW, I found that torchao's observer behaves differently from pytorch's observer when running on CUDA: torchao's observer keeps self.min_val and self.max_val on the same device as the input tensor, but pytorch's observer always keeps them on CPU. Is that something that needs a fix? Thanks.

> I see. I feel that min_val/max_val being on the same device as the input makes more sense? Or are you saying we should add an option here?

Oh, I thought you might want them to have the same behavior. It's alright if that is not an issue.

jainapurva pushed a commit that referenced this pull request on Oct 24, 2024:
* SmoothQuant using tensor subclassing

* Update UT

* Add SmoothQuant example

* Remove duplicate implementation of int_scaled_matmul for CPU

* Update example.py

* Remove unused code

* Implement with LinearActivationQuantizedTensor

* Fix load/save

* Fix device mismatch in observer

* Fix fp16 overflow issue in int_scaled_matmul

* Add linear_activation_scale_quantized.py for torch.compile

* Quantize act/wei to 7 bit on old CPU platforms

* Fix device mismatch

* Fix UT failures

* Fix UT

* Don't use torch._int_mm for CPU now because it may overflow

* Remove reduce_range

* Refine code

* Remove torch.compile from example

* Add torch.compile in example

* Debug CI failures

* Debug CI failures (1)

* Debug CI failures (2)

* Debug CI failures (3)

* Work with torch.compile

* Update torchao/kernel/intmm.py

* Update readme.md

* Update readme.md

* Debug CI failures (4)

* Reimplement with nested tensor subclassing

* Test torch.compile only with PyTorch >= 2.5

* Debug CI failures (5)

* Debug CI failures (6)

* Debug CI failures (7)

* Use MovingAvg observer for activation; Update UT and readme

* Revert changes to test_spinquant.py; refine readme

* Debug CI failures (8)

* Debug CI failures (9)

* Fix CI failure

* Refactor SmoothQuantObserver

* Rename readme.md -> README.md

* Rename insert_smooth_quant_observer -> insert_smooth_quant_observer_ to indicate inplace

* Fix device mismatch in observer

* Fall back to conventional quantization if alpha is None

* Update README.md to provide more benchmark data; fix CI

* Fix CI failures

* Add a comment in affine_quantized_tensor.py