[FEAT] Add custom CUDA tinygemm unpacker #415
Conversation
Dr. CI (automated): See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/ao/415. Note: links to docs will display an error until the docs builds have been completed. 1 New Failure as of commit e90e280 with merge base e5548b7: one job has failed. This comment was automatically generated by Dr. CI and updates every 15 minutes.
Great, after this is landed we can replace this workaround code: ao/torchao/dtypes/affine_quantized_tensor.py, Line 505 in 9dc2c11.
Took a look at this. Are there tests for get_plain()? Currently working on fusing dequant into the unpacking kernel; however, a simple sanity check using the same logic as the standard dequant path isn't matching up. That is, I'm using ...
We don't have get_plain() tests yet, but I'm planning to add some tests for AffineQuantizedTensor in the future. The way that tinygemm dequantize is implemented is a bit different from the normal path. Here is how it's implemented:
function call to our primitive ops: ao/torchao/quantization/utils.py, Line 370 in 37c348e
code path: ao/torchao/quantization/quant_primitives.py, Line 266 in 37c348e
The main difference is that the zero_point is in the floating point domain (while the quant/dequant that we are more familiar with is in the integer domain): ao/torchao/quantization/quant_primitives.py, Lines 43 to 44 in 37c348e
Many thanks for the clarification! Helps explain why my sanity check wasn't matching. This is good to know, as the original motivation for this PR was to help with switching backends on the fly. What is the mathematical derivation of the tinygemm scales / zeros formulation? Is there a paper or blog that explains it?
Yes, that motivation makes sense.
As for the derivation: I'm not aware of any formal papers or blogs. The differences are shown in our quant_primitive ops via these two flags (traditional integer quantization vs. tinygemm): ao/torchao/quantization/quant_primitives.py, Lines 303 to 318 in c2cf973
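For illustration, here is a rough sketch of the two conventions in Python; it follows the formulas quoted elsewhere in this thread, not the exact quant_primitives source, so treat the function names and signatures as assumptions:

import torch

# Illustrative sketch only (not the torchao quant_primitives code).
def quant_dequant_int_zero_point(x, scale, zero_point, n_bit=4):
    # "traditional" convention: zero_point lives in the integer domain
    qmin, qmax = 0, 2 ** n_bit - 1
    q = torch.clamp(torch.round(x / scale) + zero_point, qmin, qmax)
    x_hat = (q - zero_point) * scale
    return q, x_hat

def quant_dequant_float_zero_point(x, scale, zero, n_bit=4):
    # tinygemm-style convention: the zero is kept in the floating point domain,
    # and dequant re-centers the uint4 value around mid_point = 2 ** (n_bit - 1)
    mid_point = 2 ** (n_bit - 1)
    q = torch.clamp(torch.round((x - zero) / scale) + mid_point, 0, 2 ** n_bit - 1)
    x_hat = (q - mid_point) * scale + zero
    return q, x_hat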
Thanks! Looking at the referenced qparams code, the scales and zeros are computed as
scales = (max_val - min_val).clamp(min=1e-6) / max_int
zeros = min_val + scales * (2 ** (n_bit - 1))
where max_val / min_val are the per-group max / min. Then in the tinygemm kernel, how is dequantization performed in terms of these scales and zeros?
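As a quick numeric illustration of those formulas (assuming max_int = 2 ** n_bit - 1 and a per-group min/max, as in the uint4 case):

# Worked example for n_bit = 4 (uint4), one quantization group.
n_bit = 4
max_int = 2 ** n_bit - 1          # 15
mid_point = 2 ** (n_bit - 1)      # 8
min_val, max_val = -1.0, 2.0      # per-group min / max

scales = max(max_val - min_val, 1e-6) / max_int   # 3.0 / 15 = 0.2
zeros = min_val + scales * mid_point              # -1.0 + 0.2 * 8 = 0.6

# With the float-domain dequant x = (q - mid_point) * scales + zeros,
# the endpoints of the uint4 range map back onto [min_val, max_val]:
assert abs((0 - mid_point) * scales + zeros - min_val) < 1e-6        # q = 0  -> -1.0
assert abs((max_int - mid_point) * scales + zeros - max_val) < 1e-6  # q = 15 ->  2.0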
m.impl("torchao::unpack_int4_to_int", &_unpack_int4_to_int); | ||
m.impl("torchao::dequantize_int4", &_dequantize_int4); |
nit: I feel we probably need to mention tensor_core_tiled layout in the name of these ops if these are specific to that packing format
Sorry, I just wrote down the quantize function there, not the dequant function; I should probably add all algorithms (choose_qparams, quant, dequant) there. The dequant we are using is here: ao/torchao/quantization/quant_primitives.py, Lines 267 to 274 in 37c348e
@jerryzh168 If we unpack what the tinygemm kernel appears to be doing, the dequantization is
x = q * scales + zeros
  = q * scales + min_val + scales * mid_point
where mid_point = 2 ** (n_bit - 1) and zeros = min_val + scales * mid_point as above.
Comparing this to the quant_primitives dequant
x = (q - mid_point) * scales + zeros
  = q * scales - scales * mid_point + zeros
Assuming zeros = min_val + scales * mid_point, this simplifies to
x = q * scales - scales * mid_point + min_val + scales * mid_point
  = q * scales + min_val
How to reconcile these differences? How are the scales and zeros supposed to be prepared for each path?
Yeah, no problem. I think the main difference, as you listed, is the extra scales * mid_point term. I'm not very familiar with the tinygemm kernel implementation itself, but I think this should be accounted for either by pre-processing of the zeros or inside the kernel. Also, for some additional context: the current quant primitives in torchao are adapted from the original gpt-fast/tinygemm choose_qparams/quantize/dequantize implementations, and we have regression tests to make sure they match.
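To make the zero-point bookkeeping concrete, here is a small illustration (my own, not torchao code) of converting an integer-domain zero_point into a float-domain zero so the two dequant conventions agree:

import torch

# Illustration only: setting (q - zp_int) * s == (q - mid) * s + z_float for all q
# gives z_float = (mid - zp_int) * s.
n_bit = 4
mid = 2 ** (n_bit - 1)              # 8 for uint4
s = torch.tensor(0.2)               # per-group scale
zp_int = torch.tensor(5)            # integer-domain zero_point
z_float = (mid - zp_int) * s        # equivalent float-domain zero

q = torch.arange(0, 2 ** n_bit)     # all uint4 codes
assert torch.allclose((q - zp_int) * s, (q - mid) * s + z_float)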
Maybe related to https://github.com/pytorch/pytorch/blob/93a33bf3ac0b4c9560b49780eabcad2f76dcf43e/aten/src/ATen/native/cuda/int4mm.cu#L197. cc @HDCharles, do you know how the tinygemm kernel dequant implementation matches up with the python dequant implementation?
For the purposes of this PR, then, what should be the expected behavior of the dequantize op? That is, given packed weights, scales, zeros, etc., what should be the calculation to dequantize the weights from the packed uint4 values back to bfloat16, when checking against the quant_primitives implementation?
@jeromeku OK, I just confirmed with Jeff Johnson that this code is actually doing both a uint4 -> int4 conversion ([0, 15] --> [-8, 7]), which is equivalent to (q_val - mid_point) in our dequant code, and also a conversion to bfloat16, so I feel the two implementations should match up. For tests, what you described makes sense; I think you can do two tests:
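A hypothetical sketch of what such checks could look like (function names, layouts, and tolerances here are my assumptions, not necessarily the tests that ended up in test_ops.py):

import torch
import torchao
from torchao.quantization.utils import groupwise_affine_dequantize_tensor_from_qparams

# Hypothetical test sketch. Op names follow the (renamed) ops in this PR; the reference
# helper, the scales_and_zeros layout, and the packing signature are assumptions and may
# differ across torch/torchao versions.
def check_unpack_and_dequant(N=256, K=512, group_size=128, inner_k_tiles=8, n_bit=4):
    q = torch.randint(0, 2 ** n_bit, (N, K), dtype=torch.int32, device="cuda")
    scales = torch.rand(N, K // group_size, dtype=torch.bfloat16, device="cuda") + 0.01
    zeros = torch.rand(N, K // group_size, dtype=torch.bfloat16, device="cuda")

    packed = torch.ops.aten._convert_weight_to_int4pack(q, inner_k_tiles)
    # assumed tinygemm layout: (K // group_size, N, 2) with scale / zero in the last dim
    scales_and_zeros = torch.stack([scales, zeros], dim=-1).transpose(0, 1).contiguous()

    # 1) unpack, then dequantize through the python reference path
    unpacked = torchao.ops.unpack_tensor_core_tiled_layout(packed, inner_k_tiles)
    ref = groupwise_affine_dequantize_tensor_from_qparams(unpacked, scales, zeros, n_bit, group_size)

    # 2) the fused CUDA unpack + dequant should match the reference
    fused = torchao.ops.dequantize_tensor_core_tiled_layout(packed, scales_and_zeros, group_size, inner_k_tiles)
    torch.testing.assert_close(fused, ref, rtol=1e-2, atol=1e-2)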
Thanks for the clarification. Will update the tests accordingly. Would be good to add some additional documentation explaining the pre-processing / post-processing needed to use quantized weights, scales, and zero-points prepared using the "conventional" (integer zero_point) scheme with the tinygemm kernels.
Sure, thanks. I'll add some docs for quant_primitives in our README.
torch._check(scales_and_zeros.size(1) == N, lambda: "scales_and_zeros must have N at dim 1")
torch._check(scales_and_zeros.size(2) == 2, lambda: "scales_and_zeros must have 2 at dim 2")
return torch.empty((N, K), dtype=torch.bfloat16, device=packed_w.device)
Same for this: is this supposed to call dequantize_tensor_core_tiled_layout?
I thought this was the expected pattern for registering a custom op? I was following the example of the pre-existing fp6_linear custom op already in ops.py.
Previously one would register an abstract impl for composability with torch.compile. Thought this was the expected interface with the new custom op registration API. That is, register a "fake" implementation that runs checks and just returns the expected shape of the output.
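For reference, a minimal sketch of that pattern using the newer torch.library API (the op name and the shape arithmetic below are illustrative, not the exact torchao registration code; older versions used torch.library.impl_abstract instead of register_fake):

import torch

# Minimal sketch: the real CUDA op is assumed to be registered elsewhere (e.g. via
# TORCH_LIBRARY in C++). The fake/meta impl only does shape arithmetic and returns an
# empty tensor with the right shape/dtype so torch.compile can trace without running
# the kernel.
@torch.library.register_fake("torchao::dequantize_tensor_core_tiled_layout")
def _(packed_w, scales_and_zeros, group_size, inner_k_tiles):
    # shape arithmetic is illustrative of the tensor-core tiled packing
    N = packed_w.size(0) * 8
    K = packed_w.size(1) * inner_k_tiles * 16
    return torch.empty((N, K), dtype=torch.bfloat16, device=packed_w.device)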
Oh OK, I think I understand now: register_custom_op is calling register_fake / impl_abstract. I feel we need to rename this util to something more accurate. cc @msaroufim
looks good overall, thanks for working on this @jeromeku! just had a few nits + requested additional tests and questions around the motivation of having two ops
Summary: att, per request in pytorch#415 (comment) Test Plan: doc changes Reviewers: Subscribers: Tasks: Tags:
Fixed all the above.
return torch.empty((N, K), dtype=torch.int32, device=packed_w.device)
def dequantize_tensor_core_tiled_layout(packed_w: Tensor, scales_and_zeros: Tensor, group_size: int, inner_k_tiles: int) -> Tensor: |
Is this specific to uint4 btw? Looks like so; maybe we can add uint4 to the name as well in that case, unless this layout makes sense for other dtypes as well and we want to extend it to other dtypes in the future.
LGTM! really appreciate adding this functionality and the thorough comments/testing!
Just a minor merge conflict and this should be good to merge.
Getting a CI failure unrelated to the PR.
Cool, thank you for the awesome work @jeromeku and thank you for the thorough review @jerryzh168!
The CI failure indeed seems unrelated, most likely a flake due to connection issues with HF
…#469) Summary: att, per request in pytorch#415 (comment) Test Plan: doc changes Reviewers: Subscribers: Tasks: Tags:
* add unpack cuda
* add tests
* fix tests
* refactor tinygemm unpacking kernel
* add dequant
* add additional dequant check
* update tinygemm dequantize test
* correct dequant kernel logic
* clean up kernel
* update dequantize kernel tests
* rename kernel ops to tensor_core_tiled_layout
* add renamed kernel source
* add back test_aot_dispatch opcheck
* rename innerKTiles to inner_k_tiles
* add unpack and dequant test
* additional numerical checks for unpack then dequant
* rebase test_ops on main
* remove commented out code
* skip dynamic opcheck unless torch>=2.5

* Revert "Revert "Embedding quantization per backend (pytorch#402)" (pytorch#411)" This reverts commit 8b35acdff4fded779799ab8a419e55f885dd8918.
* 4b and 8b embedding table quantization
* minor changes
* remove extra et workflow
Description
Adds CUDA custom ops to unpack weights that have been packed with torch.ops.aten._convert_weight_to_int4pack for use with torch.ops.aten._weight_int4pack_mm. Currently there is only a packing function that permutes and prepacks the weights in tensor-core format. However, there is no equivalent unpacking function that reorders the weights back to the original logical layout.
The implementation is an adaptation of the original packing code (int4mm.cu) with modifications to simplify indexing logic and fused unpacking & dequantization.
Motivation
Fast unpacking of packed weights is needed when switching quantized gemm backends during inference. As workloads transition from memory-bound to compute-bound (i.e., context length growth during decoding), users might wish to switch to a different kernel implementation that is more performant in this regime than tinygemm. In order to do this, the weights need to be unpacked from the packed format. The alternative would be to store 2 copies of the weights -- one packed, one in logical format -- but this is clearly not ideal given the memory burden.
Features
Add 2 custom CUDA ops, registered per the instructions in the torchao custom op documentation:
torchao.ops.unpack_int4 - unpacks the packed weight to the original N x K logical layout with dtype torch.int. Can be used within TensorCoreTiledAQTLayout.get_plain to recover the original layout of the (quantized) tensor.
torchao.ops.dequantize_int4 - dequantizes the packed weight to a bfloat16 tensor with the original N x K logical layout. This is useful for developers who want to unpack and dequantize the packed weight when switching quantized matmul backends on the fly.
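A rough usage sketch of the two ops (op names as listed above; they were later renamed to unpack/dequantize_tensor_core_tiled_layout during review, and the packing signature may differ across PyTorch versions):

import torch
import torchao

# Illustrative only: shapes, dtypes and signatures are assumptions.
N, K, group_size, inner_k_tiles = 256, 512, 128, 8
q = torch.randint(0, 16, (N, K), dtype=torch.int32, device="cuda")   # uint4 values in an int32 container
packed = torch.ops.aten._convert_weight_to_int4pack(q, inner_k_tiles)

# Recover the original N x K quantized tensor (e.g., for TensorCoreTiledAQTLayout.get_plain):
unpacked = torchao.ops.unpack_int4(packed, inner_k_tiles)             # torch.int32, shape (N, K)

# Fused unpack + dequantize to bfloat16; scales_and_zeros uses the (K // group_size, N, 2) tinygemm layout.
scales_and_zeros = torch.randn(K // group_size, N, 2, dtype=torch.bfloat16, device="cuda")
dequant = torchao.ops.dequantize_int4(packed, scales_and_zeros, group_size, inner_k_tiles)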
Tests
Tests have been added to test/test_ops.py for both correctness as well as for correct custom op registration.
TODO
- Kernel works against a reference implementation per my understanding of dequantization but needs further verification (see notes in test/test_ops.py:test_dequant_int4_correctness)
- dequantize for ZeroPointDomain.Float per update
- test_aot_dispatch_dynamic opcheck failure
- AQT get_plain
cc @msaroufim