
Add experimental INT8 quantized training #644

Merged: 45 commits into pytorch:main, Aug 16, 2024

Conversation

gau-nernst (Collaborator) commented Aug 9, 2024:

Address #554 (but don't close it)

This PR upstreams some of the exploration work done in https://github.com/gau-nernst/quantized-training

Introduction (from the new README)

This folder contains experimental work on quantized training (QT). The main difference from quantization-aware training (QAT) is that in QT, we don't keep a high-precision copy of the model weights.

Typically, low-precision weights cannot be trained directly due to quantization error: a small change in the quantized weight will be rounded down to zero. To tackle this problem, we use stochastic rounding for the weight update. In simple terms, stochastic rounding rounds up or down randomly, with a probability that depends on how close the value is to each side. For example, 0.8 has an 80% chance of rounding up (to 1) and a 20% chance of rounding down (to 0). It follows that, in expectation, stochastic rounding reproduces the floating-point value exactly.
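As a rough illustration (not the PR's actual kernel), stochastic rounding can be written in a few lines of PyTorch; the helper name stochastic_round below is hypothetical:

import torch

def stochastic_round(x: torch.Tensor) -> torch.Tensor:
    # round up with probability equal to the fractional part, down otherwise,
    # so that E[stochastic_round(x)] == x
    floor = x.floor()
    return floor + (torch.rand_like(x) < (x - floor)).to(x.dtype)

x = torch.full((100_000,), 0.8)
print(stochastic_round(x).mean())  # ~0.8 on average, whereas x.round() gives 1.0 everywhere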

There are 2 main benefits of training this way:

  1. Reduced memory footprint. This also reduces communication bandwidth in distributed settings.
  2. What you train is what you serve (WYTIWYS).

Currently we only support weight-only channel-wise INT8 symmetric quantization.

In this recipe, all linear weights are quantized to INT8 using channel-wise symmetric quantization over the range [-127, 127]. In the forward and backward passes, the weights are upcast to the activations' dtype (e.g. BF16); their gradients are therefore also in the activations' dtype.
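To make the recipe concrete, here is a minimal sketch of the quantize/dequantize round trip described above (one scale per output channel, symmetric range); it is illustrative only and the helper names are not from the PR:

import torch

def int8_channelwise_quantize(w: torch.Tensor):
    # one scale per output channel (row), symmetric quantization to [-127, 127]
    scale = w.abs().amax(dim=-1, keepdim=True) / 127
    scale = scale.clamp(min=1e-12)  # guard against all-zero rows
    int_data = (w / scale).round().clip(-128, 127).to(torch.int8)
    return int_data, scale

def dequantize(int_data: torch.Tensor, scale: torch.Tensor, dtype: torch.dtype):
    # upcast back to the activations' dtype for the forward/backward matmul
    return int_data.to(dtype) * scale.to(dtype)

w = torch.randn(64, 128)
int_data, scale = int8_channelwise_quantize(w)
w_hat = dequantize(int_data, scale, torch.bfloat16)
print((w - w_hat.float()).abs().max())  # small per-channel quantization error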

Usage

from torchao.prototype.quantized_training import int8_weight_only_quantized_training
from torchao.prototype.low_bit_optim import AdamW
from torchao.quantization.quant_api import quantize_

model = ...
# replace the weight of every nn.Linear with the INT8 quantized-training tensor subclass, in-place
quantize_(model, int8_weight_only_quantized_training())

optim = AdamW(model.parameters(), lr=3e-4)

It is recommended to use optimizers from torchao.prototype.low_bit_optim for quantized training, because they can automatically generate an efficient fused optimizer kernel for dequant -> optimizer_step -> quant thanks to torch.compile().

Results

Training loss. Run the included benchmark, which launches an LLM pre-training run with a Llama2-style model (470M parameters) on TinyStories. Add --quantize int8_weight_only for INT8 quantized training.

python benchmarks/quantized_training/pretrain_llama2.py --seed 2024 --n_steps 10_000 --compile

Memory benchmark. 1B model; batch size 4, sequence length 2048 -> 8192 tokens/iteration; activation checkpointing enabled.

Model               Peak memory (GB)
BF16 eager          11.06847
BF16 compile        10.16915
INT8 QT eager       10.11437
INT8 QT compile     10.03365

In eager mode, the reduction in memory looks correct (1B model -> ~1 GB reduction going from BF16 to INT8). In compile mode, however, there is not much memory reduction compared to BF16 compile. Maybe related to the transposed weight? #624
 

Extra results

Some extra results from https://github.com/gau-nernst/quantized-training that were obtained before this PR.

Llama2-style LLM pre-training on TinyStories

[Training loss curves: 470M params (left), 1B params (right)]

The gap is much smaller for 1B params, indicating that larger models are easier to quantize (this may also depend on the dataset; perhaps a 1B model is too big for TinyStories). Validation loss shows the same gap as training loss.

Model                        TinyStories validation loss
470M BF16                    0.9904
470M BF16 + INT8 PTQ         0.9904
470M INT8 QT                 1.0389
470M INT8 QT + 8-bit Adam    1.0927
1B BF16                      0.9898
1B BF16 + INT8 PTQ           0.9898
1B INT8 QT                   1.0000

ViT fine-tuning

Fine-tune timm/vit_giant_patch14_dinov2.lvd142m (1B params) on RESISC45

Model                    RESISC45 val acc
BF16 model               93.94%
BF16 model + INT8 PTQ    93.92%
INT8 QT model            92.40%

LLM fine-tuning

Fine-tune SmolLM-1.7B on MetaMathQA


Using LR=1e-5, only the INT8 QT model learns, while the BF16 model does not learn at all, probably because the LR is too small (BF16 has only ~3 decimal digits of precision; a quick check of this is sketched after the table below). Increasing the LR to 1e-4, the BF16 model can now be trained. Because of the different LRs, it's hard to compare BF16 LR=1e-4 and INT8 QT LR=1e-5 directly. I will update this if I have time to re-run INT8 QT with LR=1e-4.

Model                      GSM8K acc
BF16 model (LR=1e-4)       14.86
INT8 QT model (LR=1e-5)    19.11
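As a side note on the precision argument above, a quick check (assuming a plain additive weight update in pure BF16) shows why a 1e-5 step can vanish entirely without stochastic rounding:

import torch

w = torch.tensor(1.0, dtype=torch.bfloat16)
update = 1e-5                     # roughly lr * grad for lr=1e-5
print(w + update == w)            # tensor(True): 1.0 + 1e-5 rounds back to 1.0 in BF16
# the spacing between BF16 values near 1.0 is 2**-7 ~ 0.0078, so tiny updates are lost;
# stochastic rounding would instead apply such an update with probability ~ update / 0.0078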

Note on optimizer

Optimizer logic is still done in high precision (FP32 or BF16). To minimize quantization error, we should re-quantize the params only once, at the end of the optimizer logic. However, existing PyTorch optimizers may modify the param in-place multiple times, e.g. AdamW:

https://github.com/pytorch/pytorch/blob/32be3e942c3251dc50892334c6614a89327c122c/torch/optim/adamw.py#L384
 
Therefore, this PR also adds an alternative implementation of AdamW in torchao.prototype.low_bit_optim.AdamW, which performs the in-place update of the param only once, at the very end of the optimizer logic (a sketch of the idea follows).
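A hedged sketch of the idea (not the actual torchao.prototype.low_bit_optim implementation): compute the AdamW update out-of-place in high precision and write it back to the quantized param exactly once, so the tensor subclass re-quantizes only at that single copy_:

import torch

@torch.no_grad()
def adamw_step(p, grad, exp_avg, exp_avg_sq, step, lr=3e-4,
               beta1=0.9, beta2=0.999, eps=1e-8, weight_decay=1e-2):
    # optimizer state and math stay in high precision
    exp_avg.lerp_(grad, 1 - beta1)
    exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
    bias1 = 1 - beta1 ** step
    bias2 = 1 - beta2 ** step
    denom = (exp_avg_sq / bias2).sqrt().add_(eps)
    # build the new value out-of-place (p.float() assumes the subclass can dequantize to FP32)...
    new_p = p.float() * (1 - lr * weight_decay) - lr * (exp_avg / bias1) / denom
    # ...then write back exactly once: for an INT8 param this is the single point of re-quantization
    p.copy_(new_p)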

Future ideas

  • INT8 activation x INT8 weight. This can potentially leverage INT8 Tensor Cores, which is 2x faster than FP16/BF16 Tensor Cores.
  • INT4 weight only (with group-wise quantization). This can be used with INT4 tinygemm deployment in mind (or other optimized INT4 kernels).
  • FP8 activation x FP8 weight. The current FP8 training recipe can be seen as a form of QAT, which maintains a high-precision copy of model weights. We can eliminate the high-precision copy.


pytorch-bot bot commented Aug 9, 2024

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/ao/644

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit 4924e8d with merge base 0b66ff0:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@facebook-github-bot facebook-github-bot added the CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed. label Aug 9, 2024
@andrewor14 andrewor14 self-requested a review August 9, 2024 17:13
from torchao.quantization.quant_api import quantize_

model = ...
quantize_(model, int8_weight_only_quantized_training())
Contributor:

ah this is nice

optim_int8.zero_grad()


class TestFSDP2(FSDPTest):
Contributor:

it feels like we could probably add this as part of the standard test suite, which we could use to sanity-check whether FSDP is supported for any dtype / tensor subclass

andrewor14 (Contributor) left a comment:

Thanks @gau-nernst, looks great! One question I have is do the forward numerics match the inference path? E.g.

quantize_(model, int8_weight_only())  # vs
quantize_(model, int8_weight_only_quantized_training())

Do we want them to match?

Also, for Llama2, do you have the eval accuracies/perplexities compared to bf16 training?

torchao/prototype/quantized_training/int8.py (resolved review thread)
tensor = tensor.round()

# NOTE: is clipping necessary?
tensor = tensor.clip(-128, 127).to(torch.int8)
Contributor:

actually is it -127 or -128? From your PR description it says -127

gau-nernst (Collaborator, Author):

this clipping is to ensure that the value is within the INT8 range; it's not related to the scaling step before this. Strictly speaking, for stochastic rounding to work correctly, clipping shouldn't occur (it would bias the results), which is why I added the comment there.

torchao/prototype/quantized_training/int8.py (resolved review thread)
# don't do anything. workaround for FSDP2. might give unexpected or wrong results.
@Int8QTLinearWeight.implements([aten.view.default, aten.as_strided.default])
def _(func, types, args, kwargs):
    out = Int8QTLinearWeight(args[0].int_data, args[0].scale)
Contributor:

Can you explain why we don't need to do anything here? What happens if we call view on the inner tensors and return a copy?

gau-nernst (Collaborator, Author):

These ops are called by FSDP2 at these lines (there are a few more; you can Ctrl+F, but these are the first hits when FSDP2 runs):

https://github.com/pytorch/pytorch/blob/cd565bc45554f6c7d5659bff19a270bda58ce71d/torch/distributed/_composable/fsdp/_fsdp_param.py#L349
https://github.com/pytorch/pytorch/blob/cd565bc45554f6c7d5659bff19a270bda58ce71d/torch/distributed/_composable/fsdp/_fsdp_param.py#L466-L471

From what I understand, FSDP2 expects a "view" here, i.e. not a copy. Since we do channel-wise quantization, a view doesn't quite make sense: it changes the dimensions, so the quantization scale should change too (the group size is tied to the tensor dims). Propagating the ops to the inner tensors may not make much sense either. For example, as_strided can't be called on .scale, since scale has a different shape from the outer tensor.

I looked into the NF4 impl, and it seems like they also keep the inner tensors as-is and only update the kwargs passed to Tensor._make_wrapper_subclass(), which changes the outer tensor's appearance:

@implements(
    [
        aten.view.default,
    ]
)
@expect_args_len_at_k(1, CompareOp.EQ, 1, "aten.view(NF4Tensor) with len(size)=")
def nf4_view(aten_op, args, kwargs=None):
    nf4tensor = args[0]
    size = args[1]
    if size[0] != -1:
        raise NotImplementedError(f"aten.view(NF4Tensor) with size={size}")
    updated_attrs = apply_to_inner_tensors(nf4tensor, aten_op, args[1:], kwargs)
    updated_attrs.update({
        "size": [nf4tensor.numel()],
        "stride": (1, ),
    })
    return NF4Tensor(*construct_nf4_args(nf4tensor, updated_attrs))


@implements(
    [
        aten.as_strided.default,
    ]
)
@expect_args_len_at_k(1, CompareOp.LT, 3, "aten.as_strided(NF4Tensor) only support dim <= 2 but got dim=")
def nf4_as_strided(aten_op, args, kwargs=None):
    nf4tensor = args[0]
    size = args[1]
    stride = tuple(args[2])
    storage_offset = args[3]
    if math.prod(size) != nf4tensor.numel():
        raise NotImplementedError(f"aten.as_strided(NF4Tensor) different numel={nf4tensor.numel()} and size={size}")
    if stride != make_contiguous_strides_for(size):
        raise NotImplementedError(f"aten.as_strided(NF4Tensor) only support continuous stride={make_contiguous_strides_for(size)} but got stride={stride}")
    if nf4tensor.storage_offset() != storage_offset:
        raise NotImplementedError(f"aten.as_strided(NF4Tensor) only support original storage offset {nf4tensor.storage_offset()} but got {storage_offset}")
    kwargs = {
        "size": torch.Size(size),
        "stride": stride,
        "storage_offset": storage_offset,
    }
    return NF4Tensor(*construct_nf4_args(nf4tensor, kwargs))

So basically, it's only a hack to make it work with FSDP2. Normally we shouldn't call these ops at all! I don't know if passing the "expected kwargs" to Tensor._make_wrapper_subclass() is necessary. The FSDP2 test passes on my machine (with an earlier version of PyTorch nightly), so I think it might be OK? This is also why I requested a review from @awgu, to check what the best practice / recommended way is here.

Contributor:

Sounds good, can you add a comment in the code to explain this?

return Tensor._make_wrapper_subclass(
    cls,
    int_data.shape,
    dtype=scale.dtype,
Contributor:

I feel it makes more sense for the dtype of this tensor to be torch.int8 instead? Otherwise we can get something like the following, which doesn't make as much sense:

t = Int8QTLinearWeight(weight, scale)
print(t.dtype)  # torch.float32

gau-nernst (Collaborator, Author):

It's related to #442 😄. Basically, .dtype should be the dtype of the gradients for autograd to work. NF4 and AQT also do this.


# the main difference of this tensor subclass from AffineQuantizedTensor:
# 1. F.linear is differentiable i.e. backward is defined.
# 2. support stochastic rounding when casting from floating point.
Contributor:

I think stochastic_rounding=False when casting from floating point? (L73)

torchao/prototype/quantized_training/int8.py (resolved review thread)

@classmethod
def from_float(cls, tensor: Tensor):
    """Convert a float tensor into INT8 quantized weight. No stochastic rounding is performed.
Contributor:

What would happen if we just always apply stochastic rounding?

gau-nernst (Collaborator, Author):

In terms of semantics, I think it's better to keep from_float() performing normal rounding, since it is a dtype conversion. Strictly speaking, stochastic rounding is only done in the optimizer step, so it should be part of the optimizer logic. We use a tensor subclass here so it's easier to implement.

In practice, once we start training, I think it shouldn't matter much whether we stochastically round the first dtype conversion or not.

@@ -0,0 +1,42 @@
# Quantized training
Contributor:

This is great! Thanks for adding such a detailed README

supriyar (Contributor) commented:

There are 2 main benefits for training in this way:

Reduce memory footprint. Also reduce communication bandwidth in distributed setting.
What you train is what you serve (WYTIWYS).

For 1, for the experiments run in this PR, what is the reduction in memory footprint vs bf16 training?

For 2, do you have any accuracy comparisons of the model (at inference time) vs bf16?

@msaroufim msaroufim requested a review from vkuzo August 13, 2024 22:46
gau-nernst (Collaborator, Author) commented:

@andrewor14 @supriyar

do you have the eval accuracies/perplexities compared to bf16 training?
do you have any accuracy comparisons of the model (inference time) vs bf16

I added eval results to the PR description; reproducing them here for convenience. I also added a PTQ eval.

Model                        TinyStories validation loss
470M BF16                    0.9904
470M BF16 + INT8 PTQ         0.9904
470M INT8 QT                 1.0389
470M INT8 QT + 8-bit Adam    1.0927
1B BF16                      0.9898
1B BF16 + INT8 PTQ           0.9898
1B INT8 QT                   1.0000

For INT8 PTQ vs BF16, the loss only differs at the 5th decimal place. I double-checked to make sure I didn't make any mistakes. So in its current state, INT8 PTQ is still better, though perhaps INT8 QT needs slightly different hparams (e.g. a larger LR).

One question I have is do the forward numerics match the inference path

Currently they don't match, for the following reasons:

  1. Precision: AQT keeps the original dtype when doing quantization, while I upcast to FP32 before quantization and downcast afterwards.
  2. Calculating the scale: AQT uses input.abs().amax() / 127.5, while I use input.abs().amax() / 127.
  3. Applying the scale: AQT uses input * (1 / scale), while I use input / scale.

I checked that once I change my implementation to match AQT on the above 3 points, the inference results match exactly in eager mode.

Point 2 is especially important for quantized training. Empirical results from my previous experiments show that training will not converge if 127.5 is used, likely because there will be clipping on the positive side (i.e. +127.5 -> +127), which biases the results.

for the experiments run in this PR, what is the reduction in memory footprint vs bf16 training?

Using a 1B model; batch size 4, sequence length 2048 -> 8192 tokens/iteration; activation checkpointing (I also put this table in the PR description).

Model               Peak memory (GB)
BF16 eager          11.06847
BF16 compile        10.16915
INT8 QT eager       10.11437
INT8 QT compile     10.03365

In eager mode, the reduction in memory looks correct (1B model -> ~1 GB reduction going from BF16 to INT8). In compile mode, however, there is not much memory reduction compared to BF16 compile. Maybe related to the transposed weight? #624

andrewor14 (Contributor) commented:

@gau-nernst Thanks for the detailed evaluation and explanation! I have a couple of follow-up questions, but I'm OK with merging this as a prototype; we should probably document the above caveats either in the README or in inline comments. Also curious if others have an opinion on this (cc @msaroufim @jerryzh168).

I checked that once I change my implementation to match AQT on the above 3 pointers, the inference results match exactly in eager mode.

That's great. Can you include these 3 in the comments for your tensor subclass? This will help us when we refactor to use AQT in a future PR.

Point 2 is especially important for quantized training. Empirical results from my previous experiments show that training will not converge if 127.5 is used, likely because there will be clipping on the positive side (i.e. +127.5 -> +127), which biases the results.

That seems a bit fragile. Do you know why this is the case? Is there anyone else in the community that has reported something similar?

In eager mode, the reduction in memory looks correct (1B model -> ~1 GB reduction going from BF16 to INT8). In compile mode, however, there is not much memory reduction compared to BF16 compile. Maybe related to the transposed weight? #624

Please add a TODO to investigate this. I think it's worth looking into because saving memory is an important reason for QT in the first place.

msaroufim (Member) commented:

Also, if you need to add any skipped tests to get everything green, let's do that, merge, and just list out the remaining work in issues.

gau-nernst (Collaborator, Author) commented:

That seems a bit fragile. Do you know why this is the case? Is there anyone else in the community that has reported something similar?

I suspect it's due to clipping. E.g. using 127.5: when you have the two values [-1, +1], the scale puts them at [-127.5, +127.5]. +127.5 can stochastically round to 127 or 128, but due to INT8 clipping it will always be clipped to 127. In the end, E[SR(x)] != x. This only happens to the largest positive value, though it does seem surprising that such a small difference can make such a big impact. A small numerical check is sketched below.
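To make the E[SR(x)] != x point concrete, here is a small standalone check (illustrative only, with stochastic rounding written inline):

import torch

x = torch.full((1_000_000,), 127.5)  # largest value after scaling by amax / 127.5
sr = x.floor() + (torch.rand_like(x) < (x - x.floor())).float()  # stochastic rounding: 127 or 128
clipped = sr.clip(-128, 127)         # INT8 clipping
print(sr.mean())                     # ~127.5 -> unbiased before clipping
print(clipped.mean())                # ~127.0 -> biased: E[SR(x)] != x once clipping kicks in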

I'm not aware of work reporting this phenomenon. https://github.com/google/aqt also uses 127, though they are actually doing QAT. In terms of PTQ, I'm not an expert, but I think many libraries also use 127 instead of 127.5? E.g. https://github.com/ggerganov/llama.cpp/blob/4b9afbbe9037f8a2d659097c0c7d9fce32c6494c/ggml/src/ggml-cuda/quantize.cu#L27. Though for PTQ, it probably doesn't matter much whether 127 or 127.5 is used.

@msaroufim msaroufim merged commit 8c6b4f9 into pytorch:main Aug 16, 2024
16 checks passed
tensor = tensor.float()

# absmax symmetric quantization
scale = tensor.abs().amax(-1) / 127
Contributor:

So does this mean the quant_min/quant_max for INT8 will be (-127, 127)? I think we could match this in post-training quant by explicitly setting quant_min/quant_max.

@gau-nernst gau-nernst deleted the qt_int8 branch August 17, 2024 02:21