int8 dynamic quant + bsr support #821

Merged 17 commits on Sep 26, 2024
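This PR hooks int8 dynamic activation + int8 weight quantization up to block-sparse (BSR) weights, alongside the existing 2:4 semi-structured and Marlin paths. As a rough sketch of the user-facing flow the new test exercises (assuming a CUDA machine and this branch of torchao; `BlockSparseLayoutType` is imported from `torchao.dtypes.affine_quantized_tensor` exactly as in the test below, and `create_block_sparse_tensor` is the test helper the diff uses to fabricate block-sparse weights):

# Sketch only: mirrors the new TestQuantBlockSparseWeight test below; not an official recipe.
import torch
from torch import nn

from torchao.dtypes.affine_quantized_tensor import BlockSparseLayoutType
from torchao.quantization.quant_api import int8_dynamic_activation_int8_weight, quantize_
from torchao.sparsity.utils import create_block_sparse_tensor  # test helper

model = (
    nn.Sequential(nn.Linear(128, 256), nn.Linear(256, 128))
    .to(torch.bfloat16)
    .cuda()
    .eval()
)

# Give each Linear a 50%-sparse weight built from 64x64 blocks so the BSR
# kernels have real block structure to exploit.
for linear in model:
    M, N = linear.weight.shape
    linear.weight.data = create_block_sparse_tensor(M, N, 64, 0.5, torch.bfloat16)

# Quantize: dynamic int8 activations, int8 weights stored in a BSR layout.
quantize_(
    model,
    int8_dynamic_activation_int8_weight(layout_type=BlockSparseLayoutType(blocksize=64)),
)

out = model(torch.rand((256, 128), dtype=torch.bfloat16).cuda())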
test/sparsity/test_sparse_api.py: 121 additions, 18 deletions (139 changes)
@@ -4,27 +4,24 @@
 
 import torch
 from torch import nn
-
-from torchao.sparsity import (
-    apply_fake_sparsity,
-    sparsify_,
-    semi_sparse_weight,
-)
+from torch.testing._internal import common_utils
 from torchao.dtypes import MarlinSparseLayoutType, SemiSparseLayoutType
 from torchao.quantization.quant_api import (
+    int4_weight_only,
     int8_dynamic_activation_int8_weight,
     quantize_,
-    int4_weight_only,
 )
-from torchao.utils import TORCH_VERSION_AT_LEAST_2_3
-from torch.testing._internal.common_utils import TestCase
+from torchao.sparsity import apply_fake_sparsity, semi_sparse_weight, sparsify_
+from torchao.utils import TORCH_VERSION_AFTER_2_5, TORCH_VERSION_AT_LEAST_2_3, TORCH_VERSION_AT_LEAST_2_5, TORCH_VERSION_AT_LEAST_2_4
 
 
 logging.basicConfig(
     format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", level=logging.INFO
 )
 
-class TestSemiStructuredSparse(TestCase):
+
+class TestSemiStructuredSparse(common_utils.TestCase):
 
     @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_3, "pytorch 2.3+ feature")
     @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
@@ -37,6 +34,7 @@ def test_sparse(self):
             )
             .half()
             .cuda()
+            .eval()
         )
 
         apply_fake_sparsity(model)
@@ -45,13 +43,17 @@ def test_sparse(self):
         sparsify_(model, semi_sparse_weight())
         sparse_result = model(input)
 
-        assert torch.allclose(dense_result, sparse_result, rtol=1e-3, atol=1e-3)
+        torch.testing.assert_close(dense_result, sparse_result, rtol=1e-3, atol=1e-3)
 
-class TestQuantSemiSparse(TestCase):
+
+class TestQuantSemiSparse(common_utils.TestCase):
 
-    @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_3, "pytorch 2.3+ feature")
+    @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_5, "pytorch 2.5+ feature")
     @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
-    def test_quant_semi_sparse(self):
+    @common_utils.parametrize("compile", [True, False])
+    def test_quant_semi_sparse(self, compile):
+        torch.sparse.SparseSemiStructuredTensor._FORCE_CUTLASS = False
+
         input = torch.rand((128, 128)).half().cuda()
         model = (
             nn.Sequential(
@@ -60,19 +62,27 @@ def test_quant_semi_sparse(self):
             )
             .half()
             .cuda()
+            .eval()
         )
         apply_fake_sparsity(model)
         model_copy = copy.deepcopy(model)
         quantize_(model_copy, int8_dynamic_activation_int8_weight())
         dense_result = model_copy(input)
 
-        quantize_(model, int8_dynamic_activation_int8_weight(layout_type=SemiSparseLayoutType()))
+        quantize_(
+            model,
+            int8_dynamic_activation_int8_weight(layout_type=SemiSparseLayoutType()),
+        )
+        if compile:
+            model = torch.compile(model)
         sparse_result = model(input)
 
-        assert torch.allclose(dense_result, sparse_result, rtol=1e-2, atol=1e-2)
+        torch.testing.assert_close(dense_result, sparse_result, rtol=1e-2, atol=1e-2)
 
     @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_5, "pytorch 2.5+ feature")
     @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
-    def test_sparse_marlin(self):
+    @common_utils.parametrize("compile", [True, False])
+    def test_sparse_marlin(self, compile):
         input = torch.rand((256, 256)).half().cuda()
         model = (
             nn.Sequential(
@@ -81,6 +91,7 @@ def test_sparse_marlin(self):
             )
             .half()
             .cuda()
+            .eval()
         )
 
         apply_fake_sparsity(model)
@@ -92,9 +103,101 @@
 
         # Sparse + quantized
         quantize_(model, int4_weight_only(layout_type=MarlinSparseLayoutType()))
+        if compile:
+            model = torch.compile(model)
         sparse_result = model(input)
 
-        assert torch.allclose(dense_result, sparse_result, atol=3e-1), "Results are not close"
+        torch.testing.assert_close(dense_result, sparse_result, atol=3e-1, rtol=3e-1)
+
+
+class TestBlockSparseWeight(common_utils.TestCase):
+    @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_4, "pytorch 2.4+ feature due to need for custom op support")
+    @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
+    @common_utils.parametrize("compile", [True, False])
+    def test_sparse(self, compile):
+        input = torch.rand((1024, 1024)).half().cuda()
+        model = (
+            nn.Sequential(
+                nn.Linear(1024, 2048),
+                nn.Linear(2048, 1024),
+            )
+            .half()
+            .cuda()
+            .eval()
+        )
+
+        from torchao.sparsity.utils import create_block_sparse_tensor
+
+        M, N = model[0].weight.shape
+        model[0].weight.data = create_block_sparse_tensor(M, N, 64, 0.5, torch.float16)
+        M, N = model[1].weight.shape
+        model[1].weight.data = create_block_sparse_tensor(M, N, 64, 0.5, torch.float16)
+        dense_result = model(input)
+
+        from torchao.sparsity.prototype.superblock.blocksparse import (
+            block_sparse_weight,
+        )
+
+        sparsify_(model, block_sparse_weight(blocksize=64))
+        # if compile:
+        #     model = torch.compile(model)
+        sparse_result = model(input)
+
+        torch.testing.assert_close(dense_result, sparse_result, rtol=1e-3, atol=1e-3)
+
+
+class TestQuantBlockSparseWeight(common_utils.TestCase):
+    @unittest.skipIf(not TORCH_VERSION_AFTER_2_5, "pytorch 2.6+ feature")
+    @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
+    @common_utils.parametrize("compile", [True, False])
+    def test_sparse(self, compile):
+        input = torch.rand((256, 128)).to(torch.bfloat16).cuda()
+        model = (
+            nn.Sequential(
+                nn.Linear(128, 256),
+                nn.Linear(256, 128),
+            )
+            .to(torch.bfloat16)
+            .cuda()
+            .eval()
+        )
+        from torchao.sparsity.prototype.superblock.blocksparse import (
+            blocksparse_int_addmm,
+        )
+        from torchao.sparsity.utils import create_block_sparse_tensor
+
+        M, N = model[0].weight.shape
+        model[0].weight.data = (
+            create_block_sparse_tensor(M, N, 64, 0.5, torch.bfloat16)
+            * torch.rand(M, N, dtype=torch.bfloat16).cuda()
+        )
+        M, N = model[1].weight.shape
+        model[1].weight.data = create_block_sparse_tensor(M, N, 64, 0.5, torch.bfloat16)
+
+        model_copy = copy.deepcopy(model)
+
+        quantize_(model_copy, int8_dynamic_activation_int8_weight())
+        reference = model_copy(input)
+
+        from torchao.dtypes.affine_quantized_tensor import BlockSparseLayoutType
+
+        quantize_(
+            model,
+            int8_dynamic_activation_int8_weight(
+                layout_type=BlockSparseLayoutType(blocksize=64)
+            ),
+        )
+        if compile:
+            model = torch.compile(model)
+        sparse_result = model(input)
+
+        torch.testing.assert_close(reference, sparse_result, rtol=1e-1, atol=1e-1)
+
+
+common_utils.instantiate_parametrized_tests(TestSemiStructuredSparse)
+common_utils.instantiate_parametrized_tests(TestQuantSemiSparse)
+common_utils.instantiate_parametrized_tests(TestBlockSparseWeight)
+common_utils.instantiate_parametrized_tests(TestQuantBlockSparseWeight)
 
 if __name__ == "__main__":
     unittest.main()
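For comparison, the pre-existing 2:4 semi-structured path covered by TestQuantSemiSparse boils down to the following. Again a sketch, assuming a CUDA GPU whose sparse kernels support 2:4 patterns; the imports match the test file above:

# Sketch of the 2:4 semi-structured int8 path from TestQuantSemiSparse above.
import torch
from torch import nn

from torchao.dtypes import SemiSparseLayoutType
from torchao.quantization.quant_api import int8_dynamic_activation_int8_weight, quantize_
from torchao.sparsity import apply_fake_sparsity

model = nn.Sequential(nn.Linear(128, 128), nn.Linear(128, 128)).half().cuda().eval()
apply_fake_sparsity(model)  # prune each weight to a 2:4 pattern first

quantize_(model, int8_dynamic_activation_int8_weight(layout_type=SemiSparseLayoutType()))
out = model(torch.rand((128, 128)).half().cuda())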