Fix int4pack_mm error #517

Merged on Jul 29, 2024 (6 commits)

Conversation

yanbing-j
Contributor

This PR depends on pytorch/pytorch#130915, which updates the meta shape in PyTorch and needs to land first.

pytorch-bot bot commented Jul 17, 2024

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/ao/517

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit 8aadb7d with merge base afde175:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@facebook-github-bot added the CLA Signed label Jul 17, 2024
@svekars svekars requested review from msaroufim, jerryzh168 and andrewor14 and removed request for andrewor14, msaroufim and jerryzh168 July 17, 2024 15:26
andrewor14 added a commit that referenced this pull request Jul 17, 2024
int4 tinygemm quantization is currently broken in master and
being fixed in #517. Let's
skip these tests for now until that is fixed.
@yanbing-j yanbing-j force-pushed the yanbing/fix_int4_woq branch 2 times, most recently from 49b47a2 to a11e455 Compare July 19, 2024 03:48
@andrewor14 andrewor14 requested a review from HDCharles July 19, 2024 14:25
@@ -349,6 +350,8 @@ def groupwise_affine_quantize_tensor_from_qparams(
    quant_max = 2 ** n_bit - 1

    int_data = quantize_affine(w, block_size, scales, zeros, output_dtype, quant_min, quant_max, zero_point_domain = ZeroPointDomain.FLOAT)
    if TORCH_VERSION_AFTER_2_5:
        int_data = (int_data[::, ::2] << 4 | int_data[::, 1::2]).to(torch.uint8)
Contributor

@manuelcandales manuelcandales Jul 25, 2024

This should break on the MPS backend, since __lshift__.Scalar is not currently implemented for MPS.

Contributor Author

Is int_data on an MPS device in this function? If so, we can move int_data to the CPU, do the shift there, and then move the result back to the MPS device.
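
The CPU round trip proposed in this comment could look roughly like the sketch below. The helper name pack_on_cpu and the sample tensor are hypothetical and not part of this PR; the sketch only assumes that the shift and OR are supported on the CPU.

```python
import torch


def pack_on_cpu(int_data: torch.Tensor) -> torch.Tensor:
    """Do the shift-and-OR on the CPU, then return to the original device."""
    original_device = int_data.device
    cpu_data = int_data.cpu()
    # Even columns become the high nibble, odd columns the low nibble.
    packed = (cpu_data[:, ::2] << 4 | cpu_data[:, 1::2]).to(torch.uint8)
    return packed.to(original_device)


# Use MPS when it is available; otherwise the round trip is a no-op on CPU.
device = "mps" if torch.backends.mps.is_available() else "cpu"
int_data = torch.tensor([[1, 2, 3, 4]], dtype=torch.int32, device=device)
packed = pack_on_cpu(int_data)
```

The extra device transfers cost time, so this would only make sense as a stopgap while the op has no native dispatch on the device.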

Contributor

@malfet landed pytorch/pytorch#131813, so this won't be a problem anymore

Contributor

@manuelcandales manuelcandales Jul 26, 2024

In any case, I learned from @malfet today (see his suggestion on line 203) that if we use torch.bitwise_left_shift(x, 4) here instead of <<, the op falls back to the CPU. So things would work even before his PR landed, as long as torch.bitwise_left_shift is used instead of <<.

Contributor Author

Thanks for the clarification. With pytorch/pytorch#131813, __lshift__.Scalar now has an MPS dispatch.
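
As a minimal sketch of the two spellings discussed in this thread, the snippet below packs pairs of 4-bit values into uint8 with both << and torch.bitwise_left_shift and checks that they agree on CPU. The helper name pack_int4_pairs and the sample values are illustrative assumptions, not code from the PR.

```python
import torch


def pack_int4_pairs(int_data: torch.Tensor) -> torch.Tensor:
    """Pack adjacent 4-bit values along the last dim into one uint8 each."""
    high = int_data[:, ::2]   # even columns go to the high nibble
    low = int_data[:, 1::2]   # odd columns go to the low nibble
    return (torch.bitwise_left_shift(high, 4) | low).to(torch.uint8)


# Hypothetical 2x4 tensor of already-quantized values in [0, 15].
int_data = torch.tensor([[1, 2, 3, 4], [15, 0, 7, 9]], dtype=torch.int32)

packed_fn = pack_int4_pairs(int_data)
# Same computation with the operator form; << binds tighter than |.
packed_op = (int_data[:, ::2] << 4 | int_data[:, 1::2]).to(torch.uint8)
```

Both forms halve the last dimension, so a [2, 4] int32 input becomes a [2, 2] uint8 tensor.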

@msaroufim
Member

msaroufim commented Jul 25, 2024

@yanbing-j what's the status of this PR? If a breaking change requires more than 1 week of work to figure out on our end, the right solution is to revert the offending PR.

@yanbing-j
Contributor Author

@msaroufim This PR is pending on pytorch/pytorch#130915, which is blocked by a "RuntimeError: CUDA error: invalid device function" when using OpInfo.
Once pytorch/pytorch#130915 is merged into PyTorch, this PR can fix the int4 error in torchao.
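
Because the fix only applies once the PyTorch change is available, the diff gates the packing on TORCH_VERSION_AFTER_2_5. How torchao computes that flag is not shown in this thread; the sketch below is one hypothetical way to derive such a flag from a version string, with version_at_least as an invented helper.

```python
import re


def version_at_least(version: str, major: int, minor: int) -> bool:
    """Return True if `version` is at least major.minor; suffixes are ignored."""
    match = re.match(r"(\d+)\.(\d+)", version)
    if match is None:
        raise ValueError(f"unrecognized version string: {version!r}")
    return (int(match.group(1)), int(match.group(2))) >= (major, minor)


# Hypothetical version strings; a nightly build carries a .dev suffix.
needs_packing_nightly = version_at_least("2.5.0.dev20240729+cpu", 2, 5)
needs_packing_release = version_at_least("2.4.0", 2, 5)
```

Treating a 2.5 nightly as "at least 2.5" is an assumption made here so that dev builds containing the upstream change take the new code path.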

@yanbing-j
Contributor Author

@msaroufim I have updated pytorch/pytorch#130915 so that it no longer uses OpInfo.

@malfet malfet left a comment

[EDIT] Please ignore; both the CUDA and MPS changes will land at the same time.

@@ -349,6 +350,8 @@ def groupwise_affine_quantize_tensor_from_qparams(
    quant_max = 2 ** n_bit - 1

    int_data = quantize_affine(w, block_size, scales, zeros, output_dtype, quant_min, quant_max, zero_point_domain = ZeroPointDomain.FLOAT)
    if TORCH_VERSION_AFTER_2_5:
        int_data = (int_data[::, ::2] << 4 | int_data[::, 1::2]).to(torch.uint8)
Suggested change
- int_data = (int_data[::, ::2] << 4 | int_data[::, 1::2]).to(torch.uint8)
+ int_data = (torch.bitwise_left_shift(int_data[::, ::2], 4) | int_data[::, 1::2]).to(torch.uint8)

@@ -198,6 +199,8 @@ def hqq_quants_to_torch_quants(
        .reshape(shape)
        .contiguous()
    )
    if TORCH_VERSION_AFTER_2_5:
        W_q = (W_q[::, ::2] << 4 | W_q[::, 1::2]).to(torch.uint8)
Suggested change
- W_q = (W_q[::, ::2] << 4 | W_q[::, 1::2]).to(torch.uint8)
+ W_q = (torch.bitwise_left_shift(W_q[::, ::2], 4) | W_q[::, 1::2]).to(torch.uint8)
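
To make the packed layout in these suggestions concrete, here is a small round-trip sketch: the high nibble of each byte holds the even column and the low nibble holds the odd column. The unpack_int4_pairs helper and the sample values are hypothetical and only illustrate the bit layout; they are not part of torchao.

```python
import torch


def unpack_int4_pairs(packed: torch.Tensor) -> torch.Tensor:
    """Split each uint8 into its high and low nibble and re-interleave them."""
    high = (packed >> 4) & 0xF
    low = packed & 0xF
    # stack gives [rows, cols, 2]; reshape flattens it to [rows, 2 * cols].
    pairs = torch.stack((high, low), dim=-1)
    return pairs.reshape(packed.shape[0], -1).to(torch.int32)


# Hypothetical quantized weights with values in [0, 15].
W_q = torch.tensor([[1, 2, 3, 4], [15, 0, 7, 9]], dtype=torch.int32)
W_packed = (torch.bitwise_left_shift(W_q[:, ::2], 4) | W_q[:, 1::2]).to(torch.uint8)
W_roundtrip = unpack_int4_pairs(W_packed)
```

The round trip recovers the original values only when every input already fits in 4 bits; larger values would spill into the neighbouring nibble.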

@yanbing-j yanbing-j force-pushed the yanbing/fix_int4_woq branch from ecd2a86 to 5f41c1e Compare July 28, 2024 02:57
@msaroufim
Member

Hi @yanbing-j, just a heads up since I haven't seen CI go green: we're planning a release on Friday Aug 8 and a code freeze on Friday Aug 2. If this PR can't land by this Wednesday, I will have no choice but to revert your changes in core, since customers such as https://github.com/mobiusml/hqq depend on this feature.

@yanbing-j
Contributor Author

@msaroufim Thanks for the information. Could you please start this CI again? Thanks!

@yanbing-j
Contributor Author

@msaroufim @jerryzh168 I found that pytorch/pytorch@6de65d5 will break test_int8_weight_only_quant_subclass and test_int4_weight_only_quant_subclass_api. Today's nightly works, but tomorrow's will not.
Test plan:
python test/integration/test_integration.py -k test_int8_weight_only_quant_subclass_api
python test/integration/test_integration.py -k test_int4_weight_only_quant_subclass_api

@yanbing-j yanbing-j force-pushed the yanbing/fix_int4_woq branch from f03a014 to 8aadb7d Compare July 29, 2024 06:37
@msaroufim
Member

Thanks @yanbing-j!

pytorch/pytorch@6de65d5 was reverted, so we should indeed only see breakages for one day.

@msaroufim msaroufim merged commit 8fa11a6 into pytorch:main Jul 29, 2024
13 checks passed
dbyoung18 pushed a commit to dbyoung18/ao that referenced this pull request Jul 31, 2024
int4 tinygemm quantization is currently broken in master and
being fixed in pytorch#517. Let's
skip these tests for now until that is fixed.
dbyoung18 pushed a commit to dbyoung18/ao that referenced this pull request Jul 31, 2024
* Fix int4pack_mm error

* fix CI

* Fix CI

* Fix CI

* Fix CI

* Fix CI
yanbing-j pushed a commit to yanbing-j/ao that referenced this pull request Dec 9, 2024
* Update iOS.md

* Update iOS.md
yanbing-j pushed a commit to yanbing-j/ao that referenced this pull request Dec 9, 2024
* make --device fast the default

* Update iOS.md (pytorch#517)

* Update iOS.md

* Update iOS.md

* Pip to pip3 (pytorch#504)

* remove macos-12 test

* pip to pip3

* break aoti CI jobs separately (pytorch#500)

* init

* fixes

* more fixes

* fixes

* fix

* fix

* bug fix

* add objcopy update

* suppress int8

* undefined variable

---------

Co-authored-by: Michael Gschwind

* Support llama3 in chat in run.cpp  (pytorch#486)

* refactor chat runner in preparation for llama3

* add sketch for llama3 prompt template and move to returning tokens

* fix tiktoken

* fixes to chat

* add default llama_ver

* Add tests for quantize json, add cuda device specification and precision to cuda.json (pytorch#519)

* remove code for no KV Cache path (pytorch#527)

* Update ADVANCED-USERS.md (pytorch#529)

Update Advanced Users description to reflect changes in the repo since the description was initially created.

* runner-aoti on cuda (pytorch#531)

* runner-aoti on cuda

* transfer results back to CPU

* transfer results back to CPU

* runner-aoti on cuda

* Update runner_build.md (pytorch#530)

Update description of runner and build process in runner_build.md

* clean up runner code a little (pytorch#532)

* clean up runner code a little

* update

* update

* pull out generate loop in chat

* updates

* edit docs

* typo

* move int8 linear class and function into qops.py (pytorch#534)

* add dtype tests for runner-aoti + runner-et (pytorch#539)

* add dtype tests for runner-aoti + runner-et

* typo

* Quantized embedding (pytorch#536)

* move int8 linear class and function into qops.py

* move Quantized Embedding to qops.py

* Move Linear int4 to qops (pytorch#537)

* move int8 linear class and function into qops.py

* move Quantized Embedding to qops.py

* move int4 linear to qops

* Revert "add dtype tests for runner-aoti + runner-et (pytorch#539)" (pytorch#548)

This reverts commit a7a24577a65be67ac9ae4dc05452f35d9c49e5d1.

* fix generate for llama3 (pytorch#538)

* fix generate for llama3

* switch more things to C

* remove C++ header

* add delegation visualization instructions (pytorch#551)

* Add dtype runner aoti (pytorch#552)

* add dtype tests for runner-aoti + runner-et

* typo

* add dtype test runner-aoti

* test sdpa with fp16 (pytorch#553)

* test sdpa with fp16

* kv cache fp32

* typo

* update (pytorch#560)

* Only support newest versions of lm-eval (pytorch#556)

Summary:
remove support for lm-eval 0.3 to reduce the options we have

Test Plan:
CI

Reviewers:

Subscribers:

Tasks:

Tags:

* split cpu eval CI by dtype (pytorch#554)

* split cpu eval CI by dtype

* fix

* differentiate names with checks

* keep one name the same as old

* fix

* Removing duplicate HF issue message from README (pytorch#559)

Co-authored-by: Michael Gschwind

* doc updates (pytorch#567)

* Add VM-safe MPS check

---------

Co-authored-by: Anthony Shoumikhin
Co-authored-by: metascroy
Co-authored-by: Nikita Shulga
Co-authored-by: lucylq
Co-authored-by: Jerry Zhang
Co-authored-by: Jack-Khuu
yanbing-j pushed a commit to yanbing-j/ao that referenced this pull request Dec 9, 2024
* code beautification

* code beautification, move functions together

* make --device fast the default (pytorch#515)

* make --device fast the default

* Update iOS.md (pytorch#517)

* Update iOS.md

* Update iOS.md

* Pip to pip3 (pytorch#504)

* remove macos-12 test

* pip to pip3

* break aoti CI jobs separately (pytorch#500)

* init

* fixes

* more fixes

* fixes

* fix

* fix

* bug fix

* add objcopy update

* suppress int8

* undefined variable

---------

Co-authored-by: Michael Gschwind

* Support llama3 in chat in run.cpp  (pytorch#486)

* refactor chat runner in preparation for llama3

* add sketch for llama3 prompt template and move to returning tokens

* fix tiktoken

* fixes to chat

* add default llama_ver

* Add tests for quantize json, add cuda device specification and precision to cuda.json (pytorch#519)

* remove code for no KV Cache path (pytorch#527)

* Update ADVANCED-USERS.md (pytorch#529)

Update Advanced Users description to reflect changes in the repo since the description was initially created.

* runner-aoti on cuda (pytorch#531)

* runner-aoti on cuda

* transfer results back to CPU

* transfer results back to CPU

* runner-aoti on cuda

* Update runner_build.md (pytorch#530)

Update description of runner and build process in runner_build.md

* clean up runner code a little (pytorch#532)

* clean up runner code a little

* update

* update

* pull out generate loop in chat

* updates

* edit docs

* typo

* move int8 linear class and function into qops.py (pytorch#534)

* add dtype tests for runner-aoti + runner-et (pytorch#539)

* add dtype tests for runner-aoti + runner-et

* typo

* Quantized embedding (pytorch#536)

* move int8 linear class and function into qops.py

* move Quantized Embedding to qops.py

* Move Linear int4 to qops (pytorch#537)

* move int8 linear class and function into qops.py

* move Quantized Embedding to qops.py

* move int4 linear to qops

* Revert "add dtype tests for runner-aoti + runner-et (pytorch#539)" (pytorch#548)

This reverts commit a7a24577a65be67ac9ae4dc05452f35d9c49e5d1.

* fix generate for llama3 (pytorch#538)

* fix generate for llama3

* switch more things to C

* remove C++ header

* add delegation visualization instructions (pytorch#551)

* Add dtype runner aoti (pytorch#552)

* add dtype tests for runner-aoti + runner-et

* typo

* add dtype test runner-aoti

* test sdpa with fp16 (pytorch#553)

* test sdpa with fp16

* kv cache fp32

* typo

* update (pytorch#560)

* Only support newest versions of lm-eval (pytorch#556)

Summary:
remove support for lm-eval 0.3 to reduce the options we have

Test Plan:
CI

Reviewers:

Subscribers:

Tasks:

Tags:

* split cpu eval CI by dtype (pytorch#554)

* split cpu eval CI by dtype

* fix

* differentiate names with checks

* keep one name the same as old

* fix

* Removing duplicate HF issue message from README (pytorch#559)

Co-authored-by: Michael Gschwind

* doc updates (pytorch#567)

* Add VM-safe MPS check

---------

Co-authored-by: Anthony Shoumikhin
Co-authored-by: metascroy
Co-authored-by: Nikita Shulga
Co-authored-by: lucylq
Co-authored-by: Jerry Zhang
Co-authored-by: Jack-Khuu

* add unpacking support (pytorch#525)

* add unpacking support

* fix typos and linter

* perform parallel prefill when possible (pytorch#568)

* perform parallel prefill when possible

* typo

* disable hack

* remove print

* remove debug messages which prevent export

* fixes

* stream results in generate.py (pytorch#571)

* remove logging interfering with export

---------

Co-authored-by: Anthony Shoumikhin
Co-authored-by: metascroy
Co-authored-by: Nikita Shulga
Co-authored-by: lucylq
Co-authored-by: Jerry Zhang
Co-authored-by: Jack-Khuu