From cc77513b5ff05024673cc21cbfc9d1b6ae61604c Mon Sep 17 00:00:00 2001
From: Thien Tran <gau.nernst@yahoo.com.sg>
Date: Sun, 7 Jul 2024 01:04:24 +0800
Subject: [PATCH] Update low-bit Adam benchmark (#481)

* update benchmark

* add rank1 option to lpmm

* add comma

* update readme

* remove unwanted file

* update
---
 benchmarks/benchmark_low_bit_adam.py      | 11 ++++++++---
 torchao/prototype/low_bit_optim/README.md | 23 ++++++++++++-----------
 2 files changed, 20 insertions(+), 14 deletions(-)

diff --git a/benchmarks/benchmark_low_bit_adam.py b/benchmarks/benchmark_low_bit_adam.py
index 63fdeff3dd..716541ee9c 100644
--- a/benchmarks/benchmark_low_bit_adam.py
+++ b/benchmarks/benchmark_low_bit_adam.py
@@ -37,6 +37,7 @@
     Adam8bitAo=Adam8bit,
     Adam4bitLpmm=partial(lpmm.optim.AdamW, weight_decay=0, fused=True),
     Adam4bitAo=Adam4bit,
+    Adam4bitRank1Lpmm=partial(lpmm.optim.AdamW, weight_decay=0, qconfig=argparse.Namespace(scale_type="rank1")),
 )
 
 
@@ -92,6 +93,7 @@ def get_parser():
     parser.add_argument("--project")
     parser.add_argument("--run_name", default="debug")
     parser.add_argument("--profile", action="store_true")
+    parser.add_argument("--seed", type=int)
     return parser
 
 
@@ -155,6 +157,8 @@ def evaluate_model(model, args):
 
     if args.profile:
         args.n_epochs = 1
+    if args.seed is not None:
+        torch.manual_seed(args.seed)
 
     for k, v in vars(args).items():
         print(f"{k}: {v}")
@@ -176,11 +180,11 @@ def evaluate_model(model, args):
 
     grad_scaler = torch.amp.GradScaler("cuda", enabled=args.amp == "fp16")
 
-    start_time = datetime.datetime.now()
     step = 0
     for epoch_idx in range(args.n_epochs):
         model.train()
         prof = profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA]) if args.profile else nullcontext()
+        start_time = datetime.datetime.now()
 
         with prof:
             for batch in tqdm(dloader, dynamic_ncols=True, desc=f"Epoch {epoch_idx + 1}/{args.n_epochs}"):
@@ -212,9 +216,10 @@ def evaluate_model(model, args):
             prof.export_chrome_trace("trace.json")
 
         else:
+            print(f"Time taken for epoch {epoch_idx + 1}: {(datetime.datetime.now() - start_time)}")
+
             val_acc = evaluate_model(model, args)
             print(f"Epoch {epoch_idx + 1}/{args.n_epochs}: val_acc={val_acc.item() * 100:.2f}")
             logger.log(dict(val_acc=val_acc), step=step)
 
-    print(f"Time taken: {(datetime.datetime.now() - start_time)}")
-    print(f"Max used: {torch.cuda.max_memory_allocated() / 1e9:.02f} GB")
+    print(f"Max memory used: {torch.cuda.max_memory_allocated() / 1e9:.02f} GB")
diff --git a/torchao/prototype/low_bit_optim/README.md b/torchao/prototype/low_bit_optim/README.md
index da3cc064f5..5c1d631d2c 100644
--- a/torchao/prototype/low_bit_optim/README.md
+++ b/torchao/prototype/low_bit_optim/README.md
@@ -31,17 +31,18 @@ NOTE:
 
 Benchmark script for fine-tuning a [timm](https://github.com/huggingface/pytorch-image-models) model on [resisc45](https://huggingface.co/datasets/timm/resisc45) dataset is available at [benchmarks/benchmark_low_bit_adam.py](../../../benchmarks/benchmark_low_bit_adam.py).
 
-Results for fine-tuning ViT-H (630M params) with BF16 AMP, batch size 4, 1 epoch, on 4070Ti SUPER:
-
-Adam impl  | max memory (GB) | time taken | accuracy
------------|-----------------|------------|----------
-PyTorch    | 12.98           | 10m 08s    | 87.70
-bnb 8-bit  |  8.31           |  8m 38s    | 86.22
-ao 8-bit   |  8.32           | 10m 54s    | 86.67
-lpmm 4-bit |  7.72           |  7m 48s    | 84.70
-ao 4-bit   |  7.72           |  9m 17s    | 85.60
-
-NOTE: time taken includes validation time, and compile time for torchao optimizers.
+Results for fine-tuning ViT-H (630M params) with BF16 AMP for 2 epochs, batch size 8, on 4070Ti SUPER, with a fixed random seed:
+
+Adam impl      | max memory (GB) | time taken for 2nd epoch | accuracy
+---------------|-----------------|--------------------------|----------
+PyTorch        | 12.94           |  8m 18s                  | 91.14
+bnb 8-bit      |  8.31           |  6m 50s                  | 90.67
+ao 8-bit       |  8.32           |  9m 04s                  | 90.71
+lpmm 4-bit     |  7.72           |  5m 59s                  | 89.97
+ao 4-bit       |  7.72           |  7m 00s                  | 89.94
+lpmm 4-bit (*) |  7.73           | 11m 10s                  | 89.71
+
+(*) means rank-1 normalization is used for the 2nd optimizer state. Refer to the [paper](https://arxiv.org/abs/2309.01507) for more details.
 
 ## Credits