Add semi-structured sparse + dynamic int8 subclasses (#36)
This PR adds int8 dynamic quantization + semi-structured (2:4) sparsity support to torchao.
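At a glance, the new subclasses are applied by swapping a linear layer's weight tensor for a quantized + sparse tensor subclass. Below is a minimal sketch of the intended usage, modeled on the test added in this PR; the two-layer fp16 model is only a placeholder, not part of the change, and a CUDA device plus a recent PyTorch (the tests gate on TORCH_VERSION_AFTER_2_3) is assumed.

import torch
from torch import nn

from torchao.sparsity import apply_fake_sparsity
from torchao.sparsity.prototype.dynamic_quant_sparse import Int8DynamicallyQuantizedSemiStructuredSparseLinearWeight
from torchao.quantization.quant_api import (
    _replace_with_custom_fn_if_matches_filter,
    _get_subclass_inserter,
    _is_linear,
)

# placeholder model: any stack of nn.Linear layers works the same way
model = nn.Sequential(nn.Linear(128, 256), nn.Linear(256, 128)).half().cuda()

# prune weights to a 2:4 semi-structured pattern first
apply_fake_sparsity(model)

# then swap every linear weight for the int8 dynamic quant + 2:4 sparse subclass
_replace_with_custom_fn_if_matches_filter(
    model,
    _get_subclass_inserter(Int8DynamicallyQuantizedSemiStructuredSparseLinearWeight),
    _is_linear,
)

out = model(torch.rand((128, 128)).half().cuda())

The output should stay close to the dense fp16 result within the loose tolerances used in the tests below.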
Showing 6 changed files with 541 additions and 1 deletion.
@@ -0,0 +1,129 @@ (new file: SAM image-encoder benchmark script)
import pandas as pd
import torch
from segment_anything import sam_model_registry
from torch.utils.benchmark import Timer
from torch.sparse import SparseSemiStructuredTensor, SparseSemiStructuredTensorCUTLASS, SparseSemiStructuredTensorCUSPARSELT
from torchao.quantization.quant_api import (
    _replace_with_custom_fn_if_matches_filter,
    _get_subclass_inserter,
    _is_linear,
    QuantizedLinearWeightBase,
    Int8DynamicallyQuantizedLinearWeight,
)
from torchao.quantization import change_linear_weights_to_int8_dqtensors
from torchao.sparsity import (
    apply_sparse_semi_structured,
    apply_fake_sparsity,
)
from torchao.sparsity.prototype.dynamic_quant_sparse import Int8DynamicallyQuantized24CusparseltLinearFuseMulWeight, Int8DynamicallyQuantizedSemiStructuredSparseLinearWeight
from itertools import product
from tqdm import tqdm

sam_checkpoint_base_path = "/home/jessecai/local/MODELS"
model_type = 'vit_h'
model_name = 'sam_vit_h_4b8939.pth'
checkpoint_path = f"{sam_checkpoint_base_path}/{model_name}"

torch._inductor.config.epilogue_fusion = True
torch._inductor.config.coordinate_descent_tuning = True
torch._inductor.config.coordinate_descent_check_all_directions = True
torch._inductor.config.force_fuse_int_mm_with_mul = True

@torch.no_grad()
def benchmark(f, *args, **kwargs):
    # warmup before timing
    for _ in range(3):
        f(*args, **kwargs)
        torch.cuda.synchronize()

    torch.cuda.reset_peak_memory_stats()
    t0 = Timer(
        stmt="f(*args, **kwargs)", globals={"args": args, "kwargs": kwargs, "f": f}
    )
    res = t0.adaptive_autorange(.03, min_run_time=.2, max_run_time=20)
    # median latency in ms, peak CUDA memory in GB
    return {'time': res.median * 1e3, 'memory': torch.cuda.max_memory_allocated() / 1e9}

def get_sam_model(only_one_block=False, batchsize=1):
    sam = sam_model_registry[model_type](checkpoint=checkpoint_path).cuda()
    model = sam.image_encoder.eval()
    image = torch.randn(batchsize, 3, 1024, 1024, device='cuda')

    # code to use just a single block of the model
    if only_one_block:
        model = model.blocks[0]
        image = torch.randn(batchsize, 64, 64, 1280, device='cuda')
    return model, image

def qkv_only(mod, name):
    return isinstance(mod, torch.nn.Linear) and 'qkv' in name

def proj_only(mod, name):
    return isinstance(mod, torch.nn.Linear) and 'proj' in name

def lin1_only(mod, name):
    return isinstance(mod, torch.nn.Linear) and 'lin1' in name

def lin2_only(mod, name):
    return isinstance(mod, torch.nn.Linear) and 'lin2' in name

# weight tensor subclasses that can be swapped in for each group of linear layers
SUBCLASSES = {
    "quant" : Int8DynamicallyQuantizedLinearWeight,
    "quant+sparse (cutlass)" : Int8DynamicallyQuantizedSemiStructuredSparseLinearWeight,
    "quant+sparse (cusparselt)" : Int8DynamicallyQuantized24CusparseltLinearFuseMulWeight,
    "sparse (cutlass)" : SparseSemiStructuredTensorCUTLASS,
    "sparse (cusparselt)" : SparseSemiStructuredTensorCUSPARSELT,
}

def run_once(block_only=False, dtype=torch.bfloat16, batchsize=32, compile=True, qkv=None, proj=None, lin1=None, lin2=None):
    res = {
        "block_only": block_only,
        "batchsize": batchsize,
        "dtype": dtype,
        "compile": compile,
        "qkv" : qkv,
        "proj": proj,
        "lin1": lin1,
        "lin2": lin2,
    }
    with torch.no_grad():
        model, image = get_sam_model(block_only, batchsize)
        model = model.to(dtype)
        image = image.to(dtype)

        # 2:4 prune model
        apply_fake_sparsity(model)
        option_and_filter_fn = zip([qkv, proj, lin1, lin2], [qkv_only, proj_only, lin1_only, lin2_only])

        for option, filter_fn in option_and_filter_fn:
            subclass = SUBCLASSES.get(option, None)
            if subclass and issubclass(subclass, SparseSemiStructuredTensor):
                # replace with to_sparse_semi_structured
                for name, mod in model.named_modules():
                    if filter_fn(mod, name):
                        mod.weight = torch.nn.Parameter(subclass.from_dense(mod.weight))
            elif subclass and issubclass(subclass, QuantizedLinearWeightBase):
                _replace_with_custom_fn_if_matches_filter(model, _get_subclass_inserter(subclass), filter_fn)

        if compile:
            model = torch.compile(model, mode='max-autotune')

        res.update(benchmark(model, image))
        res["img/s"] = 1 / (res['time'] / 1000 / res['batchsize'])
        return res

if __name__ == "__main__":
    print("BENCHMARKING")
    ALL_RUNS = [run_once(qkv="quant+sparse (cutlass)", proj="quant", lin1="quant+sparse (cutlass)", lin2="quant+sparse (cutlass)")]
    # for option in tqdm(SUBCLASSES)]
    # ALL_RUNS = [
    #     run_once(),
    #     run_once(qkv="quant", proj="quant", lin1="quant", lin2="quant"),
    #     run_once(qkv="quant+sparse (cusparselt)", proj="quant+sparse (cusparselt)", lin1="quant+sparse (cusparselt)", lin2="quant+sparse (cutlass)"),
    #     run_once(qkv="quant+sparse (cusparselt)", proj="quant", lin1="quant+sparse (cutlass)", lin2="quant+sparse (cutlass)"),
    #     run_once(qkv="quant", proj="quant", lin1="quant+sparse (cusparselt)", lin2="quant+sparse (cusparselt)"),
    #     run_once(qkv="sparse (cusparselt)", proj="sparse (cusparselt)", lin1="sparse (cusparselt)", lin2="sparse (cusparselt)"),
    #     run_once(qkv="sparse (cutlass)", proj="sparse (cutlass)", lin1="sparse (cutlass)", lin2="sparse (cutlass)"),
    #     run_once(qkv="quant+sparse (cutlass)", proj="quant+sparse (cutlass)", lin1="quant+sparse (cutlass)", lin2="quant+sparse (cutlass)"),
    # ]
    df = pd.DataFrame(ALL_RUNS)
    df.to_csv("sam_benchmark_results.csv")
    print(df)
@@ -0,0 +1,5 @@ (new file: benchmark results CSV)
,block_only,batchsize,dtype,compile,qkv,proj,lin1,lin2,time,memory,img/s
0,False,32,torch.bfloat16,True,,,,,1457.0417301729321,28.280423936,21.96230851686177
1,False,32,torch.bfloat16,True,quant,quant,quant,quant,1318.5919532552361,28.261341696,24.268311300551254
2,False,32,torch.bfloat16,True,quant+sparse (cusparselt),quant,quant+sparse (cutlass),quant+sparse (cutlass),1253.1237555667758,28.18694656,25.536184960061433
3,False,32,torch.bfloat16,True,quant+sparse (cutlass),quant+sparse (cutlass),quant+sparse (cutlass),quant+sparse (cutlass),1290.4946617782116,27.837008896,24.796693041648258
@@ -0,0 +1,71 @@ (new file: semi-structured sparsity tests)
import logging
import unittest

import torch
from torch import nn

from torchao.sparsity import apply_fake_sparsity, apply_sparse_semi_structured
from torchao.sparsity.prototype.dynamic_quant_sparse import Int8DynamicallyQuantizedSemiStructuredSparseLinearWeight
from torchao.quantization.quant_api import (
    _replace_with_custom_fn_if_matches_filter,
    _get_subclass_inserter,
    _is_linear,
)
from torchao.quantization.utils import TORCH_VERSION_AFTER_2_3
from torch.testing._internal.common_utils import TestCase


logging.basicConfig(
    format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", level=logging.INFO
)


class TestSemiStructuredSparse(TestCase):

    @unittest.skipIf(not TORCH_VERSION_AFTER_2_3, "pytorch 2.3+ feature")
    @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
    def test_sparse(self):
        input = torch.rand((128, 128)).half().cuda()
        model = (
            nn.Sequential(
                nn.Linear(128, 256),
                nn.Linear(256, 128),
            )
            .half()
            .cuda()
        )

        apply_fake_sparsity(model)
        dense_result = model(input)

        apply_sparse_semi_structured(model)
        sparse_result = model(input)

        assert torch.allclose(dense_result, sparse_result, rtol=1e-3, atol=1e-3)


class TestQuantSemiSparse(TestCase):

    @unittest.skipIf(not TORCH_VERSION_AFTER_2_3, "pytorch 2.3+ feature")
    @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
    def test_quant_semi_sparse(self):
        input = torch.rand((128, 128)).half().cuda()
        model = (
            nn.Sequential(
                nn.Linear(128, 256),
                nn.Linear(256, 128),
            )
            .half()
            .cuda()
        )

        apply_fake_sparsity(model)
        dense_result = model(input)

        # swap each linear weight for the int8 dynamic quant + 2:4 sparse subclass
        _replace_with_custom_fn_if_matches_filter(model, _get_subclass_inserter(Int8DynamicallyQuantizedSemiStructuredSparseLinearWeight), _is_linear)
        sparse_result = model(input)

        assert torch.allclose(dense_result, sparse_result, rtol=1e-1, atol=1e-1)


if __name__ == "__main__":
    unittest.main()