Commit 9a0e918

wip

jcaip committed Sep 25, 2024
1 parent 7dff17a commit 9a0e918
Showing 8 changed files with 465 additions and 46 deletions.
63 changes: 63 additions & 0 deletions test/sparsity/test_sparse_api.py
@@ -7,6 +7,7 @@

from torchao.sparsity import (
apply_fake_sparsity,
apply_fake_block_sparsity,
sparsify_,
semi_sparse_weight,
)
@@ -96,5 +97,67 @@ def test_sparse_marlin(self):

assert torch.allclose(dense_result, sparse_result, atol=3e-1), "Results are not close"

class TestBlockSparseWeight(TestCase):
@unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_3, "pytorch 2.3+ feature")
@unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
def test_sparse(self):
input = torch.rand((1024, 1024)).half().cuda()
model = (
nn.Sequential(
nn.Linear(1024, 2048),
nn.Linear(2048, 1024),
)
.half()
.cuda()
)

from torchao.sparsity.utils import create_block_sparse_tensor
M, N = model[0].weight.shape
model[0].weight.data = create_block_sparse_tensor(M, N, 64, 0.5, torch.float16)
M, N = model[1].weight.shape
model[1].weight.data = create_block_sparse_tensor(M, N, 64, 0.5, torch.float16)
dense_result = model(input)

from torchao.sparsity.prototype.superblock.blocksparse import block_sparse_weight
sparsify_(model, block_sparse_weight())
sparse_result = model(input)

assert torch.allclose(dense_result, sparse_result, rtol=1e-3, atol=1e-3)
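
# Illustrative helper (not part of this diff): a sketch of what
# create_block_sparse_tensor is assumed to produce -- a dense tensor whose
# 64x64 blocks are either entirely zero or entirely kept, with roughly the
# requested fraction of blocks zeroed out.
def _reference_block_sparse_mask(M, N, blocksize=64, sparsity=0.5, dtype=torch.float16):
    keep = (torch.rand(M // blocksize, N // blocksize) > sparsity).to(dtype)
    return keep.repeat_interleave(blocksize, 0).repeat_interleave(blocksize, 1)
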

class TestQuantBlockSparseWeight(TestCase):
@unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_3, "pytorch 2.3+ feature")
@unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
def test_sparse(self):
input = torch.rand((128, 128)).to(torch.bfloat16).cuda()
model = (
nn.Sequential(
nn.Linear(128, 256),
nn.Linear(256, 128),
)
.to(torch.bfloat16)
.cuda()
)

from torchao.sparsity.utils import create_block_sparse_tensor
M, N = model[0].weight.shape
model[0].weight.data = create_block_sparse_tensor(M, N, 64, 0.5, torch.bfloat16) * torch.rand(M, N, dtype=torch.bfloat16).cuda()
print(model[0].weight)
M, N = model[1].weight.shape
model[1].weight.data = create_block_sparse_tensor(M, N, 64, 0.5, torch.bfloat16)
print(model[1].weight)

model_copy = copy.deepcopy(model)

quantize_(model_copy, int8_dynamic_activation_int8_weight())
reference = model_copy(input)

from torchao.dtypes.affine_quantized_tensor import BlockSparseLayoutType
        quantize_(model, int8_dynamic_activation_int8_weight(layout_type=BlockSparseLayoutType()))
sparse_result = model(input)

        assert torch.allclose(reference, sparse_result, rtol=1e-2, atol=1e-2)

if __name__ == "__main__":
unittest.main()
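
For orientation, a condensed usage sketch of the path these tests exercise, assuming the prototype BlockSparseLayoutType lands as shown in this diff (in practice the weight would first be pruned to 64x64 block sparsity):

import torch
from torch import nn
from torchao.quantization import quantize_, int8_dynamic_activation_int8_weight
from torchao.dtypes.affine_quantized_tensor import BlockSparseLayoutType

model = nn.Sequential(nn.Linear(128, 256)).to(torch.bfloat16).cuda()
quantize_(model, int8_dynamic_activation_int8_weight(layout_type=BlockSparseLayoutType()))
out = model(torch.randn(32, 128, dtype=torch.bfloat16, device="cuda"))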
201 changes: 200 additions & 1 deletion torchao/dtypes/affine_quantized_tensor.py
@@ -47,7 +47,6 @@
from torchao.float8.inference import Float8MMConfig
aten = torch.ops.aten


###############################
# Base Layout Tensor Subclass #
###############################
@@ -472,6 +471,13 @@ def pre_process(self, input: torch.Tensor) -> torch.Tensor:
temp.view(-1, 4).scatter_(1, pruning_inds, value=0)
return temp

@dataclass(frozen=True)
class BlockSparseLayoutType(LayoutType):
    """Layout type for block-sparse (BSR) storage of quantized weights."""

    def pre_process(self, input: torch.Tensor) -> torch.Tensor:
        return input


@dataclass(frozen=True)
class TensorCoreTiledLayoutType(LayoutType):
@@ -669,6 +675,162 @@ def from_plain(
int_data_compressed = torch._cslt_compress(int_data)
return cls(int_data_compressed, scale, zero_point, layout_type)

@register_layout_cls(BlockSparseLayoutType)
class BlockSparseAQTLayout(PlainAQTLayout):
quantized_linear_impl = "block"
bsr_crow_indices: Optional[torch.Tensor]
bsr_col_indices: Optional[torch.Tensor]
bsr_values: Optional[torch.Tensor]
scale: Optional[torch.Tensor]
zero_point: Optional[torch.Tensor]

__slots__ = ["bsr_crow_indices", "bsr_col_indices", "bsr_values", "scale", "zero_point"]

@staticmethod
def __new__( # noqa: PYI034
cls,
shape: torch.Size,
bsr_crow_indices: Optional[torch.Tensor],
bsr_col_indices: Optional[torch.Tensor],
bsr_values: Optional[torch.Tensor],
scale: Optional[torch.Tensor],
zero_point: Optional[torch.Tensor],
layout_type: LayoutType,
requires_grad: bool = False,
):
if bsr_values is None:
raise ValueError("bsr values must be provided!")
else:
previous_tensor = bsr_values

kwargs = {
"device": previous_tensor.device,
"dtype": previous_tensor.dtype,
"layout": previous_tensor.layout,
"requires_grad": requires_grad,
}
return torch.Tensor._make_wrapper_subclass(cls, shape, **kwargs) # type: ignore[attr-defined]

def __init__( # noqa: PYI034
self,
shape: torch.Size,
bsr_crow_indices: Optional[torch.Tensor],
bsr_col_indices: Optional[torch.Tensor],
bsr_values: Optional[torch.Tensor],
scale: Optional[torch.Tensor],
zero_point: Optional[torch.Tensor],
layout_type: LayoutType,
requires_grad: bool = False,
):
self.bsr_crow_indices = bsr_crow_indices
self.bsr_col_indices = bsr_col_indices
self.bsr_values = bsr_values
self.scale = scale
self.zero_point = zero_point
self.layout_type = layout_type

def __repr__(self) -> str: # type: ignore[override]
assert hasattr(self, "shape")
return f"{self.__class__.__name__}(shape={self.shape})"

def __tensor_flatten__(self):
inner_tensors = list(
filter(lambda x: getattr(self, x) is not None, self.__slots__)
)
tensor_meta = (self.shape, self.layout_type, self.requires_grad)
return inner_tensors, tensor_meta

@classmethod
def __tensor_unflatten__(
cls,
inner_tensors,
        tensor_meta: Tuple[torch.Size, LayoutType, bool],
outer_size,
outer_stride,
) -> torch.Tensor:
shape, layout_type, requires_grad = tensor_meta
return cls(
shape=shape,
bsr_crow_indices=inner_tensors.get("bsr_crow_indices", None),
bsr_col_indices=inner_tensors.get("bsr_col_indices", None),
bsr_values=inner_tensors.get("bsr_values", None),
scale=inner_tensors.get("scale", None),
zero_point=inner_tensors.get("zero_point", None),
layout_type=layout_type,
requires_grad=requires_grad,
)
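
    # Sketch (illustrative, not part of this diff): these two hooks let
    # torch.compile and serialization take the subclass apart and rebuild it
    # from its constituent tensors, e.g. for a hypothetical `layout_tensor`:
    #
    #   names, meta = layout_tensor.__tensor_flatten__()
    #   inner = {name: getattr(layout_tensor, name) for name in names}
    #   rebuilt = BlockSparseAQTLayout.__tensor_unflatten__(inner, meta, None, None)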

@classmethod
def from_plain(cls, int_data, scale, zero_point, layout_type):
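        # Convert the quantized weight to BSR form; the 64x64 blocksize is currently hardcoded.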
bsr_tensor = int_data.to_sparse_bsr(64)
return cls(
shape=int_data.shape,
bsr_crow_indices=bsr_tensor.crow_indices(),
bsr_col_indices=bsr_tensor.col_indices(),
bsr_values=bsr_tensor.values(),
scale=scale,
zero_point=zero_point,
            layout_type=layout_type,
requires_grad=False,
)
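
    # For reference (illustrative, not part of this diff), the three BSR
    # components stored above can be inspected on any dense tensor:
    #
    #   w = torch.zeros(128, 128); w[:64, :64] = 1.0   # a single non-zero 64x64 block
    #   w_bsr = w.to_sparse_bsr(64)
    #   w_bsr.crow_indices()    # per block-row offsets into col_indices
    #   w_bsr.col_indices()     # block-column index of each stored block
    #   w_bsr.values().shape    # (num_blocks, 64, 64) block values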

    @torch._dynamo.disable
    def get_plain(self):
        # TODO: reconstruct the dense int_data from the BSR components (e.g. by
        # multiplying the BSR weight with an identity matrix via bsr_dense_mm).
        # For now this is a placeholder that returns zeros.
        return torch.zeros(self.shape, device=self.device).to(self.dtype), self.scale, self.zero_point

def _apply_fn_to_data(self, func):
return self.__class__(
            shape=self.shape,
bsr_crow_indices=func(self.bsr_crow_indices),
bsr_col_indices=func(self.bsr_col_indices),
bsr_values=func(self.bsr_values),
scale=self.scale,
zero_point=self.zero_point,
layout_type=self.layout_type,
requires_grad=self.requires_grad,
)

@classmethod
def __torch_dispatch__(cls, func, types, args, kwargs):
kwargs = {} if kwargs is None else kwargs

if func is aten.detach.default:
return return_and_correct_aliasing(
func, args, kwargs, args[0]._apply_fn_to_data(torch.detach)
)
if func is aten.clone.default:
return return_and_correct_aliasing(
func, args, kwargs, args[0]._apply_fn_to_data(torch.clone)
)
        if func is aten.t.default:
            # We don't need to repack the weight here; we rely on the external
            # shape change and just record the transpose status.
            args[0].transposed = not args[0].transposed
            return return_and_correct_aliasing(func, args, kwargs, args[0])

if func is aten.crow_indices.default:
return args[0].bsr_crow_indices.detach()

if func is aten.col_indices.default:
return args[0].bsr_col_indices.detach()

if func is aten.values.default:
return args[0].bsr_values.detach()

if func is aten._nnz.default:
return args[0].bsr_values.shape[0]

raise NotImplementedError(
f"BlockSparseAQTLayout dispatch: attempting to run {func}, this is not supported"
)

@register_layout_cls(MarlinSparseLayoutType)
class MarlinSparseAQTLayout(AQTLayout):
@@ -1221,6 +1383,42 @@ def _linear_int8_act_int8_weight_semi_structured_sparse_impl(input_tensor, weight_tensor, bias):
y += bias
return y

def _linear_int8_act_int8_weight_block_sparse_check(input_tensor, weight_tensor, bias):
return (
isinstance(input_tensor, AffineQuantizedTensor) and
_aqt_is_int8_reduced_range(input_tensor) and
isinstance(weight_tensor, AffineQuantizedTensor) and
weight_tensor.is_cuda and
input_tensor.dtype == weight_tensor.dtype and
isinstance(input_tensor.layout_type, PlainLayoutType) and
isinstance(weight_tensor.layout_type, BlockSparseLayoutType) and
weight_tensor.layout_tensor.quantized_linear_impl == "block"
)


def _linear_int8_act_int8_weight_block_sparse_impl(input_tensor, weight_tensor, bias):
x_vals_int8 = input_tensor.layout_tensor.int_data
x_scales = input_tensor.layout_tensor.scale
w_vals = weight_tensor.layout_tensor
w_scales = weight_tensor.layout_tensor.scale
tmp = x_vals_int8.reshape(-1, x_vals_int8.shape[-1])
tmp_t = tmp.t()

y = torch.ops.blocksparse.int_addmm(w_vals.crow_indices(),
w_vals.col_indices(),
w_vals.values(),
tmp_t,
w_scales,
x_scales.reshape(-1))
y_shape = (*x_vals_int8.shape[:-1], w_scales.shape[-1])
y = y.reshape(*y_shape)

# can downcast only at the very end
output_dtype = input_tensor.dtype
y = y.to(output_dtype)
if bias is not None:
y += bias
return y
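
# Illustrative dense reference (not part of this diff) for what the fused
# torch.ops.blocksparse.int_addmm call above is assumed to compute, given
# per-token activation scales and per-channel weight scales:
def _int8_block_sparse_linear_reference(x_int8, x_scale, w_int8, w_scale):
    # x_int8: (B, K) int8, x_scale: (B,); w_int8: (N, K) int8, w_scale: (N,)
    acc = x_int8.float() @ w_int8.float().t()  # (B, N) integer accumulation in fp32
    return acc * x_scale.reshape(-1, 1) * w_scale.reshape(1, -1)
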
def _linear_bf16_act_uint4_weight_check(input_tensor, weight_tensor, bias):
return (
# input is native bfloat16 tensor
@@ -1473,6 +1671,7 @@ def _register_aqt_quantized_linear_dispatches():
for dispatch_condition, impl in [
(_linear_int8_act_int8_weight_check, _linear_int8_act_int8_weight_impl),
(_linear_int8_act_int8_weight_semi_structured_sparse_check, _linear_int8_act_int8_weight_semi_structured_sparse_impl),
(_linear_int8_act_int8_weight_block_sparse_check, _linear_int8_act_int8_weight_block_sparse_impl),
(_linear_fp8_act_fp8_weight_check, _linear_fp8_act_fp8_weight_impl),
(_linear_bf16_act_uint4_weight_check, _linear_bf16_act_uint4_weight_impl),
(_linear_fp_act_int8_weight_check, _linear_fp_act_int8_weight_impl),
2 changes: 2 additions & 0 deletions torchao/sparsity/__init__.py
@@ -8,6 +8,7 @@
from .utils import PerChannelNormObserver # noqa: F403
from .sparse_api import (
apply_fake_sparsity,
apply_fake_block_sparsity,
sparsify_,
semi_sparse_weight,
int8_dynamic_activation_int8_semi_sparse_weight
@@ -17,6 +18,7 @@
"WandaSparsifier",
"PerChannelNormObserver",
"apply_fake_sparsity",
"apply_fake_block_sparsity",
"sparsify_"
"semi_sparse_weight",
"int8_dynamic_activation_int8_semi_sparse_weight"
26 changes: 21 additions & 5 deletions torchao/sparsity/prototype/superblock/benchmark.py
@@ -13,6 +13,7 @@
import utils
from torch import nn
from torch.sparse._triton_ops_meta import optimize_bsr_dense_addmm
from torch.sparse._triton_ops_meta import dump as store_tuned_kernel_params
from torchao.sparsity.prototype.superblock.utils import accelerate_with_sparsity, simulate_sparsity
from torchao.utils import benchmark_model, profiler_runner

@@ -34,15 +35,30 @@ def main(args):
# BSR kernel tuning
if args.bsr and args.tune_kernel_params:
print("Tuning kernel params")
kwargs = dict(
dtype=torch.int8 if args.quantization else dtype,
sparsity=args.sparsity_linear, verbose=True,
            # as used by blocksparse_int_addmm:
            alpha=1, beta=0, use_left_alpha=True, use_right_alpha=True,
            # Forcing a re-tune would be justified because the existing tuning
            # parameters were computed for use_left/right_alpha=False; however,
            # re-tuning turns out to produce the same set of tuning parameters:
            # force=True
)
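        # The (M, K, N) triples below are assumed to be the two MLP linear
        # layers of each model (3072x768 and 768x3072 for ViT-B/16, 5120x1280
        # and 1280x5120 for ViT-H/14), with N the flattened token count
        # (256 x 197 = 50432 and 256 x 257 = 65792, respectively).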
if args.model == "vit_b_16":
optimize_bsr_dense_addmm(3072, 768, 50432, args.bsr, args.bsr, dtype=dtype, sparsity=args.sparsity_linear, verbose=True)
optimize_bsr_dense_addmm(768, 3072, 50432, args.bsr, args.bsr, dtype=dtype, sparsity=args.sparsity_linear, verbose=True)
optimize_bsr_dense_addmm(3072, 768, 50432, args.bsr, args.bsr, **kwargs)
optimize_bsr_dense_addmm(768, 3072, 50432, args.bsr, args.bsr, **kwargs)
elif args.model == "vit_h_14":
optimize_bsr_dense_addmm(5120, 1280, 65792, args.bsr, args.bsr, dtype=dtype, sparsity=args.sparsity_linear, verbose=True)
optimize_bsr_dense_addmm(1280, 5120, 65792, args.bsr, args.bsr, dtype=dtype, sparsity=args.sparsity_linear, verbose=True)
optimize_bsr_dense_addmm(5120, 1280, 65792, args.bsr, args.bsr, **kwargs)
optimize_bsr_dense_addmm(1280, 5120, 65792, args.bsr, args.bsr, **kwargs)
else:
raise NotImplementedError("Tuning kernel params for this model is not supported yet.")

# Warning: the following call will overwrite the source code
# of torch.sparse._triton_ops_meta (hence it is commented out
        # by default) but, when used, it enables reusing the tuned
# parameters in subsequent runs of this script:
# store_tuned_kernel_params()
print("Creating model")
model = torchvision.models.get_model(args.model, weights=args.weights, num_classes=num_classes)
