Add experimental INT8 quantized training #644

Merged: 45 commits, Aug 16, 2024

Commits (45, all by gau-nernst):
3d42329  initial commit (Aug 9, 2024)
eca170a  add tests (Aug 9, 2024)
dd162a8  add training (Aug 9, 2024)
b286f5d  support py3.9 (Aug 9, 2024)
8a84aca  skip test for torch<2.3 (Aug 9, 2024)
ea47c7d  fix pytest (Aug 9, 2024)
f20486b  fix adamw (Aug 9, 2024)
3415244  add some FSDP ops (Aug 9, 2024)
5d0e658  add more fsdp ops (Aug 10, 2024)
d753476  more ops (Aug 10, 2024)
9c77800  add benchmark script (Aug 10, 2024)
158eb61  some organisation (Aug 10, 2024)
db0290f  add FSDP test (Aug 10, 2024)
1c32b78  clean up (Aug 10, 2024)
ff69121  update FSDP test (Aug 10, 2024)
45342ba  add compile test (things are crashing) (Aug 10, 2024)
f1587a2  fix bias (Aug 10, 2024)
7f9102a  substantial update to tests (Aug 10, 2024)
0428330  fix compile for FSDP (Aug 10, 2024)
001422c  update readme. rename file (Aug 10, 2024)
2eb2787  speed up CI (Aug 10, 2024)
d39caba  fix typo (Aug 10, 2024)
de6aa25  fix typo (Aug 10, 2024)
adbe47d  typos. unset some dynamo flags (Aug 10, 2024)
3fdf776  update readme (Aug 10, 2024)
ea0ee4f  remove requires_grad, since it is unnecessary (Aug 11, 2024)
36d0e1a  remove note (Aug 11, 2024)
2360a97  Merge branch 'pytorch:main' into qt_int8 (Aug 11, 2024)
9e19104  Merge branch 'main' into qt_int8 (Aug 13, 2024)
6bc7621  don't set inductor flags (Aug 13, 2024)
6646c0b  rename (Aug 13, 2024)
00e25cf  update README (Aug 13, 2024)
927a6d1  rename optimizer (Aug 13, 2024)
8377707  Merge branch 'main' into qt_int8 (Aug 14, 2024)
de49e8b  update benchmark script (Aug 14, 2024)
f80ac97  make compile explicit (Aug 14, 2024)
e375c3d  update docs (Aug 14, 2024)
6396a95  Merge branch 'main' into qt_int8 (Aug 16, 2024)
662c61f  use torch.optim.Adam to avoid FSDP optim compile bug (Aug 16, 2024)
cc90298  update docs (Aug 16, 2024)
f1c588b  update doc (Aug 16, 2024)
f444fa6  update docs (Aug 16, 2024)
640ec2d  fix CI test (Aug 16, 2024)
dad6560  skip test (Aug 16, 2024)
4924e8d  fix compiled test (Aug 16, 2024)
148 changes: 148 additions & 0 deletions benchmarks/benchmark_int8_qt.py
@@ -0,0 +1,148 @@
# pre-train a mini Llama2 on TinyStories with INT8 quantized training
# pip install transformers sentencepiece wandb
#
# BF16 baseline: python benchmarks/benchmark_int8_qt.py --seed 2024 --n_steps 10_000
# INT8 QT: python benchmarks/benchmark_int8_qt.py --seed 2024 --n_steps 10_000 --quantize int8_weight_only

import os

os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
Member: Just curious, were you running into OOMs without this?

Collaborator Author (gau-nernst): Not for the experiments I ran here, I think. I copied it over from my repo, and I think this flag is generally good.
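
For reference, the same allocator setting can also be applied from the shell instead of inside the script, e.g.:

PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True python benchmarks/benchmark_int8_qt.py --seed 2024 --n_steps 10_000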


import argparse
from pathlib import Path

import numpy as np
import torch
import wandb
from tqdm import tqdm
from transformers import LlamaConfig, LlamaForCausalLM

from torchao.prototype import low_bit_optim
from torchao.prototype.quantized_training import int8_weight_only_quantized_training
from torchao.quantization.quant_api import quantize_


def get_loss(model: LlamaForCausalLM, batch: torch.Tensor):
    return model(batch, labels=batch).loss


def get_tinystories():
    save_path = Path("tinystories.bin")

    if not save_path.exists():
        import sentencepiece as spm
Member: Perhaps create a separate quantized training folder along with the dependencies for hf_hub and sentencepiece.

        from huggingface_hub import hf_hub_download

        tokenizer_path = hf_hub_download("meta-llama/Llama-2-7b", "tokenizer.model")
        tokenizer = spm.SentencePieceProcessor(tokenizer_path)
        assert tokenizer.vocab_size() < (1 << 16)  # make sure we can use uint16

        # do everything in memory. we have enough RAM
        filepath = hf_hub_download(
            "roneneldan/TinyStories",
            "TinyStoriesV2-GPT4-train.txt",
            repo_type="dataset",
        )
        stories = open(filepath).read().split("\n<|endoftext|>\n")

        tokens_list = []
        chunk_size = 10_000
        for i in tqdm(range(0, len(stories), chunk_size), desc="Tokenizing TinyStories"):
            chunk = stories[i : min(i + chunk_size, len(stories))]
            tokens_list.extend(tokenizer.Encode(chunk, add_bos=True, add_eos=True, num_threads=4))

        total_size = sum(len(x) for x in tokens_list)
        mmap_tokens = np.memmap(save_path, dtype=np.uint16, mode="w+", shape=total_size)
        i = 0
        for tokens in tokens_list:
            mmap_tokens[i : i + len(tokens)] = tokens
            i += len(tokens)
        mmap_tokens.flush()

    tokens = np.memmap(save_path, dtype=np.uint16, mode="r")
    return torch.from_numpy(tokens)


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    # default config is 470M
    parser.add_argument("--d_model", type=int, default=1024)
    parser.add_argument("--depth", type=int, default=24)
    parser.add_argument("--ffn_size", type=int, default=4096)
    parser.add_argument("--head_dim", type=int, default=64)

    parser.add_argument("--quantize")
    parser.add_argument("--activation_checkpointing", action="store_true")

    parser.add_argument("--n_steps", type=int, default=1000)
    parser.add_argument("--batch_size", type=int, default=4)
    parser.add_argument("--seq_len", type=int, default=2048)

    parser.add_argument("--optim", default="AdamW")
    parser.add_argument("--lr", type=float, default=3e-4)
    parser.add_argument("--weight_decay", type=float, default=1e-2)

    parser.add_argument("--project", default="int8_quantized_training")
    parser.add_argument("--run_name")
    parser.add_argument("--seed", type=int)
    args = parser.parse_args()

    if args.seed is not None:
        torch.manual_seed(args.seed)

    config = LlamaConfig(
Member: Did you also try the llama model we have in AO? Curious what was more convenient here?

Collaborator Author (gau-nernst), Aug 12, 2024: I haven't. Will give it a try.

Collaborator Author (gau-nernst): I gave it a try and faced some difficulties, mainly because the model was written for pure inference only. To be specific, I would need to change:

1. Initialize freqs_cis without initializing the KV cache. Currently setup_caches does both:

    def setup_caches(self, max_batch_size, max_seq_length):
        if self.max_seq_length >= max_seq_length and self.max_batch_size >= max_batch_size:
            return
        head_dim = self.config.dim // self.config.n_head
        max_seq_length = find_multiple(max_seq_length, 8)
        self.max_seq_length = max_seq_length
        self.max_batch_size = max_batch_size
        dtype = self.output.weight.dtype
        # For quantized layers, dtype is encoded in scales
        if hasattr(self.output, "scales"):
            dtype = self.output.scales.dtype
        elif hasattr(self.output, "scales_and_zeros"):
            dtype = self.output.scales_and_zeros.dtype
        for b in self.layers:
            b.attention.kv_cache = KVCache(max_batch_size, max_seq_length, self.config.n_local_heads, head_dim, dtype)
        self.freqs_cis = precompute_freqs_cis(self.config.block_size, self.config.dim // self.config.n_head, self.config.rope_base, dtype)
        self.causal_mask = torch.tril(torch.ones(self.max_seq_length, self.max_seq_length, dtype=torch.bool))

2. Don't use an attention mask, just pass is_causal=True directly instead of:

    y = F.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0)

So for now I think I won't use the built-in Llama for this PR. Perhaps we can modify the Llama model definition in a separate PR to support training, then I will change this script to use the built-in Llama.
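
For illustration, a rough sketch of those two changes, reusing the names from the quoted torchao Llama code (not something this PR adds):

    # 1. compute only the RoPE frequencies; skip the KVCache allocation, which training does not need
    self.freqs_cis = precompute_freqs_cis(self.config.block_size, self.config.dim // self.config.n_head, self.config.rope_base, self.output.weight.dtype)

    # 2. rely on built-in causal masking instead of a precomputed mask tensor
    y = F.scaled_dot_product_attention(q, k, v, dropout_p=0.0, is_causal=True)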

        hidden_size=args.d_model,
        intermediate_size=args.ffn_size,
        num_hidden_layers=args.depth,
        num_attention_heads=args.d_model // args.head_dim,
        max_position_embeddings=args.seq_len,
        use_cache=False,
    )
    model = LlamaForCausalLM(config).bfloat16().cuda()
    if args.activation_checkpointing:
        model.gradient_checkpointing_enable()
    if args.quantize == "int8_weight_only":
        quantize_(model, int8_weight_only_quantized_training())
    elif args.quantize is not None:
        raise ValueError(f"Unsupported quantize={args.quantize}")
    print(f"No. of params: {sum(p.numel() for p in model.parameters()):,}")
    print(f"No. of buffers: {sum(p.numel() for p in model.buffers()):,}")

    # turn off these flags (set by quantize_()) to speed up compile time
Member: Huh, interesting, were the flags that bad? cc @HDCharles @jerryzh168 in case we need to change the defaults.

Collaborator Author (gau-nernst): I think they are fine for inference. Maybe ok for training too, but I didn't want to wait 🤣

Contributor: You can turn it off by setting set_inductor_config=False:

    def quantize_(model: torch.nn.Module, apply_tensor_subclass: Callable[[torch.nn.Module], torch.nn.Module], filter_fn: Optional[Callable[[torch.nn.Module, str], bool]]=None, set_inductor_config: bool=True):
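
A minimal sketch of that alternative for this script, assuming the signature quoted above (the script keeps the manual flag toggles below instead):

    # skip the inductor autotuning flags that quantize_() would otherwise enable
    quantize_(model, int8_weight_only_quantized_training(), set_inductor_config=False)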

    torch._inductor.config.coordinate_descent_tuning = False
    torch._inductor.config.coordinate_descent_check_all_directions = False

    optim = getattr(low_bit_optim, args.optim)(model.parameters(), lr=args.lr, weight_decay=args.weight_decay)

    data = get_tinystories().cuda()
    run = wandb.init(dir="/tmp", config=args, project=args.project, name=args.run_name)

    step = 0
    log_interval = 50
    pbar = tqdm(total=args.n_steps, dynamic_ncols=True)
    model.train()

    while step < args.n_steps:
        # randomly select a continuous chunk, then reshape it
        idx = torch.randint(0, data.shape[0] - args.batch_size * args.seq_len, (1,)).item()
        batch = data[idx : idx + args.batch_size * args.seq_len].view(args.batch_size, args.seq_len).long()

        loss = torch.compile(get_loss)(model, batch)
        loss.backward()

        if step % log_interval == 0:
            log_dict = dict(
                loss=loss.item(),
                lr=optim.param_groups[0]["lr"],
                max_memory_allocated=torch.cuda.max_memory_allocated(),
            )
            run.log(log_dict, step=step)
            pbar.set_postfix(loss=log_dict["loss"])

        optim.step()
        optim.zero_grad()

        step += 1
        pbar.update()

    run.finish()
24 changes: 24 additions & 0 deletions test/prototype/test_low_bit_optim.py
@@ -98,6 +98,30 @@ def test_optim_smoke(self, optim_name, dtype, device):
        optim.step()
        optim.zero_grad()

    @parametrize("device", _DEVICES)
    def test_optim_standard_correctness(self, device):
        model1 = nn.Sequential(nn.Linear(32, 1024), nn.ReLU(), nn.Linear(1024, 128)).to(device)
        model2 = copy.deepcopy(model1)

        optim1 = torch.optim.AdamW(model1.parameters())
        optim2 = low_bit_optim.AdamW(model2.parameters())

        for _ in range(2):
            x = torch.randn(4, 32, device=device)

            loss1 = model1(x).sum()
            loss1.backward()
            optim1.step()
            optim1.zero_grad()

            loss2 = model2(x).sum()
            loss2.backward()
            optim2.step()
            optim2.zero_grad()

        for p1, p2 in zip(model1.parameters(), model2.parameters()):
            torch.testing.assert_close(p2, p1, rtol=1e-5, atol=1e-5)

    @pytest.mark.skipif(bnb is None, reason="bitsandbytes is not available")
    @pytest.mark.skipif(not torch.cuda.is_available(), reason="bitsandbytes 8-bit Adam only works for CUDA")
    @pytest.mark.xfail(not TORCH_VERSION_AFTER_2_3, reason="torch.compile() fails for PyTorch < 2.3")