
Commit

initial commit
jcaip committed Sep 5, 2024
1 parent f5703b0 commit b6fc991
Showing 12 changed files with 407 additions and 21 deletions.
87 changes: 87 additions & 0 deletions test/sparsity/test_sparse_api.py
@@ -21,11 +21,33 @@
from torchao.utils import TORCH_VERSION_AT_LEAST_2_3
from torch.testing._internal.common_utils import TestCase

from torch.ao.pruning import WeightNormSparsifier


logging.basicConfig(
    format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", level=logging.INFO
)

def apply_fake_block_sparsity(model, **kwargs):
    """
    This function simulates block sparsity on all linear layers in a model.
    It uses the torch.ao.pruning flow.
    """
    filter_fn = kwargs.pop("filter_fn", _is_linear)
    # torch.ao.pruning flow
    sparse_config = []
    for name, mod in model.named_modules():
        if filter_fn(mod, name):
            sparse_config.append({"tensor_fqn": f"{name}.weight"})

    sparsifier = WeightNormSparsifier(
        sparsity_level=0.5, sparse_block_shape=(64, 64)
    )
    sparsifier.prepare(model, sparse_config)
    sparsifier.step()
    sparsifier.squash_mask()
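
For reference, a minimal usage sketch of the helper above (not part of this diff; the toy model and shapes are invented for illustration):

# Sketch only: apply fake 64x64 block sparsity to a toy two-layer model.
toy = nn.Sequential(nn.Linear(128, 128), nn.Linear(128, 128))
apply_fake_block_sparsity(toy)
# After squash_mask(), each weight is a plain dense tensor with whole
# 64x64 blocks zeroed out (sparsity_level=0.5 zeroes half the blocks).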


class TestSemiStructuredSparse(TestCase):

    @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_3, "pytorch 2.3+ feature")
@@ -73,5 +95,70 @@ def test_quant_semi_sparse(self):

        assert torch.allclose(dense_result, sparse_result, rtol=1e-2, atol=1e-2)



class TestBlockSparseWeight(TestCase):
    @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_3, "pytorch 2.3+ feature")
    @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
    def test_sparse(self):
        input = torch.rand((1024, 1024)).half().cuda()
        model = (
            nn.Sequential(
                nn.Linear(1024, 2048),
                nn.Linear(2048, 1024),
            )
            .half()
            .cuda()
        )

        from torchao.sparsity.utils import create_block_sparse_tensor
        M, N = model[0].weight.shape
        model[0].weight.data = create_block_sparse_tensor(M, N, 64, 0.5, torch.float16)
        M, N = model[1].weight.shape
        model[1].weight.data = create_block_sparse_tensor(M, N, 64, 0.5, torch.float16)
        dense_result = model(input)

        from torchao.sparsity.prototype.superblock.blocksparse import block_sparse_weight
        sparsify_(model, block_sparse_weight())
        sparse_result = model(input)

        assert torch.allclose(dense_result, sparse_result, rtol=1e-3, atol=1e-3)

class TestQuantBlockSparseWeight(TestCase):
    @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_3, "pytorch 2.3+ feature")
    @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
    def test_sparse(self):
        input = torch.rand((128, 128)).to(torch.bfloat16).cuda()
        model = (
            nn.Sequential(
                nn.Linear(128, 256),
                nn.Linear(256, 128),
            )
            .to(torch.bfloat16)
            .cuda()
        )

        from torchao.sparsity.utils import create_block_sparse_tensor
        M, N = model[0].weight.shape
        model[0].weight.data = (
            create_block_sparse_tensor(M, N, 64, 0.5, torch.bfloat16)
            * torch.rand(M, N, dtype=torch.bfloat16).cuda()
        )
        print(model[0].weight)
        M, N = model[1].weight.shape
        model[1].weight.data = create_block_sparse_tensor(M, N, 64, 0.5, torch.bfloat16)
        print(model[1].weight)

        model_copy = copy.deepcopy(model)

        quantize_(model_copy, int8_dynamic_activation_int8_weight())
        reference = model_copy(input)

        from torchao.dtypes.affine_quantized_tensor import BlockSparseLayoutType
        quantize_(model, int8_dynamic_activation_int8_weight(layout_type=BlockSparseLayoutType()))
        sparse_result = model(input)

        print(reference)
        print(sparse_result)
        assert torch.allclose(reference, sparse_result, rtol=1e-2, atol=1e-2)


if __name__ == "__main__":
    unittest.main()
19 changes: 10 additions & 9 deletions torchao/_models/sam/benchmark.sh
@@ -1,10 +1,11 @@
# baseline
python eval_combo.py --coco_root_dir datasets/coco2017 --coco_slice_name val2017 --sam_checkpoint_base_path checkpoints --sam_model_type vit_h --point_sampling_cache_dir tmp/sam_coco_mask_center_cache --mask_debug_out_dir tmp/sam_eval_masks_out --batch_size 32 --num_workers 32 --use_compile max-autotune --use_half bfloat16 --device cuda --print_header True
# int8 dynamic quant (all)
python eval_combo.py --coco_root_dir datasets/coco2017 --coco_slice_name val2017 --sam_checkpoint_base_path checkpoints --sam_model_type vit_h --point_sampling_cache_dir tmp/sam_coco_mask_center_cache --mask_debug_out_dir tmp/sam_eval_masks_out --batch_size 32 --num_workers 32 --use_compile max-autotune --use_half bfloat16 --device cuda --compress int8_dynamic_quant
# 2:4 sparsity (mlp only)
python eval_combo.py --coco_root_dir datasets/coco2017 --coco_slice_name val2017 --sam_checkpoint_base_path checkpoints --sam_model_type vit_h --point_sampling_cache_dir tmp/sam_coco_mask_center_cache --mask_debug_out_dir tmp/sam_eval_masks_out --batch_size 32 --num_workers 32 --use_compile max-autotune --use_half bfloat16 --device cuda --compress sparse_mlp_only
# 2:4 sparsity (all)
python eval_combo.py --coco_root_dir datasets/coco2017 --coco_slice_name val2017 --sam_checkpoint_base_path checkpoints --sam_model_type vit_h --point_sampling_cache_dir tmp/sam_coco_mask_center_cache --mask_debug_out_dir tmp/sam_eval_masks_out --batch_size 32 --num_workers 32 --use_compile max-autotune --use_half bfloat16 --device cuda --compress sparse
# int8 dynamic quant + 2:4 sparsity (attn: int8, mlp lin1: int8+2:4 fuse mul, mlp lin2: 2:4 sparse)
python eval_combo.py --coco_root_dir datasets/coco2017 --coco_slice_name val2017 --sam_checkpoint_base_path checkpoints --sam_model_type vit_h --point_sampling_cache_dir tmp/sam_coco_mask_center_cache --mask_debug_out_dir tmp/sam_eval_masks_out --batch_size 32 --num_workers 32 --use_compile max-autotune --use_half bfloat16 --device cuda --compress int8_dynamic_quant_sparse
#python eval_combo.py --coco_root_dir datasets/coco2017 --coco_slice_name val2017 --sam_checkpoint_base_path checkpoints --sam_model_type vit_h --point_sampling_cache_dir tmp/sam_coco_mask_center_cache --mask_debug_out_dir tmp/sam_eval_masks_out --batch_size 32 --num_workers 32 --use_compile max-autotune --use_half bfloat16 --device cuda --print_header True
## int8 dynamic quant (all)
#python eval_combo.py --coco_root_dir datasets/coco2017 --coco_slice_name val2017 --sam_checkpoint_base_path checkpoints --sam_model_type vit_h --point_sampling_cache_dir tmp/sam_coco_mask_center_cache --mask_debug_out_dir tmp/sam_eval_masks_out --batch_size 32 --num_workers 32 --use_compile max-autotune --use_half bfloat16 --device cuda --compress int8_dynamic_quant
## 2:4 sparsity (mlp only)
#python eval_combo.py --coco_root_dir datasets/coco2017 --coco_slice_name val2017 --sam_checkpoint_base_path checkpoints --sam_model_type vit_h --point_sampling_cache_dir tmp/sam_coco_mask_center_cache --mask_debug_out_dir tmp/sam_eval_masks_out --batch_size 32 --num_workers 32 --use_compile max-autotune --use_half bfloat16 --device cuda --compress sparse_mlp_only
## 2:4 sparsity (all)
#python eval_combo.py --coco_root_dir datasets/coco2017 --coco_slice_name val2017 --sam_checkpoint_base_path checkpoints --sam_model_type vit_h --point_sampling_cache_dir tmp/sam_coco_mask_center_cache --mask_debug_out_dir tmp/sam_eval_masks_out --batch_size 32 --num_workers 32 --use_compile max-autotune --use_half bfloat16 --device cuda --compress sparse
## int8 dynamic quant + 2:4 sparsity (attn: int8, mlp lin1: int8+2:4 fuse mul, mlp lin2: 2:4 sparse)
#python eval_combo.py --coco_root_dir datasets/coco2017 --coco_slice_name val2017 --sam_checkpoint_base_path checkpoints --sam_model_type vit_h --point_sampling_cache_dir tmp/sam_coco_mask_center_cache --mask_debug_out_dir tmp/sam_eval_masks_out --batch_size 32 --num_workers 32 --use_compile max-autotune --use_half bfloat16 --device cuda --compress int8_dynamic_quant_sparse
# int8 dynamic quant + block sparsity (mlp only)
python eval_combo.py --coco_root_dir datasets/coco2017 --coco_slice_name val2017 --sam_checkpoint_base_path checkpoints --sam_model_type vit_h --point_sampling_cache_dir tmp/sam_coco_mask_center_cache --mask_debug_out_dir tmp/sam_eval_masks_out --batch_size 32 --num_workers 32 --use_compile max-autotune --use_half bfloat16 --device cuda --compress int8_dynamic_quant_block_sparse
6 changes: 5 additions & 1 deletion torchao/_models/sam/eval_combo.py
@@ -320,7 +320,11 @@ def mlp_only(mod, name):
                  mlp_lin2_only)
        if not TORCH_VERSION_AT_LEAST_2_5:
            predictor.model.image_encoder = unwrap_tensor_subclass(predictor.model.image_encoder)

    elif compress == "int8_dynamic_quant_block_sparse":
        def mlp_only(mod, name):
            return isinstance(mod, torch.nn.Linear) and 'mlp' in name
        from torchao.dtypes.affine_quantized_tensor import BlockSparseLayoutType
        quantize_(predictor.model.image_encoder,
                  int8_dynamic_activation_int8_weight(layout_type=BlockSparseLayoutType()),
                  mlp_only)
    else:
        assert compress is None, f"Unsupported compress mode {compress}"

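For orientation, a standalone sketch of this compress path outside eval_combo.py (the Block module, sizes, and names are invented for illustration; quantize_ and int8_dynamic_activation_int8_weight are assumed importable from torchao.quantization, and BlockSparseLayoutType is the prototype layout imported in the diff above):

import torch
import torch.nn as nn
from torchao.quantization import quantize_, int8_dynamic_activation_int8_weight
from torchao.dtypes.affine_quantized_tensor import BlockSparseLayoutType

class Block(nn.Module):
    # Toy stand-in for an image-encoder block; only linears under `.mlp`
    # match the filter below, mirroring the mlp_only filter in eval_combo.py.
    def __init__(self):
        super().__init__()
        self.attn = nn.Linear(256, 256)
        self.mlp = nn.Sequential(nn.Linear(256, 512), nn.Linear(512, 256))

    def forward(self, x):
        return self.mlp(self.attn(x))

def mlp_only(mod, name):
    return isinstance(mod, nn.Linear) and "mlp" in name

model = Block().cuda().to(torch.bfloat16)
# Quantize only the MLP linears with the int8 dynamic-activation,
# int8-weight config using the block-sparse layout.
quantize_(model, int8_dynamic_activation_int8_weight(layout_type=BlockSparseLayoutType()), mlp_only)
out = model(torch.rand(128, 256, dtype=torch.bfloat16, device="cuda"))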
6 changes: 6 additions & 0 deletions torchao/_models/sam/results.csv
@@ -4,3 +4,9 @@ cuda,vit_h,32,15154,18,25.16516896830006,39.73746416166231,0.5818834536577897,ma
cuda,vit_h,32,15632,19,24.824717871078573,40.282431614863405,0.5675837487618974,max-autotune,torch.bfloat16,sparse_mlp_only,False,True,True,32,154,4928,None,None
cuda,vit_h,32,13429,16,24.589577947798148,40.66763578142439,0.5306639662569573,max-autotune,torch.bfloat16,sparse,False,True,True,32,154,4928,None,None
cuda,vit_h,32,14869,18,26.597207143088742,37.597932543073384,0.5669944616184625,max-autotune,torch.bfloat16,int8_dynamic_quant_sparse,False,True,True,32,154,4928,None,None
device,sam_model_type,batch_size,memory(MiB),memory(%),img_s(avg),batch_ms(avg)/batch_size,mIoU,use_compile,use_half,compress,use_compile_decoder,use_rel_pos,pad_input_image_batch,num_workers,num_batches,num_images,profile_path,memory_path
cuda,vit_h,32,15172,18,22.787559123509425,43.88359431477336,0.5809962729163862,max-autotune,torch.bfloat16,None,False,True,True,32,154,4928,None,None
cuda,vit_h,32,15153,18,24.872293344547476,40.20537978333312,0.5821541984818872,max-autotune,torch.bfloat16,int8_dynamic_quant,False,True,True,32,154,4928,None,None
cuda,vit_h,32,15640,19,24.64409232721636,40.5776762528853,0.5674436009126148,max-autotune,torch.bfloat16,sparse_mlp_only,False,True,True,32,154,4928,None,None
cuda,vit_h,32,13429,16,24.710537332827382,40.46856555691013,0.530554119734646,max-autotune,torch.bfloat16,sparse,False,True,True,32,154,4928,None,None
cuda,vit_h,32,14869,18,26.5429434697436,37.67479673608557,0.566992236284673,max-autotune,torch.bfloat16,int8_dynamic_quant_sparse,False,True,True,32,154,4928,None,None