Add semi-structured sparse + dynamic int8 subclasses (#36)
This PR adds int8 dynamic quantization + semi-structured sparsity support to torchao.
jcaip authored Apr 26, 2024
1 parent 06a2969 commit 739e62d
Showing 6 changed files with 541 additions and 1 deletion.
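For orientation before the diff: a minimal sketch of how the new fused subclass is meant to be applied, mirroring the test added in test/sparsity/test_sparse_api.py below. The two-layer fp16 model and the 128-dim input are illustrative, not part of this commit.

import torch
from torch import nn

from torchao.sparsity import apply_fake_sparsity
from torchao.sparsity.prototype.dynamic_quant_sparse import Int8DynamicallyQuantizedSemiStructuredSparseLinearWeight
from torchao.quantization.quant_api import (
    _replace_with_custom_fn_if_matches_filter,
    _get_subclass_inserter,
    _is_linear,
)

# Illustrative model: any stack of nn.Linear modules works the same way.
model = nn.Sequential(nn.Linear(128, 256), nn.Linear(256, 128)).half().cuda()

apply_fake_sparsity(model)  # prune each linear weight to a 2:4 sparse pattern

# Swap every linear weight for the fused int8-dynamic-quant + 2:4-sparse subclass.
_replace_with_custom_fn_if_matches_filter(
    model,
    _get_subclass_inserter(Int8DynamicallyQuantizedSemiStructuredSparseLinearWeight),
    _is_linear,
)
out = model(torch.rand((128, 128)).half().cuda())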
129 changes: 129 additions & 0 deletions benchmarks/benchmark_sam.py
@@ -0,0 +1,129 @@
import pandas as pd
import torch
from segment_anything import sam_model_registry
from torch.utils.benchmark import Timer
from torch.sparse import SparseSemiStructuredTensor, SparseSemiStructuredTensorCUTLASS, SparseSemiStructuredTensorCUSPARSELT
from torchao.quantization.quant_api import (
    _replace_with_custom_fn_if_matches_filter,
    _get_subclass_inserter,
    _is_linear,
    QuantizedLinearWeightBase,
    Int8DynamicallyQuantizedLinearWeight,
)
from torchao.quantization import change_linear_weights_to_int8_dqtensors
from torchao.sparsity import (
    apply_sparse_semi_structured,
    apply_fake_sparsity,
)
from torchao.sparsity.prototype.dynamic_quant_sparse import Int8DynamicallyQuantized24CusparseltLinearFuseMulWeight, Int8DynamicallyQuantizedSemiStructuredSparseLinearWeight
from itertools import product
from tqdm import tqdm

sam_checkpoint_base_path = "/home/jessecai/local/MODELS"
model_type = 'vit_h'
model_name = 'sam_vit_h_4b8939.pth'
checkpoint_path = f"{sam_checkpoint_base_path}/{model_name}"

torch._inductor.config.epilogue_fusion = True
torch._inductor.config.coordinate_descent_tuning = True
torch._inductor.config.coordinate_descent_check_all_directions = True
torch._inductor.config.force_fuse_int_mm_with_mul = True

@torch.no_grad()
def benchmark(f, *args, **kwargs):
    # warmup
    for _ in range(3):
        f(*args, **kwargs)
        torch.cuda.synchronize()

    torch.cuda.reset_peak_memory_stats()
    t0 = Timer(
        stmt="f(*args, **kwargs)", globals={"args": args, "kwargs": kwargs, "f": f}
    )
    res = t0.adaptive_autorange(.03, min_run_time=.2, max_run_time=20)
    # median latency in ms, peak CUDA memory in GB
    return {'time': res.median * 1e3, 'memory': torch.cuda.max_memory_allocated() / 1e9}

def get_sam_model(only_one_block=False, batchsize=1):
    sam = sam_model_registry[model_type](checkpoint=checkpoint_path).cuda()
    model = sam.image_encoder.eval()
    image = torch.randn(batchsize, 3, 1024, 1024, device='cuda')

    # code to use just a single block of the model
    if only_one_block:
        model = model.blocks[0]
        image = torch.randn(batchsize, 64, 64, 1280, device='cuda')
    return model, image

def qkv_only(mod, name):
    return isinstance(mod, torch.nn.Linear) and 'qkv' in name

def proj_only(mod, name):
    return isinstance(mod, torch.nn.Linear) and 'proj' in name

def lin1_only(mod, name):
    return isinstance(mod, torch.nn.Linear) and 'lin1' in name

def lin2_only(mod, name):
    return isinstance(mod, torch.nn.Linear) and 'lin2' in name

SUBCLASSES = {
"quant" : Int8DynamicallyQuantizedLinearWeight,
"quant+sparse (cutlass)" : Int8DynamicallyQuantizedSemiStructuredSparseLinearWeight,
"quant+sparse (cusparselt)" : Int8DynamicallyQuantized24CusparseltLinearFuseMulWeight,
"sparse (cutlass)" : SparseSemiStructuredTensorCUTLASS,
"sparse (cusparselt)" : SparseSemiStructuredTensorCUSPARSELT,
}

def run_once(block_only=False, dtype=torch.bfloat16, batchsize=32, compile=True, qkv=None, proj=None, lin1=None, lin2=None):
    res = {
        "block_only": block_only,
        "batchsize": batchsize,
        "dtype": dtype,
        "compile": compile,
        "qkv" : qkv,
        "proj": proj,
        "lin1": lin1,
        "lin2": lin2,
    }
    with torch.no_grad():
        model, image = get_sam_model(block_only, batchsize)
        model = model.to(dtype)
        image = image.to(dtype)

        # 2:4 prune model
        apply_fake_sparsity(model)
        option_and_filter_fn = zip([qkv, proj, lin1, lin2], [qkv_only, proj_only, lin1_only, lin2_only])

        for option, filter_fn in option_and_filter_fn:
            subclass = SUBCLASSES.get(option, None)
            if subclass and issubclass(subclass, SparseSemiStructuredTensor):
                # replace with to_sparse_semi_structured
                for name, mod in model.named_modules():
                    if filter_fn(mod, name):
                        mod.weight = torch.nn.Parameter(subclass.from_dense(mod.weight))
            elif subclass and issubclass(subclass, QuantizedLinearWeightBase):
                _replace_with_custom_fn_if_matches_filter(model, _get_subclass_inserter(subclass), filter_fn)

        if compile:
            model = torch.compile(model, mode='max-autotune')

        res.update(benchmark(model, image))
        # img/s = batchsize / median seconds per forward pass
        res["img/s"] = 1 / (res['time'] / 1000 / res['batchsize'])
        return res

if __name__ == "__main__":
print("BENCHMARKING")
ALL_RUNS = [run_once(qkv="quant+sparse (cutlass)", proj="quant", lin1="quant+sparse (cutlass)", lin2="quant+sparse (cutlass)")]
# for option in tqdm(SUBCLASSES)]
# ALL_RUNS = [
# run_once(),
# run_once(qkv="quant", proj="quant", lin1="quant", lin2="quant"),
# run_once(qkv="quant+sparse (cusparselt)", proj="quant+sparse (cusparselt)", lin1="quant+sparse (cusparselt)", lin2="quant+sparse (cutlass)"),
# run_once(qkv="quant+sparse (cusparselt)", proj="quant", lin1="quant+sparse (cutlass)", lin2="quant+sparse (cutlass)"),
# run_once(qkv="quant", proj="quant", lin1="quant+sparse (cusparselt)", lin2="quant+sparse (cusparselt)"),
# run_once(qkv="sparse (cusparselt)", proj="sparse (cusparselt)", lin1="sparse (cusparselt)", lin2="sparse (cusparselt)"),
# run_once(qkv="sparse (cutlass)", proj="sparse (cutlass)", lin1="sparse (cutlass)", lin2="sparse (cutlass)"),
# run_once(qkv="quant+sparse (cutlass)", proj="quant+sparse (cutlass)", lin1="quant+sparse (cutlass)", lin2="quant+sparse (cutlass)"),
# ]
df = pd.DataFrame(ALL_RUNS)
df.to_csv("sam_benchmark_results.csv")
print(df)
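The CSV below appears to have been produced by this script with several of the commented-out configurations enabled; its rows match those run_once calls. Running it requires segment_anything installed and sam_checkpoint_base_path pointed at a local copy of the sam_vit_h_4b8939.pth checkpoint.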
5 changes: 5 additions & 0 deletions benchmarks/sam_benchmark_results.csv
@@ -0,0 +1,5 @@
,block_only,batchsize,dtype,compile,qkv,proj,lin1,lin2,time,memory,img/s
0,False,32,torch.bfloat16,True,,,,,1457.0417301729321,28.280423936,21.96230851686177
1,False,32,torch.bfloat16,True,quant,quant,quant,quant,1318.5919532552361,28.261341696,24.268311300551254
2,False,32,torch.bfloat16,True,quant+sparse (cusparselt),quant,quant+sparse (cutlass),quant+sparse (cutlass),1253.1237555667758,28.18694656,25.536184960061433
3,False,32,torch.bfloat16,True,quant+sparse (cutlass),quant+sparse (cutlass),quant+sparse (cutlass),quant+sparse (cutlass),1290.4946617782116,27.837008896,24.796693041648258
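Read against row 0 (the bfloat16 baseline at ~21.96 img/s), the mixed cuSPARSELt/CUTLASS quant+sparse configuration in row 2 is the fastest at ~25.54 img/s (roughly a 16% throughput improvement) while allocating slightly less peak memory; the all-CUTLASS quant+sparse configuration in row 3 reaches ~24.80 img/s.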
71 changes: 71 additions & 0 deletions test/sparsity/test_sparse_api.py
@@ -0,0 +1,71 @@
import logging
import unittest

import torch
from torch import nn

from torchao.sparsity import apply_fake_sparsity, apply_sparse_semi_structured
from torchao.sparsity.prototype.dynamic_quant_sparse import Int8DynamicallyQuantizedSemiStructuredSparseLinearWeight
from torchao.quantization.quant_api import (
    _replace_with_custom_fn_if_matches_filter,
    _get_subclass_inserter,
    _is_linear,
)
from torchao.quantization.utils import TORCH_VERSION_AFTER_2_3
from torch.testing._internal.common_utils import TestCase


logging.basicConfig(
    format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", level=logging.INFO
)

class TestSemiStructuredSparse(TestCase):

    @unittest.skipIf(not TORCH_VERSION_AFTER_2_3, "pytorch 2.3+ feature")
    @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
    def test_sparse(self):
        input = torch.rand((128, 128)).half().cuda()
        model = (
            nn.Sequential(
                nn.Linear(128, 256),
                nn.Linear(256, 128),
            )
            .half()
            .cuda()
        )

        apply_fake_sparsity(model)
        dense_result = model(input)

        apply_sparse_semi_structured(model)
        sparse_result = model(input)

        assert torch.allclose(dense_result, sparse_result, rtol=1e-3, atol=1e-3)


class TestQuantSemiSparse(TestCase):

    @unittest.skipIf(not TORCH_VERSION_AFTER_2_3, "pytorch 2.3+ feature")
    @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
    def test_quant_semi_sparse(self):
        input = torch.rand((128, 128)).half().cuda()
        model = (
            nn.Sequential(
                nn.Linear(128, 256),
                nn.Linear(256, 128),
            )
            .half()
            .cuda()
        )

        apply_fake_sparsity(model)
        dense_result = model(input)

        _replace_with_custom_fn_if_matches_filter(model, _get_subclass_inserter(Int8DynamicallyQuantizedSemiStructuredSparseLinearWeight), _is_linear)
        sparse_result = model(input)

        # looser tolerance than the sparse-only test: int8 dynamic quantization
        # perturbs the output on top of the 2:4 sparsification
        assert torch.allclose(dense_result, sparse_result, rtol=1e-1, atol=1e-1)


if __name__ == "__main__":
    unittest.main()
5 changes: 4 additions & 1 deletion torchao/sparsity/__init__.py
@@ -6,8 +6,11 @@

 from .wanda import WandaSparsifier # noqa: F403
 from .utils import PerChannelNormObserver # noqa: F403
+from .sparse_api import apply_sparse_semi_structured, apply_fake_sparsity

 __all__ = [
     "WandaSparsifier",
-    "PerChannelNormObserver"
+    "PerChannelNormObserver",
+    "apply_sparse_semi_structured",
+    "apply_fake_sparsity",
 ]