-
Notifications
You must be signed in to change notification settings - Fork 185
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add experimental INT8 quantized training #644
Changes from 28 commits
3d42329
eca170a
dd162a8
b286f5d
8a84aca
ea47c7d
f20486b
3415244
5d0e658
d753476
9c77800
158eb61
db0290f
1c32b78
ff69121
45342ba
f1587a2
7f9102a
0428330
001422c
2eb2787
d39caba
de6aa25
adbe47d
3fdf776
ea0ee4f
36d0e1a
2360a97
9e19104
6bc7621
6646c0b
00e25cf
927a6d1
8377707
de49e8b
f80ac97
e375c3d
6396a95
662c61f
cc90298
f1c588b
f444fa6
640ec2d
dad6560
4924e8d
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||||||||||||||||||
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
@@ -0,0 +1,148 @@ | ||||||||||||||||||||||||||||||||||||||||
# pre-train a mini Llama2 on TinyStories with INT8 quantized training | ||||||||||||||||||||||||||||||||||||||||
# pip install transformers sentencepiece wandb | ||||||||||||||||||||||||||||||||||||||||
# | ||||||||||||||||||||||||||||||||||||||||
# BF16 baseline: python benchmarks/benchmark_int8_qt.py --seed 2024 --n_steps 10_000 | ||||||||||||||||||||||||||||||||||||||||
# INT8 QT: python benchmarks/benchmark_int8_qt.py --seed 2024 --n_steps 10_000 --quantize int8_weight_only | ||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||
import os | ||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True" | ||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||
import argparse | ||||||||||||||||||||||||||||||||||||||||
from pathlib import Path | ||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||
import numpy as np | ||||||||||||||||||||||||||||||||||||||||
import torch | ||||||||||||||||||||||||||||||||||||||||
import wandb | ||||||||||||||||||||||||||||||||||||||||
from tqdm import tqdm | ||||||||||||||||||||||||||||||||||||||||
from transformers import LlamaConfig, LlamaForCausalLM | ||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||
from torchao.prototype import low_bit_optim | ||||||||||||||||||||||||||||||||||||||||
from torchao.prototype.quantized_training import int8_weight_only_quantized_training | ||||||||||||||||||||||||||||||||||||||||
from torchao.quantization.quant_api import quantize_ | ||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||
def get_loss(model: LlamaForCausalLM, batch: torch.Tensor): | ||||||||||||||||||||||||||||||||||||||||
return model(batch, labels=batch).loss | ||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||
def get_tinystories(): | ||||||||||||||||||||||||||||||||||||||||
save_path = Path("tinystories.bin") | ||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||
if not save_path.exists(): | ||||||||||||||||||||||||||||||||||||||||
import sentencepiece as spm | ||||||||||||||||||||||||||||||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. perhaps create a seperate quantized training folder along with the dependencies for hf_hub and sentencepiece |
||||||||||||||||||||||||||||||||||||||||
from huggingface_hub import hf_hub_download | ||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||
tokenizer_path = hf_hub_download("meta-llama/Llama-2-7b", "tokenizer.model") | ||||||||||||||||||||||||||||||||||||||||
tokenizer = spm.SentencePieceProcessor(tokenizer_path) | ||||||||||||||||||||||||||||||||||||||||
assert tokenizer.vocab_size() < (1 << 16) # make sure we can use uint16 | ||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||
# do everything in memory. we have enough RAM | ||||||||||||||||||||||||||||||||||||||||
filepath = hf_hub_download( | ||||||||||||||||||||||||||||||||||||||||
"roneneldan/TinyStories", | ||||||||||||||||||||||||||||||||||||||||
"TinyStoriesV2-GPT4-train.txt", | ||||||||||||||||||||||||||||||||||||||||
repo_type="dataset", | ||||||||||||||||||||||||||||||||||||||||
) | ||||||||||||||||||||||||||||||||||||||||
stories = open(filepath).read().split("\n<|endoftext|>\n") | ||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||
tokens_list = [] | ||||||||||||||||||||||||||||||||||||||||
chunk_size = 10_000 | ||||||||||||||||||||||||||||||||||||||||
for i in tqdm(range(0, len(stories), chunk_size), desc="Tokenizing TinyStories"): | ||||||||||||||||||||||||||||||||||||||||
chunk = stories[i : min(i + chunk_size, len(stories))] | ||||||||||||||||||||||||||||||||||||||||
tokens_list.extend(tokenizer.Encode(chunk, add_bos=True, add_eos=True, num_threads=4)) | ||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||
total_size = sum(len(x) for x in tokens_list) | ||||||||||||||||||||||||||||||||||||||||
mmap_tokens = np.memmap(save_path, dtype=np.uint16, mode="w+", shape=total_size) | ||||||||||||||||||||||||||||||||||||||||
i = 0 | ||||||||||||||||||||||||||||||||||||||||
for tokens in tokens_list: | ||||||||||||||||||||||||||||||||||||||||
mmap_tokens[i : i + len(tokens)] = tokens | ||||||||||||||||||||||||||||||||||||||||
i += len(tokens) | ||||||||||||||||||||||||||||||||||||||||
mmap_tokens.flush() | ||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||
tokens = np.memmap(save_path, dtype=np.uint16, mode="r") | ||||||||||||||||||||||||||||||||||||||||
return torch.from_numpy(tokens) | ||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||
if __name__ == "__main__": | ||||||||||||||||||||||||||||||||||||||||
parser = argparse.ArgumentParser() | ||||||||||||||||||||||||||||||||||||||||
# default config is 470M | ||||||||||||||||||||||||||||||||||||||||
parser.add_argument("--d_model", type=int, default=1024) | ||||||||||||||||||||||||||||||||||||||||
parser.add_argument("--depth", type=int, default=24) | ||||||||||||||||||||||||||||||||||||||||
parser.add_argument("--ffn_size", type=int, default=4096) | ||||||||||||||||||||||||||||||||||||||||
parser.add_argument("--head_dim", type=int, default=64) | ||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||
parser.add_argument("--quantize") | ||||||||||||||||||||||||||||||||||||||||
parser.add_argument("--activation_checkpointing", action="store_true") | ||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||
parser.add_argument("--n_steps", type=int, default=1000) | ||||||||||||||||||||||||||||||||||||||||
parser.add_argument("--batch_size", type=int, default=4) | ||||||||||||||||||||||||||||||||||||||||
parser.add_argument("--seq_len", type=int, default=2048) | ||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||
parser.add_argument("--optim", default="AdamW") | ||||||||||||||||||||||||||||||||||||||||
parser.add_argument("--lr", type=float, default=3e-4) | ||||||||||||||||||||||||||||||||||||||||
parser.add_argument("--weight_decay", type=float, default=1e-2) | ||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||
parser.add_argument("--project", default="int8_quantized_training") | ||||||||||||||||||||||||||||||||||||||||
parser.add_argument("--run_name") | ||||||||||||||||||||||||||||||||||||||||
parser.add_argument("--seed", type=int) | ||||||||||||||||||||||||||||||||||||||||
args = parser.parse_args() | ||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||
if args.seed is not None: | ||||||||||||||||||||||||||||||||||||||||
torch.manual_seed(args.seed) | ||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||
config = LlamaConfig( | ||||||||||||||||||||||||||||||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. did you also try the llama model we have in AO? CUrious what was more convenient here? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I haven't. Will give it a try There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I gave it a try and faced some difficulties, mainly because the model was written for pure inference only. To be specific, I need to change
So for now I think I won't use the built-in Llama for this PR. Perhaps we can modify Llama model definition in a separate PR to support training, then I will change this script to use the built-in Llama. |
||||||||||||||||||||||||||||||||||||||||
hidden_size=args.d_model, | ||||||||||||||||||||||||||||||||||||||||
intermediate_size=args.ffn_size, | ||||||||||||||||||||||||||||||||||||||||
num_hidden_layers=args.depth, | ||||||||||||||||||||||||||||||||||||||||
num_attention_heads=args.d_model // args.head_dim, | ||||||||||||||||||||||||||||||||||||||||
max_position_embeddings=args.seq_len, | ||||||||||||||||||||||||||||||||||||||||
use_cache=False, | ||||||||||||||||||||||||||||||||||||||||
) | ||||||||||||||||||||||||||||||||||||||||
model = LlamaForCausalLM(config).bfloat16().cuda() | ||||||||||||||||||||||||||||||||||||||||
if args.activation_checkpointing: | ||||||||||||||||||||||||||||||||||||||||
model.gradient_checkpointing_enable() | ||||||||||||||||||||||||||||||||||||||||
if args.quantize == "int8_weight_only": | ||||||||||||||||||||||||||||||||||||||||
quantize_(model, int8_weight_only_quantized_training()) | ||||||||||||||||||||||||||||||||||||||||
elif args.quantize is not None: | ||||||||||||||||||||||||||||||||||||||||
raise ValueError(f"Unsupported quantize={args.quantize}") | ||||||||||||||||||||||||||||||||||||||||
print(f"No. of params: {sum(p.numel() for p in model.parameters()):,}") | ||||||||||||||||||||||||||||||||||||||||
print(f"No. of buffers: {sum(p.numel() for p in model.buffers()):,}") | ||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||
# turn off these flags (set by quantize_()) to speed up compile time | ||||||||||||||||||||||||||||||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. huh interesting were the flags that bad? cc @HDCharles @jerryzh168 in case we need to change the defaults There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think they are fine for inference. Maybe ok for training too, but I didn't want to wait 🤣 There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. you can turn it off setting ao/torchao/quantization/quant_api.py Line 268 in e7fc0ed
|
||||||||||||||||||||||||||||||||||||||||
torch._inductor.config.coordinate_descent_tuning = False | ||||||||||||||||||||||||||||||||||||||||
torch._inductor.config.coordinate_descent_check_all_directions = False | ||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||
optim = getattr(low_bit_optim, args.optim)(model.parameters(), lr=args.lr, weight_decay=args.weight_decay) | ||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||
data = get_tinystories().cuda() | ||||||||||||||||||||||||||||||||||||||||
run = wandb.init(dir="/tmp", config=args, project=args.project, name=args.run_name) | ||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||
step = 0 | ||||||||||||||||||||||||||||||||||||||||
log_interval = 50 | ||||||||||||||||||||||||||||||||||||||||
pbar = tqdm(total=args.n_steps, dynamic_ncols=True) | ||||||||||||||||||||||||||||||||||||||||
model.train() | ||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||
while step < args.n_steps: | ||||||||||||||||||||||||||||||||||||||||
# randomly select a continuous chunk, then reshape it | ||||||||||||||||||||||||||||||||||||||||
idx = torch.randint(0, data.shape[0] - args.batch_size * args.seq_len, (1,)).item() | ||||||||||||||||||||||||||||||||||||||||
batch = data[idx : idx + args.batch_size * args.seq_len].view(args.batch_size, args.seq_len).long() | ||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||
loss = torch.compile(get_loss)(model, batch) | ||||||||||||||||||||||||||||||||||||||||
loss.backward() | ||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||
if step % log_interval == 0: | ||||||||||||||||||||||||||||||||||||||||
log_dict = dict( | ||||||||||||||||||||||||||||||||||||||||
loss=loss.item(), | ||||||||||||||||||||||||||||||||||||||||
lr=optim.param_groups[0]["lr"], | ||||||||||||||||||||||||||||||||||||||||
max_memory_allocated=torch.cuda.max_memory_allocated(), | ||||||||||||||||||||||||||||||||||||||||
) | ||||||||||||||||||||||||||||||||||||||||
run.log(log_dict, step=step) | ||||||||||||||||||||||||||||||||||||||||
pbar.set_postfix(loss=log_dict["loss"]) | ||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||
optim.step() | ||||||||||||||||||||||||||||||||||||||||
optim.zero_grad() | ||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||
step += 1 | ||||||||||||||||||||||||||||||||||||||||
pbar.update() | ||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||
run.finish() |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
just curious were you running into ooms without this?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Not for the experiments I ran here I think. I copied it over from my repo, and I think this flag is generally good.