Merge branch 'pytorch:main' into int8_woq_scale_type
Valentine233 authored Jul 27, 2024
2 parents ea32965 + afde175 commit 7b92973
Showing 26 changed files with 624 additions and 690 deletions.
4 changes: 2 additions & 2 deletions .github/workflows/regression_test.yml
@@ -23,7 +23,7 @@ jobs:
include:
- name: CUDA 2.2.2
runs-on: linux.g5.12xlarge.nvidia.gpu
torch-spec: 'torch==2.2.2'
torch-spec: 'torch==2.2.2 "numpy<2" '
gpu-arch-type: "cuda"
gpu-arch-version: "12.1"
- name: CUDA 2.3
@@ -38,7 +38,7 @@ jobs:
gpu-arch-version: "12.1"
- name: CPU 2.2.2
runs-on: linux.4xlarge
torch-spec: 'torch==2.2.2 --index-url https://download.pytorch.org/whl/cpu'
torch-spec: 'torch==2.2.2 --index-url https://download.pytorch.org/whl/cpu "numpy<2" '
gpu-arch-type: "cpu"
gpu-arch-version: ""
- name: CPU 2.3
12 changes: 5 additions & 7 deletions README.md
@@ -49,20 +49,18 @@ And a quick crash course on inference quantization to help parse the above table

Sparsifying your model is also a one-liner that should work on any model with an `nn.Linear`. We find that sparsity works best on compute-bound models like SAM, specifically the MLP layers.
```python
from torchao.sparsity import sparsify
from torch.sparse import to_sparse_semi_structured
from torchao.sparsity import sparsify_, semi_sparse_weight

m = sparsify(m, to_sparse_semi_structured)
m = sparsify_(m, semi_sparse_weight())
```
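The new `sparsify_` API also accepts a `filter_fn`, so the swap can be limited to the MLP linears where sparsity helps most. A minimal sketch under that assumption; the `mlp_only` helper and the `"mlp"` naming convention are illustrative and depend on your model's module names:
```python
import torch
from torchao.sparsity import sparsify_, semi_sparse_weight

def mlp_only(mod, name):
    # Only swap nn.Linear modules whose qualified name contains "mlp"
    return isinstance(mod, torch.nn.Linear) and "mlp" in name

sparsify_(m, semi_sparse_weight(), filter_fn=mlp_only)
```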
Sparsity can also be composed with int8 dynamic quantization for further speedups:

```python
from torchao.sparsity import sparsify
from torchao.sparsity.prototype.dynamic_quant_sparse import int8_dynamic_activation_int8_2x4_sparse_weight
from torchao.sparsity import sparsify_, int8_dynamic_activation_int8_semi_sparse_weight

m = sparsify(m, int8_dynamic_activation_int8_2x4_sparse_weight())
m = sparsify_(m, int8_dynamic_activation_int8_semi_sparse_weight())
```
We found that applying int8 dynamic quantization to the attention layers, int8 dynamic quantization + 2:4 sparsity to mlp layer 1 and 2:4 sparsity to mlp layer 2 yielded the best configuration.
We found that applying int8 dynamic quantization to the attention layers, int8 dynamic quantization + semi-structured (2:4) sparsity to mlp layer 1, and 2:4 sparsity to mlp layer 2 yielded the best configuration.
We were able to provide a **1.16x (22.7 -> 26.5 img/s) speedup over our dense baseline, while maintaining 97.5% (0.581 -> 0.567) of the evaluation accuracy (mIOU)**.
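A minimal sketch of that layer-wise configuration, adapted from the `scripts/sam/eval_combo.py` changes in this commit; the filter helpers and the module-name substrings are illustrative and will differ for other models:
```python
import torch
from torchao.quantization import quantize_, int8_dynamic_activation_int8_weight
from torchao.sparsity import (
    sparsify_,
    semi_sparse_weight,
    int8_dynamic_activation_int8_semi_sparse_weight,
)

# Name-based filters; adjust the substrings to match your model's module names.
def attn_only(mod, name):
    return isinstance(mod, torch.nn.Linear) and "attn" in name

def mlp_lin1_only(mod, name):
    return isinstance(mod, torch.nn.Linear) and "lin1" in name

def mlp_lin2_only(mod, name):
    return isinstance(mod, torch.nn.Linear) and "lin2" in name

# Assumes the MLP weights are already pruned to a 2:4 pattern
# (e.g. via apply_fake_sparsity or a sparse training workflow).

# Attention projections: int8 dynamic quantization
quantize_(m, int8_dynamic_activation_int8_weight(), attn_only)
# First MLP linear: int8 dynamic quantization fused with 2:4 sparsity
quantize_(m, int8_dynamic_activation_int8_semi_sparse_weight(), mlp_lin1_only)
# Second MLP linear: 2:4 sparsity only
sparsify_(m, semi_sparse_weight(), mlp_lin2_only)
```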

The following benchmarks were run for [segment-anything-fast](https://github.com/pytorch-labs/segment-anything-fast) ViT-h on an NVIDIA-A100-80GB, with batch_size=32 and `bfloat16` dtype, with `torch.compile="max_autotune"`:
1 change: 0 additions & 1 deletion scripts/sam/benchmark.sh
@@ -8,4 +8,3 @@ python eval_combo.py --coco_root_dir datasets/coco2017 --coco_slice_name val2017
python eval_combo.py --coco_root_dir datasets/coco2017 --coco_slice_name val2017 --sam_checkpoint_base_path checkpoints --sam_model_type vit_h --point_sampling_cache_dir tmp/sam_coco_mask_center_cache --mask_debug_out_dir tmp/sam_eval_masks_out --batch_size 32 --num_workers 32 --use_compile max-autotune --use_half bfloat16 --device cuda --compress sparse
# int8 dynamic quant + 2:4 sparsity (attn: int8, mlp lin1: int8+2:4 fuse mul, mlp lin2: 2:4 sparse)
python eval_combo.py --coco_root_dir datasets/coco2017 --coco_slice_name val2017 --sam_checkpoint_base_path checkpoints --sam_model_type vit_h --point_sampling_cache_dir tmp/sam_coco_mask_center_cache --mask_debug_out_dir tmp/sam_eval_masks_out --batch_size 32 --num_workers 32 --use_compile max-autotune --use_half bfloat16 --device cuda --compress int8_dynamic_quant_sparse

44 changes: 16 additions & 28 deletions scripts/sam/eval_combo.py
@@ -9,6 +9,10 @@
import time
import resource

from torchao.quantization import quantize_, int8_dynamic_activation_int8_weight, int4_weight_only
from torchao.sparsity import sparsify_, apply_fake_sparsity, int8_dynamic_activation_int8_semi_sparse_weight, semi_sparse_weight
from torchao.utils import unwrap_tensor_subclass

torch._dynamo.config.cache_size_limit = 50000

def unbind_jagged(device, data, sizes, offsets):
@@ -279,30 +283,17 @@ def run(
block.attn.use_rel_pos = use_rel_pos

if compress == "int8_dynamic_quant":
from torchao.quantization import quantize_, int8_dynamic_activation_int8_weight
from torchao.utils import unwrap_tensor_subclass
quantize_(predictor.model.image_encoder, int8_dynamic_activation_int8_weight())
predictor.model.image_encoder = unwrap_tensor_subclass(predictor.model.image_encoder)
elif compress == "sparse_mlp_only":
def mlp_only(mod, name):
return isinstance(mod, torch.nn.Linear) and 'mlp' in name
from torchao.sparsity import sparsify
from torch.sparse import to_sparse_semi_structured, apply_fake_sparsity
apply_fake_sparsity(predictor.model.image_encoder, filter_fn=mlp_only)
predictor.model.image_encoder = sparsify(predictor.model.image_encoder, to_sparse_semi_structured, filter_fn=mlp_only)
sparsify_(predictor.model.image_encoder, semi_sparse_weight(), filter_fn=mlp_only)
elif compress == "sparse":
from torchao.sparsity import sparsify
from torch.sparse import to_sparse_semi_structured, apply_fake_sparsity
apply_fake_sparsity(predictor.model.image_encoder)
predictor.model.image_encoder = sparsify(predictor.model.image_encoder, to_sparse_semi_structured)
sparsify_(predictor.model.image_encoder, semi_sparse_weight())
elif compress == "int8_dynamic_quant_sparse":
from torch.sparse import to_sparse_semi_structured, SparseSemiStructuredTensor
SparseSemiStructuredTensor._FORCE_CUTLASS = False
from torchao.sparsity import sparsify, apply_fake_sparsity
from torchao.sparsity.prototype.dynamic_quant_sparse import int8_dynamic_activation_int8_2x4_sparse_weight
from torchao.quantization import quantize_, int8_dynamic_activation_int8_weight
from torchao.utils import unwrap_tensor_subclass

def attn_only(mod, name):
return isinstance(mod, torch.nn.Linear) and 'attn' in name
def mlp_lin1_only(mod, name):
@@ -316,20 +307,17 @@ def mlp_only(mod, name):
apply_fake_sparsity(predictor.model.image_encoder,
filter_fn=mlp_only)

quantize_(
predictor.model.image_encoder,
int8_dynamic_activation_int8_weight(),
attn_only
)
quantize_(predictor.model.image_encoder,
int8_dynamic_activation_int8_weight(),
attn_only)
quantize_(predictor.model.image_encoder,
int8_dynamic_activation_int8_semi_sparse_weight(),
mlp_lin1_only)
sparsify_(predictor.model.image_encoder,
semi_sparse_weight(),
mlp_lin2_only)
predictor.model.image_encoder = unwrap_tensor_subclass(predictor.model.image_encoder)

predictor.model.image_encoder = sparsify(predictor.model.image_encoder,
int8_dynamic_activation_int8_2x4_sparse_weight(),
mlp_lin1_only, prune=False)

predictor.model.image_encoder = sparsify(predictor.model.image_encoder,
to_sparse_semi_structured,
mlp_lin2_only, prune=False)
else:
assert compress is None, f"Unsupported compress mode {compress}"

@@ -413,6 +401,6 @@ def mlp_only(mod, name):
vals = ",".join(map(str, [device, sam_model_type, batch_size, max_memory_allocated_bytes, max_memory_allocated_percentage, img_s, batch_ms_batch_size, mIoU, use_compile,
use_half, compress, use_compile_decoder, use_rel_pos, pad_input_image_batch, num_workers, num_batches, num_images, profile_path, memory_path]))
f.write(vals+"\n")

if __name__ == '__main__':
fire.Fire(run)
10 changes: 5 additions & 5 deletions scripts/sam/results.csv
@@ -1,6 +1,6 @@
device,sam_model_type,batch_size,memory(MiB),memory(%),img_s(avg),batch_ms(avg)/batch_size,mIoU,use_compile,use_half,compress,use_compile_decoder,use_rel_pos,pad_input_image_batch,num_workers,num_batches,num_images,profile_path,memory_path
cuda,vit_h,32,15172,18,22.74609667033727,43.96358700541707,0.5811068585673369,max-autotune,torch.bfloat16,None,False,True,True,32,154,4928,None,None
cuda,vit_h,32,15154,18,24.908711866303545,40.14659631407106,0.5822020528694204,max-autotune,torch.bfloat16,int8_dynamic_quant,False,True,True,32,154,4928,None,None
cuda,vit_h,32,15632,19,24.806623549763994,40.311814221468836,0.5671732654673084,max-autotune,torch.bfloat16,sparse_mlp_only,False,True,True,32,154,4928,None,None
cuda,vit_h,32,13429,16,24.299052218005198,41.15386851422198,0.5305645705002248,max-autotune,torch.bfloat16,sparse,False,True,True,32,154,4928,None,None
cuda,vit_h,32,14865,18,26.46342281926203,37.7880067453756,0.5668329259098808,max-autotune,torch.bfloat16,int8_dynamic_quant_sparse,False,True,True,32,154,4928,None,None
cuda,vit_h,32,15172,18,22.533401716616083,44.37856354651513,0.5812715827356921,max-autotune,torch.bfloat16,None,False,True,True,32,154,4928,None,None
cuda,vit_h,32,15154,18,25.16516896830006,39.73746416166231,0.5818834536577897,max-autotune,torch.bfloat16,int8_dynamic_quant,False,True,True,32,154,4928,None,None
cuda,vit_h,32,15632,19,24.824717871078573,40.282431614863405,0.5675837487618974,max-autotune,torch.bfloat16,sparse_mlp_only,False,True,True,32,154,4928,None,None
cuda,vit_h,32,13429,16,24.589577947798148,40.66763578142439,0.5306639662569573,max-autotune,torch.bfloat16,sparse,False,True,True,32,154,4928,None,None
cuda,vit_h,32,14869,18,26.597207143088742,37.597932543073384,0.5669944616184625,max-autotune,torch.bfloat16,int8_dynamic_quant_sparse,False,True,True,32,154,4928,None,None
9 changes: 9 additions & 0 deletions test/prototype/test_low_bit_optim.py
@@ -226,6 +226,15 @@ def _test_fsdp2(self, optim_cls):
base_optim.step()
self.assertEqual(fsdp_loss, base_loss)

base_param = base_optim.param_groups[0]["params"][0]
base_exp_avg = base_optim.state[base_param]["exp_avg"]

fsdp_param = fsdp_optim.param_groups[0]["params"][0]
fsdp_exp_avg = fsdp_optim.state[fsdp_param]["exp_avg"]
full_fsdp_exp_avg = fsdp_exp_avg.full_tensor()

self.assertEqual(base_exp_avg.dequantize(), full_fsdp_exp_avg.dequantize())


instantiate_parametrized_tests(TestQuantize)
instantiate_parametrized_tests(TestOptim)
36 changes: 20 additions & 16 deletions test/quantization/test_quant_api.py
@@ -22,12 +22,14 @@
from torchao.dtypes import (
AffineQuantizedTensor,
)
from torchao.quantization import (
LinearActivationQuantizedTensor,
)
from torchao.quantization.quant_primitives import (
MappingType,
ZeroPointDomain,
)
from torchao.quantization.subclass import (
LinearActQuantizedTensor,
Int8WeightOnlyQuantizedLinearWeight,
Int4WeightOnlyQuantizedLinearWeight,
)
@@ -504,8 +506,8 @@ def test_quantized_tensor_subclass_8da4w(self):
example_inputs = m.example_inputs()
quantize_(m, int8_dynamic_activation_int4_weight(group_size=group_size))

assert isinstance(m.linear1.weight, LinearActQuantizedTensor)
assert isinstance(m.linear2.weight, LinearActQuantizedTensor)
assert isinstance(m.linear1.weight, LinearActivationQuantizedTensor)
assert isinstance(m.linear2.weight, LinearActivationQuantizedTensor)
assert isinstance(m.linear1.weight.original_weight_tensor, AffineQuantizedTensor)
assert isinstance(m.linear2.weight.original_weight_tensor, AffineQuantizedTensor)

@@ -577,8 +579,8 @@ def test_quantized_tensor_subclass_int8_dyn_quant(self):
example_inputs = m.example_inputs(batch_size=20, dtype=torch.bfloat16, device="cuda")
quantize_(m, int8_dynamic_activation_int8_weight())

assert isinstance(m.linear1.weight, LinearActQuantizedTensor)
assert isinstance(m.linear2.weight, LinearActQuantizedTensor)
assert isinstance(m.linear1.weight, LinearActivationQuantizedTensor)
assert isinstance(m.linear2.weight, LinearActivationQuantizedTensor)
assert isinstance(m.linear1.weight.original_weight_tensor, AffineQuantizedTensor)
assert isinstance(m.linear2.weight.original_weight_tensor, AffineQuantizedTensor)

@@ -642,17 +644,19 @@ def test_int8wo_quantized_model_to_device(self):
@unittest.skipIf(TORCH_VERSION_AFTER_2_5, "Test currently doesn't work for 2.5+")
def test_int4wo_quantized_model_to_device(self):
# TODO: change initial model to "cpu"
m = ToyLinearModel().eval().to(torch.bfloat16).to("cuda")
m_copy = copy.deepcopy(m)
example_inputs = m.example_inputs(dtype=torch.bfloat16, device="cuda")

quantize_(m, int4_weight_only())
ref = m(*example_inputs)

example_inputs_cuda = (example_inputs[0].to("cuda"),)
m.to(device="cuda")
cuda_res = m(*example_inputs_cuda)
self.assertEqual(cuda_res.cpu(), ref)
devices = ["cuda", "cuda:0"]
for device in devices:
m = ToyLinearModel().eval().to(torch.bfloat16).to(device)
m_copy = copy.deepcopy(m)
example_inputs = m.example_inputs(dtype=torch.bfloat16, device=device)

quantize_(m, int4_weight_only())
ref = m(*example_inputs)

example_inputs_cuda = (example_inputs[0].to(device),)
m.to(device=device)
cuda_res = m(*example_inputs_cuda)
self.assertEqual(cuda_res.cpu(), ref)

@unittest.skipIf(not TORCH_VERSION_AFTER_2_4, "Test only enabled for 2.4+")
@unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
27 changes: 16 additions & 11 deletions test/sparsity/test_sparse_api.py
@@ -1,18 +1,24 @@
import copy
import logging
import unittest

import torch
from torch import nn
from torch.sparse import to_sparse_semi_structured

from torchao.sparsity import apply_fake_sparsity, sparsify
from torchao.sparsity.prototype.dynamic_quant_sparse import int8_dynamic_activation_int8_2x4_sparse_weight
from torchao.sparsity import (
apply_fake_sparsity,
sparsify_,
int8_dynamic_activation_int8_semi_sparse_weight,
semi_sparse_weight,
)
from torchao.quantization.quant_api import (
_replace_with_custom_fn_if_matches_filter,
_get_subclass_inserter,
_is_linear,
int8_dynamic_activation_int8_weight,
quantize_,
)
from torchao.utils import TORCH_VERSION_AFTER_2_3
from torchao.utils import TORCH_VERSION_AFTER_2_3, unwrap_tensor_subclass
from torch.testing._internal.common_utils import TestCase


@@ -38,12 +44,11 @@ def test_sparse(self):
apply_fake_sparsity(model)
dense_result = model(input)

model = sparsify(model, to_sparse_semi_structured)
sparsify_(model, semi_sparse_weight())
sparse_result = model(input)

assert torch.allclose(dense_result, sparse_result, rtol=1e-3, atol=1e-3)


class TestQuantSemiSparse(TestCase):

@unittest.skipIf(not TORCH_VERSION_AFTER_2_3, "pytorch 2.3+ feature")
@@ -58,15 +63,15 @@ def test_quant_semi_sparse(self):
.half()
.cuda()
)

apply_fake_sparsity(model)
dense_result = model(input)
model_copy = copy.deepcopy(model)
quantize_(model_copy, int8_dynamic_activation_int8_weight())
dense_result = model_copy(input)

sparsify(model, int8_dynamic_activation_int8_2x4_sparse_weight())
quantize_(model, int8_dynamic_activation_int8_semi_sparse_weight())
sparse_result = model(input)

assert torch.allclose(dense_result, sparse_result, rtol=1e-1, atol=1e-1)

assert torch.allclose(dense_result, sparse_result, rtol=1e-2, atol=1e-2)

if __name__ == "__main__":
unittest.main()
2 changes: 2 additions & 0 deletions torchao/dtypes/__init__.py
@@ -7,6 +7,7 @@
to_affine_quantized_static,
LayoutType,
PlainLayoutType,
SemiSparseLayoutType,
TensorCoreTiledLayoutType,
)

@@ -19,5 +20,6 @@
"to_affine_quantized_static",
"LayoutType",
"PlainLayoutType",
"SemiSparseLayoutType",
"TensorCoreTiledLayoutType",
]