Commit

Merge remote-tracking branch 'origin/main' into fp8_linear_quantize
jainapurva committed Nov 27, 2024
2 parents a73180d + ed76e9c commit dc1a233
Showing 45 changed files with 1,472 additions and 280 deletions.
1 change: 1 addition & 0 deletions .github/workflows/regression_test.yml
@@ -70,6 +70,7 @@ jobs:
           torch-spec: 'torch==2.5.1 --index-url https://download.pytorch.org/whl/cu121'
           gpu-arch-type: "cuda"
           gpu-arch-version: "12.1"
+
         - name: CPU 2.3
           runs-on: linux.4xlarge
           torch-spec: 'torch==2.3.0 --index-url https://download.pytorch.org/whl/cpu'
2 changes: 1 addition & 1 deletion README.md
@@ -201,7 +201,7 @@ If you find the torchao library useful, please cite it in your work as below.
 @software{torchao,
   title = {torchao: PyTorch native quantization and sparsity for training and inference},
   author = {torchao maintainers and contributors},
-  url = {https//github.com/pytorch/torchao},
+  url = {https://github.com/pytorch/torchao},
   license = {BSD-3-Clause},
   month = oct,
   year = {2024}
88 changes: 63 additions & 25 deletions benchmarks/benchmark_low_bit_adam.py
@@ -4,7 +4,7 @@
 # - lpmm (4-bit optim): pip install yacs git+https://github.com/thu-ml/low-bit-optimizers.git
 # - DeepSpeed (ZeRO-Offload):
 #     sudo apt install libopenmpi-dev
-#     LD_LIBRARY_PATH=/usr/lib/x86_64-linux-gnu pip install mpi4p
+#     LD_LIBRARY_PATH=/usr/lib/x86_64-linux-gnu pip install mpi4py
 #     DS_BUILD_CPU_ADAM=1 pip install deepspeed --no-cache-dir
 #
 # To fine-tune a pre-trained ViT-Base on resisc45 dataset with BF16 AMP, using default AdamW optimizer from PyTorch core
@@ -31,11 +31,15 @@
 import torch.nn.functional as F
 import wandb
 from torch.utils.data import DataLoader
+from torchao.utils import get_available_devices
 from torchvision.transforms import v2
 from tqdm import tqdm
 
 from torchao.prototype import low_bit_optim
 
+_DEVICE = get_available_devices()[-1]
+assert _DEVICE in ["cuda", "xpu"], "Benchmark currently only supports CUDA & XPU(BF16)"
+
 OPTIM_MAP = dict(
     AdamW=partial(torch.optim.AdamW, fused=True),
     AdamW8bitBnb=bnb.optim.AdamW8bit,
@@ -49,7 +53,9 @@
 
     OPTIM_MAP.update(
         AdamW4bitLpmm=partial(lpmm.optim.AdamW, fused=True),
-        AdamW4bitRank1Lpmm=partial(lpmm.optim.AdamW, qconfig=argparse.Namespace(scale_type="rank1")),
+        AdamW4bitRank1Lpmm=partial(
+            lpmm.optim.AdamW, qconfig=argparse.Namespace(scale_type="rank1")
+        ),
     )
 
 except ImportError:
@@ -67,8 +73,12 @@ def get_lr(self, step: int) -> float:
         if step < self.warmup_steps:
             return self.lr * step / self.warmup_steps
         if step < self.total_steps:
-            progress = (step - self.warmup_steps) / (self.total_steps - self.warmup_steps)
-            return self.final_lr + 0.5 * (self.lr - self.final_lr) * (1 + math.cos(progress * math.pi))
+            progress = (step - self.warmup_steps) / (
+                self.total_steps - self.warmup_steps
+            )
+            return self.final_lr + 0.5 * (self.lr - self.final_lr) * (
+                1 + math.cos(progress * math.pi)
+            )
         return self.final_lr
 
 
@@ -92,7 +102,9 @@ def get_parser():
     parser.add_argument("--weight_decay", type=float, default=0)
     parser.add_argument("--optim_kwargs", type=json.loads, default=dict())
    parser.add_argument("--cosine_lr_scheduler", action="store_true")
-    parser.add_argument("--optim_cpu_offload", choices=["ao", "ao_offload_grads", "deepspeed"])
+    parser.add_argument(
+        "--optim_cpu_offload", choices=["ao", "ao_offload_grads", "deepspeed"]
+    )
 
     parser.add_argument("--project")
     parser.add_argument("--run_name", default="debug")
@@ -110,11 +122,15 @@ def get_dloader(args, training: bool):
         transforms.extend([v2.Resize(256), v2.CenterCrop(224)])
 
     transforms.append(v2.ToDtype(torch.float32, scale=True))
-    transforms.append(v2.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]))
+    transforms.append(
+        v2.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
+    )
     transforms = v2.Compose(transforms)
 
     # use dataset from HF so download is fast
-    ds = datasets.load_dataset("timm/resisc45", split="train" if training else "validation")
+    ds = datasets.load_dataset(
+        "timm/resisc45", split="train" if training else "validation"
+    )
     ds = ds.select_columns(["image", "label"])
     ds.set_transform(lambda x: dict(image=transforms(x["image"]), label=x["label"]))
 
@@ -128,9 +144,9 @@ def get_dloader(args, training: bool):
     )
 
 
-def get_amp_ctx(amp):
+def get_amp_ctx(amp, device):
     dtype = dict(bf16=torch.bfloat16, fp16=torch.float16, none=None)[amp]
-    return torch.autocast("cuda", dtype=dtype, enabled=amp != "none")
+    return torch.autocast(device, dtype=dtype, enabled=amp != "none")
 
 
 @torch.no_grad()
@@ -148,8 +164,8 @@ def evaluate_model(model, args):
         if args.channels_last:
             batch["image"] = batch["image"].to(memory_format=torch.channels_last)
 
-        with get_amp_ctx(args.amp):
-            all_preds.append(model(batch["image"].cuda()).argmax(1).cpu())
+        with get_amp_ctx(args.amp, _DEVICE):
+            all_preds.append(model(batch["image"].to(_DEVICE)).argmax(1).cpu())
 
     all_labels = torch.cat(all_labels, dim=0)
     all_preds = torch.cat(all_preds, dim=0)
@@ -164,8 +180,12 @@ def evaluate_model(model, args):
     if args.full_bf16:
         assert args.amp == "none", "When --full_bf16 is set, --amp must be none"
     if args.optim_cpu_offload == "deepspeed":
-        assert args.amp == "none", "When using DeepSpeed ZeRO-Offload, --amp must be none"
-        assert args.optim == "AdamW", "When using DeepSpeed ZeRO-Offload, --optim must be AdamW"
+        assert (
+            args.amp == "none"
+        ), "When using DeepSpeed ZeRO-Offload, --amp must be none"
+        assert (
+            args.optim == "AdamW"
+        ), "When using DeepSpeed ZeRO-Offload, --optim must be AdamW"
     if args.profile:
         args.n_epochs = 1
     if args.seed is not None:
@@ -185,14 +205,16 @@ def evaluate_model(model, args):
     dloader = get_dloader(args, True)
     print(f"Train dataset: {len(dloader.dataset):,} images")
 
-    model = timm.create_model(args.model, pretrained=True, num_classes=45, **args.model_kwargs)
+    model = timm.create_model(
+        args.model, pretrained=True, num_classes=45, **args.model_kwargs
+    )
     if args.checkpoint_activations:
         model.set_grad_checkpointing()
     if args.full_bf16:
         model.bfloat16()
     if args.channels_last:
         model.to(memory_format=torch.channels_last)
-    model.cuda()  # move model to CUDA after optionally convert it to BF16
+    model.to(_DEVICE)  # move model to DEVICE after optionally convert it to BF16
     if args.compile:
         model.compile(fullgraph=True)
     print(f"Model parameters: {sum(p.numel() for p in model.parameters()):,}")
@@ -227,9 +249,15 @@ def evaluate_model(model, args):
         optim_cls = OPTIM_MAP[args.optim]
 
         if args.optim_cpu_offload == "ao":
-            optim_cls = partial(low_bit_optim.CPUOffloadOptimizer, optimizer_class=optim_cls)
+            optim_cls = partial(
+                low_bit_optim.CPUOffloadOptimizer, optimizer_class=optim_cls
+            )
         elif args.optim_cpu_offload == "ao_offload_grads":
-            optim_cls = partial(low_bit_optim.CPUOffloadOptimizer, optimizer_class=optim_cls, offload_gradients=True)
+            optim_cls = partial(
+                low_bit_optim.CPUOffloadOptimizer,
+                optimizer_class=optim_cls,
+                offload_gradients=True,
+            )
 
         optim = optim_cls(
             model.parameters(),
@@ -239,24 +267,30 @@ def evaluate_model(model, args):
         )
 
     lr_schedule = CosineSchedule(args.lr, len(dloader) * args.n_epochs)
-    grad_scaler = torch.amp.GradScaler("cuda", enabled=args.amp == "fp16")
+    grad_scaler = torch.amp.GradScaler(_DEVICE, enabled=args.amp == "fp16")
     log_interval = 10
     t0 = time.perf_counter()
 
     step = 0
     for epoch_idx in range(args.n_epochs):
         model.train()
-        pbar = tqdm(dloader, dynamic_ncols=True, desc=f"Epoch {epoch_idx + 1}/{args.n_epochs}")
+        pbar = tqdm(
+            dloader, dynamic_ncols=True, desc=f"Epoch {epoch_idx + 1}/{args.n_epochs}"
+        )
 
         with torch.profiler.profile() if args.profile else nullcontext() as prof:
             for batch in pbar:
                 if args.full_bf16:
                     batch["image"] = batch["image"].bfloat16()
                 if args.channels_last:
-                    batch["image"] = batch["image"].to(memory_format=torch.channels_last)
+                    batch["image"] = batch["image"].to(
+                        memory_format=torch.channels_last
+                    )
 
-                with get_amp_ctx(args.amp):
-                    loss = F.cross_entropy(model(batch["image"].cuda()), batch["label"].cuda())
+                with get_amp_ctx(args.amp, _DEVICE):
+                    loss = F.cross_entropy(
+                        model(batch["image"].to(_DEVICE)), batch["label"].to(_DEVICE)
+                    )
 
                 if args.optim_cpu_offload == "deepspeed":
                     model.backward(loss)
@@ -275,7 +309,9 @@ def evaluate_model(model, args):
                     log_dict = dict(loss=loss.item(), lr=optim.param_groups[0]["lr"])
                     if step > 0:
                         t1 = time.perf_counter()
-                        log_dict["imgs_per_second"] = args.batch_size * log_interval / (t1 - t0)
+                        log_dict["imgs_per_second"] = (
+                            args.batch_size * log_interval / (t1 - t0)
+                        )
                         t0 = t1
                     logger.log(log_dict, step=step)
 
@@ -296,9 +332,11 @@ def evaluate_model(model, args):
 
     else:
        val_acc = evaluate_model(model, args)
-        print(f"Epoch {epoch_idx + 1}/{args.n_epochs}: val_acc={val_acc.item() * 100:.2f}")
+        print(
+            f"Epoch {epoch_idx + 1}/{args.n_epochs}: val_acc={val_acc.item() * 100:.2f}"
+        )
         logger.log(dict(val_acc=val_acc), step=step)
 
-    peak_mem = torch.cuda.max_memory_allocated() / 1e9
+    peak_mem = getattr(torch, _DEVICE).max_memory_allocated() / 1e9
     print(f"Max memory used: {peak_mem:.02f} GB")
     logger.log(dict(max_memory_allocated=peak_mem))
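
Most of the benchmark changes above follow a single pattern: hard-coded "cuda" strings are replaced by a _DEVICE value obtained from torchao.utils.get_available_devices(), so the script can also run on XPU. A minimal sketch of that pattern, assuming torch and torchao are installed (the toy model and tensors below are illustrative stand-ins, not part of the benchmark):

import torch
from torchao.utils import get_available_devices

# Assumption: get_available_devices() lists backends with the accelerator
# last, so [-1] picks CUDA/XPU when present and falls back to CPU otherwise.
_DEVICE = get_available_devices()[-1]

def get_amp_ctx(amp: str, device: str):
    # Same device-parameterized autocast as in the diff above.
    dtype = dict(bf16=torch.bfloat16, fp16=torch.float16, none=None)[amp]
    return torch.autocast(device, dtype=dtype, enabled=amp != "none")

model = torch.nn.Linear(16, 16).to(_DEVICE)  # toy stand-in for the timm model
x = torch.randn(8, 16, device=_DEVICE)
with get_amp_ctx("bf16", _DEVICE):
    y = model(x)

# torch.cuda and torch.xpu expose the same memory-statistics API, which is
# what makes getattr(torch, _DEVICE).max_memory_allocated() work for both.
if _DEVICE in ("cuda", "xpu"):
    print(f"peak: {getattr(torch, _DEVICE).max_memory_allocated() / 1e9:.2f} GB")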
36 changes: 31 additions & 5 deletions examples/sam2_amg_server/cli.py
@@ -5,6 +5,10 @@
 from server import show_anns
 from server import model_type_to_paths
 from server import MODEL_TYPES_TO_MODEL
+from server import set_fast
+from server import set_aot_fast
+from server import load_aot_fast
+from server import set_furious
 from torchao._models.sam2.build_sam import build_sam2
 from torchao._models.sam2.automatic_mask_generator import SAM2AutomaticMaskGenerator
 from torchao._models.sam2.utils.amg import rle_to_mask
@@ -19,19 +23,28 @@ def main_docstring():
         output_path (str): Path to output image
     """
 
-def main(checkpoint_path, model_type, input_path, output_path, points_per_batch=1024, output_format='png', verbose=False):
+
+def main_headless(checkpoint_path, model_type, input_bytes, points_per_batch=1024, output_format='png', verbose=False, fast=False, furious=False, load_fast=""):
     device = "cuda"
     sam2_checkpoint, model_cfg = model_type_to_paths(checkpoint_path, model_type)
     if verbose:
         print(f"Loading model {sam2_checkpoint} with config {model_cfg}")
     sam2 = build_sam2(model_cfg, sam2_checkpoint, device=device, apply_postprocessing=False)
     mask_generator = SAM2AutomaticMaskGenerator(sam2, points_per_batch=points_per_batch, output_mode="uncompressed_rle")
-    image_tensor = file_bytes_to_image_tensor(bytearray(open(input_path, 'rb').read()))
+    if furious:
+        set_furious(mask_generator)
+    if load_fast:
+        load_aot_fast(mask_generator, load_fast)
+    if fast:
+        set_fast(mask_generator, load_fast)
+
+    image_tensor = file_bytes_to_image_tensor(input_bytes)
     if verbose:
         print(f"Loaded image of size {tuple(image_tensor.shape)} and generating mask.")
     masks = mask_generator.generate(image_tensor)
 
     # Save an example
+    if verbose:
+        print("Generating mask annotations for input image.")
     plt.figure(figsize=(image_tensor.shape[1]/100., image_tensor.shape[0]/100.), dpi=100)
     plt.imshow(image_tensor)
     show_anns(masks, rle_to_mask)
@@ -40,8 +53,21 @@ def main(checkpoint_path, model_type, input_path, output_path, points_per_batch=
     buf = BytesIO()
     plt.savefig(buf, format=output_format)
     buf.seek(0)
+    return buf.getvalue()
+
+def main(checkpoint_path, model_type, input_path, output_path, points_per_batch=1024, output_format='png', verbose=False, fast=False, furious=False, load_fast=""):
+    input_bytes = bytearray(open(input_path, 'rb').read())
+    output_bytes = main_headless(checkpoint_path,
+                                 model_type,
+                                 input_bytes,
+                                 points_per_batch=points_per_batch,
+                                 output_format=output_format,
+                                 verbose=verbose,
+                                 fast=fast,
+                                 furious=furious,
+                                 load_fast=load_fast)
     with open(output_path, "wb") as file:
-        file.write(buf.getvalue())
+        file.write(output_bytes)
 
 main.__doc__ = main_docstring()
 if __name__ == "__main__":
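
The cli.py change is a headless refactor: main is split into a byte-in/byte-out main_headless core plus a thin file-based wrapper, so the AMG server can reuse the core on in-memory payloads without touching the filesystem. A generic sketch of that shape, with hypothetical names (process, process_headless) standing in for the real functions:

from io import BytesIO

def process_headless(input_bytes: bytes) -> bytes:
    # Core logic works purely on in-memory bytes; a server can call this
    # directly with an uploaded payload. The body here is a placeholder.
    buf = BytesIO()
    buf.write(input_bytes)
    return buf.getvalue()

def process(input_path: str, output_path: str) -> None:
    # Thin CLI wrapper: file I/O at the edges, headless core in the middle.
    with open(input_path, "rb") as f:
        input_bytes = f.read()
    output_bytes = process_headless(input_bytes)
    with open(output_path, "wb") as f:
        f.write(output_bytes)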
