diff --git a/benchmarks/benchmark_low_bit_adam.py b/benchmarks/benchmark_low_bit_adam.py
index 63fdeff3dd..716541ee9c 100644
--- a/benchmarks/benchmark_low_bit_adam.py
+++ b/benchmarks/benchmark_low_bit_adam.py
@@ -37,6 +37,7 @@
     Adam8bitAo=Adam8bit,
     Adam4bitLpmm=partial(lpmm.optim.AdamW, weight_decay=0, fused=True),
     Adam4bitAo=Adam4bit,
+    Adam4bitRank1Lpmm=partial(lpmm.optim.AdamW, weight_decay=0, qconfig=argparse.Namespace(scale_type="rank1")),
 )
 
 
@@ -92,6 +93,7 @@ def get_parser():
     parser.add_argument("--project")
     parser.add_argument("--run_name", default="debug")
     parser.add_argument("--profile", action="store_true")
+    parser.add_argument("--seed", type=int)
     return parser
 
 
@@ -155,6 +157,8 @@ def evaluate_model(model, args):
 
     if args.profile:
        args.n_epochs = 1
+    if args.seed is not None:
+        torch.manual_seed(args.seed)
 
     for k, v in vars(args).items():
         print(f"{k}: {v}")
@@ -176,11 +180,11 @@ def evaluate_model(model, args):
     grad_scaler = torch.amp.GradScaler("cuda", enabled=args.amp == "fp16")
 
-    start_time = datetime.datetime.now()
     step = 0
 
     for epoch_idx in range(args.n_epochs):
         model.train()
         prof = profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA]) if args.profile else nullcontext()
+        start_time = datetime.datetime.now()
 
         with prof:
             for batch in tqdm(dloader, dynamic_ncols=True, desc=f"Epoch {epoch_idx + 1}/{args.n_epochs}"):
@@ -212,9 +216,10 @@ def evaluate_model(model, args):
             prof.export_chrome_trace("trace.json")
         else:
+            print(f"Time taken for epoch {epoch_idx + 1}: {(datetime.datetime.now() - start_time)}")
+
             val_acc = evaluate_model(model, args)
             print(f"Epoch {epoch_idx + 1}/{args.n_epochs}: val_acc={val_acc.item() * 100:.2f}")
             logger.log(dict(val_acc=val_acc), step=step)
 
-    print(f"Time taken: {(datetime.datetime.now() - start_time)}")
-    print(f"Max used: {torch.cuda.max_memory_allocated() / 1e9:.02f} GB")
+    print(f"Max memory used: {torch.cuda.max_memory_allocated() / 1e9:.02f} GB")
 
diff --git a/torchao/prototype/low_bit_optim/README.md b/torchao/prototype/low_bit_optim/README.md
index da3cc064f5..5c1d631d2c 100644
--- a/torchao/prototype/low_bit_optim/README.md
+++ b/torchao/prototype/low_bit_optim/README.md
@@ -31,17 +31,18 @@ NOTE:
 Benchmark script for fine-tuning a [timm](https://github.com/huggingface/pytorch-image-models) model on [resisc45](https://huggingface.co/datasets/timm/resisc45) dataset is available at [benchmarks/benchmark_low_bit_adam.py](../../../benchmarks/benchmark_low_bit_adam.py).
 
-Results for fine-tuning ViT-H (630M params) with BF16 AMP, batch size 4, 1 epoch, on 4070Ti SUPER:
-
-Adam impl  | max memory (GB) | time taken | accuracy
------------|-----------------|------------|----------
-PyTorch    | 12.98           | 10m 08s    | 87.70
-bnb 8-bit  | 8.31            | 8m 38s     | 86.22
-ao 8-bit   | 8.32            | 10m 54s    | 86.67
-lpmm 4-bit | 7.72            | 7m 48s     | 84.70
-ao 4-bit   | 7.72            | 9m 17s     | 85.60
-
-NOTE: time taken includes validation time, and compile time for torchao optimizers.
+Results for fine-tuning ViT-H (630M params) with BF16 AMP for 2 epochs, batch size 8, on 4070Ti SUPER, with a fixed random seed:
+
+Adam impl      | max memory (GB) | time taken for 2nd epoch | accuracy
+---------------|-----------------|--------------------------|----------
+PyTorch        | 12.94           | 8m 18s                   | 91.14
+bnb 8-bit      | 8.31            | 6m 50s                   | 90.67
+ao 8-bit       | 8.32            | 9m 04s                   | 90.71
+lpmm 4-bit     | 7.72            | 5m 59s                   | 89.97
+ao 4-bit       | 7.72            | 7m 00s                   | 89.94
+lpmm 4-bit (*) | 7.73            | 11m 10s                  | 89.71
+
+(*) means rank-1 normalization is used for the 2nd optimizer state.
 
 Refer to [paper](https://arxiv.org/abs/2309.01507) for more details.
 
 ## Credits
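
Note for reviewers: the new `Adam4bitRank1Lpmm` entry and the `(*)` row in the README table hinge on rank-1 normalization from the linked paper (arXiv:2309.01507). Below is a minimal sketch of the idea as we understand it; the `rank1_scale` helper is a hypothetical name written for illustration only, not lpmm's actual kernel, which reshapes tensors and fuses this step into its 4-bit quantization.

```python
import torch


def rank1_scale(x: torch.Tensor) -> torch.Tensor:
    # Rank-1 normalization: the scale for element (i, j) is
    # min(row_max_i, col_max_j), so only one row vector and one
    # column vector of scales need to be stored, rather than one
    # scale per fixed-size group of elements.
    assert x.ndim == 2  # assumes the tensor is already reshaped to 2D
    row_max = x.abs().amax(dim=1, keepdim=True)  # shape (M, 1)
    col_max = x.abs().amax(dim=0, keepdim=True)  # shape (1, N)
    return torch.minimum(row_max, col_max)       # broadcasts to (M, N)


# Toy usage: normalize a stand-in for Adam's 2nd state (exp_avg_sq)
# into [-1, 1] before it would be quantized to 4 bits.
exp_avg_sq = torch.rand(128, 256)
scale = rank1_scale(exp_avg_sq)
normalized = exp_avg_sq / scale.clamp(min=1e-12)
assert normalized.abs().max() <= 1.0
```

Since every element is bounded by both its row max and its column max, dividing by the elementwise minimum of the two keeps all normalized values within [-1, 1]. Storing one row and one column of scales per tensor is also consistent with the nearly unchanged max memory of the `lpmm 4-bit (*)` row in the table above.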