From 25bb515ef06b1cc12a6de3989a91f89056cc8d68 Mon Sep 17 00:00:00 2001 From: Thien Tran Date: Sat, 6 Jul 2024 14:39:35 +0800 Subject: [PATCH 1/6] update benchmark --- benchmarks/benchmark_low_bit_adam.py | 10 +++++++--- torchao/prototype/low_bit_optim/README.md | 20 +++++++++----------- 2 files changed, 16 insertions(+), 14 deletions(-) diff --git a/benchmarks/benchmark_low_bit_adam.py b/benchmarks/benchmark_low_bit_adam.py index 63fdeff3dd..79d608176e 100644 --- a/benchmarks/benchmark_low_bit_adam.py +++ b/benchmarks/benchmark_low_bit_adam.py @@ -92,6 +92,7 @@ def get_parser(): parser.add_argument("--project") parser.add_argument("--run_name", default="debug") parser.add_argument("--profile", action="store_true") + parser.add_argument("--seed", type=int) return parser @@ -155,6 +156,8 @@ def evaluate_model(model, args): if args.profile: args.n_epochs = 1 + if args.seed is not None: + torch.manual_seed(args.seed) for k, v in vars(args).items(): print(f"{k}: {v}") @@ -176,11 +179,11 @@ def evaluate_model(model, args): grad_scaler = torch.amp.GradScaler("cuda", enabled=args.amp == "fp16") - start_time = datetime.datetime.now() step = 0 for epoch_idx in range(args.n_epochs): model.train() prof = profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA]) if args.profile else nullcontext() + start_time = datetime.datetime.now() with prof: for batch in tqdm(dloader, dynamic_ncols=True, desc=f"Epoch {epoch_idx + 1}/{args.n_epochs}"): @@ -212,9 +215,10 @@ def evaluate_model(model, args): prof.export_chrome_trace("trace.json") else: + print(f"Time taken for epoch {epoch_idx + 1}: {(datetime.datetime.now() - start_time)}") + val_acc = evaluate_model(model, args) print(f"Epoch {epoch_idx + 1}/{args.n_epochs}: val_acc={val_acc.item() * 100:.2f}") logger.log(dict(val_acc=val_acc), step=step) - print(f"Time taken: {(datetime.datetime.now() - start_time)}") - print(f"Max used: {torch.cuda.max_memory_allocated() / 1e9:.02f} GB") + print(f"Max memory used: {torch.cuda.max_memory_allocated() / 1e9:.02f} GB") diff --git a/torchao/prototype/low_bit_optim/README.md b/torchao/prototype/low_bit_optim/README.md index da3cc064f5..bf9cf34c17 100644 --- a/torchao/prototype/low_bit_optim/README.md +++ b/torchao/prototype/low_bit_optim/README.md @@ -31,17 +31,15 @@ NOTE: Benchmark script for fine-tuning a [timm](https://github.com/huggingface/pytorch-image-models) model on [resisc45](https://huggingface.co/datasets/timm/resisc45) dataset is available at [benchmarks/benchmark_low_bit_adam.py](../../../benchmarks/benchmark_low_bit_adam.py). -Results for fine-tuning ViT-H (630M params) with BF16 AMP, batch size 4, 1 epoch, on 4070Ti SUPER: - -Adam impl | max memory (GB) | time taken | accuracy ------------|-----------------|------------|---------- -PyTorch | 12.98 | 10m 08s | 87.70 -bnb 8-bit | 8.31 | 8m 38s | 86.22 -ao 8-bit | 8.32 | 10m 54s | 86.67 -lpmm 4-bit | 7.72 | 7m 48s | 84.70 -ao 4-bit | 7.72 | 9m 17s | 85.60 - -NOTE: time taken includes validation time, and compile time for torchao optimizers. 
+Results for fine-tuning ViT-H (630M params) with BF16 AMP, batch size 8, on 4070Ti SUPER, with fixed random seed: + +Adam impl | max memory (GB) | time taken for 2nd epoch | accuracy +-----------|-----------------|--------------------- ----|---------- +PyTorch | 12.94 | 8m 18s | 91.14 +bnb 8-bit | 8.31 | 6m 50s | 90.67 +ao 8-bit | 8.32 | 9m 04s | 90.71 +lpmm 4-bit | 7.72 | 5m 59s | 89.97 +ao 4-bit | 7.72 | 7m 00s | 89.94 ## Credits From ac191730f9b9a89fbc28552be2b02030515d2df8 Mon Sep 17 00:00:00 2001 From: Thien Tran Date: Sat, 6 Jul 2024 14:56:16 +0800 Subject: [PATCH 2/6] add rank1 option to lpmm --- benchmarks/benchmark_low_bit_adam.py | 1 + .../code/benchmarks/benchmark_low_bit_adam.py | 225 ++++++++++++++++++ 2 files changed, 226 insertions(+) create mode 100644 wandb_logs/wandb/run-20240706_145524-9zfsbx2r/files/code/benchmarks/benchmark_low_bit_adam.py diff --git a/benchmarks/benchmark_low_bit_adam.py b/benchmarks/benchmark_low_bit_adam.py index 79d608176e..876d4add67 100644 --- a/benchmarks/benchmark_low_bit_adam.py +++ b/benchmarks/benchmark_low_bit_adam.py @@ -37,6 +37,7 @@ Adam8bitAo=Adam8bit, Adam4bitLpmm=partial(lpmm.optim.AdamW, weight_decay=0, fused=True), Adam4bitAo=Adam4bit, + Adam4bitRank1Lpmm=partial(lpmm.optim.AdamW, weight_decay=0, qconfig=argparse.Namespace(scale_type="rank1")) ) diff --git a/wandb_logs/wandb/run-20240706_145524-9zfsbx2r/files/code/benchmarks/benchmark_low_bit_adam.py b/wandb_logs/wandb/run-20240706_145524-9zfsbx2r/files/code/benchmarks/benchmark_low_bit_adam.py new file mode 100644 index 0000000000..876d4add67 --- /dev/null +++ b/wandb_logs/wandb/run-20240706_145524-9zfsbx2r/files/code/benchmarks/benchmark_low_bit_adam.py @@ -0,0 +1,225 @@ +# pip install timm wandb tqdm datasets yacs bitsandbytes git+https://github.com/thu-ml/low-bit-optimizers.git +# To fine-tune a pre-trained ViT-Base on resisc45 dataset with BF16 AMP, using default Adam optimizer from PyTorch core +# +# python benchmark_low_bit_adam.py \ +# --model "timm/vit_base_patch16_224.augreg_in21k" \ +# --amp bf16 \ +# --optim Adam +# +# See OPTIM_MAP for the available optimizer options +# To profile and export chrome trace, set --profile +# To enable cosine learning rate scheduler, set --cosine_lr_scheduler + +import argparse +import datetime +import math +from contextlib import nullcontext +from functools import partial +from pathlib import Path + +import bitsandbytes as bnb +import datasets +import lpmm +import timm +import torch +import torch.nn.functional as F +from torch.profiler import ProfilerActivity, profile +from torch.utils.data import DataLoader +from torchvision.transforms import v2 +from tqdm import tqdm + +from torchao.prototype.low_bit_optim import Adam4bit, Adam8bit + +# lpmm doesn't have Adam, only AdamW +OPTIM_MAP = dict( + Adam=torch.optim.Adam, + Adam8bitBnb=bnb.optim.Adam8bit, + Adam8bitAo=Adam8bit, + Adam4bitLpmm=partial(lpmm.optim.AdamW, weight_decay=0, fused=True), + Adam4bitAo=Adam4bit, + Adam4bitRank1Lpmm=partial(lpmm.optim.AdamW, weight_decay=0, qconfig=argparse.Namespace(scale_type="rank1")) +) + + +class CosineSchedule: + def __init__(self, lr: float, total_steps: int, warmup: float = 0.05) -> None: + self.lr = lr + self.final_lr = 0 + self.total_steps = total_steps + self.warmup_steps = round(total_steps * warmup) + + def get_lr(self, step: int) -> float: + if step < self.warmup_steps: + return self.lr * step / self.warmup_steps + if step < self.total_steps: + progress = (step - self.warmup_steps) / (self.total_steps - self.warmup_steps) + return 
self.final_lr + 0.5 * (self.lr - self.final_lr) * (1 + math.cos(progress * math.pi)) + return self.final_lr + + +class WandbLogger: + def __init__(self, args): + if args.project is not None and not args.profile: + import wandb + + Path("wandb_logs").mkdir(exist_ok=True) + self.run = wandb.init(project=args.project, name=args.run_name, config=args, dir="wandb_logs") + + else: + self.run = None + + def log(self, *args, **kwargs): + if self.run is not None: + self.run.log(*args, **kwargs) + + +def get_parser(): + parser = argparse.ArgumentParser() + parser.add_argument("--model", required=True) + + parser.add_argument("--amp", default="none") + parser.add_argument("--channels_last", action="store_true") + parser.add_argument("--compile", action="store_true") + + parser.add_argument("--n_epochs", type=int, default=10) + parser.add_argument("--batch_size", type=int, default=64) + parser.add_argument("--n_workers", type=int, default=4) + + parser.add_argument("--optim", default="Adam", choices=OPTIM_MAP.keys()) + parser.add_argument("--lr", type=float, default=1e-4) + parser.add_argument("--weight_decay", type=float, default=0) + parser.add_argument("--cosine_lr_scheduler", action="store_true") + + parser.add_argument("--project") + parser.add_argument("--run_name", default="debug") + parser.add_argument("--profile", action="store_true") + parser.add_argument("--seed", type=int) + return parser + + +def get_dloader(args, training: bool): + transforms = [v2.ToImage()] + + if training: + transforms.extend([v2.RandomResizedCrop(224), v2.RandomHorizontalFlip()]) + else: + transforms.extend([v2.Resize(256), v2.CenterCrop(224)]) + + transforms.append(v2.ToDtype(torch.float32, scale=True)) + transforms.append(v2.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])) + transforms = v2.Compose(transforms) + + # use dataset from HF so download is fast + ds = datasets.load_dataset("timm/resisc45", split="train" if training else "validation") + ds = ds.select_columns(["image", "label"]) + ds.set_transform(lambda x: dict(image=transforms(x["image"]), label=x["label"])) + + return DataLoader( + ds, + batch_size=args.batch_size, + shuffle=training, + num_workers=args.n_workers, + pin_memory=training, + drop_last=training, + ) + + +def get_amp_ctx(amp): + dtype = dict(bf16=torch.bfloat16, fp16=torch.float16, none=None)[amp] + return torch.autocast("cuda", dtype=dtype, enabled=amp != "none") + + +@torch.no_grad() +def evaluate_model(model, args): + model.eval() + val_dloader = get_dloader(args, False) + + all_labels = [] + all_preds = [] + + for batch in tqdm(val_dloader, dynamic_ncols=True, desc=f"Evaluating"): + all_labels.append(batch["label"].clone()) + if args.channels_last: + batch["image"] = batch["image"].to(memory_format=torch.channels_last) + + with get_amp_ctx(args.amp): + all_preds.append(model(batch["image"].cuda()).argmax(1).cpu()) + + all_labels = torch.cat(all_labels, dim=0) + all_preds = torch.cat(all_preds, dim=0) + + acc = (all_labels == all_preds).float().mean() + return acc + + +if __name__ == "__main__": + args = get_parser().parse_args() + + if args.profile: + args.n_epochs = 1 + if args.seed is not None: + torch.manual_seed(args.seed) + + for k, v in vars(args).items(): + print(f"{k}: {v}") + + # wandb is only enabled when args.project is set and args.profile is False + logger = WandbLogger(args) + dloader = get_dloader(args, True) + print(f"Train dataset: {len(dloader.dataset):,} images") + + model = timm.create_model(args.model, pretrained=True, num_classes=45).cuda() + if 
args.channels_last: + model.to(memory_format=torch.channels_last) + if args.compile: + model.compile(fullgraph=True) + print(f"Model parameters: {sum(p.numel() for p in model.parameters()):,}") + + optim = OPTIM_MAP[args.optim](model.parameters(), args.lr, weight_decay=args.weight_decay) + lr_schedule = CosineSchedule(args.lr, len(dloader) * args.n_epochs) + + grad_scaler = torch.amp.GradScaler("cuda", enabled=args.amp == "fp16") + + step = 0 + for epoch_idx in range(args.n_epochs): + model.train() + prof = profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA]) if args.profile else nullcontext() + start_time = datetime.datetime.now() + + with prof: + for batch in tqdm(dloader, dynamic_ncols=True, desc=f"Epoch {epoch_idx + 1}/{args.n_epochs}"): + if args.channels_last: + batch["image"] = batch["image"].to(memory_format=torch.channels_last) + + with get_amp_ctx(args.amp): + loss = F.cross_entropy(model(batch["image"].cuda()), batch["label"].cuda()) + grad_scaler.scale(loss).backward() + + if args.cosine_lr_scheduler: + lr = lr_schedule.get_lr(step) + for param_group in optim.param_groups: + param_group["lr"] = lr + + if step % 100 == 0: + logger.log(dict(loss=loss.item(), lr=optim.param_groups[0]["lr"]), step=step) + + grad_scaler.step(optim) + grad_scaler.update() + optim.zero_grad() + + step += 1 + + if args.profile and step == 20: + break + + if args.profile: + prof.export_chrome_trace("trace.json") + + else: + print(f"Time taken for epoch {epoch_idx + 1}: {(datetime.datetime.now() - start_time)}") + + val_acc = evaluate_model(model, args) + print(f"Epoch {epoch_idx + 1}/{args.n_epochs}: val_acc={val_acc.item() * 100:.2f}") + logger.log(dict(val_acc=val_acc), step=step) + + print(f"Max memory used: {torch.cuda.max_memory_allocated() / 1e9:.02f} GB") From 8d368b6f086b302de97f9b99de24ec1280745f7d Mon Sep 17 00:00:00 2001 From: Thien Tran Date: Sat, 6 Jul 2024 14:56:46 +0800 Subject: [PATCH 3/6] add comma --- benchmarks/benchmark_low_bit_adam.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/benchmarks/benchmark_low_bit_adam.py b/benchmarks/benchmark_low_bit_adam.py index 876d4add67..716541ee9c 100644 --- a/benchmarks/benchmark_low_bit_adam.py +++ b/benchmarks/benchmark_low_bit_adam.py @@ -37,7 +37,7 @@ Adam8bitAo=Adam8bit, Adam4bitLpmm=partial(lpmm.optim.AdamW, weight_decay=0, fused=True), Adam4bitAo=Adam4bit, - Adam4bitRank1Lpmm=partial(lpmm.optim.AdamW, weight_decay=0, qconfig=argparse.Namespace(scale_type="rank1")) + Adam4bitRank1Lpmm=partial(lpmm.optim.AdamW, weight_decay=0, qconfig=argparse.Namespace(scale_type="rank1")), ) From d11002129a3c7049d138d4879e56944b685406eb Mon Sep 17 00:00:00 2001 From: Thien Tran Date: Sat, 6 Jul 2024 16:08:58 +0800 Subject: [PATCH 4/6] update readme --- torchao/prototype/low_bit_optim/README.md | 17 ++++++++++------- 1 file changed, 10 insertions(+), 7 deletions(-) diff --git a/torchao/prototype/low_bit_optim/README.md b/torchao/prototype/low_bit_optim/README.md index bf9cf34c17..7a108959f7 100644 --- a/torchao/prototype/low_bit_optim/README.md +++ b/torchao/prototype/low_bit_optim/README.md @@ -33,13 +33,16 @@ Benchmark script for fine-tuning a [timm](https://github.com/huggingface/pytorch Results for fine-tuning ViT-H (630M params) with BF16 AMP, batch size 8, on 4070Ti SUPER, with fixed random seed: -Adam impl | max memory (GB) | time taken for 2nd epoch | accuracy ------------|-----------------|--------------------- ----|---------- -PyTorch | 12.94 | 8m 18s | 91.14 -bnb 8-bit | 8.31 | 6m 50s | 90.67 -ao 8-bit | 8.32 | 9m 
04s | 90.71 -lpmm 4-bit | 7.72 | 5m 59s | 89.97 -ao 4-bit | 7.72 | 7m 00s | 89.94 +Adam impl | max memory (GB) | time taken for 2nd epoch | accuracy +---------------|-----------------|--------------------------|---------- +PyTorch | 12.94 | 8m 18s | 91.14 +bnb 8-bit | 8.31 | 6m 50s | 90.67 +ao 8-bit | 8.32 | 9m 04s | 90.71 +lpmm 4-bit | 7.72 | 5m 59s | 89.97 +ao 4-bit | 7.72 | 7m 00s | 89.94 +lpmm 4-bit (*) | 7.73 | 11m 10s | 89.71 + +(*) means rank-1 normalization is used for 2nd optimizer state. Refer to [paper](https://arxiv.org/abs/2309.01507) for more details. ## Credits From daaa0b0efcac020b0760897075bd435cd01b1f64 Mon Sep 17 00:00:00 2001 From: Thien Tran Date: Sat, 6 Jul 2024 18:42:44 +0800 Subject: [PATCH 5/6] remove unwanted file --- .../code/benchmarks/benchmark_low_bit_adam.py | 225 ------------------ 1 file changed, 225 deletions(-) delete mode 100644 wandb_logs/wandb/run-20240706_145524-9zfsbx2r/files/code/benchmarks/benchmark_low_bit_adam.py diff --git a/wandb_logs/wandb/run-20240706_145524-9zfsbx2r/files/code/benchmarks/benchmark_low_bit_adam.py b/wandb_logs/wandb/run-20240706_145524-9zfsbx2r/files/code/benchmarks/benchmark_low_bit_adam.py deleted file mode 100644 index 876d4add67..0000000000 --- a/wandb_logs/wandb/run-20240706_145524-9zfsbx2r/files/code/benchmarks/benchmark_low_bit_adam.py +++ /dev/null @@ -1,225 +0,0 @@ -# pip install timm wandb tqdm datasets yacs bitsandbytes git+https://github.com/thu-ml/low-bit-optimizers.git -# To fine-tune a pre-trained ViT-Base on resisc45 dataset with BF16 AMP, using default Adam optimizer from PyTorch core -# -# python benchmark_low_bit_adam.py \ -# --model "timm/vit_base_patch16_224.augreg_in21k" \ -# --amp bf16 \ -# --optim Adam -# -# See OPTIM_MAP for the available optimizer options -# To profile and export chrome trace, set --profile -# To enable cosine learning rate scheduler, set --cosine_lr_scheduler - -import argparse -import datetime -import math -from contextlib import nullcontext -from functools import partial -from pathlib import Path - -import bitsandbytes as bnb -import datasets -import lpmm -import timm -import torch -import torch.nn.functional as F -from torch.profiler import ProfilerActivity, profile -from torch.utils.data import DataLoader -from torchvision.transforms import v2 -from tqdm import tqdm - -from torchao.prototype.low_bit_optim import Adam4bit, Adam8bit - -# lpmm doesn't have Adam, only AdamW -OPTIM_MAP = dict( - Adam=torch.optim.Adam, - Adam8bitBnb=bnb.optim.Adam8bit, - Adam8bitAo=Adam8bit, - Adam4bitLpmm=partial(lpmm.optim.AdamW, weight_decay=0, fused=True), - Adam4bitAo=Adam4bit, - Adam4bitRank1Lpmm=partial(lpmm.optim.AdamW, weight_decay=0, qconfig=argparse.Namespace(scale_type="rank1")) -) - - -class CosineSchedule: - def __init__(self, lr: float, total_steps: int, warmup: float = 0.05) -> None: - self.lr = lr - self.final_lr = 0 - self.total_steps = total_steps - self.warmup_steps = round(total_steps * warmup) - - def get_lr(self, step: int) -> float: - if step < self.warmup_steps: - return self.lr * step / self.warmup_steps - if step < self.total_steps: - progress = (step - self.warmup_steps) / (self.total_steps - self.warmup_steps) - return self.final_lr + 0.5 * (self.lr - self.final_lr) * (1 + math.cos(progress * math.pi)) - return self.final_lr - - -class WandbLogger: - def __init__(self, args): - if args.project is not None and not args.profile: - import wandb - - Path("wandb_logs").mkdir(exist_ok=True) - self.run = wandb.init(project=args.project, name=args.run_name, config=args, 
dir="wandb_logs") - - else: - self.run = None - - def log(self, *args, **kwargs): - if self.run is not None: - self.run.log(*args, **kwargs) - - -def get_parser(): - parser = argparse.ArgumentParser() - parser.add_argument("--model", required=True) - - parser.add_argument("--amp", default="none") - parser.add_argument("--channels_last", action="store_true") - parser.add_argument("--compile", action="store_true") - - parser.add_argument("--n_epochs", type=int, default=10) - parser.add_argument("--batch_size", type=int, default=64) - parser.add_argument("--n_workers", type=int, default=4) - - parser.add_argument("--optim", default="Adam", choices=OPTIM_MAP.keys()) - parser.add_argument("--lr", type=float, default=1e-4) - parser.add_argument("--weight_decay", type=float, default=0) - parser.add_argument("--cosine_lr_scheduler", action="store_true") - - parser.add_argument("--project") - parser.add_argument("--run_name", default="debug") - parser.add_argument("--profile", action="store_true") - parser.add_argument("--seed", type=int) - return parser - - -def get_dloader(args, training: bool): - transforms = [v2.ToImage()] - - if training: - transforms.extend([v2.RandomResizedCrop(224), v2.RandomHorizontalFlip()]) - else: - transforms.extend([v2.Resize(256), v2.CenterCrop(224)]) - - transforms.append(v2.ToDtype(torch.float32, scale=True)) - transforms.append(v2.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])) - transforms = v2.Compose(transforms) - - # use dataset from HF so download is fast - ds = datasets.load_dataset("timm/resisc45", split="train" if training else "validation") - ds = ds.select_columns(["image", "label"]) - ds.set_transform(lambda x: dict(image=transforms(x["image"]), label=x["label"])) - - return DataLoader( - ds, - batch_size=args.batch_size, - shuffle=training, - num_workers=args.n_workers, - pin_memory=training, - drop_last=training, - ) - - -def get_amp_ctx(amp): - dtype = dict(bf16=torch.bfloat16, fp16=torch.float16, none=None)[amp] - return torch.autocast("cuda", dtype=dtype, enabled=amp != "none") - - -@torch.no_grad() -def evaluate_model(model, args): - model.eval() - val_dloader = get_dloader(args, False) - - all_labels = [] - all_preds = [] - - for batch in tqdm(val_dloader, dynamic_ncols=True, desc=f"Evaluating"): - all_labels.append(batch["label"].clone()) - if args.channels_last: - batch["image"] = batch["image"].to(memory_format=torch.channels_last) - - with get_amp_ctx(args.amp): - all_preds.append(model(batch["image"].cuda()).argmax(1).cpu()) - - all_labels = torch.cat(all_labels, dim=0) - all_preds = torch.cat(all_preds, dim=0) - - acc = (all_labels == all_preds).float().mean() - return acc - - -if __name__ == "__main__": - args = get_parser().parse_args() - - if args.profile: - args.n_epochs = 1 - if args.seed is not None: - torch.manual_seed(args.seed) - - for k, v in vars(args).items(): - print(f"{k}: {v}") - - # wandb is only enabled when args.project is set and args.profile is False - logger = WandbLogger(args) - dloader = get_dloader(args, True) - print(f"Train dataset: {len(dloader.dataset):,} images") - - model = timm.create_model(args.model, pretrained=True, num_classes=45).cuda() - if args.channels_last: - model.to(memory_format=torch.channels_last) - if args.compile: - model.compile(fullgraph=True) - print(f"Model parameters: {sum(p.numel() for p in model.parameters()):,}") - - optim = OPTIM_MAP[args.optim](model.parameters(), args.lr, weight_decay=args.weight_decay) - lr_schedule = CosineSchedule(args.lr, len(dloader) * 
args.n_epochs) - - grad_scaler = torch.amp.GradScaler("cuda", enabled=args.amp == "fp16") - - step = 0 - for epoch_idx in range(args.n_epochs): - model.train() - prof = profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA]) if args.profile else nullcontext() - start_time = datetime.datetime.now() - - with prof: - for batch in tqdm(dloader, dynamic_ncols=True, desc=f"Epoch {epoch_idx + 1}/{args.n_epochs}"): - if args.channels_last: - batch["image"] = batch["image"].to(memory_format=torch.channels_last) - - with get_amp_ctx(args.amp): - loss = F.cross_entropy(model(batch["image"].cuda()), batch["label"].cuda()) - grad_scaler.scale(loss).backward() - - if args.cosine_lr_scheduler: - lr = lr_schedule.get_lr(step) - for param_group in optim.param_groups: - param_group["lr"] = lr - - if step % 100 == 0: - logger.log(dict(loss=loss.item(), lr=optim.param_groups[0]["lr"]), step=step) - - grad_scaler.step(optim) - grad_scaler.update() - optim.zero_grad() - - step += 1 - - if args.profile and step == 20: - break - - if args.profile: - prof.export_chrome_trace("trace.json") - - else: - print(f"Time taken for epoch {epoch_idx + 1}: {(datetime.datetime.now() - start_time)}") - - val_acc = evaluate_model(model, args) - print(f"Epoch {epoch_idx + 1}/{args.n_epochs}: val_acc={val_acc.item() * 100:.2f}") - logger.log(dict(val_acc=val_acc), step=step) - - print(f"Max memory used: {torch.cuda.max_memory_allocated() / 1e9:.02f} GB") From 8fc9ba2e9e7a8d7c72cdaeb36d61ea485d1d5819 Mon Sep 17 00:00:00 2001 From: Thien Tran Date: Sat, 6 Jul 2024 18:43:40 +0800 Subject: [PATCH 6/6] update --- torchao/prototype/low_bit_optim/README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchao/prototype/low_bit_optim/README.md b/torchao/prototype/low_bit_optim/README.md index 7a108959f7..5c1d631d2c 100644 --- a/torchao/prototype/low_bit_optim/README.md +++ b/torchao/prototype/low_bit_optim/README.md @@ -31,7 +31,7 @@ NOTE: Benchmark script for fine-tuning a [timm](https://github.com/huggingface/pytorch-image-models) model on [resisc45](https://huggingface.co/datasets/timm/resisc45) dataset is available at [benchmarks/benchmark_low_bit_adam.py](../../../benchmarks/benchmark_low_bit_adam.py). -Results for fine-tuning ViT-H (630M params) with BF16 AMP, batch size 8, on 4070Ti SUPER, with fixed random seed: +Results for fine-tuning ViT-H (630M params) with BF16 AMP for 2 epochs, batch size 8, on 4070Ti SUPER, with fixed random seed: Adam impl | max memory (GB) | time taken for 2nd epoch | accuracy ---------------|-----------------|--------------------------|----------
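

For reference, below is a minimal, hedged sketch of how the Adam4bitRank1Lpmm option introduced in PATCH 2/6 (rank-1 normalization of the lpmm 4-bit optimizer's scale, the row marked (*) in the updated README table) could be constructed outside the benchmark script. It only re-uses the partial(...) entry added to OPTIM_MAP; the toy model, learning rate, and batch size below are illustrative assumptions, not values taken from this PR, and a CUDA GPU with the thu-ml low-bit-optimizers package (lpmm) installed is assumed.

# Hedged sketch (not part of the patch series): build the rank-1 lpmm 4-bit AdamW
# exactly the way the patched OPTIM_MAP does, then run one dummy optimizer step.
# The Linear layer, lr, and batch are placeholders for illustration only.
import argparse
from functools import partial

import lpmm
import torch

# rank-1 scale normalization for the quantized optimizer state,
# mirroring the Adam4bitRank1Lpmm entry added to OPTIM_MAP in PATCH 2/6
Adam4bitRank1Lpmm = partial(
    lpmm.optim.AdamW,
    weight_decay=0,
    qconfig=argparse.Namespace(scale_type="rank1"),
)

model = torch.nn.Linear(768, 45).cuda()  # toy stand-in for the timm ViT
optim = Adam4bitRank1Lpmm(model.parameters(), lr=1e-4)

x = torch.randn(8, 768, device="cuda")
loss = model(x).sum()
loss.backward()
optim.step()
optim.zero_grad()

To exercise the same code path through the benchmark itself, the new --optim Adam4bitRank1Lpmm choice and the --seed flag added in this series can presumably be combined with the existing --model, --amp, --batch_size, and --n_epochs options of benchmarks/benchmark_low_bit_adam.py.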