-
Notifications
You must be signed in to change notification settings - Fork 185
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add experimental INT8 quantized training #644
Conversation
🔗 Helpful Links🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/ao/644
Note: Links to docs will display an error until the docs builds have been completed. ✅ No FailuresAs of commit 4924e8d with merge base 0b66ff0 (): This comment was automatically generated by Dr. CI and updates every 15 minutes. |
from torchao.quantization.quant_api import quantize_ | ||
|
||
model = ... | ||
quantize_(model, int8_weight_only_quantized_training()) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
ah this is nice
optim_int8.zero_grad() | ||
|
||
|
||
class TestFSDP2(FSDPTest): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
it feels like we could probably add this as part of standard test suite, that we can use to sanity check if FSDP is supported for any dtype/tensor subclasses
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks @gau-nernst, looks great! One question I have is do the forward numerics match the inference path? E.g.
quantize_(model, int8_weight_only()) # vs
quantize_(model, int8_weight_only_quantized_training())
Do we want them to match?
Also, for Llama2, do you have the eval accuracies/perplexities compared to bf16 training?
tensor = tensor.round() | ||
|
||
# NOTE: is clipping necessary? | ||
tensor = tensor.clip(-128, 127).to(torch.int8) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
actually is it -127 or -128? From your PR description it says -127
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
this clipping is to ensure that it is within INT8 range, not related to the scaling step before this. By right, for stochastic rounding to work correctly, clipping shouldn't occur (which bias the results), that's why I add the comment there.
# don't do anything. workaround for FSDP2. might give unexpected or wrong results. | ||
@Int8QTLinearWeight.implements([aten.view.default, aten.as_strided.default]) | ||
def _(func, types, args, kwargs): | ||
out = Int8QTLinearWeight(args[0].int_data, args[0].scale) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Can you explain why we don't need to do anything here? What happens if we call view on the inner tensors and return a copy?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
These ops are called by FSDP2 by these lines (there are a few mores, you can ctrl+F, but these are the first hits when FSDP2 runs)
https://github.com/pytorch/pytorch/blob/cd565bc45554f6c7d5659bff19a270bda58ce71d/torch/distributed/_composable/fsdp/_fsdp_param.py#L349
https://github.com/pytorch/pytorch/blob/cd565bc45554f6c7d5659bff19a270bda58ce71d/torch/distributed/_composable/fsdp/_fsdp_param.py#L466-L471
From what I understand, FSDP2 expects a "view" here i.e. not a copy. Since we do channel-wise quantization here, a view doesn't quite make sense since it changes the dimensions -> quantization scale should change too (i.e. group size is tied to tensor dims). Propagate the ops to inner tensor may not make much sense too. For example, as_strided
can't be called on .scale
, since scale has different shape from the outer tensor.
I looked into NF4 impl, and it seems like they also keep the inner tensors as is, only update the kwargs passing to Tensor._make_wrapper_subclass()
, which change the outer tensor appearance.
ao/torchao/dtypes/nf4tensor.py
Lines 230 to 270 in e7fc0ed
@implements( | |
[ | |
aten.view.default, | |
] | |
) | |
@expect_args_len_at_k(1, CompareOp.EQ, 1, "aten.view(NF4Tensor) with len(size)=") | |
def nf4_view(aten_op, args, kwargs=None): | |
nf4tensor = args[0] | |
size = args[1] | |
if size[0] != -1: | |
raise NotImplementedError(f"aten.view(NF4Tensor) with size={size}") | |
updated_attrs = apply_to_inner_tensors(nf4tensor, aten_op, args[1:], kwargs) | |
updated_attrs.update({ | |
"size": [nf4tensor.numel()], | |
"stride": (1, ), | |
}) | |
return NF4Tensor(*construct_nf4_args(nf4tensor, updated_attrs)) | |
@implements( | |
[ | |
aten.as_strided.default, | |
] | |
) | |
@expect_args_len_at_k(1, CompareOp.LT, 3, "aten.as_strided(NF4Tensor) only support dim <= 2 but got dim=") | |
def nf4_as_strided(aten_op, args, kwargs=None): | |
nf4tensor = args[0] | |
size = args[1] | |
stride = tuple(args[2]) | |
storage_offset = args[3] | |
if math.prod(size) != nf4tensor.numel(): | |
raise NotImplementedError(f"aten.as_strided(NF4Tensor) different numel={nf4tensor.numel()} and size={size}") | |
if stride != make_contiguous_strides_for(size): | |
raise NotImplementedError(f"aten.as_strided(NF4Tensor) only support continuous stride={make_contiguous_strides_for(size)} but got stride={stride}") | |
if nf4tensor.storage_offset() != storage_offset: | |
raise NotImplementedError(f"aten.as_strided(NF4Tensor) only support original storage offset {nf4tensor.storage_offset()} but got {storage_offset}") | |
kwargs = { | |
"size": torch.Size(size), | |
"stride": stride, | |
"storage_offset": storage_offset, | |
} | |
return NF4Tensor(*construct_nf4_args(nf4tensor, kwargs)) |
So basically, it's only a hack to make it work with FSDP2. Normally we shouldn't call these ops at all! I don't know if passing the "expected kwargs" to Tensor._make_wrapper_subclass()
is necessary. The FSDP2 test passes on my machine (with an earlier version of pytorch nightly), so I think it might be ok? This is why I also request review from @awgu to check what is the best practice / recommended way here.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Sounds good, can you add a comment in the code to explain this?
return Tensor._make_wrapper_subclass( | ||
cls, | ||
int_data.shape, | ||
dtype=scale.dtype, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I feel it makes more sense for the dtype of this tensor to be torch.int8
instead? Otherwise we can get something like the following, which doesn't make as much sense:
t = Int8QTLinearWeight(weight, scale)
print(t.dtype) # torch.float32
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It's related to #442 😄. Basically .dtype
should be the dtype of gradients for autograd to work. NF4 and AQT are also doing this.
|
||
# the main difference of this tensor subclass from AffineQuantizedTensor: | ||
# 1. F.linear is differentiable i.e. backward is defined. | ||
# 2. support stochastic rounding when casting from floating point. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think stochastic_rounding=False
when casting from floating point? (L73)
|
||
@classmethod | ||
def from_float(cls, tensor: Tensor): | ||
"""Convert a float tensor into INT8 quantized weight. No stochastic rounding is performed. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
What would happen if we just always apply stochastic rounding?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
In terms of semantics, I think it's better to keep from_float()
to perform normal rounding, as it is a dtype conversion. By right, stochastic rounding is only done in the optimizer step -> should be part of the optimizer logic. We use tensor subclass here so it's easier to implement.
In practice, once we do training, I think it shouldn't matter much if we stochastic round the first dtype conversion or not.
@@ -0,0 +1,42 @@ | |||
# Quantized training |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This is great! Thanks for adding such a detailed README
For 1, for the experiments run in this PR, what is the reduction in memory footprint vs bf16 training? For 2, do you have any accuracy comparisons of the model (inference time) vs bf16 |
I added eval results in the PR description. Reproduce it here for convenience. Also added PTQ eval.
For INT8 PTQ vs BF16, the loss only differs by the 5th decimal point. I double-checked to make sure I didn't make any mistakes. So at the current state, INT8 PTQ is still better, though perhaps INT8 QT needs slightly different hparams (e.g. larger LR).
Currently they don't match for the following reasons
I checked that once I change my implementation to match AQT on the above 3 pointers, the inference results match exactly in eager mode. Pointer 2 is especially important for quantized training. Empirical results from my previous experiments show that training will not converge if 127.5 is used. Likely because there will be clipping in the +ve side (i.e. +127.5 -> +127), which bias the results.
Using 1B model; bs=4, seq_len=2048 -> 8192 toks/iter; activation checkpointing (I also put this table in the PR description)
In eager mode, the reduction in memory looks correct (1B model -> 1GB reduction from BF16->INT8). However, in compile mode, there is not a lot of memory reduction compared to BF16 compile. Maybe related to transposed weight? #624 |
@gau-nernst Thanks for the detailed evaluation and explanation! I have a couple of follow-up questions but I'm OK with merging this as a prototype, but we should probably talk about the above caveats either in the README or in inline comments. Also curious if others have an opinion on this (cc @msaroufim @jerryzh168).
That's great. Can you include these 3 in the comments for your tensor subclass? This will help us when we refactor to use AQT in a future PR.
That seems a bit fragile. Do you know why this is the case? Is there anyone else in the community that has reported something similar?
Please add a TODO to investigate this. I think it's worth looking into because saving memory is an important reason for QT in the first place. |
also if you need to add any skip tests to get everything green let's do that, merge and just list out the remaining work that needs to be done in issues. |
I suspect it's due to clipping. E.g. using 127.5, when you have 2 values [-1,+1] -> scale = 127.5, values are scaled to [-127.5,+127.5]. +127.5 can stochastic round to 127 or 128, but due to INT8 clipping, it will always be clipped to 127. In the end, I'm not aware of works reporting this phenomenon. https://github.com/google/aqt also uses 127, though they are actually doing QAT. In terms of PTQ, not an expert but I think many libraries also uses 127 instead of 127.5? E.g. https://github.com/ggerganov/llama.cpp/blob/4b9afbbe9037f8a2d659097c0c7d9fce32c6494c/ggml/src/ggml-cuda/quantize.cu#L27. Though for PTQ, it's probably doesn't matter much whether 127 or 127.5 are used. |
tensor = tensor.float() | ||
|
||
# absmax symmetric quantization | ||
scale = tensor.abs().amax(-1) / 127 |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
so does this mean the quant_min/quant_max for int8 will be (-127, 127)? we could match this in post training quant by explicitly setting quant_min/quant_max I think
Address #554 (but don't close it)
This PR upstreams some of the exploration work done in https://github.com/gau-nernst/quantized-training
Introduction (from the new README)
This folder contains experimental work on quantized training (QT). The main difference from quantization-aware training (QAT) is that in QT, we don't keep a high-precision copy of model weights. We take inspirations from:
Typically, low-precision weights cannot be trained directly due to quantization error: a small change in the quantized weight will be round down to zero. To tackle this problem, we use stochastic rounding for weight update. In simple terms, stochastic rounding will round up or down randomly, but with a higher chance if it is closer to that direction. For example, 0.8 will have 80% chance of rounding up and 20% of rounding down. It also follows that on average, stochastic rounding will estimate the floating point value exactly.
There are 2 main benefits for training in this way:
Currently we only support weight-only channel-wise INT8 symmetric quantization.
In this recipe, all linear weights are quantized to INT8 using channel-wise symmetric quantization
[-127, 127]
. In the forward and backward pass, the weights are upcast to activations' dtype (e.g. BF16). Therefore, their gradients are also in activations' dtype.Usage
It is recommended to use optimizers from
torchao.prototype.low_bit_optim
for quantized training, because they can automatically generate efficient fused optimizer kernel fordequant->optimizer_step->quant
thanks totorch.compile()
.Results
Training loss. Run the included benchmark. This launches an LLM pre-training using a Llama2-style model with 470M parameters on TinyStories. Add
--quantize int8_weight_only
for INT8 quantized training.Memory benchmark. 1B model; bs=4, seq_len=2048 -> 8192 toks/iter; activation checkpointing
In eager mode, the reduction in memory looks correct (1B model -> 1GB reduction from BF16->INT8). However, in compile mode, there is not a lot of memory reduction compared to BF16 compile. Maybe related to transposed weight? #624
Extra results
Some extra results from https://github.com/gau-nernst/quantized-training that was done before this PR.
Llama2-style LLM pre-training on TinyStories
The gap is much smaller for 1B params, indicating that larger model is easier to quantize (maybe dependent on the dataset too, perhaps 1B model is too big for TinyStories). Validation loss has the same gap as training loss.
ViT fine-tuning
Fine-tune
timm/vit_giant_patch14_dinov2.lvd142m
(1B params) on RESISC45LLM fine-tuning
Fine-tune SmolLM-1.7B on MetaMathQA
Using LR=1e-5, only INT8 QT model can learn, while BF16 model cannot learn at all, probably due to LR being too small (BF16 only has ~3 decimal precision). Increasing LR to 1e-4, BF16 model can be trained now. Due to different LR, it's hard to compare BF16 LR=1e-4 and INT8 QT LR=1e-5 directly. Will update this if I have time to re-run INT8 QT with LR=1e-4.
Note on optimizer
Optimizer logic is still done in high precision (FP32 or BF16). To minimize errors, we should only re-quantize params once, at the end of the optimizer logic. However, existing PyTorch optimizers may modify the param in-place multiple times. e.g. AdamW
https://github.com/pytorch/pytorch/blob/32be3e942c3251dc50892334c6614a89327c122c/torch/optim/adamw.py#L384
Therefore, this PR also adds an alternative implementation of AdamW in
torchao.prototype.low_bit_optim.AdamW
, which only applies in-place update of param in the final step.Future ideas