diff --git a/test/sparsity/test_parametrization.py b/test/sparsity/test_parametrization.py new file mode 100644 index 0000000000..ebcae785d5 --- /dev/null +++ b/test/sparsity/test_parametrization.py @@ -0,0 +1,175 @@ +import logging +import torch +import unittest +from torch import nn +from torch.nn.utils import parametrize +from torch.testing._internal.common_utils import TestCase + +from torchao.sparsity.prototype.sparsifier import utils + +logging.basicConfig( + format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", level=logging.INFO +) + + +class ModelUnderTest(nn.Module): + def __init__(self, bias=True): + super().__init__() + self.linear = nn.Linear(16, 16, bias=bias) + self.seq = nn.Sequential( + nn.Linear(16, 16, bias=bias), nn.Linear(16, 16, bias=bias) + ) + + # Make sure the weights are not random + self.linear.weight = nn.Parameter(torch.zeros_like(self.linear.weight) + 1.0) + self.seq[0].weight = nn.Parameter(torch.zeros_like(self.seq[0].weight) + 2.0) + self.seq[1].weight = nn.Parameter(torch.zeros_like(self.seq[1].weight) + 3.0) + if bias: + self.linear = nn.Parameter(torch.zeros_like(self.linear.bias) + 10.0) + self.seq[0] = nn.Parameter(torch.zeros_like(self.seq[0].bias) + 20.0) + self.seq[0] = nn.Parameter(torch.zeros_like(self.seq[0].bias) + 30.0) + + def forward(self, x): + x = self.linear(x) + x = self.seq(x) + return x + + +class TestFakeSparsity(TestCase): + def test_masking_logic(self): + model = nn.Linear(16, 16, bias=False) + model.weight = nn.Parameter(torch.eye(16)) + x = torch.randn(3, 16) + self.assertEqual(torch.mm(x, torch.eye(16)), model(x)) + + mask = torch.zeros(16, 16) + sparsity = utils.FakeSparsity(mask) + parametrize.register_parametrization(model, "weight", sparsity) + + x = torch.randn(3, 16) + self.assertEqual(torch.zeros(3, 16), model(x)) + + def test_weights_parametrized(self): + model = ModelUnderTest(bias=False) + + assert not hasattr(model.linear, "parametrizations") + assert not hasattr(model.seq[0], "parametrizations") + assert not hasattr(model.seq[1], "parametrizations") + mask = torch.eye(16) + parametrize.register_parametrization( + model.linear, "weight", utils.FakeSparsity(mask) + ) + mask = torch.eye(16) + parametrize.register_parametrization( + model.seq[0], "weight", utils.FakeSparsity(mask) + ) + mask = torch.eye(16) + parametrize.register_parametrization( + model.seq[1], "weight", utils.FakeSparsity(mask) + ) + + assert hasattr(model.linear, "parametrizations") + assert parametrize.is_parametrized(model.linear, "weight") + assert hasattr(model.seq[0], "parametrizations") + assert parametrize.is_parametrized(model.linear, "weight") + assert hasattr(model.seq[1], "parametrizations") + assert parametrize.is_parametrized(model.linear, "weight") + + def test_state_dict_preserved(self): + model_save = ModelUnderTest(bias=False) + + mask = torch.eye(16) + parametrize.register_parametrization( + model_save.linear, "weight", utils.FakeSparsity(mask) + ) + mask = torch.eye(16) + parametrize.register_parametrization( + model_save.seq[0], "weight", utils.FakeSparsity(mask) + ) + mask = torch.eye(16) + parametrize.register_parametrization( + model_save.seq[1], "weight", utils.FakeSparsity(mask) + ) + state_dict = model_save.state_dict() + + model_load = ModelUnderTest(bias=False) + mask = torch.zeros(model_load.linear.weight.shape) + parametrize.register_parametrization( + model_load.linear, "weight", utils.FakeSparsity(mask) + ) + mask = torch.zeros(model_load.seq[0].weight.shape) + parametrize.register_parametrization( + model_load.seq[0], "weight", utils.FakeSparsity(mask) + ) + mask = torch.zeros(model_load.seq[1].weight.shape) + parametrize.register_parametrization( + model_load.seq[1], "weight", utils.FakeSparsity(mask) + ) + # Keep this strict, as we are not loading the 'mask' + model_load.load_state_dict(state_dict, strict=False) + + # Check the parametrizations are preserved + assert hasattr(model_load.linear, "parametrizations") + assert parametrize.is_parametrized(model_load.linear, "weight") + assert hasattr(model_load.seq[0], "parametrizations") + assert parametrize.is_parametrized(model_load.linear, "weight") + assert hasattr(model_load.seq[1], "parametrizations") + assert parametrize.is_parametrized(model_load.linear, "weight") + + # Check the weights are preserved + self.assertEqual( + model_save.linear.parametrizations["weight"].original, + model_load.linear.parametrizations["weight"].original, + ) + self.assertEqual( + model_save.seq[0].parametrizations["weight"].original, + model_load.seq[0].parametrizations["weight"].original, + ) + self.assertEqual( + model_save.seq[1].parametrizations["weight"].original, + model_load.seq[1].parametrizations["weight"].original, + ) + + # Check the masks are not preserved in the state_dict + # We store the state_dicts in the sparsifier, not in the model itself. + # TODO: Need to find a clean way of exporting the parametrized model + self.assertNotEqual( + model_save.linear.parametrizations["weight"][0].mask, + model_load.linear.parametrizations["weight"][0].mask, + ) + self.assertNotEqual( + model_save.seq[0].parametrizations["weight"][0].mask, + model_load.seq[0].parametrizations["weight"][0].mask, + ) + self.assertNotEqual( + model_save.seq[1].parametrizations["weight"][0].mask, + model_load.seq[1].parametrizations["weight"][0].mask, + ) + + def test_jit_trace(self): + model = ModelUnderTest(bias=False) + + mask = torch.eye(16) + parametrize.register_parametrization( + model.linear, "weight", utils.FakeSparsity(mask) + ) + mask = torch.eye(16) + parametrize.register_parametrization( + model.seq[0], "weight", utils.FakeSparsity(mask) + ) + mask = torch.eye(16) + parametrize.register_parametrization( + model.seq[1], "weight", utils.FakeSparsity(mask) + ) + + # Tracing + example_x = torch.ones(3, 16) + model_trace = torch.jit.trace_module(model, {"forward": example_x}) + + x = torch.randn(3, 16) + y = model(x) + y_hat = model_trace(x) + self.assertEqual(y_hat, y) + +if __name__ == "__main__": + unittest.main() diff --git a/test/sparsity/test_scheduler.py b/test/sparsity/test_scheduler.py new file mode 100644 index 0000000000..0cfc898dcd --- /dev/null +++ b/test/sparsity/test_scheduler.py @@ -0,0 +1,194 @@ +import warnings +import unittest + +from torch import nn +from torch.testing._internal.common_utils import TestCase + +from torchao.sparsity.prototype import BaseScheduler, CubicSL, LambdaSL, WeightNormSparsifier + +class ImplementedScheduler(BaseScheduler): + def get_sl(self): + if self.last_epoch > 0: + return [group["sparsity_level"] * 0.5 for group in self.sparsifier.groups] + else: + return list(self.base_sl) + + +class TestScheduler(TestCase): + def test_constructor(self): + model = nn.Sequential(nn.Linear(16, 16)) + sparsifier = WeightNormSparsifier() + sparsifier.prepare(model, config=None) + scheduler = ImplementedScheduler(sparsifier) + + assert scheduler.sparsifier is sparsifier + assert scheduler._step_count == 1 + assert scheduler.base_sl == [sparsifier.groups[0]["sparsity_level"]] + + def test_order_of_steps(self): + """Checks if the warning is thrown if the scheduler step is called + before the sparsifier step""" + + model = nn.Sequential(nn.Linear(16, 16)) + sparsifier = WeightNormSparsifier() + sparsifier.prepare(model, config=None) + scheduler = ImplementedScheduler(sparsifier) + + # Sparsifier step is not called + with self.assertWarns(UserWarning): + scheduler.step() + + # Correct order has no warnings + # Note: This will trigger if other warnings are present. + with warnings.catch_warnings(record=True) as w: + sparsifier.step() + scheduler.step() + # Make sure there is no warning related to the base_scheduler + for warning in w: + fname = warning.filename + fname = "/".join(fname.split("/")[-5:]) + assert fname != "torch/ao/sparsity/scheduler/base_scheduler.py" + + def test_step(self): + model = nn.Sequential(nn.Linear(16, 16)) + sparsifier = WeightNormSparsifier() + sparsifier.prepare(model, config=None) + assert sparsifier.groups[0]["sparsity_level"] == 0.5 + scheduler = ImplementedScheduler(sparsifier) + assert sparsifier.groups[0]["sparsity_level"] == 0.5 + + sparsifier.step() + scheduler.step() + assert sparsifier.groups[0]["sparsity_level"] == 0.25 + + def test_lambda_scheduler(self): + model = nn.Sequential(nn.Linear(16, 16)) + sparsifier = WeightNormSparsifier() + sparsifier.prepare(model, config=None) + assert sparsifier.groups[0]["sparsity_level"] == 0.5 + scheduler = LambdaSL(sparsifier, lambda epoch: epoch * 10) + assert sparsifier.groups[0]["sparsity_level"] == 0.0 # Epoch 0 + scheduler.step() + assert sparsifier.groups[0]["sparsity_level"] == 5.0 # Epoch 1 + + +class TestCubicScheduler(TestCase): + def setUp(self): + self.model_sparse_config = [ + {"tensor_fqn": "0.weight", "sparsity_level": 0.8}, + {"tensor_fqn": "2.weight", "sparsity_level": 0.4}, + ] + self.sorted_sparse_levels = [ + conf["sparsity_level"] for conf in self.model_sparse_config + ] + self.initial_sparsity = 0.1 + self.initial_step = 3 + + def _make_model(self, **kwargs): + model = nn.Sequential( + nn.Linear(13, 17), + nn.Dropout(0.5), + nn.Linear(17, 3), + ) + return model + + def _make_scheduler(self, model, **kwargs): + sparsifier = WeightNormSparsifier() + sparsifier.prepare(model, config=self.model_sparse_config) + + scheduler_args = { + "init_sl": self.initial_sparsity, + "init_t": self.initial_step, + } + scheduler_args.update(kwargs) + + scheduler = CubicSL(sparsifier, **scheduler_args) + return sparsifier, scheduler + + @staticmethod + def _get_sparsity_levels(sparsifier, precision=32): + r"""Gets the current levels of sparsity in a sparsifier.""" + return [ + round(group["sparsity_level"], precision) for group in sparsifier.groups + ] + + def test_constructor(self): + model = self._make_model() + sparsifier, scheduler = self._make_scheduler(model=model, initially_zero=True) + self.assertIs( + scheduler.sparsifier, sparsifier, msg="Sparsifier is not properly attached" + ) + self.assertEqual( + scheduler._step_count, + 1, + msg="Scheduler is initialized with incorrect step count", + ) + self.assertEqual( + scheduler.base_sl, + self.sorted_sparse_levels, + msg="Scheduler did not store the target sparsity levels correctly", + ) + + # Value before t_0 is 0 + self.assertEqual( + self._get_sparsity_levels(sparsifier), + scheduler._make_sure_a_list(0.0), + msg="Sparsifier is not reset correctly after attaching to the Scheduler", + ) + + # Value before t_0 is s_0 + model = self._make_model() + sparsifier, scheduler = self._make_scheduler(model=model, initially_zero=False) + self.assertEqual( + self._get_sparsity_levels(sparsifier), + scheduler._make_sure_a_list(self.initial_sparsity), + msg="Sparsifier is not reset correctly after attaching to the Scheduler", + ) + + def test_step(self): + # For n=5, dt=2, there will be totally 10 steps between s_0 and s_f, starting from t_0 + model = self._make_model() + sparsifier, scheduler = self._make_scheduler( + model=model, initially_zero=True, init_t=3, delta_t=2, total_t=5 + ) + + scheduler.step() + scheduler.step() + self.assertEqual( + scheduler._step_count, + 3, + msg="Scheduler step_count is expected to increment", + ) + # Value before t_0 is supposed to be 0 + self.assertEqual( + self._get_sparsity_levels(sparsifier), + scheduler._make_sure_a_list(0.0), + msg="Scheduler step updating the sparsity level before t_0", + ) + + scheduler.step() # Step = 3 => sparsity = initial_sparsity + self.assertEqual( + self._get_sparsity_levels(sparsifier), + scheduler._make_sure_a_list(self.initial_sparsity), + msg="Sparsifier is not reset to initial sparsity at the first step", + ) + + scheduler.step() # Step = 4 => sparsity ~ [0.3, 0.2] + self.assertEqual( + self._get_sparsity_levels(sparsifier, 1), + [0.3, 0.2], + msg="Sparsity level is not set correctly after the first step", + ) + + current_step = scheduler._step_count - scheduler.init_t[0] - 1 + more_steps_needed = scheduler.delta_t[0] * scheduler.total_t[0] - current_step + for _ in range(more_steps_needed): # More steps needed to final sparsity level + scheduler.step() + self.assertEqual( + self._get_sparsity_levels(sparsifier), + self.sorted_sparse_levels, + msg="Sparsity level is not reaching the target level afer delta_t * n steps ", + ) + +if __name__ == "__main__": + unittest.main() diff --git a/test/sparsity/test_sparsifier.py b/test/sparsity/test_sparsifier.py new file mode 100644 index 0000000000..0deeea9ca7 --- /dev/null +++ b/test/sparsity/test_sparsifier.py @@ -0,0 +1,490 @@ +# Owner(s): ["module: unknown"] + +import itertools +import logging +import re +import unittest + +import torch +from torch import nn +from torchao.sparsity.prototype import ( + BaseSparsifier, + FakeSparsity, + NearlyDiagonalSparsifier, + WeightNormSparsifier, +) +from torch.nn.utils.parametrize import is_parametrized +from torch.testing._internal.common_pruning import ( + ImplementedSparsifier, + MockSparseLinear, + SimpleLinear, +) + +from torch.testing._internal.common_utils import TestCase + +logging.basicConfig( + format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", level=logging.INFO +) + + +class TestBaseSparsifier(TestCase): + def test_constructor(self): + # Cannot instantiate the abstract base + self.assertRaises(TypeError, BaseSparsifier) + # Can instantiate the model with no configs + model = SimpleLinear() + sparsifier = ImplementedSparsifier(test=3) + sparsifier.prepare(model, config=None) + assert len(sparsifier.groups) == 5 + sparsifier.step() + # Can instantiate the model with configs + sparsifier = ImplementedSparsifier(test=3) + sparsifier.prepare(model, [{"tensor_fqn": "linear1.weight"}]) + assert len(sparsifier.groups) == 1 + assert sparsifier.groups[0]["tensor_fqn"] == "linear1.weight" + assert "test" in sparsifier.groups[0] + assert sparsifier.groups[0]["test"] == 3 + + def test_prepare_config(self): + model = SimpleLinear() + sparsifier = ImplementedSparsifier(test=3) + # Make sure there are no parametrizations before `prepare` + assert not hasattr(model.seq[0], "parametrizations") + assert not hasattr(model.linear1, "parametrizations") + assert not hasattr(model.linear2, "parametrizations") + sparsifier.prepare( + model, + config=[ + {"tensor_fqn": "seq.0.weight", "test": 42}, + # No 'linear1' to make sure it will be skipped in the sparsification + {"tensor_fqn": "linear2.weight"}, + ], + ) + assert len(sparsifier.groups) == 2 + # Check if default argument is not assigned if explicit + assert sparsifier.groups[0]["tensor_fqn"] == "seq.0.weight" + assert sparsifier.groups[0]["test"] == 42 + # Check if FQN and module are pointing to the same location + assert sparsifier.groups[1]["tensor_fqn"] == "linear2.weight" + assert sparsifier.groups[1]["module"] == model.linear2 + # Check if parameterizations are attached + assert hasattr(model.seq[0], "parametrizations") + assert not hasattr(model.linear1, "parametrizations") + assert hasattr(model.linear2, "parametrizations") + + def test_step(self): + model = SimpleLinear() + sparsifier = ImplementedSparsifier(test=3) + sparsifier.enable_mask_update = True + sparsifier.prepare(model, [{"tensor_fqn": "linear1.weight"}]) + sparsifier.step() + assert torch.all(model.linear1.parametrizations.weight[0].mask[0] == 0) + + def test_state_dict(self): + step_count = 3 + model0 = SimpleLinear() + sparsifier0 = ImplementedSparsifier(test=3) + sparsifier0.prepare(model0, [{"tensor_fqn": "linear1.weight"}]) + mask = model0.linear1.parametrizations["weight"][0].mask + mask.data = torch.arange(mask.shape[0] * mask.shape[1]).reshape(mask.shape) + for step in range(step_count): + sparsifier0.step() + state_dict = sparsifier0.state_dict() + + # Check the expected keys in the state_dict + assert "state" in state_dict + assert "step_count" in state_dict["state"]["linear1.weight"] + assert state_dict["state"]["linear1.weight"]["step_count"] == 3 + assert "groups" in state_dict + assert "test" in state_dict["groups"][0] + assert "tensor_fqn" in state_dict["groups"][0] + assert state_dict["groups"][0]["tensor_fqn"] == "linear1.weight" + + # Check loading static_dict creates an equivalent model + model1 = SimpleLinear() + sparsifier1 = ImplementedSparsifier() + sparsifier1.prepare(model1, None) + + assert sparsifier0.state != sparsifier1.state + + # Make sure the masks are different in the beginning + for mg in sparsifier0.groups: + if mg["tensor_fqn"] == "linear1.weight": + mask0 = mg["module"].parametrizations.weight[0].mask + for mg in sparsifier1.groups: + if mg["tensor_fqn"] == "linear1.weight": + mask1 = mg["module"].parametrizations.weight[0].mask + self.assertNotEqual(mask0, mask1) + + sparsifier1.load_state_dict(state_dict) + + # Make sure the states are loaded, and are correct + assert sparsifier0.state == sparsifier1.state + + # Make sure the masks (and all dicts) are the same after loading + assert len(sparsifier0.groups) == len(sparsifier1.groups) + for idx in range(len(sparsifier0.groups)): + mg0 = sparsifier0.groups[idx] + mg1 = sparsifier1.groups[idx] + for key in mg0.keys(): + assert key in mg1 + if key == "module": + # We cannot compare modules as they are different + param0 = mg0[key].parametrizations.weight[0] + param1 = mg1[key].parametrizations.weight[0] + assert hasattr(param0, "mask") + assert hasattr(param1, "mask") + self.assertEqual(param0.__dict__, param1.__dict__) + else: + assert mg0[key] == mg1[key] + + def test_convert(self): + model = SimpleLinear() + sparsifier = ImplementedSparsifier(test=3) + sparsifier.prepare(model, [{"tensor_fqn": "linear1.weight"}]) + new_model = sparsifier.convert( + model, mapping={nn.Linear: MockSparseLinear}, inplace=False + ) + + assert isinstance(new_model.linear1, MockSparseLinear) + assert isinstance(new_model.seq[0], nn.Linear) + assert isinstance(new_model.linear2, nn.Linear) + + def test_mask_squash(self): + model = SimpleLinear() + sparsifier = ImplementedSparsifier(test=3) + sparsifier.prepare(model, [{"tensor_fqn": "linear1.weight"}]) + assert hasattr(model.linear1.parametrizations.weight[0], "mask") + assert is_parametrized(model.linear1, "weight") + assert not is_parametrized(model.seq[0], "weight") + + sparsifier.squash_mask() + assert not is_parametrized(model.seq[0], "weight") + assert not is_parametrized(model.linear1, "weight") + + def test_mask_squash_with_params1(self): + model = SimpleLinear() + sparsifier = ImplementedSparsifier(foo=3, bar=2, baz=1) + sparsifier.prepare( + model, [{"tensor_fqn": "linear1.weight"}, {"tensor_fqn": "seq.0.weight"}] + ) + sparsifier.squash_mask( + params_to_keep_per_layer={"linear1": ("foo", "bar"), "seq.0": ("baz",)} + ) + assert not is_parametrized(model.seq[0], "weight") + assert not is_parametrized(model.linear1, "weight") + assert hasattr(model.seq[0], "sparse_params") + assert hasattr(model.linear1, "sparse_params") + assert model.seq[0].sparse_params.get("foo", None) is None + assert model.seq[0].sparse_params.get("bar", None) is None + assert model.seq[0].sparse_params.get("baz", None) == 1 + assert model.linear1.sparse_params.get("foo", None) == 3 + assert model.linear1.sparse_params.get("bar", None) == 2 + assert model.linear1.sparse_params.get("baz", None) is None + + def test_mask_squash_with_params2(self): + model = SimpleLinear() + sparsifier = ImplementedSparsifier(foo=3, bar=2, baz=1) + sparsifier.prepare( + model, [{"tensor_fqn": "linear1.weight"}, {"tensor_fqn": "seq.0.weight"}] + ) + sparsifier.squash_mask(params_to_keep=("foo", "bar")) + assert not is_parametrized(model.seq[0], "weight") + assert not is_parametrized(model.linear1, "weight") + assert hasattr(model.seq[0], "sparse_params") + assert hasattr(model.linear1, "sparse_params") + assert model.seq[0].sparse_params.get("foo", None) == 3 + assert model.seq[0].sparse_params.get("bar", None) == 2 + assert model.seq[0].sparse_params.get("baz", None) is None + assert model.linear1.sparse_params.get("foo", None) == 3 + assert model.linear1.sparse_params.get("bar", None) == 2 + assert model.linear1.sparse_params.get("baz", None) is None + + def test_mask_squash_with_params3(self): + model = SimpleLinear() + sparsifier = ImplementedSparsifier(foo=3, bar=2, baz=1) + sparsifier.prepare( + model, [{"tensor_fqn": "linear1.weight"}, {"tensor_fqn": "seq.0.weight"}] + ) + sparsifier.squash_mask( + params_to_keep=("foo", "bar"), params_to_keep_per_layer={"seq.0": ("baz",)} + ) + assert not is_parametrized(model.seq[0], "weight") + assert not is_parametrized(model.linear1, "weight") + assert hasattr(model.seq[0], "sparse_params") + assert hasattr(model.linear1, "sparse_params") + assert model.seq[0].sparse_params.get("foo", None) == 3 + assert model.seq[0].sparse_params.get("bar", None) == 2 + assert model.seq[0].sparse_params.get("baz", None) == 1 + assert model.linear1.sparse_params.get("foo", None) == 3 + assert model.linear1.sparse_params.get("bar", None) == 2 + assert model.linear1.sparse_params.get("baz", None) is None + + +class TestWeightNormSparsifier(TestCase): + def test_constructor(self): + model = SimpleLinear() + sparsifier = WeightNormSparsifier() + sparsifier.prepare(model, config=None) + for g in sparsifier.groups: + assert isinstance(g["module"], nn.Linear) + # The groups are unordered + assert g["module_fqn"] in ("seq.0", "seq.1", "seq.2", "linear1", "linear2") + + def test_step(self): + model = SimpleLinear() + sparsifier = WeightNormSparsifier(sparsity_level=0.5) + sparsifier.prepare(model, config=[{"tensor_fqn": "linear1.weight"}]) + for g in sparsifier.groups: + # Before step + module = g["module"] + assert ( + 1.0 - module.parametrizations["weight"][0].mask.mean() + ) == 0 # checking sparsity level is 0 + sparsifier.enable_mask_update = True + sparsifier.step() + self.assertAlmostEqual( + model.linear1.parametrizations["weight"][0].mask.mean().item(), + 0.5, + places=2, + ) + for g in sparsifier.groups: + # After step + module = g["module"] + assert ( + 1.0 - module.parametrizations["weight"][0].mask.mean() + ) > 0 # checking sparsity level has increased + # Test if the mask collapses to all zeros if the weights are randomized + iters_before_collapse = 1000 + for _ in range(iters_before_collapse): + model.linear1.weight.data = torch.randn(model.linear1.weight.shape) + sparsifier.step() + for g in sparsifier.groups: + # After step + module = g["module"] + assert ( + 1.0 - module.parametrizations["weight"][0].mask.mean() + ) > 0 # checking sparsity level did not collapse + + def test_step_2_of_4(self): + model = SimpleLinear() + sparsifier = WeightNormSparsifier( + sparsity_level=1.0, sparse_block_shape=(1, 4), zeros_per_block=2 + ) + sparsifier.prepare(model, config=[{"tensor_fqn": "linear1.weight"}]) + sparsifier.step() + # make sure the sparsity level is approximately 50% + mask = model.linear1.parametrizations["weight"][0].mask.to( + torch.float + ) # mean works on float only + self.assertAlmostEqual(mask.mean().item(), 0.5, places=2) + # Make sure each block has exactly 50% zeros + module = sparsifier.groups[0]["module"] + mask = module.parametrizations["weight"][0].mask + for row in mask: + for idx in range(0, len(row), 4): + block = row[idx : idx + 4] + block, _ = block.sort() + assert (block[:2] == 0).all() + assert (block[2:] != 0).all() + + def test_prepare(self): + model = SimpleLinear() + sparsifier = WeightNormSparsifier() + sparsifier.prepare(model, config=None) + for g in sparsifier.groups: + module = g["module"] + # Check mask exists + assert hasattr(module.parametrizations["weight"][0], "mask") + # Check parametrization exists and is correct + assert is_parametrized(module, "weight") + assert type(module.parametrizations.weight[0]) == FakeSparsity + + def test_mask_squash(self): + model = SimpleLinear() + sparsifier = WeightNormSparsifier() + sparsifier.prepare(model, config=None) + sparsifier.squash_mask() + for g in sparsifier.groups: + module = g["module"] + assert not is_parametrized(module, "weight") + assert not hasattr(module, "mask") + + def test_sparsity_levels(self): + sparsity_levels = [-1.0, 0.0, 0.5, 1.0, 2.0] + sparse_block_shapes = [(1, 1), (1, 4), (2, 2), (4, 1)] + zeros_per_blocks = [0, 1, 2, 3, 4] + + testcases = itertools.tee( + itertools.product(sparsity_levels, sparse_block_shapes, zeros_per_blocks) + ) + # Create a config and model with all the testcases + model = nn.Sequential() + sparsifier = WeightNormSparsifier() + + sparsity_per_layer_config = [] + p = re.compile(r"[-\.\s]") + for sl, sbs, zpb in testcases[0]: + # Make sure the number of zeros is not > values in a block + if zpb > sbs[0] * sbs[1]: + continue + layer_name = f"{sl}_{sbs}_{zpb}" + layer_name = p.sub("_", layer_name) + + layer = nn.Linear(12, 12, bias=False) + layer.weight = nn.Parameter(torch.ones(12, 12)) + model.add_module(layer_name, layer) + config = { + "tensor_fqn": layer_name + ".weight", + "sparsity_level": sl, + "sparse_block_shape": sbs, + "zeros_per_block": zpb, + } + sparsity_per_layer_config.append(config) + + sparsifier.prepare(model, sparsity_per_layer_config) + sparsifier.step() + sparsifier.squash_mask() + model.eval() + + for sl, sbs, zpb in testcases[1]: + if zpb > sbs[0] * sbs[1]: + continue + layer_name = f"{sl}_{sbs}_{zpb}" + layer_name = p.sub("_", layer_name) + layer = getattr(model, layer_name) + + # Level of sparsity is achieved + sparse_mask = (layer.weight == 0).float() + if zpb == 0: + assert sparse_mask.mean() == 0 + else: + # Ratio of individual zeros in the tensor + true_sl = min(max(sl, 0.0), 1.0) + true_sl = true_sl * zpb / sbs[0] / sbs[1] + assert sparse_mask.mean() == true_sl + + +class TestNearlyDiagonalSparsifier(TestCase): + def test_constructor(self): + model = SimpleLinear() + sparsifier = NearlyDiagonalSparsifier(nearliness=1) + sparsifier.prepare(model, config=None) + for g in sparsifier.groups: + assert isinstance(g["module"], nn.Linear) + # The groups are unordered + assert g["module_fqn"] in ("seq.0", "seq.1", "seq.2", "linear1", "linear2") + + def test_step(self): + model = SimpleLinear() + sparsifier = NearlyDiagonalSparsifier(nearliness=1) + sparsifier.prepare(model, config=[{"tensor_fqn": "linear1.weight"}]) + + for g in sparsifier.groups: + # Before step + module = g["module"] + assert ( + 1.0 - module.parametrizations["weight"][0].mask.mean() + ) == 0 # checking sparsity level is 0 + + sparsifier.enable_mask_update = True + sparsifier.step() + mask = module.parametrizations["weight"][0].mask + height, width = mask.shape + assert torch.all(mask == torch.eye(height, width)) + + for g in sparsifier.groups: + # After step + module = g["module"] + assert ( + 1.0 - module.parametrizations["weight"][0].mask.mean() + ) > 0 # checking sparsity level has increased + + # Test if the mask collapses to all zeros if the weights are randomized + iters_before_collapse = 1000 + for _ in range(iters_before_collapse): + model.linear1.weight.data = torch.randn(model.linear1.weight.shape) + sparsifier.step() + for g in sparsifier.groups: + # After step + module = g["module"] + assert ( + 1.0 - module.parametrizations["weight"][0].mask.mean() + ) > 0 # checking sparsity level did not collapse + + def test_prepare(self): + model = SimpleLinear() + sparsifier = NearlyDiagonalSparsifier(nearliness=1) + sparsifier.prepare(model, config=None) + for g in sparsifier.groups: + module = g["module"] + # Check mask exists + assert hasattr(module.parametrizations["weight"][0], "mask") + # Check parametrization exists and is correct + assert is_parametrized(module, "weight") + assert type(module.parametrizations.weight[0]) == FakeSparsity + + def test_mask_squash(self): + model = SimpleLinear() + sparsifier = NearlyDiagonalSparsifier(nearliness=1) + sparsifier.prepare(model, config=None) + sparsifier.step() + sparsifier.squash_mask() + for g in sparsifier.groups: + module = g["module"] + assert not is_parametrized(module, "weight") + assert not hasattr(module, "mask") + weights = module.weight + height, width = weights.shape + assert torch.all( + weights == torch.eye(height, width) * weights + ) # only diagonal to be present + + def test_sparsity_levels(self): + nearliness_levels = list(range(-1, 100)) + model = nn.Sequential() + + p = re.compile(r"[-\.\s]") + for nearliness in nearliness_levels: + sparsifier = NearlyDiagonalSparsifier(nearliness=1) + layer_name = f"{nearliness}" + layer_name = p.sub("_", layer_name) + + layer = nn.Linear(32, 32, bias=False) + layer.weight = nn.Parameter(torch.ones(32, 32)) + width, height = layer.weight.shape + model.add_module(layer_name, layer) + config = {"tensor_fqn": layer_name + ".weight", "nearliness": nearliness} + + sparsifier.prepare(model, [config]) + # should raise a ValueError when nearliness arg is illegal + if (nearliness > 0 and nearliness % 2 == 0) or ( + nearliness // 2 >= min(width, height) + ): + with self.assertRaises(ValueError): + sparsifier.step() + else: + sparsifier.step() + sparsifier.squash_mask() + model.eval() + + layer = getattr(model, layer_name) + # verify that mask created corresponds to the nearliness + self._verify_nearliness(layer.weight, nearliness) + + # helper function to verify nearliness of a mask + def _verify_nearliness(self, mask: torch.Tensor, nearliness: int): + if nearliness <= 0: + assert torch.all(mask == torch.zeros(mask.shape[0], mask.shape[1])) + else: + height, width = mask.shape + dist_to_diagonal = nearliness // 2 + for row in range(0, height): + for col in range(0, width): + if abs(row - col) <= dist_to_diagonal: + assert mask[row, col] == 1 + else: + assert mask[row, col] == 0 + +if __name__ == "__main__": + unittest.main() diff --git a/test/sparsity/test_sparsity_utils.py b/test/sparsity/test_sparsity_utils.py new file mode 100644 index 0000000000..91d0d2d562 --- /dev/null +++ b/test/sparsity/test_sparsity_utils.py @@ -0,0 +1,151 @@ +import logging +import unittest + +import torch +from torchao.sparsity.prototype.sparsifier.utils import ( + fqn_to_module, + get_arg_info_from_tensor_fqn, + module_to_fqn, +) + +from torch.testing._internal.common_quantization import ( + ConvBnReLUModel, + ConvModel, + FunctionalLinear, + LinearAddModel, + ManualEmbeddingBagLinear, + SingleLayerLinearModel, + TwoLayerLinearModel, +) +from torch.testing._internal.common_utils import TestCase + +logging.basicConfig( + format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", level=logging.INFO +) + +model_list = [ + ConvModel, + SingleLayerLinearModel, + TwoLayerLinearModel, + LinearAddModel, + ConvBnReLUModel, + ManualEmbeddingBagLinear, + FunctionalLinear, +] + + +class TestSparsityUtilFunctions(TestCase): + def test_module_to_fqn(self): + """ + Tests that module_to_fqn works as expected when compared to known good + module.get_submodule(fqn) function + """ + for model_class in model_list: + model = model_class() + list_of_modules = [m for _, m in model.named_modules()] + [model] + for module in list_of_modules: + fqn = module_to_fqn(model, module) + check_module = model.get_submodule(fqn) + self.assertEqual(module, check_module) + + def test_module_to_fqn_fail(self): + """ + Tests that module_to_fqn returns None when an fqn that doesn't + correspond to a path to a node/tensor is given + """ + for model_class in model_list: + model = model_class() + fqn = module_to_fqn(model, torch.nn.Linear(3, 3)) + self.assertEqual(fqn, None) + + def test_module_to_fqn_root(self): + """ + Tests that module_to_fqn returns '' when model and target module are the same + """ + for model_class in model_list: + model = model_class() + fqn = module_to_fqn(model, model) + self.assertEqual(fqn, "") + + def test_fqn_to_module(self): + """ + Tests that fqn_to_module operates as inverse + of module_to_fqn + """ + for model_class in model_list: + model = model_class() + list_of_modules = [m for _, m in model.named_modules()] + [model] + for module in list_of_modules: + fqn = module_to_fqn(model, module) + check_module = fqn_to_module(model, fqn) + self.assertEqual(module, check_module) + + def test_fqn_to_module_fail(self): + """ + Tests that fqn_to_module returns None when it tries to + find an fqn of a module outside the model + """ + for model_class in model_list: + model = model_class() + fqn = "foo.bar.baz" + check_module = fqn_to_module(model, fqn) + self.assertEqual(check_module, None) + + def test_fqn_to_module_for_tensors(self): + """ + Tests that fqn_to_module works for tensors, actually all parameters + of the model. This is tested by identifying a module with a tensor, + and generating the tensor_fqn using module_to_fqn on the module + + the name of the tensor. + """ + for model_class in model_list: + model = model_class() + list_of_modules = [m for _, m in model.named_modules()] + [model] + for module in list_of_modules: + module_fqn = module_to_fqn(model, module) + for tensor_name, tensor in module.named_parameters(recurse=False): + tensor_fqn = ( # string manip to handle tensors on root + module_fqn + ("." if module_fqn != "" else "") + tensor_name + ) + check_tensor = fqn_to_module(model, tensor_fqn) + self.assertEqual(tensor, check_tensor) + + def test_get_arg_info_from_tensor_fqn(self): + """ + Tests that get_arg_info_from_tensor_fqn works for all parameters of the model. + Generates a tensor_fqn in the same way as test_fqn_to_module_for_tensors and + then compares with known (parent) module and tensor_name as well as module_fqn + from module_to_fqn. + """ + for model_class in model_list: + model = model_class() + list_of_modules = [m for _, m in model.named_modules()] + [model] + for module in list_of_modules: + module_fqn = module_to_fqn(model, module) + for tensor_name, tensor in module.named_parameters(recurse=False): + tensor_fqn = ( + module_fqn + ("." if module_fqn != "" else "") + tensor_name + ) + arg_info = get_arg_info_from_tensor_fqn(model, tensor_fqn) + self.assertEqual(arg_info["module"], module) + self.assertEqual(arg_info["module_fqn"], module_fqn) + self.assertEqual(arg_info["tensor_name"], tensor_name) + self.assertEqual(arg_info["tensor_fqn"], tensor_fqn) + + def test_get_arg_info_from_tensor_fqn_fail(self): + """ + Tests that get_arg_info_from_tensor_fqn works as expected for invalid tensor_fqn + inputs. The string outputs still work but the output module is expected to be None. + """ + for model_class in model_list: + model = model_class() + tensor_fqn = "foo.bar.baz" + arg_info = get_arg_info_from_tensor_fqn(model, tensor_fqn) + self.assertEqual(arg_info["module"], None) + self.assertEqual(arg_info["module_fqn"], "foo.bar") + self.assertEqual(arg_info["tensor_name"], "baz") + self.assertEqual(arg_info["tensor_fqn"], "foo.bar.baz") + + +if __name__ == "__main__": + unittest.main() diff --git a/test/sparsity/test_structured_sparsifier.py b/test/sparsity/test_structured_sparsifier.py new file mode 100644 index 0000000000..471e891ee7 --- /dev/null +++ b/test/sparsity/test_structured_sparsifier.py @@ -0,0 +1,1097 @@ +# Owner(s): ["module: unknown"] +import copy +import logging +import random +import unittest + +import torch +from torch import nn +from torchao.sparsity.prototype.pruner import ( + BaseStructuredSparsifier, + FakeStructuredSparsity, + FPGMPruner, + LSTMSaliencyPruner, + SaliencyPruner, +) +from torch.nn.utils import parametrize +from torch.testing._internal.common_pruning import ( + Conv2dActivation, + Conv2dBias, + Conv2dPadBias, + Conv2dPool, + Conv2dPoolFlatten, + Conv2dPoolFlattenFunctional, + LinearActivation, + LinearActivationFunctional, + LinearBias, + LSTMLayerNormLinearModel, + LSTMLinearModel, + rows_are_subset, + SimpleConv2d, + SimpleLinear, +) + +from torch.testing._internal.common_utils import skipIfTorchDynamo, TestCase + + +logging.basicConfig( + format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", level=logging.INFO +) + +DEVICES = { + torch.device("cpu"), + torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu"), +} + + +class SimplePruner(BaseStructuredSparsifier): + def update_mask(self, module, tensor_name, **kwargs): + getattr(module.parametrizations, tensor_name)[0].mask[1] = False + + +class ImplementedPruner(BaseStructuredSparsifier): + def update_mask(self, module, tensor_name, **kwargs): + """Prunes 1/3 of the weight output channels, so resulting module has 33.3% pruning""" + num_rows = len(module.parametrizations[tensor_name][0].mask) + prune = random.sample(list(range(num_rows)), num_rows // 3) + module.parametrizations[tensor_name][0].mask[prune] = False + + +class BottomHalfLSTMPruner(BaseStructuredSparsifier): + """ + Pruner that will remove the bottom half of the rows. + This is primarily meant for testing purposes + """ + + def update_mask(self, module, tensor_name, **kwargs): + for p in getattr(module.parametrizations, tensor_name): + if isinstance(p, FakeStructuredSparsity): + mask = p.mask + masks = torch.split(mask, len(mask) // 4) + for small in masks: + num = len(small) + small[num // 2 :] = False + new_mask = torch.cat(masks) + mask.data = new_mask.data + + +class TestSaliencyPruner(TestCase): + def test_saliency_pruner_update_mask(self): + """Test that we prune out the row with the lowest saliency (first row)""" + model = SimpleLinear() + with torch.no_grad(): + model.linear1.weight = nn.Parameter( + torch.Tensor([[1, 1, 1, 1], [2, 2, 2, 2], [3, 3, 3, 3], [4, 4, 4, 4]]) + ) + pruning_config = [{"tensor_fqn": "linear1.weight", "sparsity_level": 0.5}] + pruner = SaliencyPruner({}) + + pruner.prepare(model, pruning_config) + pruner.enable_mask_update = True + pruner.step() + pruned_model = pruner.prune() + + expected = torch.Tensor([[3, 3, 3, 3], [4, 4, 4, 4]]) + pruned = pruned_model.linear1.weight + + assert expected.shape == pruned.shape + assert torch.isclose(expected, pruned, rtol=1e-05, atol=1e-07).all() + + def test_lstm_saliency_pruner_update_mask(self): + model = LSTMLinearModel( + input_dim=2, + hidden_dim=2, + output_dim=2, + num_layers=1, + ) + + manual_weights = torch.Tensor( + [[1, 1], [2, 2], [2, 2], [1, 1], [-1, -1], [-2, -2], [-2, -2], [-1, -1]] + ) + + with torch.no_grad(): + model.lstm.weight_ih_l0 = nn.Parameter(manual_weights) + model.lstm.weight_hh_l0 = nn.Parameter(torch.Tensor(manual_weights)) + model.lstm.bias_ih_l0 = nn.Parameter(manual_weights[:, 0]) + model.lstm.bias_hh_l0 = nn.Parameter(manual_weights[:, 0]) + + config = [ + {"tensor_fqn": "lstm.weight_ih_l0"}, + {"tensor_fqn": "lstm.weight_hh_l0"}, + ] + lstm_input = torch.ones((1, 2)) + fx_pruner = LSTMSaliencyPruner({"sparsity_level": 0.5}) + fx_pruner.prepare(model, config) + fx_pruner.enable_mask_update = True + fx_pruner.step() + + model.eval() + pruned_model = fx_pruner.prune() + pruned_model.eval() + + # make sure both models run + model(lstm_input) + pruned_model(lstm_input) + + # make sure lowest saliency rows are pruned + expected = torch.Tensor([[2, 2], [2, 2], [-2, -2], [-2, -2]]) + pruned = model.lstm.weight_ih_l0 + assert expected.shape == pruned.shape + assert torch.isclose(expected, pruned, rtol=1e-05, atol=1e-07).all() + + expected = torch.Tensor([[2], [2], [-2], [-2]]) + pruned = model.lstm.weight_hh_l0 + assert expected.shape == pruned.shape + assert torch.isclose(expected, pruned, rtol=1e-05, atol=1e-07).all() + + expected = torch.Tensor([2, 2, -2, -2]) + for pruned in [model.lstm.bias_ih_l0, model.lstm.bias_hh_l0]: + assert expected.shape == pruned.shape + assert torch.isclose(expected, pruned, rtol=1e-05, atol=1e-07).all() + + +class TestBaseStructuredSparsifier(TestCase): + def _check_pruner_prepared(self, model, pruner, device): + for config in pruner.groups: + module = config["module"] + assert module.weight.device.type == device.type + # Check mask exists + assert config["tensor_fqn"] in pruner.state + # Check parametrization exists and is correct + assert parametrize.is_parametrized(module) + assert hasattr(module, "parametrizations") + # Assume that this is the 1st/only parametrization + assert type(module.parametrizations.weight[0]) == FakeStructuredSparsity + + def _check_pruner_valid_before_step(self, model, pruner, device): + for config in pruner.groups: + modules = [] + if type(config["module"]) is tuple: + modules.extend(config["module"]) + else: + module = config["module"] + modules.append(module) + for module in modules: + assert module.weight.device.type == device.type + assert module.parametrizations.weight[0].mask.dtype == torch.bool + + def _check_pruner_valid_after_step(self, model, pruner, mask, device): + for config in pruner.groups: + modules = [] + if type(config["module"]) is tuple: + modules.extend(config["module"]) + else: + module = config["module"] + modules.append(module) + for module in modules: + assert module.weight.device.type == device.type + total = module.parametrizations.weight[0].mask.numel() + assert ( + module.parametrizations.weight[0].mask.count_nonzero() + == total - mask + ) + + def _test_constructor_on_device(self, model, device): + self.assertRaisesRegex( + TypeError, + "BaseStructuredSparsifier.*update_mask", + BaseStructuredSparsifier, + ) + model1 = copy.deepcopy(model).to(device) + pruner = SimplePruner(None) + pruner.prepare(model1, None) + pruner.enable_mask_update = True + for g in pruner.groups: + module = g["module"] + assert module.weight.device.type == device.type + assert len(pruner.groups) == 5 + pruner.step() + # Can instantiate the model with configs + model2 = copy.deepcopy(model).to(device) + pruner = SimplePruner({"test": 3}) + pruner.prepare(model2, [{"tensor_fqn": "seq.0.weight"}]) + assert len(pruner.groups) == 1 + assert pruner.groups[0]["module_fqn"] == "seq.0" + assert "test" in pruner.groups[0] + assert pruner.groups[0]["test"] == 3 + + def test_constructor(self): + model = SimpleLinear() + for device in DEVICES: + self._test_constructor_on_device(model, torch.device(device)) + + def _test_prepare_linear_on_device(self, model, device): + model = copy.deepcopy(model).to(device) + x = torch.ones(128, 7, device=device) + pruner = SimplePruner(None) + pruner.prepare(model, None) + self._check_pruner_prepared(model, pruner, device) + assert model(x).shape == (128, 10) + + def test_prepare_linear(self): + models = [ + SimpleLinear(), + LinearBias(), + LinearActivation(), + LinearActivationFunctional(), + ] # without and with bias + for device in DEVICES: + for model in models: + self._test_prepare_linear_on_device(model, torch.device(device)) + + def _test_prepare_conv2d_on_device(self, model, expected_shape, config, device): + x = torch.ones((1, 1, 28, 28), device=device) + pruner = SimplePruner(None) + pruner.prepare(model, config) + self._check_pruner_prepared(model, pruner, device) + assert model(x).shape == expected_shape + + def test_prepare_conv2d(self): + models = [ + SimpleConv2d(), + Conv2dBias(), + Conv2dActivation(), + Conv2dPadBias(), + Conv2dPool(), + ] + shapes = [ + (1, 52, 20, 20), + (1, 52, 18, 18), + (1, 52, 18, 18), + (1, 52, 24, 24), + (1, 52, 3, 3), + ] + configs = [None, None, None, None, None] + for device in DEVICES: + for model, shape, config in zip(models, shapes, configs): + model = model.to(device) + self._test_prepare_conv2d_on_device( + model, shape, config, torch.device(device) + ) + + def _test_step_linear_on_device(self, model, device): + model = model.to(device) + x = torch.ones(7, 7, device=device) + pruner = SimplePruner(None) + pruner.prepare(model, None) + pruner.enable_mask_update = True + self._check_pruner_valid_before_step(model, pruner, device) + pruner.step() + self._check_pruner_valid_after_step(model, pruner, 1, device) + + def test_step_linear(self): + models = [ + SimpleLinear(), + LinearBias(), + LinearActivation(), + LinearActivationFunctional(), + ] + for device in DEVICES: + for model in models: + self._test_step_linear_on_device(model, torch.device(device)) + + def _test_step_conv2d_on_device(self, model, expected_shape, config, device): + model = model.to(device) + x = torch.ones((1, 1, 28, 28), device=device) + pruner = SimplePruner(None) + pruner.prepare(model, config) + pruner.enable_mask_update = True + self._check_pruner_valid_before_step(model, pruner, device) + pruner.step() + self._check_pruner_valid_after_step(model, pruner, 1, device) + assert model(x).shape == expected_shape + + @skipIfTorchDynamo("TorchDynamo fails with unknown reason") + def test_step_conv2d(self): + models = [ + SimpleConv2d(), + Conv2dBias(), + Conv2dActivation(), + Conv2dPadBias(), + Conv2dPool(), + ] + shapes = [ + (1, 52, 20, 20), + (1, 52, 18, 18), + (1, 52, 18, 18), + (1, 52, 24, 24), + (1, 52, 3, 3), + ] + configs = [None, None, None, None, None] + for device in DEVICES: + for model, shape, config in zip(models, shapes, configs): + self._test_step_conv2d_on_device( + model, shape, config, torch.device(device) + ) + + def _check_pruner_pruned(self, model, pruner, device): + for config in pruner.groups: + module = config["module"] + assert not hasattr(module, "parametrizations") + assert not hasattr(module, "mask") + + def _test_linear_on_device( + self, model, config, expected_shape, device, also_prune_bias + ): + model = model.to(device) + model.eval() + num_original_params = sum(p.numel() for p in model.parameters()) + x = torch.ones(128, 7, device=device) + + pruner = ImplementedPruner({"prune_bias": also_prune_bias}) + pruner.prepare(model, config) + pruner.enable_mask_update = True + pruner.step() + + y_expected = model(x) + + assert y_expected.shape == (128, 10) + self._check_pruner_prepared(model, pruner, device) + + # Pruning step + pruned = pruner.prune() + y_pruned = pruned(x) + num_pruned_params = sum(p.numel() for p in pruned.parameters()) + + assert y_pruned.shape == expected_shape + self._check_pruner_pruned(model, pruner, device) + if y_pruned.shape == y_expected.shape: + assert torch.isclose(y_expected, y_pruned, rtol=1e-05, atol=1e-07).all() + assert num_pruned_params < num_original_params + + def test_prune_linear_linear(self): + r"""test pruning linear-> linear modules""" + configs, shapes = [], [] + configs.append( + [ + {"tensor_fqn": "seq.0.weight"}, + {"tensor_fqn": "seq.1.weight"}, + {"tensor_fqn": "seq.2.weight"}, + ] + ) + shapes.append((128, 10)) + + configs.append( + [ + {"tensor_fqn": "seq.0.weight"}, + {"tensor_fqn": "seq.1.weight"}, + {"tensor_fqn": "seq.2.weight"}, + {"tensor_fqn": "linear1.weight"}, + ] + ) + shapes.append((128, 10)) + + configs.append( + [ + {"tensor_fqn": "seq.0.weight"}, + {"tensor_fqn": "seq.2.weight"}, + ] + ) + shapes.append((128, 10)) + for device in DEVICES: + for also_prune_bias in [True, False]: + for config, shape in zip(configs, shapes): + self._test_linear_on_device( + SimpleLinear(), + config, + shape, + torch.device(device), + also_prune_bias, + ) + + def test_prune_linear_bias_linear(self): + # linear(bias) -> linear(no bias) + configs, shapes = [], [] + configs.append( + [ + {"tensor_fqn": "seq.0.weight"}, + {"tensor_fqn": "seq.1.weight"}, + ] + ) + shapes.append((128, 10)) + + # linear(bias) -> linear(bias) + configs.append( + [ + {"tensor_fqn": "seq.2.weight"}, + {"tensor_fqn": "seq.3.weight"}, + ] + ) + shapes.append((128, 10)) + + # linear(no bias) -> linear(bias) + configs.append( + [ + {"tensor_fqn": "seq.0.weight"}, + {"tensor_fqn": "seq.1.weight"}, + {"tensor_fqn": "seq.2.weight"}, + ] + ) + shapes.append((128, 10)) + + for device in DEVICES: + for also_prune_bias in [True, False]: + for config, shape in zip(configs, shapes): + self._test_linear_on_device( + LinearBias(), + config, + shape, + torch.device(device), + also_prune_bias, + ) + + def test_prune_linear_activation_linear(self): + config = [ + {"tensor_fqn": "seq.0.weight"}, + {"tensor_fqn": "seq.2.weight"}, + {"tensor_fqn": "seq.4.weight"}, + {"tensor_fqn": "linear1.weight"}, + ] + shape = (128, 10) + + for device in DEVICES: + for also_prune_bias in [True, False]: + # test version with nn.Modules + self._test_linear_on_device( + LinearActivation(), + config, + shape, + torch.device(device), + also_prune_bias, + ) + # test functional version + self._test_linear_on_device( + LinearActivationFunctional(), + config, + shape, + torch.device(device), + also_prune_bias, + ) + + def _test_conv2d_on_device( + self, model, config, x, expected_shape, device, also_prune_bias + ): + model = model.to(device) + num_original_params = sum(p.numel() for p in model.parameters()) + model.eval() + + pruner = ImplementedPruner({"prune_bias": also_prune_bias}) + pruner.prepare(model, config) + pruner.enable_mask_update = True + pruner.step() + + y_expected = model(x) + assert y_expected.shape == expected_shape + + self._check_pruner_prepared(model, pruner, device) + + # Fusion step + pruned = pruner.prune() + y_pruned = pruned(x) + num_pruned_params = sum(p.numel() for p in pruned.parameters()) + + assert y_pruned.shape == expected_shape + self._check_pruner_pruned(model, pruner, device) + if y_pruned.shape == y_expected.shape: + # TODO This rtol is a little high, need to double check if something specific is causing this to fail + assert torch.isclose( + y_expected, + y_pruned, + rtol=1e-3, + atol=1e-3, + ).all(), f"fail for {type(model)}" + # only time this should be equal is when all layers have padding and we can't prune + assert num_pruned_params <= num_original_params + + def test_prune_conv2d_conv2d(self): + configs, shapes = [], [] + # all within sequential blocks + configs.append( + [ + {"tensor_fqn": "seq.0.weight"}, + ] + ) + shapes.append((1, 52, 20, 20)) + # prune across sequential blocks + configs.append( + [ + {"tensor_fqn": "seq.0.weight"}, + {"tensor_fqn": "seq.1.weight"}, + {"tensor_fqn": "conv2d1.weight"}, + ] + ) + shapes.append((1, 52, 20, 20)) + + for device in DEVICES: + x = torch.ones((1, 1, 28, 28), device=device) + for also_prune_bias in [True, False]: + for config, shape in zip(configs, shapes): + self._test_conv2d_on_device( + SimpleConv2d(), + config, + x, + shape, + torch.device(device), + also_prune_bias, + ) + + def test_prune_conv2d_bias_conv2d(self): + # Conv2d with Bias and no Activation + configs, shapes = [], [] + # conv2d(bias) -> conv2d(bias) + configs.append( + [ + {"tensor_fqn": "seq.0.weight"}, + {"tensor_fqn": "seq.1.weight"}, + ] + ) + shapes.append((1, 52, 18, 18)) + + # conv2d(no bias) -> conv2d(bias) + configs.append( + [ + {"tensor_fqn": "seq.0.weight"}, + {"tensor_fqn": "seq.1.weight"}, + {"tensor_fqn": "conv2d1.weight"}, + ] + ) + shapes.append((1, 52, 18, 18)) + + # conv2d(bias) -> conv2d(no bias) + configs.append( + [ + {"tensor_fqn": "seq.0.weight"}, + {"tensor_fqn": "seq.1.weight"}, + {"tensor_fqn": "seq.2.weight"}, + ] + ) + shapes.append((1, 52, 18, 18)) + + for device in DEVICES: + x = torch.ones((1, 1, 28, 28), device=device) + for also_prune_bias in [True, False]: + for config, shape in zip(configs, shapes): + self._test_conv2d_on_device( + Conv2dBias(), + config, + x, + shape, + torch.device(device), + also_prune_bias, + ) + + def test_prune_conv2d_activation_conv2d(self): + # Conv2d with Activation and no Bias + configs, shapes = [], [] + + # conv2d(no bias) -> activation -> conv2d(no bias) + configs.append( + [ + {"tensor_fqn": "seq.4.weight"}, + ] + ) + shapes.append((1, 52, 18, 18)) + + # conv2d(bias) -> activation -> conv2d(bias) + configs.append( + [ + {"tensor_fqn": "seq.0.weight"}, + {"tensor_fqn": "seq.2.weight"}, + ] + ) + shapes.append((1, 52, 18, 18)) + + # conv2d(bias) -> activation -> conv2d(no bias) + configs.append( + [ + {"tensor_fqn": "seq.2.weight"}, + {"tensor_fqn": "seq.4.weight"}, + ] + ) + shapes.append((1, 52, 18, 18)) + + # conv2d(no bias) -> activation -> conv2d(bias) + configs.append( + [ + {"tensor_fqn": "conv2d1.weight"}, + ] + ) + shapes.append((1, 52, 18, 18)) + + for device in DEVICES: + x = torch.ones((1, 1, 28, 28), device=device) + for also_prune_bias in [True, False]: + for config, shape in zip(configs, shapes): + self._test_conv2d_on_device( + Conv2dActivation(), + config, + x, + shape, + torch.device(device), + also_prune_bias, + ) + + def test_prune_conv2d_padding_conv2d(self): + # Conv2d with Padded layers after Bias layers + configs, shapes = [], [] + + # conv(padded, bias) -> conv(padded, bias) + configs.append( + [ + {"tensor_fqn": "seq.4.weight"}, + ] + ) + shapes.append((1, 52, 24, 24)) + + # conv(no bias, no pad) -> conv(padded, bias) + configs.append( + [ + {"tensor_fqn": "seq.2.weight"}, + ] + ) + shapes.append((1, 52, 24, 24)) + + # conv(padded, bias) -> conv ( no bias ,no pad) + configs.append( + [ + {"tensor_fqn": "seq.0.weight"}, + ] + ) + shapes.append((1, 52, 24, 24)) + # conv(pad, bias) -> conv(no pad, bias) + configs.append( + [ + {"tensor_fqn": "seq.6.weight"}, + ] + ) + shapes.append((1, 52, 24, 24)) + # conv(no pad, bias) -> conv(pad, bias) + configs.append( + [ + {"tensor_fqn": "seq.8.weight"}, + ] + ) + shapes.append((1, 52, 24, 24)) + + for device in DEVICES: + x = torch.ones((1, 1, 28, 28), device=device) + for also_prune_bias in [True, False]: + for config, shape in zip(configs, shapes): + self._test_conv2d_on_device( + Conv2dPadBias(), + config, + x, + shape, + torch.device(device), + also_prune_bias, + ) + + def test_prune_conv2d_pool_conv2d(self): + # Conv2d with Pooling layers + config = [ + {"tensor_fqn": "seq.0.weight"}, + {"tensor_fqn": "seq.3.weight"}, + {"tensor_fqn": "conv2d1.weight"}, + {"tensor_fqn": "conv2d2.weight"}, + ] + shape = (1, 52, 3, 3) + + for device in DEVICES: + x = torch.ones((1, 1, 28, 28), device=device) + for also_prune_bias in [True, False]: + self._test_conv2d_on_device( + Conv2dPool(), + config, + x, + shape, + torch.device(device), + also_prune_bias, + ) + + @skipIfTorchDynamo("TorchDynamo fails with unknown reason") + def test_complex_conv2d(self): + """Test fusion for models that contain Conv2d & Linear modules. + Currently supports: Conv2d-Pool2d-Flatten-Linear, Skip-add""" + config = [ + {"tensor_fqn": "seq.0.weight"}, + {"tensor_fqn": "seq.3.weight"}, + {"tensor_fqn": "conv2d1.weight"}, + {"tensor_fqn": "conv2d2.weight"}, + ] + shape = (1, 13) + + for device in DEVICES: + x = torch.ones((1, 1, 28, 28), device=device) + for also_prune_bias in [True, False]: + self._test_conv2d_on_device( + Conv2dPoolFlattenFunctional(), + config, + x, + shape, + torch.device(device), + also_prune_bias, + ) + self._test_conv2d_on_device( + Conv2dPoolFlatten(), + config, + x, + shape, + torch.device(device), + also_prune_bias, + ) + + def test_prune_lstm_linear_multiple_layer(self): + """ + Test fusion support for LSTM(multi-layer) -> Linear + """ + model = LSTMLinearModel( + input_dim=8, + hidden_dim=8, + output_dim=8, + num_layers=2, + ) + + config = [ + {"tensor_fqn": "lstm.weight_ih_l0"}, + {"tensor_fqn": "lstm.weight_hh_l0"}, + {"tensor_fqn": "lstm.weight_ih_l1"}, + {"tensor_fqn": "lstm.weight_hh_l1"}, + ] + + lstm_input = torch.ones((1, 8)) + fx_pruner = BottomHalfLSTMPruner({"sparsity_level": 0.5}) + fx_pruner.prepare(model, config) + + fx_pruner.enable_mask_update = True + fx_pruner.step() + + model.eval() + _, _ = model(lstm_input) + pruned_model = fx_pruner.prune() + pruned_model.eval() + _, _ = pruned_model(lstm_input) + + expected_params = dict(model.named_parameters()) + for name, param in model.named_parameters(): + assert name in expected_params + # We cannot compare y_expected == y_pruned, as the 0 elements mess up the numerics + # Instead we check that the weights of the new LSTM are a subset of the weights of + # the old LSTM + assert rows_are_subset(param, expected_params[name]) + del expected_params[name] + + # assert we haven't deleted any keys + assert len(expected_params) == 0 + + def test_prune_lstm_linear_single_layer(self): + """ + Test fusion support for LSTM (single-layer) -> Linear + """ + model = LSTMLinearModel( + input_dim=8, + hidden_dim=8, + output_dim=8, + num_layers=1, + ) + + config = [ + {"tensor_fqn": "lstm.weight_ih_l0"}, + {"tensor_fqn": "lstm.weight_hh_l0"}, + ] + + lstm_input = torch.ones((1, 8)) + fx_pruner = BottomHalfLSTMPruner({"sparsity_level": 0.5}) + fx_pruner.prepare(model, config) + fx_pruner.enable_mask_update = True + fx_pruner.step() + model.eval() + + out_expected, lstm_out_expected = model(lstm_input) + pruned_model = fx_pruner.prune() + pruned_model.eval() + out_pruned, lstm_out_pruned = pruned_model(lstm_input) + r, c = lstm_out_expected.size() + + # We cannot check that y_expected == y_pruned as usual because + # zeros vs. missing elements yield different numerical results. + # Instead that we check that the pruned elements are the first half of the results + # since we are using a BottomHalfLSTMPruner + assert torch.isclose( + lstm_out_expected[:, : c // 2], lstm_out_pruned, rtol=1e-05, atol=1e-07 + ).all() + # also check that output of linear is the same shape, this means we've resized + # linear columns correctly. + assert out_expected.shape == out_pruned.shape + + def test_prune_lstm_layernorm_linear_multiple_layer(self): + """ + Test fusion support for LSTM(multi-layer) -> Linear + """ + model = LSTMLayerNormLinearModel( + input_dim=8, + output_dim=8, + hidden_dim=8, + num_layers=2, + ) + + config = [ + {"tensor_fqn": "lstm.weight_ih_l0"}, + {"tensor_fqn": "lstm.weight_hh_l0"}, + {"tensor_fqn": "lstm.weight_ih_l1"}, + {"tensor_fqn": "lstm.weight_hh_l1"}, + ] + + lstm_input = torch.ones((1, 8)) + fx_pruner = BottomHalfLSTMPruner({"sparsity_level": 0.5}) + fx_pruner.prepare(model, config) + + fx_pruner.enable_mask_update = True + fx_pruner.step() + + model.eval() + _, _ = model(lstm_input) + pruned_model = fx_pruner.prune() + pruned_model.eval() + _, _ = pruned_model(lstm_input) + + expected_params = dict(model.named_parameters()) + for name, param in model.named_parameters(): + assert name in expected_params + # We cannot compare y_expected == y_pruned, as the 0 elements mess up the numerics + # Instead we check that the weights of the new LSTM are a subset of the weights of + # the old LSTM + assert rows_are_subset(param, expected_params[name]) + del expected_params[name] + + # assert we haven't deleted any keys + assert len(expected_params) == 0 + + def test_prune_lstm_layernorm_linear_single_layer(self): + """ + Test fusion support for LSTM (single-layer) -> Linear + """ + model = LSTMLinearModel( + input_dim=8, + hidden_dim=8, + output_dim=8, + num_layers=1, + ) + + config = [ + {"tensor_fqn": "lstm.weight_ih_l0"}, + {"tensor_fqn": "lstm.weight_hh_l0"}, + ] + + lstm_input = torch.ones((1, 8)) + fx_pruner = BottomHalfLSTMPruner({"sparsity_level": 0.5}) + fx_pruner.prepare(model, config) + fx_pruner.enable_mask_update = True + fx_pruner.step() + model.eval() + + out_expected, lstm_out_expected = model(lstm_input) + pruned_model = fx_pruner.prune() + pruned_model.eval() + out_pruned, lstm_out_pruned = pruned_model(lstm_input) + r, c = lstm_out_expected.size() + + # We cannot check that y_expected == y_pruned as usual because + # zeros vs. missing elements yield different numerical results. + # Instead that we check that the pruned elements are the first half of the results + # since we are using a BottomHalfLSTMPruner + assert torch.isclose( + lstm_out_expected[:, : c // 2], lstm_out_pruned, rtol=1e-05, atol=1e-07 + ).all() + # also check that output of linear is the same shape, this means we've resized + # linear columns correctly. + assert out_expected.shape == out_pruned.shape + + +class TestFPGMPruner(TestCase): + """ + Test case for the implementation of paper: + `Filter Pruning via Geometric Median for Deep Convolutional Neural Networks Acceleration `_. + """ + + class SimpleConvFPGM(nn.Module): + def __init__(self): + super().__init__() + self.conv2d1 = nn.Conv2d( + in_channels=1, out_channels=3, kernel_size=3, padding=1, bias=False + ) + # Manually set the filter weights for demonstration purposes + """ + Three filters' weight are manually set to values 3.0, 2.0, and 0.1. + Different from the norm-based decision that prunes filter with value 0.1, + FPGM will prune the one with value 2.0. + """ + weights = torch.tensor([3.0, 2.0, 0.1]) # Weight weights for each filter + weights = weights[:, None, None, None] # broadcasting + self.conv2d1.weight.data.copy_( + torch.ones(self.conv2d1.weight.shape) * weights + ) + + # Second Convolutional Layer + self.conv2d2 = nn.Conv2d( + in_channels=3, out_channels=4, kernel_size=3, padding=1, bias=False + ) + weights = torch.tensor([6.0, 7.0, 0.4, 0.5]) + weights = weights[:, None, None, None] + self.conv2d2.weight.data.copy_( + torch.ones(self.conv2d2.weight.shape) * weights + ) + + def forward(self, x): + x = self.conv2d1(x) + x = self.conv2d2(x) + return x + + def test_compute_distance(self, device="cpu"): + """Test the distance computation function""" + model = TestFPGMPruner.SimpleConvFPGM().to(device) + pruner = FPGMPruner(0.3) + dist_conv1 = pruner._compute_distance(model.conv2d1.weight) + + # compute the distance matrix using torch.cdist + flattened_filters = torch.Tensor( + [ + [ + 3.0000, + 3.0000, + 3.0000, + 3.0000, + 3.0000, + 3.0000, + 3.0000, + 3.0000, + 3.0000, + ], + [ + 2.0000, + 2.0000, + 2.0000, + 2.0000, + 2.0000, + 2.0000, + 2.0000, + 2.0000, + 2.0000, + ], + [ + 0.1000, + 0.1000, + 0.1000, + 0.1000, + 0.1000, + 0.1000, + 0.1000, + 0.1000, + 0.1000, + ], + ] + ) + + """ + Expected distance matrix should have the following values: + [0.0000, 3.0000, 8.7000], + [3.0000, 0.0000, 5.7000], + [8.7000, 5.7000, 0.0000], + the distance should therefore be: + [11.7000, 8.7000, 14.4000] + """ + expected_dist_matrix_conv1 = torch.cdist( + flattened_filters, flattened_filters, p=2 + ) + expected_dist_conv1 = torch.sum(torch.abs(expected_dist_matrix_conv1), 1) + assert torch.isclose( + dist_conv1, expected_dist_conv1, rtol=1e-05, atol=1e-07 + ).all() + + def _test_update_mask_on_single_layer(self, expected_conv1, device): + """Test that pruning is conducted based on the pair-wise distance measurement instead of absolute norm value""" + # test pruning with one layer of conv2d + model = TestFPGMPruner.SimpleConvFPGM().to(device) + x = torch.ones((1, 1, 32, 32), device=device) + pruner = FPGMPruner(0.3) + config = [{"tensor_fqn": "conv2d1.weight"}] + pruner.prepare(model, config) + pruner.enable_mask_update = True + pruner.step() + assert ( + pruner.groups[0]["module"].parametrizations.weight[0].mask[-1].item() + is not False + ), "do not prune the least-norm filter" + + # fusion step + pruned_model = pruner.prune() + + pruned_y = pruned_model(x) + # assert shapes + expected_conv1 = expected_conv1.to(device) + assert pruned_y.shape == (1, 4, 32, 32) + assert pruned_model.conv2d1.weight.shape == expected_conv1.shape + assert pruned_model.conv2d2.weight.shape == ( + 4, + 2, + 3, + 3, + ), "conv2d2 should have input channel pruned" + # assert value + assert torch.isclose( + pruned_model.conv2d1.weight, expected_conv1, rtol=1e-05, atol=1e-07 + ).all() + + def _test_update_mask_on_multiple_layer( + self, expected_conv1, expected_conv2, device + ): + # the second setting + model = TestFPGMPruner.SimpleConvFPGM().to(device) + x = torch.ones((1, 1, 32, 32), device=device) + pruner = FPGMPruner(0.3) + config = [ + {"tensor_fqn": "conv2d1.weight"}, + {"tensor_fqn": "conv2d2.weight", "sparsity_level": 0.5}, + ] + pruner.prepare(model, config) + pruner.enable_mask_update = True + pruner.step() + # Get the masks for the two least-norm filters + mask1 = pruner.groups[0]["module"].parametrizations.weight[0].mask[-1] + mask2 = pruner.groups[0]["module"].parametrizations.weight[0].mask[-2] + # Check if either of the least-norm filters is not pruned + assert ( + mask1.item() is not False or mask2.item() is not False + ), "Do not prune all least-norm filters" + + # fusion step + pruned_model = pruner.prune() + pruned_y = pruned_model(x) + # assert shapes + expected_conv1 = expected_conv1.to(device) + expected_conv2 = expected_conv2.to(device) + assert pruned_y.shape == (1, 2, 32, 32) + assert pruned_model.conv2d1.weight.shape == expected_conv1.shape + assert pruned_model.conv2d2.weight.shape == expected_conv2.shape + # assert values + assert torch.isclose( + pruned_model.conv2d1.weight, expected_conv1, rtol=1e-05, atol=1e-07 + ).all() + assert torch.isclose( + pruned_model.conv2d2.weight, expected_conv2, rtol=1e-05, atol=1e-07 + ).all() + + def test_update_mask(self): + weights = torch.tensor([3.0, 0.1]) + expected_conv1 = torch.ones((2, 1, 3, 3)) * weights[:, None, None, None] + + weights = torch.tensor([7.0, 0.4]) + expected_conv2 = torch.ones((2, 2, 3, 3)) * weights[:, None, None, None] + + for device in DEVICES: + self._test_update_mask_on_single_layer(expected_conv1, device) + self._test_update_mask_on_multiple_layer( + expected_conv1, expected_conv2, device + ) + +if __name__ == "__main__": + unittest.main() diff --git a/torchao/sparsity/prototype/__init__.py b/torchao/sparsity/prototype/__init__.py new file mode 100644 index 0000000000..350a310501 --- /dev/null +++ b/torchao/sparsity/prototype/__init__.py @@ -0,0 +1,15 @@ +# Sparsifier +from torchao.sparsity.prototype.sparsifier.base_sparsifier import BaseSparsifier +from torchao.sparsity.prototype.sparsifier.weight_norm_sparsifier import WeightNormSparsifier +from torchao.sparsity.prototype.sparsifier.nearly_diagonal_sparsifier import NearlyDiagonalSparsifier + +# Scheduler +from torchao.sparsity.prototype.scheduler.base_scheduler import BaseScheduler +from torchao.sparsity.prototype.scheduler.lambda_scheduler import LambdaSL +from torchao.sparsity.prototype.scheduler.cubic_scheduler import CubicSL + +# Parametrizations +from torchao.sparsity.prototype.sparsifier.utils import FakeSparsity +from torchao.sparsity.prototype.sparsifier.utils import module_to_fqn +from torchao.sparsity.prototype.sparsifier.utils import fqn_to_module +from torchao.sparsity.prototype.sparsifier.utils import get_arg_info_from_tensor_fqn diff --git a/torchao/sparsity/prototype/pruner/FPGM_pruner.py b/torchao/sparsity/prototype/pruner/FPGM_pruner.py new file mode 100644 index 0000000000..d8c3d20052 --- /dev/null +++ b/torchao/sparsity/prototype/pruner/FPGM_pruner.py @@ -0,0 +1,93 @@ +from typing import Callable, Optional, Union + +import torch + +from .base_structured_sparsifier import BaseStructuredSparsifier + +__all__ = ["FPGMPruner"] + + +class FPGMPruner(BaseStructuredSparsifier): + r"""Filter Pruning via Geometric Median (FPGM) Structured Pruner + This sparsifier prune fliter (row) in a tensor according to distances among filters according to + `Filter Pruning via Geometric Median for Deep Convolutional Neural Networks Acceleration `_. + + This sparsifier is controlled by three variables: + 1. `sparsity_level` defines the number of filters (rows) that are zeroed-out. + 2. `dist` defines the distance measurement type. Default: 3 (L2 distance). + Available options are: [1, 2, (custom callable distance function)]. + + Note:: + Inputs should be a 4D convolutional tensor of shape (N, C, H, W). + - N: output channels size + - C: input channels size + - H: height of kernel + - W: width of kernel + """ + + def __init__( + self, sparsity_level: float = 0.5, dist: Optional[Union[Callable, int]] = None + ): + defaults = { + "sparsity_level": sparsity_level, + } + + if dist is None: + dist = 2 + + if callable(dist): + self.dist_fn = dist + elif dist == 1: + self.dist_fn = lambda x: torch.cdist(x, x, p=1) + elif dist == 2: + self.dist_fn = lambda x: torch.cdist(x, x, p=2) + else: + raise NotImplementedError("Distance function is not yet implemented.") + super().__init__(defaults=defaults) + + def _compute_distance(self, t): + r"""Compute distance across all entries in tensor `t` along all dimension + except for the one identified by dim. + Args: + t (torch.Tensor): tensor representing the parameter to prune + Returns: + distance (torch.Tensor): distance computed across filtters + """ + dim = 0 # prune filter (row) + + size = t.size(dim) + slc = [slice(None)] * t.dim() + + # flatten the tensor along the dimension + t_flatten = [ + t[tuple(slc[:dim] + [slice(i, i + 1)] + slc[dim + 1 :])].reshape(-1) + for i in range(size) + ] + t_flatten = torch.stack(t_flatten) + + # distance measurement + dist_matrix = self.dist_fn(t_flatten) + + # more similar with other filter indicates large in the sum of row + distance = torch.sum(torch.abs(dist_matrix), 1) + + return distance + + def update_mask(self, module, tensor_name, sparsity_level, **kwargs): + tensor_weight = getattr(module, tensor_name) + mask = getattr(module.parametrizations, tensor_name)[0].mask + + if sparsity_level <= 0: + mask.data = torch.ones_like(mask).bool() + elif sparsity_level >= 1.0: + mask.data = torch.zeros_like(mask).bool() + else: + distance = self._compute_distance(tensor_weight) + + tensor_size = tensor_weight.shape[0] # prune filter (row) + nparams_toprune = round(sparsity_level * tensor_size) + nparams_toprune = min( + max(nparams_toprune, 0), tensor_size + ) # clamp to [0, tensor_size] + topk = torch.topk(distance, k=nparams_toprune, largest=False) + mask[topk.indices] = False diff --git a/torchao/sparsity/prototype/pruner/README.md b/torchao/sparsity/prototype/pruner/README.md new file mode 100644 index 0000000000..026fd33b28 --- /dev/null +++ b/torchao/sparsity/prototype/pruner/README.md @@ -0,0 +1,251 @@ +# Structured Pruning + +## Intro / Motivation + +**Pruning** is the technique of removing parameters from a model to reduce the computational cost. The goal of pruning is to improve the performance of the model while maintaining it's accuracy. + +### Unstructured vs. Structured Pruning +One way to do this is to consider each parameter individually. This gives us the greatest granularity when pruning and is called **unstructured pruning**. + +For example, consider a simple linear regression model that is parametrized by a weight tensor W. + +``` +W = [[1 2 3] + [4 5 6] + [7 1 9]] +``` + +We can prune the lowest absolute value elements in W in order to preserve as much information as possible. +Below we've removed three parameters from W. + +``` +W_pruned = [[0 0 3] + [4 5 6] + [7 0 9]] +``` + +Unfortunately, zeroing out parameters does not offer a speed-up to the model out of the box. We need custom sparse kernels that are designed to take advantage of sparsity to speed up computation. For more information about unstructured pruning check out our tutorials [here](). + +However, if we zero out a row of parameters at a time instead of a single parameter, we can speed up computation by resizing the weight matrix. This is called **structured pruning** and is what this folder implements. + +``` +W_pruned = [[0 0 0] = [[4, 5, 6], + [4 5 6] [7, 1, 9]] + [7 1 9]] + +``` +### Weight Resizing + +However, since the pruned weight tensor has a different shape than the original weight tensor, subsequent operations will cause an error due to this shape mismatch. We need to remove both the weights of the original weight tensor and the columns of subsequent tensors that correspond to the pruned rows. + +You can see an example of this below for a model containing two linear layers, one parametrized by W and another by U + +![](./images/prune_5.png) + +By removing a row from U and a column from W, we can avoid a shape mismatch. + +![](./images/prune_6.png) + + +One benefit of **structured pruning** is that it uses the same dense kernels that the original model uses, and does not rely on custom sparse kernel like **unstructured pruning**. +However, structured pruning degrades accuracy more than unstructured pruning because of the lack of granularity, so it is not always the right choice. + +Generally the structured pruning process looks something like this: +1. Define what layers in the model you want to structured prune. +2. Evaluate the importance of each row in each layer in the model. +3. Remove rows by resizing the weight matrices of each layer +4. Stop if target sparsity level is met. + +The accuracy degradation of pruning can be quite large initially. Once we are satisfied with our pruned tensor, we usually retrain the model after pruning in order to restore some of this accuracy loss. + +## Quickstart Guide + +**Your model must be FX symbolically traceable**. + +You can test this with the following bit of code: + +```python +from torch.fx import symbolic_trace +model = MyModel() +symbolic_trace(model) +``` + +Using `torch.fx` we can get a compute graph of our model. Each operation (add, multiply, ReLU) is a node in the graph, and the order of operations is defined by the edges of the graph. + +Structured pruning works by traversing this graph and looking for specific **patterns**, which are just a specific sequence of operations. + +Each pattern is tied to a pruning function, which is responsible for structured pruning the graph nodes that match the pattern. + +The above [example](#weight-resizing) of two linear layers would match against a `(nn.Linear, nn.Linear)` pattern. This is how we identify the rows to remove and the columns of the subsequent layer. + +Structured pruning also works on other patterns other than two adjacent Linear layers, + +- linear -> linear +- linear -> activation -> linear +- conv2d -> conv2d +- conv2d -> activation -> conv2d +- conv2d -> activation -> pool -> conv2d +- conv2d -> pool -> activation -> conv2d +- conv2d -> adaptive pool -> flatten -> linear + +A complete set of the patterns we support can be found [here](https://github.com/pytorch/pytorch/blob/master/torch/ao/pruning/_experimental/pruner/base_structured_sparsifier.py#L85). + +If you are looking to prune a currently unsupported pattern, you can do this by modifying the pattern dict that we provide to the pruner, see [here](#writing-custom-patterns-and-pruning-functions-for-structured-pruning). Feel free to open a PR to add in new patterns. + + +Here is an example script that will prune away 50% of the rows for all the linear layers in the model, based on the saliency of each row. +```python +from torch.ao.pruning._experimental.pruner import SaliencyPruner + +# Define model +class Model(nn.Module): + def __init__(self): + super().__init__() + self.seq = nn.Sequential( + nn.Linear(700, 500, bias=True), + nn.ReLU(), + nn.Linear(500, 800, bias=False), + nn.ReLU(), + nn.Linear(800, 600, bias=True), + nn.ReLU(), + ) + self.linear = nn.Linear(600, 4, bias=False) + + def forward(self, x): + x = self.seq(x) + x = self.linear(x) + return x + +# Define pruning_config, which specifies which tensors you wish to prune. +# The SaliencyPruner also needs a sparsity_level parameter to specify what % of rows to prune. +pruning_config = [ + {"tensor_fqn": "seq.0.weight", "sparsity_level": 0.5}, + {"tensor_fqn": "seq.2.weight", "sparsity_level": 0.5}, + {"tensor_fqn": "seq.4.weight", "sparsity_level": 0.5}, + {"tensor_fqn": "linear.weight", "sparsity_level": 0.5}, +] + +original = Model() +# define defaults +# for structured pruning, we also prune biases by default. +defaults = {"prune_bias": True} +# any configs passed in here are defaults that are propagated +# Your selection criteria is decided by which pruner you use +pruner = SaliencyPruner(defaults, patterns=patterns) + +# Next we call `prepare`, which will attach `FakeStructuredSparsity` parameterizations +# to the tensors specified in the config. These parameterizations will zero out +# the appropriate weights in order to make the model behave as if it has been pruned. +pruner.prepare(original, sparse_config) + +# take one pruning step. This will update the masks +pruner.enable_mask_update = True +pruner.step() + +# pruner.prune() will find patterns and apply that patterns pruning function to it's matching nodes. +# The output of pruner.prune() is a model with resized weights and the masks / parametrizations removed. +pruned_model = pruner.prune() +``` +Afterwards, by printing the name and size of each parameter in our model, we can see that it has been pruned. + +``` +# original model +Parameter name | Shape | # of elements +--------------------|-----------------|--------------- +seq.0.weight | 500, 700 | 350000 +seq.0.bias | 500 | 500 +seq.2.weight | 800, 500 | 400000 +seq.4.weight | 600, 800 | 480000 +seq.4.bias | 600 | 600 +linear.weight | 4, 600 | 2400 +=== Total Number of Parameters: 1233500 === +``` +``` +# pruned model +Parameter name | Shape | # of elements +--------------------|-----------------|--------------- +seq.0.weight | 250, 700 | 175000 +seq.0.bias | 250 | 250 +seq.2.weight | 400, 250 | 100000 +seq.4.weight | 300, 400 | 120000 +seq.4.bias | 300 | 300 +linear.weight | 2, 300 | 600 +=== Total Number of Parameters: 396150 === +``` + +Although we pruned 50% of the rows, the total number of parameters is 25% of the original model. + +Since we remove both the rows of a weight tensor and the columns of the subsequent tensor. The total number of parameters is roughly (1-0.5)* (1-0.5) = 0.25 of the original number of parameters. + +## Advanced Tutorial + +### Pruning Config + +To specify the layers to prune we just need the fully qualified name (FQN) of the tensor you are looking to prune in the module. +You can get the FQN of a tensor by printing out `model.named_parameters()`. + +To prune multiple layers, we just append entries to the pruning config. +**tensor_fqn** is the only required key in the pruning config. You can pass additional information in the config, for example the sparsity level you want to prune to by adding a key to the config. You can then access this additional information when you update the masks. + +### Implementing a Pruner + +If you want to prune weights using a different pruning criteria than saliency, you'll need to implement your own pruner. + +To do this, we need to extend a `BaseStructuredSparsifier` with a custom `update_mask` function. + +This `update_mask` function contains the user logic for picking what weights to prune. + +One common pruning criteria is to use the **saliency** of a row, which is defined as the sum of all the L1 norms of the weights in the row. +The idea is to remove the weights that are small, since they wouldn't contribute much to the final prediction. + +Below we can see an implemented Saliency Pruner + +```python +class SaliencyPruner(BaseStructuredSparsifier): + """ + Prune filters based on the saliency + The saliency for a filter is given by the sum of the L1 norms of all of its weights + """ + + def update_mask(self, module, tensor_name, **kwargs): + # tensor_name will give you the FQN, all other keys in pruning config are present in kwargs + weights = getattr(module, tensor_name) + mask = getattr(module.parametrizations, tensor_name)[0].mask + + # use negative weights so we can use topk (we prune out the smallest) + saliency = -weights.norm(dim=tuple(range(1, weights.dim())), p=1) + num_to_pick = int(len(mask) * kwargs["sparsity_level"]) + prune = saliency.topk(num_to_pick).indices + + # Set the mask to be false for the rows we want to prune + mask.data[prune] = False + +``` + +### Writing Custom Patterns and Pruning Functions for Structured Pruning +If you're working with linear/conv2d layers, it's very probable that you just need to add an entry to the pattern dict mapping your pattern to an existing prune_function. + +This is because there are many modules, for example **pooling** that behave the same way and do not need to be modified by the pruning code. + +```python +from torch.ao.pruning._experimental.pruner.prune_functions import prune_conv2d_activation_conv2d + +def prune_conv2d_pool_activation_conv2d( + c1: nn.Conv2d, + pool: nn.Module, + activation: Optional[Callable[[Tensor], Tensor]], + c2: nn.Conv2d, +) -> None: + prune_conv2d_activation_conv2d(c1, activation, c2) + +# note how the pattern defined in the key will be passed to the pruning function as args +my_patterns = {(nn.Conv2d, nn.MaxPool2d, nn.ReLU, nn.Conv2d): prune_conv2d_activation_conv2d} + +pruning_patterns = _get_default_structured_pruning_patterns() +pruning_patterns.update(my_patterns) + +pruner = SaliencyPruner({}, patterns=pruning_patterns) +``` +However, there are also modules like batch norm, which will not work properly without being pruned as well. In this instance, you would need to write a custom pruning function in order to handle that logic properly. + +You can see the implemented pruning functions [here](https://github.com/pytorch/pytorch/blob/master/torch/ao/pruning/_experimental/pruner/prune_functions.py) for examples. Please feel free to open a PR so we get a complete set of the patterns and pruning functions. diff --git a/torchao/sparsity/prototype/pruner/__init__.py b/torchao/sparsity/prototype/pruner/__init__.py new file mode 100644 index 0000000000..6f017aa9e2 --- /dev/null +++ b/torchao/sparsity/prototype/pruner/__init__.py @@ -0,0 +1,8 @@ +from .base_structured_sparsifier import BaseStructuredSparsifier +from .parametrization import ( + FakeStructuredSparsity, + BiasHook, +) +from .saliency_pruner import SaliencyPruner +from .lstm_saliency_pruner import LSTMSaliencyPruner +from .FPGM_pruner import FPGMPruner diff --git a/torchao/sparsity/prototype/pruner/base_structured_sparsifier.py b/torchao/sparsity/prototype/pruner/base_structured_sparsifier.py new file mode 100644 index 0000000000..a1a34b77f8 --- /dev/null +++ b/torchao/sparsity/prototype/pruner/base_structured_sparsifier.py @@ -0,0 +1,310 @@ +from itertools import chain +from operator import getitem +import torch +import torch.nn.functional as F +from torch import nn +from torch.fx import symbolic_trace +from torch.nn.utils import parametrize +from typing import Type, Set, Dict, Callable, Tuple, Optional, Union + +from torchao.sparsity.prototype import BaseSparsifier +from .parametrization import FakeStructuredSparsity, BiasHook, module_contains_param +from .match_utils import apply_match, MatchAllNode +from .prune_functions import ( + prune_linear, + prune_linear_linear, + prune_linear_activation_linear, + prune_conv2d, + prune_conv2d_conv2d, + prune_conv2d_activation_conv2d, + prune_conv2d_activation_pool_conv2d, + prune_conv2d_pool_activation_conv2d, + prune_conv2d_pool_flatten_linear, + prune_lstm_output_linear, + prune_lstm_output_layernorm_linear, +) + + +def _get_supported_structured_pruning_modules(): + SUPPORTED_STRUCTURED_PRUNING_MODULES = { # added to config if None given + nn.Linear, + nn.Conv2d, + nn.LSTM, + } + return SUPPORTED_STRUCTURED_PRUNING_MODULES + + +def _get_supported_activation_functions(): + SUPPORTED_ACTIVATION_FUNCTIONS = { + F.relu, + F.rrelu, + F.hardtanh, + F.relu6, + F.sigmoid, + F.hardsigmoid, + F.tanh, + F.silu, + F.mish, + F.hardswish, + F.elu, + F.celu, + F.selu, + F.hardshrink, + F.leaky_relu, + F.logsigmoid, + F.softplus, + F.prelu, + F.softsign, + F.tanhshrink, + F.gelu, + } + return SUPPORTED_ACTIVATION_FUNCTIONS + + +def _get_supported_activation_modules(): + SUPPORTED_ACTIVATION_MODULES = { + nn.ReLU, + nn.RReLU, + nn.Hardtanh, + nn.ReLU6, + nn.Sigmoid, + nn.Hardsigmoid, + nn.Tanh, + nn.SiLU, + nn.Mish, + nn.Hardswish, + nn.ELU, + nn.CELU, + nn.SELU, + nn.Hardshrink, + nn.LeakyReLU, + nn.LogSigmoid, + nn.Softplus, + nn.PReLU, + nn.Softsign, + nn.Tanhshrink, + nn.GELU, + } + return SUPPORTED_ACTIVATION_MODULES + + +def _get_default_structured_pruning_patterns() -> Dict[ + Tuple[Union[Type[nn.Module], Callable, MatchAllNode, str], ...], + Callable[..., None], +]: + """ + Returns the patterns for conv2d / linear conversion for each element in the activation functions/modules defined above. + """ + patterns: Dict[ + Tuple[Union[Type[nn.Module], Callable, MatchAllNode, str], ...], + Callable[..., None], + ] = { + # linear -> linear + (nn.Linear, "output"): prune_linear, + (nn.Linear, nn.Linear): prune_linear_linear, + # conv2d -> conv2d + (nn.Conv2d, "output"): prune_conv2d, + (nn.Conv2d, nn.Conv2d): prune_conv2d_conv2d, + # TODO LSTM Structured pruning does not support returned state currently. + # Should find a way to explicitly match getitem(0) instead of getitem. + # This will also require changing the pruning function. + # lstm -> getitem(0) -> linear + (nn.LSTM, getitem, nn.Linear): prune_lstm_output_linear, + # lstm -> getitem(0) -> layernorm -> linear + (nn.LSTM, getitem, nn.LayerNorm, nn.Linear): prune_lstm_output_layernorm_linear, + } + + for activation in chain( + _get_supported_activation_functions(), _get_supported_activation_modules() + ): + patterns.update( + { + # linear -> activation -> linear + (nn.Linear, activation, nn.Linear): prune_linear_activation_linear, + # conv2d -> activation -> conv2d + (nn.Conv2d, activation, nn.Conv2d): prune_conv2d_activation_conv2d, + # conv2d -> activation -> pool -> conv2d + ( + nn.Conv2d, + activation, + nn.AvgPool2d, + nn.Conv2d, + ): prune_conv2d_activation_pool_conv2d, + ( + nn.Conv2d, + activation, + F.avg_pool2d, + nn.Conv2d, + ): prune_conv2d_activation_pool_conv2d, + ( + nn.Conv2d, + activation, + nn.MaxPool2d, + nn.Conv2d, + ): prune_conv2d_activation_pool_conv2d, + ( + nn.Conv2d, + activation, + F.max_pool2d, + nn.Conv2d, + ): prune_conv2d_activation_pool_conv2d, + # conv2d -> pool -> activation -> conv2d + ( + nn.Conv2d, + nn.AvgPool2d, + activation, + nn.Conv2d, + ): prune_conv2d_pool_activation_conv2d, + ( + nn.Conv2d, + F.avg_pool2d, + activation, + nn.Conv2d, + ): prune_conv2d_pool_activation_conv2d, + ( + nn.Conv2d, + nn.MaxPool2d, + activation, + nn.Conv2d, + ): prune_conv2d_pool_activation_conv2d, + ( + nn.Conv2d, + F.max_pool2d, + activation, + nn.Conv2d, + ): prune_conv2d_pool_activation_conv2d, + # conv2d -> adaptive pool -> flatten -> linear + ( + nn.Conv2d, + nn.AdaptiveAvgPool2d, + nn.Flatten, + nn.Linear, + ): prune_conv2d_pool_flatten_linear, + ( + nn.Conv2d, + nn.AdaptiveAvgPool2d, + torch.flatten, + nn.Linear, + ): prune_conv2d_pool_flatten_linear, + ( + nn.Conv2d, + nn.AdaptiveMaxPool2d, + nn.Flatten, + nn.Linear, + ): prune_conv2d_pool_flatten_linear, + ( + nn.Conv2d, + nn.AdaptiveMaxPool2d, + torch.flatten, + nn.Linear, + ): prune_conv2d_pool_flatten_linear, + } + ) + return patterns + + +class BaseStructuredSparsifier(BaseSparsifier): + r"""Base class for structured pruning. + + Abstract methods that need to be implemented: + - update_mask: Function to compute a new mask for all keys in the + `groups` attribute. + + Args: + - defaults [dict]: default configurations will be attached to the + configuration. Only the keys that don't exist in the `config` will + be updated. + """ + + def __init__(self, defaults, patterns=None): + super().__init__(defaults) + if patterns is None: + patterns = _get_default_structured_pruning_patterns() + self.patterns = patterns + + def make_config_from_model( + self, + model: nn.Module, + SUPPORTED_MODULES: Optional[Set[Type]] = None, + ) -> None: + if SUPPORTED_MODULES is None: + SUPPORTED_MODULES = _get_supported_structured_pruning_modules() + super().make_config_from_model(model, SUPPORTED_MODULES=SUPPORTED_MODULES) + + def _prepare(self, *args, **kwargs) -> None: + r"""This function will attach the FakeStructuredSparsity parameterizations + and BiasHooks at the appropriate points in the model. + """ + for config in self.groups: + module = config["module"] + tensor_name = config["tensor_name"] + parametrization = config.get("parametrization", FakeStructuredSparsity) + tensor = getattr(module, tensor_name) + + mask = config.get( + "mask", + torch.ones(tensor.shape[0], dtype=torch.bool, device=tensor.device), + ) + self.state[config["tensor_fqn"]]["mask"] = mask + parametrize.register_parametrization( + module, tensor_name, parametrization(mask) + ) + + # if linear / conv, we add in bias hooks + if isinstance(module, (nn.Linear, nn.Conv2d)): + prune_bias = config.get("prune_bias", True) + if module.bias is not None: + module.register_parameter( + "_bias", nn.Parameter(module.bias.detach()) + ) + module.bias = None + module.prune_bias = prune_bias + + module.register_forward_hook( + BiasHook(module.parametrizations.weight[0], prune_bias) + ) + + def prune(self) -> None: + r""" + This function will FX symbolically trace the model and then find instances of the patterns + defined in self.patterns (by default SUPPORTED_STRUCTURED_PRUNING_PATTERNS ). + + For each pattern, it will apply to corresponding conversion function, which will modify the output + and input size expected by the modules within the pattern + """ + + self.traced = symbolic_trace(self.model) + modules = dict(self.traced.named_modules()) + + # Right now we check for matches simply by iterating across all the patterns + # if this is slow we can store patterns in a trie-structure and modify this code for faster lookup + for node in self.traced.graph.nodes: + for pattern, convert_fn in self.patterns.items(): + matched = apply_match(modules, pattern, node, []) + if matched is None: + continue + + first_module = modules.get(node.target) + # check if first module exists and has appropriate parameterization, otherwise skip + if ( + first_module is not None + and parametrize.is_parametrized(first_module) + and module_contains_param(first_module, FakeStructuredSparsity) + ): + convert_block = [] + for node in matched: + if node.op == "call_module": + convert_block.append(modules.get(node.target)) + elif node.op == "call_function": + convert_block.append(node.target) + convert_fn(*convert_block) + + for module in self.traced.modules(): + if module_contains_param(module, FakeStructuredSparsity): + raise Exception( + f"Error: {module} still contains FakeStructuredSparsity parametrizations!" + ) + + self.traced.graph.lint() + self.traced.recompile() + return self.traced diff --git a/torchao/sparsity/prototype/pruner/images/prune_1.png b/torchao/sparsity/prototype/pruner/images/prune_1.png new file mode 100644 index 0000000000..f7f4875922 Binary files /dev/null and b/torchao/sparsity/prototype/pruner/images/prune_1.png differ diff --git a/torchao/sparsity/prototype/pruner/images/prune_2.png b/torchao/sparsity/prototype/pruner/images/prune_2.png new file mode 100644 index 0000000000..5aad9d0451 Binary files /dev/null and b/torchao/sparsity/prototype/pruner/images/prune_2.png differ diff --git a/torchao/sparsity/prototype/pruner/images/prune_3.png b/torchao/sparsity/prototype/pruner/images/prune_3.png new file mode 100644 index 0000000000..1af2c3cb4e Binary files /dev/null and b/torchao/sparsity/prototype/pruner/images/prune_3.png differ diff --git a/torchao/sparsity/prototype/pruner/images/prune_4.png b/torchao/sparsity/prototype/pruner/images/prune_4.png new file mode 100644 index 0000000000..fe7586edc1 Binary files /dev/null and b/torchao/sparsity/prototype/pruner/images/prune_4.png differ diff --git a/torchao/sparsity/prototype/pruner/images/prune_5.png b/torchao/sparsity/prototype/pruner/images/prune_5.png new file mode 100644 index 0000000000..6bd92544d5 Binary files /dev/null and b/torchao/sparsity/prototype/pruner/images/prune_5.png differ diff --git a/torchao/sparsity/prototype/pruner/images/prune_6.png b/torchao/sparsity/prototype/pruner/images/prune_6.png new file mode 100644 index 0000000000..aeb1a718b0 Binary files /dev/null and b/torchao/sparsity/prototype/pruner/images/prune_6.png differ diff --git a/torchao/sparsity/prototype/pruner/lstm_saliency_pruner.py b/torchao/sparsity/prototype/pruner/lstm_saliency_pruner.py new file mode 100644 index 0000000000..4a0d74d6dc --- /dev/null +++ b/torchao/sparsity/prototype/pruner/lstm_saliency_pruner.py @@ -0,0 +1,48 @@ +from typing import cast + +import torch +from .base_structured_sparsifier import BaseStructuredSparsifier, FakeStructuredSparsity + +class LSTMSaliencyPruner(BaseStructuredSparsifier): + """ + Prune packed LSTM weights based on saliency. + For each layer {k} inside a LSTM, we have two packed weight matrices + - weight_ih_l{k} + - weight_hh_l{k} + + These tensors pack the weights for the 4 linear layers together for efficiency. + + [W_ii | W_if | W_ig | W_io] + + Pruning this tensor directly will lead to weights being misassigned when unpacked. + To ensure that each packed linear layer is pruned the same amount: + 1. We split the packed weight into the 4 constituent linear parts + 2. Update the mask for each individual piece using saliency individually + + This applies to both weight_ih_l{k} and weight_hh_l{k}. + """ + + def update_mask(self, module, tensor_name, **kwargs): + weights = getattr(module, tensor_name) + + for p in getattr(module.parametrizations, tensor_name): + if isinstance(p, FakeStructuredSparsity): + mask = cast(torch.Tensor, p.mask) + + # select weights based on magnitude + if weights.dim() <= 1: + raise Exception("Structured pruning can only be applied to a 2+dim weight tensor!") + # take norm over all but first dim + dims = tuple(range(1, weights.dim())) + saliency = weights.norm(dim=dims, p=1) + + # handle weights in 4 groups + split_size = len(mask) // 4 + masks = torch.split(mask, split_size) + saliencies = torch.split(saliency, split_size) + + for keep_mask, sal in zip(masks, saliencies): + # mask smallest k values to be removed + k = int(len(keep_mask) * kwargs["sparsity_level"]) + prune = sal.topk(k, largest=False, sorted=False).indices + keep_mask.data[prune] = False # modifies underlying p.mask directly diff --git a/torchao/sparsity/prototype/pruner/match_utils.py b/torchao/sparsity/prototype/pruner/match_utils.py new file mode 100644 index 0000000000..d0f7a9f629 --- /dev/null +++ b/torchao/sparsity/prototype/pruner/match_utils.py @@ -0,0 +1,59 @@ +""" +Contains utility functions to check if a pattern is in the graph and return the matching nodes +""" +import torch +from torch import nn +from torch.ao.quantization.utils import ( + MatchAllNode, +) +from torch.fx import Node +from torch.nn.utils import parametrize +from typing import Any, Dict, List, Optional, Tuple, Union + +def _match(modules: Dict[str, nn.ModuleDict], node: Node, current: Union[nn.Module, Any]) -> bool: + r""" + checks to see if a single node of a pattern matches + """ + if isinstance(current, type) and issubclass(current, MatchAllNode): + return True + if not isinstance(node, Node): + return False + if isinstance(current, type) and issubclass(current, torch.nn.Module): + return ( + node.op == "call_module" + and parametrize.type_before_parametrizations(modules[node.target]) + == current + ) + elif callable(current): + return node.op == "call_function" and node.target is current + elif isinstance(current, str): + return node.target == current + return False + +def apply_match( + modules: Dict[str, nn.ModuleDict], + pattern: Union[Tuple[Any], Any], + node: Node, + matched_node_pattern: List[Node], +) -> Optional[List[Node]]: + r""" + This function will return the matched nodes if the pattern matches the node given + If there is no match, it will return None + """ + if isinstance(pattern, tuple): + if len(pattern) == 1: + if _match(modules, node, pattern[0]): + return matched_node_pattern + [node] + + first, *rest = pattern + if _match(modules, node, first): + if rest is None: + return matched_node_pattern + [node] + + for user in node.users: + return apply_match( + modules, tuple(rest), user, matched_node_pattern + [node] + ) + elif _match(modules, node, pattern): + return [node] + return None diff --git a/torchao/sparsity/prototype/pruner/parametrization.py b/torchao/sparsity/prototype/pruner/parametrization.py new file mode 100644 index 0000000000..df94f7093b --- /dev/null +++ b/torchao/sparsity/prototype/pruner/parametrization.py @@ -0,0 +1,59 @@ +import torch +from torch import nn +from torch.nn.utils.parametrize import is_parametrized + + +def module_contains_param(module, parametrization): + if is_parametrized(module): + # see if any of the module tensors have a parametriztion attached that matches the one passed in + return any( + any(isinstance(param, parametrization) for param in param_list) + for key, param_list in module.parametrizations.items() + ) + return False + + +# Structured Pruning Parameterizations +class FakeStructuredSparsity(nn.Module): + r""" + Parametrization for Structured Pruning. Like FakeSparsity, this should be attached to + the 'weight' or any other parameter that requires a mask. + + Instead of an element-wise bool mask, this parameterization uses a row-wise bool mask. + """ + + def __init__(self, mask): + super().__init__() + self.register_buffer("mask", mask) + + def forward(self, x): + assert isinstance(self.mask, torch.Tensor) + assert self.mask.shape[0] == x.shape[0] + shape = [1] * len(x.shape) + shape[0] = -1 + return self.mask.reshape(shape) * x + + def state_dict(self, *args, **kwargs): + # avoid double saving masks + return {} + + +class BiasHook: + def __init__(self, parametrization, prune_bias): + self.param = parametrization + self.prune_bias = prune_bias + + def __call__(self, module, input, output): + + if getattr(module, "_bias", None) is not None: + bias = module._bias.data + if self.prune_bias: + bias[~self.param.mask] = 0 + + # reshape bias to broadcast over output dimensions + idx = [1] * len(output.shape) + idx[1] = -1 + bias = bias.reshape(idx) + + output += bias + return output diff --git a/torchao/sparsity/prototype/pruner/prune_functions.py b/torchao/sparsity/prototype/pruner/prune_functions.py new file mode 100644 index 0000000000..a75c09cc30 --- /dev/null +++ b/torchao/sparsity/prototype/pruner/prune_functions.py @@ -0,0 +1,475 @@ +""" +Collection of conversion functions for linear / conv2d structured pruning +Also contains utilities for bias propagation +""" +from typing import cast, List, Optional, Callable, Tuple + +import torch +from torch import nn, Tensor +from torch.nn.utils import parametrize +from torch.nn.utils.parametrize import ParametrizationList +from .parametrization import FakeStructuredSparsity, BiasHook + +# BIAS PROPAGATION +def _remove_bias_handles(module: nn.Module) -> None: + if hasattr(module, "_forward_hooks"): + bias_hooks: List[int] = [] + for key, hook in module._forward_hooks.items(): + if isinstance(hook, BiasHook): + bias_hooks.append(key) + + for key in bias_hooks: + del module._forward_hooks[key] + + +def _get_adjusted_next_layer_bias( + next_layer: nn.Module, pruned_biases: Tensor, mask: Tensor +) -> nn.Parameter: + r"""Returns new adjusted bias for the second supported module""" + if parametrize.is_parametrized(next_layer): + # need to access original weight + parametrization_dict = cast(nn.ModuleDict, next_layer.parametrizations) + weight_parameterizations = cast( + ParametrizationList, parametrization_dict.weight + ) + next_weight = weight_parameterizations.original + else: + next_weight = cast(Tensor, next_layer.weight) + + scaling_weight = next_weight[:, ~mask] + if isinstance(next_layer, nn.Conv2d): # checking for Conv2d + # Propagating first layer pruned biases and calculating the new second layer bias + # involves more steps since the Conv2d scaling weight has extra dimensions, + # so adding bias involves broadcasting, logically: + # for each channel k in range(oC): + # scaled_biases = sum(first_bias[pruned_idx] @ next_weight[k, pruned_idx, :, :].T) + # new_next_bias[k] = old_next_bias[k] + scaled_biases + scaling_product = torch.matmul( + pruned_biases.reshape(1, -1), torch.transpose(scaling_weight, 1, 2) + ) + sum_range = list(range(len(scaling_product.shape)))[ + 1: + ] # all but the first dimension + scaled_biases = torch.sum(scaling_product, sum_range) + elif isinstance(next_layer, nn.Linear): # Linear + scaled_biases = torch.matmul( + pruned_biases, torch.transpose(scaling_weight, 0, 1) + ) # recall b2_new = b1 @ w2.T + b2 + else: + raise NotImplementedError(f"Type {type(next_layer)} not supported yet.") + + if ( + parametrize.is_parametrized(next_layer) + and getattr(next_layer, "_bias", None) is not None + ): # next_layer is parametrized & has original bias ._bias + adjusted_bias = nn.Parameter(scaled_biases + next_layer._bias) + elif ( + not parametrize.is_parametrized(next_layer) and next_layer.bias is not None + ): # next_layer not parametrized & has .bias + adjusted_bias = nn.Parameter(scaled_biases + next_layer.bias) + else: # next_layer has no bias + adjusted_bias = nn.Parameter(scaled_biases) + return adjusted_bias + + +def _prune_module_bias(module: nn.Module, mask: Tensor) -> None: + r"""Applies mask to given modules bias""" + # prune bias along with weights, discard pruned indices of bias + original_bias = cast(Tensor, getattr(module, "_bias", module.bias)) + if original_bias is not None: + module.bias = nn.Parameter(original_bias[mask]) + + # remove _bias parameter + if hasattr(module, "_bias"): + delattr(module, "_bias") + + +def _propogate_module_bias(module: nn.Module, mask: Tensor) -> Optional[Tensor]: + r""" + In the case that we need to propagate biases, this function will return the biases we need + """ + # set current module bias + if module.bias is not None: + module.bias = nn.Parameter(cast(Tensor, module.bias)[mask]) + elif getattr(module, "_bias", None) is not None: + module.bias = nn.Parameter(cast(Tensor, module._bias)[mask]) + + # get pruned biases to propagate to subsequent layer + if getattr(module, "_bias", None) is not None: + pruned_biases = cast(Tensor, module._bias)[~mask] + else: + pruned_biases = None + + if hasattr(module, "_bias"): + delattr(module, "_bias") + + return pruned_biases + + +# LINEAR +def _prune_linear_helper(linear: nn.Linear) -> Tensor: + # expects linear to be a parameterized linear module + parametrization_dict = cast(nn.ModuleDict, linear.parametrizations) + weight_parameterizations = cast(ParametrizationList, parametrization_dict.weight) + for p in weight_parameterizations: + if isinstance(p, FakeStructuredSparsity): + mask = cast(Tensor, p.mask) + + with torch.no_grad(): + parametrize.remove_parametrizations(linear, "weight", leave_parametrized=True) + linear.weight = nn.Parameter(linear.weight[mask]) # type: ignore[possibly-undefined] + linear.out_features = linear.weight.shape[0] + _remove_bias_handles(linear) + + return mask + + +def prune_linear(linear: nn.Linear) -> None: + mask = _prune_linear_helper(linear) + if getattr(linear, "prune_bias", False): + _prune_module_bias(linear, mask) + + +def prune_linear_linear(linear1: nn.Linear, linear2: nn.Linear) -> None: + prune_linear_activation_linear(linear1, None, linear2) + + +def prune_linear_activation_linear( + linear1: nn.Linear, + activation: Optional[Callable[[Tensor], Tensor]], + linear2: nn.Linear, +): + mask = _prune_linear_helper(linear1) + if getattr(linear1, "prune_bias", False): + _prune_module_bias(linear1, mask) + else: + pruned_biases = _propogate_module_bias(linear1, mask) + if pruned_biases is not None: + if activation: + pruned_biases = activation(pruned_biases) + linear2.bias = _get_adjusted_next_layer_bias(linear2, pruned_biases, mask) + + with torch.no_grad(): + if parametrize.is_parametrized(linear2): + parametrization_dict = cast(nn.ModuleDict, linear2.parametrizations) + weight_parameterizations = cast( + ParametrizationList, parametrization_dict.weight + ) + + weight_parameterizations.original = nn.Parameter( + weight_parameterizations.original[:, mask] + ) + linear2.in_features = weight_parameterizations.original.shape[1] + else: + linear2.weight = nn.Parameter(linear2.weight[:, mask]) + linear2.in_features = linear2.weight.shape[1] + + +# CONV2D +def _prune_conv2d_helper(conv2d: nn.Conv2d) -> Tensor: + parametrization_dict = cast(nn.ModuleDict, conv2d.parametrizations) + weight_parameterizations = cast(ParametrizationList, parametrization_dict.weight) + for p in weight_parameterizations: + if isinstance(p, FakeStructuredSparsity): + mask = cast(Tensor, p.mask) + + with torch.no_grad(): + parametrize.remove_parametrizations(conv2d, "weight", leave_parametrized=True) + conv2d.weight = nn.Parameter(conv2d.weight[mask]) # type: ignore[possibly-undefined] + conv2d.out_channels = conv2d.weight.shape[0] + + _remove_bias_handles(conv2d) + return mask + + +def prune_conv2d_padded(conv2d_1: nn.Conv2d) -> None: + parametrization_dict = cast(nn.ModuleDict, conv2d_1.parametrizations) + weight_parameterizations = cast(ParametrizationList, parametrization_dict.weight) + for p in weight_parameterizations: + if isinstance(p, FakeStructuredSparsity): + mask = cast(Tensor, p.mask) + + with torch.no_grad(): + parametrize.remove_parametrizations(conv2d_1, "weight", leave_parametrized=True) + + if getattr(conv2d_1, "_bias", None) is not None: + if ( + conv2d_1.bias is not None + ): # conv2d_1 has original bias and bias propagated from previous layer + new_bias = torch.zeros(conv2d_1.bias.shape) + new_bias[mask] = conv2d_1.bias[mask] # type: ignore[possibly-undefined] + # adjusted bias that to keep in conv2d_1 + new_bias[~mask] = cast(Tensor, conv2d_1._bias)[~mask] + # pruned biases that are kept instead of propagated + conv2d_1.bias = nn.Parameter(new_bias) + else: # conv2d_1 has only original bias + conv2d_1.bias = nn.Parameter(cast(Tensor, conv2d_1._bias)) + else: + # no original bias, only propagated bias + if ( + conv2d_1.bias is not None + ): # conv2d_1 has bias propagated from previous layer + conv2d_1.bias.data[~mask] = 0 # type: ignore[possibly-undefined] + + if hasattr(conv2d_1, "_bias"): + delattr(conv2d_1, "_bias") + + +def prune_conv2d(conv2d: nn.Conv2d) -> None: + mask = _prune_conv2d_helper(conv2d) + if getattr(conv2d, "prune_bias", False): + _prune_module_bias(conv2d, mask) + + +def prune_conv2d_conv2d(conv2d_1: nn.Conv2d, conv2d_2: nn.Conv2d) -> None: + prune_conv2d_activation_conv2d(conv2d_1, None, conv2d_2) + + +def prune_conv2d_activation_conv2d( + conv2d_1: nn.Conv2d, + activation: Optional[Callable[[Tensor], Tensor]], + conv2d_2: nn.Conv2d, +): + r""" + Fusion Pattern for conv2d -> some activation module / function -> conv2d layers + """ + parametrization_dict = cast(nn.ModuleDict, conv2d_1.parametrizations) + weight_parameterizations = cast(ParametrizationList, parametrization_dict.weight) + for p in weight_parameterizations: + if isinstance(p, FakeStructuredSparsity): + mask = cast(Tensor, p.mask) + + prune_bias = getattr(conv2d_1, "prune_bias", False) + if ( + hasattr(conv2d_2, "padding") + and cast(Tuple[int], conv2d_2.padding) > (0, 0) + and (conv2d_1.bias is not None or getattr(conv2d_1, "_bias", None) is not None) + ): + prune_conv2d_padded(conv2d_1) + else: + mask = _prune_conv2d_helper(conv2d_1) + if prune_bias: + _prune_module_bias(conv2d_1, mask) + else: + pruned_biases = _propogate_module_bias(conv2d_1, mask) + if pruned_biases is not None: + if activation: + pruned_biases = activation(pruned_biases) + conv2d_2.bias = _get_adjusted_next_layer_bias( + conv2d_2, pruned_biases, mask + ) + + if ( + not ( + hasattr(conv2d_2, "padding") + and cast(Tuple[int], conv2d_2.padding) > (0, 0) + ) + or conv2d_1.bias is None + ): + with torch.no_grad(): + if parametrize.is_parametrized(conv2d_2): + parametrization_dict = cast( + nn.ModuleDict, conv2d_2.parametrizations + ) + weight_parameterizations = cast( + ParametrizationList, parametrization_dict.weight + ) + weight_parameterizations.original = nn.Parameter( + weight_parameterizations.original[:, mask] + ) + conv2d_2.in_channels = weight_parameterizations.original.shape[1] + else: + conv2d_2.weight = nn.Parameter(conv2d_2.weight[:, mask]) + conv2d_2.in_channels = conv2d_2.weight.shape[1] + + +def prune_conv2d_pool_activation_conv2d( + c1: nn.Conv2d, + pool: nn.Module, + activation: Optional[Callable[[Tensor], Tensor]], + c2: nn.Conv2d, +) -> None: + prune_conv2d_activation_conv2d(c1, activation, c2) + + +def prune_conv2d_activation_pool_conv2d( + c1: nn.Conv2d, + activation: Optional[Callable[[Tensor], Tensor]], + pool: nn.Module, + c2: nn.Conv2d, +) -> None: + prune_conv2d_activation_conv2d(c1, activation, c2) + + +def prune_conv2d_pool_flatten_linear( + conv2d: nn.Conv2d, + pool: nn.Module, + flatten: Optional[Callable[[Tensor], Tensor]], + linear: nn.Linear, +) -> None: + mask = _prune_conv2d_helper(conv2d) + + # We map the pruned indices of the Conv2d output to the flattened indices of the Linear following the Flatten layer. + # we determine the flattening scale (h * w), and readjust `first_pruned_indices` + # (each idx maps to range idx * h * w to (idx+1) * h * w), `first_valid_indices`, + # and `pruned_biases` (repeat each bias by h * w). + if parametrize.is_parametrized(linear): + parametrization_dict = cast(nn.ModuleDict, linear.parametrizations) + weight_parameterizations = cast( + ParametrizationList, parametrization_dict.weight + ) + linear_ic = weight_parameterizations.original.shape[1] + else: + linear_ic = linear.weight.shape[1] + + conv2d_oc = len(mask) + assert ( + linear_ic % conv2d_oc == 0 + ), f"Flattening from dimensions {conv2d_oc} to {linear_ic} not supported" + + flatten_scale = linear_ic // conv2d_oc + flattened_mask = torch.tensor( + [[val] * flatten_scale for val in mask], dtype=torch.bool, device=mask.device + ).flatten() + + if getattr(conv2d, "prune_bias", False): + _prune_module_bias(conv2d, mask) + else: + pruned_biases = cast(Tensor, _propogate_module_bias(conv2d, mask)) + flattened_pruned_biases = torch.tensor( + [[bias] * flatten_scale for bias in pruned_biases], device=mask.device + ).flatten() + linear.bias = _get_adjusted_next_layer_bias( + linear, flattened_pruned_biases, flattened_mask + ) + + with torch.no_grad(): + if parametrize.is_parametrized(linear): + parametrization_dict = cast(nn.ModuleDict, linear.parametrizations) + weight_parameterizations = cast( + ParametrizationList, parametrization_dict.weight + ) + weight_parameterizations.original = nn.Parameter( + weight_parameterizations.original[:, flattened_mask] + ) + linear.in_features = weight_parameterizations.original.shape[1] + else: + linear.weight = nn.Parameter(linear.weight[:, flattened_mask]) + linear.in_features = linear.weight.shape[1] + + +def prune_lstm_output_linear( + lstm: nn.LSTM, getitem: Callable, linear: nn.Linear +) -> None: + prune_lstm_output_layernorm_linear(lstm, getitem, None, linear) + + +def prune_lstm_output_layernorm_linear( + lstm: nn.LSTM, + getitem: Callable, + layernorm: Optional[nn.LayerNorm], + linear: nn.Linear, +) -> None: + for i in range(lstm.num_layers): + if parametrize.is_parametrized(lstm, f"weight_ih_l{i}"): + parametrization_dict = cast(nn.ModuleDict, lstm.parametrizations) + weight_parameterizations = cast( + ParametrizationList, parametrization_dict[f"weight_ih_l{i}"] + ) + mask = weight_parameterizations[0].mask + + with torch.no_grad(): + parametrize.remove_parametrizations( + lstm, f"weight_ih_l{i}", leave_parametrized=True + ) + setattr( + lstm, + f"weight_ih_l{i}", + nn.Parameter(getattr(lstm, f"weight_ih_l{i}")[mask]), + ) + setattr( + lstm, + f"bias_ih_l{i}", + nn.Parameter(getattr(lstm, f"bias_ih_l{i}")[mask]), + ) + + if parametrize.is_parametrized(lstm, f"weight_hh_l{i}"): + parametrization_dict = cast(nn.ModuleDict, lstm.parametrizations) + weight_parameterizations = cast( + ParametrizationList, parametrization_dict[f"weight_hh_l{i}"] + ) + mask = weight_parameterizations[0].mask + + with torch.no_grad(): + parametrize.remove_parametrizations( + lstm, f"weight_hh_l{i}", leave_parametrized=True + ) + # splitting out hidden-hidden masks + W_hi, W_hf, W_hg, W_ho = torch.split( + getattr(lstm, f"weight_hh_l{i}"), lstm.hidden_size + ) + M_hi, M_hf, M_hg, M_ho = torch.split(mask, lstm.hidden_size) + + # resize each individual weight separately + W_hi = W_hi[M_hi][:, M_hi] + W_hf = W_hf[M_hf][:, M_hf] + W_hg = W_hg[M_hg][:, M_hg] + W_ho = W_ho[M_ho][:, M_ho] + + # concat, use this as new weight + new_weight = torch.cat((W_hi, W_hf, W_hg, W_ho)) + setattr(lstm, f"weight_hh_l{i}", nn.Parameter(new_weight)) + setattr( + lstm, + f"bias_hh_l{i}", + nn.Parameter(getattr(lstm, f"bias_hh_l{i}")[mask]), + ) + + # If this is the final layer, then we need to prune linear layer columns + if i + 1 == lstm.num_layers: + lstm.hidden_size = int(M_hi.sum()) + with torch.no_grad(): + if parametrize.is_parametrized(linear): + parametrization_dict = cast( + nn.ModuleDict, linear.parametrizations + ) + weight_parameterizations = cast( + ParametrizationList, parametrization_dict.weight + ) + + weight_parameterizations.original = nn.Parameter( + weight_parameterizations.original[:, M_ho] + ) + linear.in_features = weight_parameterizations.original.shape[1] + else: + linear.weight = nn.Parameter(linear.weight[:, M_ho]) + linear.in_features = linear.weight.shape[1] + + # if layernorm module, prune weight and bias + if layernorm is not None: + layernorm.normalized_shape = (linear.in_features,) + layernorm.weight = nn.Parameter(layernorm.weight[M_ho]) + layernorm.bias = nn.Parameter(layernorm.bias[M_ho]) + + # otherwise need to prune the columns of the input of the next LSTM layer + else: + with torch.no_grad(): + if parametrize.is_parametrized(lstm, f"weight_ih_l{i+1}"): + parametrization_dict = cast( + nn.ModuleDict, lstm.parametrizations + ) + weight_parameterizations = cast( + ParametrizationList, + getattr(parametrization_dict, f"weight_ih_l{i+1}"), + ) + + weight_parameterizations.original = nn.Parameter( + weight_parameterizations.original[:, M_ho] + ) + else: + next_layer_weight = getattr(lstm, f"weight_ih_l{i+1}") + setattr( + lstm, + f"weight_ih_l{i+1}", + nn.Parameter(next_layer_weight[:, M_ho]), + ) diff --git a/torchao/sparsity/prototype/pruner/saliency_pruner.py b/torchao/sparsity/prototype/pruner/saliency_pruner.py new file mode 100644 index 0000000000..f965fa647d --- /dev/null +++ b/torchao/sparsity/prototype/pruner/saliency_pruner.py @@ -0,0 +1,29 @@ +from .base_structured_sparsifier import BaseStructuredSparsifier + + +class SaliencyPruner(BaseStructuredSparsifier): + """ + Prune rows based on the saliency (L1 norm) of each row. + + This pruner works on N-Dimensional weight tensors. + For each row, we will calculate the saliency, whic is the sum the L1 norm of all weights in that row. + We expect that the resulting saliency vector has the same shape as our mask. + We then pick elements to remove until we reach the target sparsity_level. + """ + + def update_mask(self, module, tensor_name, **kwargs): + # tensor_name will give you the FQN, all other entries in sparse config is present in kwargs + weights = getattr(module, tensor_name) + mask = getattr(module.parametrizations, tensor_name)[0].mask + + # use negative weights so we can use topk (we prune out the smallest) + if weights.dim() <= 1: + raise Exception("Structured pruning can only be applied to a 2+dim weight tensor!") + saliency = -weights.norm(dim=tuple(range(1, weights.dim())), p=1) + assert saliency.shape == mask.shape + + num_to_pick = int(len(mask) * kwargs["sparsity_level"]) + prune = saliency.topk(num_to_pick).indices + + # Set the mask to be false for the rows we want to prune + mask.data[prune] = False diff --git a/torchao/sparsity/prototype/scheduler/__init__.py b/torchao/sparsity/prototype/scheduler/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/torchao/sparsity/prototype/scheduler/base_scheduler.py b/torchao/sparsity/prototype/scheduler/base_scheduler.py new file mode 100644 index 0000000000..f102f351ea --- /dev/null +++ b/torchao/sparsity/prototype/scheduler/base_scheduler.py @@ -0,0 +1,159 @@ + +from functools import wraps +import warnings +import weakref + +from torchao.sparsity.prototype.sparsifier.base_sparsifier import BaseSparsifier + +__all__ = ["BaseScheduler"] + +class BaseScheduler: + + def __init__(self, sparsifier, last_epoch=-1, verbose=False): + + # Attach sparsifier + if not isinstance(sparsifier, BaseSparsifier): + raise TypeError(f'{type(sparsifier).__name__} is not an instance of torch.ao.pruning.BaseSparsifier') + self.sparsifier = sparsifier + + # Initialize epoch and base sparsity levels + + self.base_sl = [group['sparsity_level'] for group in sparsifier.groups] + self.last_epoch = last_epoch + + # Following https://github.com/pytorch/pytorch/issues/20124 + # We would like to ensure that `scheduler.step()` is called after + # `sparsifier.step()` + def with_counter(method): + if getattr(method, '_with_counter', False): + # `sparsifier.step()` has already been replaced, return. + return method + + # Keep a weak reference to the sparsifier instance to prevent + # cyclic references. + instance_ref = weakref.ref(method.__self__) + # Get the unbound method for the same purpose. + func = method.__func__ + cls = instance_ref().__class__ + del method + + @wraps(func) + def wrapper(*args, **kwargs): + instance = instance_ref() + instance._step_count += 1 # type: ignore[union-attr] + wrapped = func.__get__(instance, cls) + return wrapped(*args, **kwargs) + + # Note that the returned function here is no longer a bound method, + # so attributes like `__func__` and `__self__` no longer exist. + wrapper._with_counter = True # type: ignore[attr-defined] + return wrapper + + self.sparsifier.step = with_counter(self.sparsifier.step) # type: ignore[assignment] + self.sparsifier._step_count = 0 # type: ignore[attr-defined] + self._step_count: int = 0 + self.verbose = verbose + + # Housekeeping + self._get_sl_called_within_step: bool = False + + self.step() + + def state_dict(self): + """Returns the state of the scheduler as a :class:`dict`. + + It contains an entry for every variable in self.__dict__ which + is not the sparsifier. + """ + return {key: value for key, value in self.__dict__.items() if key != 'sparsifier'} + + def load_state_dict(self, state_dict): + """Loads the schedulers state. + + Args: + state_dict (dict): scheduler state. Should be an object returned + from a call to :meth:`state_dict`. + """ + self.__dict__.update(state_dict) + + def get_last_sl(self): + """ Return last computed sparsity level by current scheduler. + """ + return self._last_sl + + def get_sl(self): + # Compute sparsity level using chainable form of the scheduler + # Note: This method is not intended to be called directly, and is only + # used by the ".step" method. Use .get_last_sl() instead. + if not self._get_sl_called_within_step: + warnings.warn( + "To get the last sparsity level computed by the scheduler, " + "please use `get_last_sl()`.") + raise NotImplementedError + + def print_sl(self, is_verbose, group, sl, epoch=None): + """Display the current sparsity level. + """ + if is_verbose: + if epoch is None: + print(f'Adjusting sparsity level of group {group} to {sl:.4e}.') + else: + print(f'Epoch {epoch:5d}: adjusting sparsity level of group {group} to {sl:.4e}.') + + def __repr__(self): + format_string = self.__class__.__name__ + ' (' + format_string += '\n' + format_string += f'Sparsifier {self.sparsifier}\n' + format_string += f' base_sl: {self.base_sl}\n' + format_string += ')' + return format_string + + def step(self, epoch=None): + # Raise warning if trying to call scheduler step before the sparsifier. + # https://github.com/pytorch/pytorch/issues/20124 + if self._step_count == 1: + if not hasattr(self.sparsifier.step, "_with_counter"): + warnings.warn("Seems like `sparsifier.step()` has been overridden after sparsity scheduler " + "initialization. Please, make sure to call `sparsifier.step()` before " + "`scheduler.step()`.", UserWarning) + + # Just check if there were two first scheduler.step() calls before sparsifier.step() + elif self.sparsifier._step_count < 1: # type: ignore[attr-defined] + warnings.warn("Detected call of `scheduler.step()` before `sparsifier.step()`. " + "You have to make sure you run the sparsifier.step() BEFORE any " + "calls to the scheduler.step().", UserWarning) + self._step_count += 1 + + class _enable_get_sl_call: + + def __init__(self, o): + self.o = o + + def __enter__(self): + self.o._get_sl_called_within_step = True + return self + + def __exit__(self, type, value, traceback): + self.o._get_sl_called_within_step = False + + with _enable_get_sl_call(self): + self.last_epoch += 1 + values = self.get_sl() + + for i, data in enumerate(zip(self.sparsifier.groups, values)): + param_group, sl = data + param_group['sparsity_level'] = sl + self.print_sl(self.verbose, i, sl, epoch) + + self._last_sl = [group['sparsity_level'] for group in self.sparsifier.groups] + self.sparsifier.enable_mask_update = True + + def _make_sure_a_list(self, var): + r"""Utility that extends it to the same length as the .groups, ensuring it is a list""" + n = len(self.sparsifier.groups) + if not isinstance(var, (list, tuple)): + return [var] * n + else: + if len(var) != n: + raise ValueError(f"Expected variable of length {n}, but got {len(var)}") + return list(var) # We want the result to be in a list, not tuple diff --git a/torchao/sparsity/prototype/scheduler/cubic_scheduler.py b/torchao/sparsity/prototype/scheduler/cubic_scheduler.py new file mode 100644 index 0000000000..76fc61daa2 --- /dev/null +++ b/torchao/sparsity/prototype/scheduler/cubic_scheduler.py @@ -0,0 +1,107 @@ +import warnings + +from .base_scheduler import BaseScheduler + +__all__ = ["CubicSL"] + +def _clamp(x, lo, hi): + return max(lo, min(hi, x)) + + +class CubicSL(BaseScheduler): + r"""Sets the sparsity level of each parameter group to the final sl + plus a given exponential function. + + .. math:: + + s_i = s_f + (s_0 - s_f) \cdot \left( 1 - \frac{t - t_0}{n\Delta t} \right)^3 + + where :math:`s_i` is the sparsity at epoch :math:`t`, :math;`s_f` is the final + sparsity level, :math:`f(i)` is the function to be applied to the current epoch + :math:`t`, initial epoch :math:`t_0`, and final epoch :math:`t_f`. + :math:`\Delta t` is used to control how often the update of the sparsity level + happens. By default, + + Args: + sparsifier (BaseSparsifier): Wrapped sparsifier. + init_sl (int, list): Initial level of sparsity + init_t (int, list): Initial step, when pruning starts + delta_t (int, list): Pruning frequency + total_t (int, list): Total number of pruning steps + initially_zero (bool, list): If True, sets the level of sparsity to 0 + before init_t (:math:`t_0`). Otherwise, the sparsity level before + init_t (:math:`t_0`) is set to init_sl(:math:`s_0`) + last_epoch (int): The index of last epoch. Default: -1. + verbose (bool): If ``True``, prints a message to stdout for + each update. Default: ``False``. + """ + def __init__(self, + sparsifier, + init_sl=0.0, + init_t=0, + delta_t=10, + total_t=100, + initially_zero=False, + last_epoch=-1, + verbose=False + ): + self.sparsifier = sparsifier + + self.init_sl = self._make_sure_a_list(init_sl) + self.init_t = self._make_sure_a_list(init_t) + self.delta_t = self._make_sure_a_list(delta_t) + self.total_t = self._make_sure_a_list(total_t) + + self.initially_zero = self._make_sure_a_list(initially_zero) + + super().__init__(sparsifier, last_epoch, verbose) + + @staticmethod + def sparsity_compute_fn(s_0, s_f, t, t_0, dt, n, initially_zero=False): + r""""Computes the current level of sparsity. + + Based on https://arxiv.org/pdf/1710.01878.pdf + + Args: + s_0: Initial level of sparsity, :math:`s_i` + s_f: Target level of sparsity, :math:`s_f` + t: Current step, :math:`t` + t_0: Initial step, :math:`t_0` + dt: Pruning frequency, :math:`\Delta T` + n: Pruning steps, :math:`n` + initially_zero: Sets the level of sparsity to 0 before t_0. + If False, sets to s_0 + + Returns: + The sparsity level :math:`s_t` at the current step :math:`t` + """ + if initially_zero and t < t_0: + return 0 + s_t = s_f + (s_0 - s_f) * (1.0 - (t - t_0) / (dt * n)) ** 3 + s_t = _clamp(s_t, s_0, s_f) + return s_t + + def get_sl(self): + if not self._get_sl_called_within_step: + warnings.warn( + "To get the last sparsity level computed by the scheduler, " + "please use `get_last_sl()`.") + return [ + self.sparsity_compute_fn( + s_0=initial_sparsity, + s_f=final_sparsity, + t=self.last_epoch, + t_0=initial_epoch, + dt=delta_epoch, + n=interval_epochs, + initially_zero=initially_zero + ) for initial_sparsity, final_sparsity, initial_epoch, delta_epoch, interval_epochs, initially_zero in + zip( + self.init_sl, + self.base_sl, + self.init_t, + self.delta_t, + self.total_t, + self.initially_zero + ) + ] diff --git a/torchao/sparsity/prototype/scheduler/lambda_scheduler.py b/torchao/sparsity/prototype/scheduler/lambda_scheduler.py new file mode 100644 index 0000000000..a88d99a1f8 --- /dev/null +++ b/torchao/sparsity/prototype/scheduler/lambda_scheduler.py @@ -0,0 +1,47 @@ +import warnings + +from .base_scheduler import BaseScheduler + +__all__ = ["LambdaSL"] + +class LambdaSL(BaseScheduler): + """Sets the sparsity level of each parameter group to the final sl + times a given function. When last_epoch=-1, sets initial sl as zero. + Args: + sparsifier (BaseSparsifier): Wrapped sparsifier. + sl_lambda (function or list): A function which computes a multiplicative + factor given an integer parameter epoch, or a list of such + functions, one for each group in sparsifier.param_groups. + last_epoch (int): The index of last epoch. Default: -1. + verbose (bool): If ``True``, prints a message to stdout for + each update. Default: ``False``. + Example: + >>> # Assuming sparsifier has two groups. + >>> lambda1 = lambda epoch: epoch // 30 + >>> lambda2 = lambda epoch: 0.95 ** epoch + >>> # xdoctest: +SKIP + >>> scheduler = LambdaSL(sparsifier, sl_lambda=[lambda1, lambda2]) + >>> for epoch in range(100): + >>> train(...) + >>> validate(...) + >>> scheduler.step() + """ + + def __init__(self, sparsifier, sl_lambda, last_epoch=-1, verbose=False): + self.sparsifier = sparsifier + + if not isinstance(sl_lambda, list) and not isinstance(sl_lambda, tuple): + self.sl_lambdas = [sl_lambda] * len(sparsifier.groups) + else: + if len(sl_lambda) != len(sparsifier.groups): + raise ValueError(f"Expected {len(sparsifier.groups)} lr_lambdas, but got {len(sl_lambda)}") + self.sl_lambdas = list(sl_lambda) + super().__init__(sparsifier, last_epoch, verbose) + + def get_sl(self): + if not self._get_sl_called_within_step: + warnings.warn( + "To get the last sparsity level computed by the scheduler, " + "please use `get_last_sl()`.") + return [base_sl * lmbda(self.last_epoch) + for lmbda, base_sl in zip(self.sl_lambdas, self.base_sl)] diff --git a/torchao/sparsity/prototype/sparsifier/__init__.py b/torchao/sparsity/prototype/sparsifier/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/torchao/sparsity/prototype/sparsifier/base_sparsifier.py b/torchao/sparsity/prototype/sparsifier/base_sparsifier.py new file mode 100644 index 0000000000..1c210ace34 --- /dev/null +++ b/torchao/sparsity/prototype/sparsifier/base_sparsifier.py @@ -0,0 +1,353 @@ +import abc +import copy +from collections import defaultdict +from typing import Any, Dict, Optional, Set, Tuple, List, Type + +import torch +from torch import nn +from torch.nn.utils import parametrize +from torch.nn.utils.parametrize import type_before_parametrizations + +from .utils import ( + module_contains_param, + swap_module, + FakeSparsity, + get_arg_info_from_tensor_fqn, + module_to_fqn, +) + +__all__ = ["BaseSparsifier"] + +SUPPORTED_MODULES = {nn.Linear} + +KEYS_NOT_IN_STATE_DICT = ["module", "module_fqn", "tensor_name"] + +__all__ = ["BaseSparsifier"] + + +# TODO update desc with new config args +class BaseSparsifier(abc.ABC): + r"""Base class for all sparsifiers. + + Abstract methods that need to be implemented: + + - update_mask: Function to compute a new mask for all keys in the + `groups`. + + Args: + - model [nn.Module]: model to configure. The model itself is not saved + but used for the state_dict saving / loading. + - config [list]: configuration elements should be a dict map that includes + `tensor_fqn` of tensors to sparsify + - defaults [dict]: default configurations will be attached to the + configuration. Only the keys that don't exist in the `config` will + be updated. + + Example:: + + >>> # xdoctest: +SKIP("Can't instantiate abstract class BaseSparsifier with abstract method update_mask") + >>> config = [{'tensor_fqn': 'layer1.weight', 'tensor_fqn': 'linear2.weight2', 'sparsity_level': 0.5}] + >>> defaults = {'sparsity_level': 0.7} + >>> # model.layer1.weight will have `sparsity_level` = 0.7 (getting default) + >>> sparsifier = BaseSparsifier(config, defaults) + """ + + def __init__(self, defaults: Optional[Dict[str, Any]] = None): + super().__init__() + self.defaults: Dict[str, Any] = defaults or {} + + self.state: Dict[str, Dict] = defaultdict(dict) + self.groups: List[Dict[str, Any]] = [] + self.enable_mask_update = True + + def __getstate__(self) -> Dict[str, Any]: + return { + "defaults": self.defaults, + "state": self.state, + "groups": self.groups, + } + + def __setstate__(self, state: Dict[str, Dict[str, Any]]) -> None: + self.__dict__.update(state) + + def __repr__(self): + format_string = self.__class__.__name__ + " (" + for i, sparse_args in enumerate(self.groups): + module = sparse_args["module"] + format_string += "\n" + format_string += f"\tGroup {i}\n" + format_string += f"\t module: {module}\n" + for key in sorted(sparse_args.keys()): + if key == "module": + continue + format_string += f"\t {key}: {sparse_args[key]}\n" + format_string += ")" + return format_string + + def state_dict(self) -> Dict[str, Any]: + r"""Returns the state of the optimizer as a :class:`dict`. + + It contains: + * state - current state of the sparsification. + * groups - a list containing all sparsity configuration groups + with the key 'tensor_fqn' specifying the path to the sparsified tensor within a model + + TODO: Need a clean way of loading the state of the "prepared" module + """ + + groups: List[Dict[str, Any]] = [ + dict( + filter( + lambda key_value: key_value[0] not in KEYS_NOT_IN_STATE_DICT, + mg.items(), + ) + ) + for mg in self.groups + ] + + return { + "state": self.state, + "groups": groups, + } + + def load_state_dict(self, state_dict: Dict[str, Any], strict: bool = True): + groups = copy.deepcopy(state_dict["groups"]) + states = state_dict["state"] + for tensor_fqn, s in states.items(): + arg_info = get_arg_info_from_tensor_fqn(self.model, tensor_fqn) + module = arg_info["module"] + tensor_name = arg_info["tensor_name"] + if strict and module is None: + raise RuntimeError(f"Error loading {tensor_fqn} into the model") + + found = False + for p in module.parametrizations[tensor_name]: + if isinstance(p, FakeSparsity): + found = True + break + if not found: + p = FakeSparsity(torch.ones(getattr(module, tensor_name).shape)) + parametrize.register_parametrization(module, tensor_name, p) + if s.get("mask", None) is not None: + mask = s.pop("mask") + p.mask = mask + + for mg in groups: + if mg["tensor_fqn"] == tensor_fqn: + mg.update(arg_info) + self.__setstate__({"state": states, "groups": groups}) + + def make_config_from_model( + self, + model: nn.Module, + SUPPORTED_MODULES: Set[Type] = SUPPORTED_MODULES, + ) -> None: + self.config = [] + stack = [model] + while stack: + module = stack.pop() + for name, child in module.named_children(): + if type(child) in SUPPORTED_MODULES: + module_fqn = module_to_fqn(model, child) + assert isinstance(module_fqn, str) # for mypy + self.config.append({"tensor_fqn": module_fqn + ".weight"}) + else: + stack.append(child) + + def prepare(self, model, config): + r"""Prepares a model, by adding the parametrizations. + + Note:: + + The model is modified inplace. If you need to preserve the original + model, use copy.deepcopy. + """ + self.model = model # TODO: Need to figure out how to load without this. + self.config = config + + # If no config -- try getting all the supported layers + if self.config is None: + self.make_config_from_model(model) + + # TODO: Remove the configuration by reference ('module') + for module_config in self.config: + assert isinstance(module_config, dict), ( + "config elements should be dicts not modules i.e.:" + "[{`tensor_fqn`: `foo.bar.weight`}, {`tensor_fqn`: ... }, ...]" + ) + + assert isinstance(self.defaults, Dict) # for mypy + local_args = copy.deepcopy(self.defaults) + local_args.update(module_config) + + tensor_fqn = local_args.get("tensor_fqn", None) + assert tensor_fqn is not None, ( + "tensor_fqn is a required argument in the sparsity config which" + "replaces previous `module` and [module]`fqn` arguments" + ) + + # populate all information from tensor_fqn + info_from_tensor_fqn = get_arg_info_from_tensor_fqn(model, tensor_fqn) + + # check that whatever was put into local_args agrees with what was obtained + # from tensor_fqn + for key in info_from_tensor_fqn.keys(): + if key in local_args: + assert ( + info_from_tensor_fqn[key] == local_args[key] + or ( + key == "tensor_fqn" + and "." + info_from_tensor_fqn[key] == local_args[key] + ) + # info_from_tensor_fqn will chop leading '.' from tensor_fqn so ignore that + ), ( + f"Given both `{key}` and `tensor_fqn` in the config, it is expected them to agree!" + ) + local_args.update(info_from_tensor_fqn) + self.groups.append(local_args) + self._prepare() + + def _prepare(self, *args, **kwargs): + r"""Adds mask parametrization to the layer weight""" + for config in self.groups: + module = config["module"] + tensor_name = config["tensor_name"] + parametrization = config.get("parametrization", FakeSparsity) + mask = config.get("mask", torch.ones_like(getattr(module, tensor_name))) + self.state[config["tensor_fqn"]]["mask"] = mask + parametrize.register_parametrization( + module, tensor_name, parametrization(mask) + ) + + def squash_mask( + self, + params_to_keep: Optional[Tuple[str, ...]] = None, + params_to_keep_per_layer: Optional[Dict[str, Tuple[str, ...]]] = None, + *args, + **kwargs, + ): + r"""Squashes the sparse masks into the appropriate tensors. + + If either the `params_to_keep` or `params_to_keep_per_layer` is set, + the module will have a `sparse_params` dict attached to it. + + Args: + params_to_keep: List of keys to save in the module or a dict + representing the modules and keys that will have + sparsity parameters saved + params_to_keep_per_layer: Dict to specify the params that should be + saved for specific layers. The keys in the dict + should be the module fqn, while the values should + be a list of strings with the names of the variables + to save in the `sparse_params` + + Examples: + >>> # xdoctest: +SKIP("locals are undefined") + >>> # Don't save any sparse params + >>> sparsifier.squash_mask() + >>> hasattr(model.submodule1, 'sparse_params') + False + + >>> # Keep sparse params per layer + >>> sparsifier.squash_mask( + ... params_to_keep_per_layer={ + ... 'submodule1.linear1': ('foo', 'bar'), + ... 'submodule2.linear42': ('baz',) + ... }) + >>> print(model.submodule1.linear1.sparse_params) + {'foo': 42, 'bar': 24} + >>> print(model.submodule2.linear42.sparse_params) + {'baz': 0.1} + + >>> # Keep sparse params for all layers + >>> sparsifier.squash_mask(params_to_keep=('foo', 'bar')) + >>> print(model.submodule1.linear1.sparse_params) + {'foo': 42, 'bar': 24} + >>> print(model.submodule2.linear42.sparse_params) + {'foo': 42, 'bar': 24} + + >>> # Keep some sparse params for all layers, and specific ones for + >>> # some other layers + >>> sparsifier.squash_mask( + ... params_to_keep=('foo', 'bar'), + ... params_to_keep_per_layer={ + ... 'submodule2.linear42': ('baz',) + ... }) + >>> print(model.submodule1.linear1.sparse_params) + {'foo': 42, 'bar': 24} + >>> print(model.submodule2.linear42.sparse_params) + {'foo': 42, 'bar': 24, 'baz': 0.1} + """ + for config in self.groups: + module = config["module"] + tensor_name = config["tensor_name"] + parametrize.remove_parametrizations( + module, tensor_name, leave_parametrized=True + ) + sparse_params = {} + if params_to_keep is not None: + global_params = {k: config[k] for k in params_to_keep} + sparse_params.update(global_params) + if params_to_keep_per_layer is not None: + params = params_to_keep_per_layer.get(config["module_fqn"], None) + if params is not None: + per_layer_params = {k: config[k] for k in params} + sparse_params.update(per_layer_params) + if sparse_params: + # TODO handle multiple tensor being quantized on a single module, where to store sparse_params? + module.sparse_params = sparse_params + + def convert( + self, + module: nn.Module, + mapping: Optional[Dict[Type[nn.Module], Type[nn.Module]]] = None, + inplace: bool = False, + parameterization: Type[nn.Module] = FakeSparsity, + ): + r"""Converts submodules in input module to a different module according to `mapping` + by calling `from_dense` method on the target module class + Args: + module: input module + mapping: a dictionary that maps from source module type to target + module type, can be overwritten to allow swapping user defined + Modules + inplace: carry out model transformations in-place, the original module + is mutated + """ + if mapping is None: + raise NotImplementedError("Need to auto generate mapping ") + if not inplace: + module = copy.deepcopy(module) + + reassign = {} + for name, mod in module.named_children(): + # leaf node + if ( + module_contains_param(mod, parameterization) + and type_before_parametrizations(mod) in mapping + ): + reassign[name] = swap_module(mod, mapping) + else: + # recurse + reassign[name] = self.convert( + mod, + mapping=mapping, + inplace=True, + parameterization=parameterization, + ) + + for key, value in reassign.items(): + module._modules[key] = value + + return module + + def step(self, use_path: bool = True) -> None: + if not self.enable_mask_update: + return + with torch.no_grad(): + for config in self.groups: + self.update_mask(**config) + + @abc.abstractmethod + def update_mask(self, module: nn.Module, tensor_name: str, **kwargs): + pass diff --git a/torchao/sparsity/prototype/sparsifier/nearly_diagonal_sparsifier.py b/torchao/sparsity/prototype/sparsifier/nearly_diagonal_sparsifier.py new file mode 100644 index 0000000000..4f44e81485 --- /dev/null +++ b/torchao/sparsity/prototype/sparsifier/nearly_diagonal_sparsifier.py @@ -0,0 +1,55 @@ +import torch + +from . import base_sparsifier + + +class NearlyDiagonalSparsifier(base_sparsifier.BaseSparsifier): + r"""Nearly Diagonal Sparsifier + + This sparsifier creates a nearly diagonal mask to be applied to the weight matrix. + Nearly Diagonal Matrix is a matrix that contains non-zero elements near the diagonal and the rest are zero. + An example of a nearly diagonal matrix with degree (or nearliness) 3 and 5 are follows respectively. + 1 1 0 0 1 1 1 0 + 1 1 1 0 1 1 1 1 + 0 1 1 1 1 1 1 1 + 0 0 1 1 0 1 1 1 + Note that a nearly diagonal matrix with degree 1 is just a matrix with main diagonal populated + + This sparsifier is controlled by one variable: + 1. `nearliness` defines the number of non-zero diagonal lines that are closest to the main diagonal. + Currently - supports only odd number + + Note: + This can be accelerated (vectorized) once the Spdiagonal feature (PR: #78439) is landed or the banded matrix + feature is landed: https://stackoverflow.com/questions/52463972/generating-banded-matrices-using-numpy + + Args: + nearliness: The degree of nearliness (default = 1) + + """ + def __init__(self, nearliness: int = 1): + defaults = {'nearliness': nearliness} + super().__init__(defaults=defaults) + + def update_mask(self, module, tensor_name, nearliness, + **kwargs): + mask = getattr(module.parametrizations, tensor_name)[0].mask + mask.data = torch.zeros_like(mask) + if nearliness <= 0: + return + + tensor = getattr(module, tensor_name) + height, width = tensor.shape + + if nearliness % 2 == 0: + raise ValueError("nearliness can only be an odd number") + dist_to_diagonal = nearliness // 2 + # check + if dist_to_diagonal >= min(height, width): + raise ValueError("nearliness cannot be larger than the dimensions of tensor.") + + for row in range(0, height): + # Bounds of entries that needs to be set to 1 + low = max(0, row - dist_to_diagonal) + high = min(width, row + dist_to_diagonal + 1) + mask[row, low:high].fill_(1) diff --git a/torchao/sparsity/prototype/sparsifier/utils.py b/torchao/sparsity/prototype/sparsifier/utils.py new file mode 100644 index 0000000000..98f489904c --- /dev/null +++ b/torchao/sparsity/prototype/sparsifier/utils.py @@ -0,0 +1,136 @@ +from typing import Any, Dict, Optional, Type +from torch.nn.utils.parametrize import type_before_parametrizations, is_parametrized +from itertools import chain + +from torch import nn + +__all__ = [ + "module_contains_param", + "swap_module", + "module_to_fqn", + "fqn_to_module", + "get_arg_info_from_tensor_fqn", + "FakeSparsity", +] + + +def module_contains_param(module: nn.Module, parametrization: Type[nn.Module]) -> bool: + if is_parametrized(module): + # see if any of the module tensors have a parametriztion attached that matches the one passed in + return any( + any(isinstance(param, parametrization) for param in param_list) + for key, param_list in module.parametrizations.items() # type: ignore[union-attr,operator] + ) + return False + + +def swap_module( + mod: nn.Module, mapping: Dict[Type[nn.Module], Type[nn.Module]] +) -> nn.Module: + r"""Swaps the module using from_dense according to the mapping passed in. + Args: + mod: input module + mapping: a dictionary that maps from nn module to sparse nn module + Return: + The corresponding sparse module of `mod` according to mapping, created using from_dense + """ + if type_before_parametrizations(mod) in mapping: + sparse_mod = mapping[type_before_parametrizations(mod)] + + # TODO Fix this typing, as Type[Module] has no attribute "from_dense" + new_mod = sparse_mod.from_dense(mod) # type: ignore[attr-defined] + + # Preserve module's pre forward hooks. They'll be called on quantized input + for pre_hook_fn in mod._forward_pre_hooks.values(): + new_mod.register_forward_pre_hook(pre_hook_fn) + # Preserve module's post forward hooks except _observer_forward_hook + # After convert they'll work with quantized output + for hook_fn in mod._forward_hooks.values(): + new_mod.register_forward_hook(hook_fn) + + # respect device affinity when swapping modules + devices = {p.device for p in chain(mod.parameters(), mod.buffers())} + assert len(devices) <= 1, ( + f"swap_module only works with cpu or single-device CUDA modules, but got devices {devices}" + ) + device = next(iter(devices)) if len(devices) > 0 else None + if device: + new_mod.to(device) + + return new_mod + + else: + return mod + + +def module_to_fqn( + model: nn.Module, module: nn.Module, prefix: str = "" +) -> Optional[str]: + """ + Returns the fqn for a module or None if module not a descendent of model. + """ + if module is model: + return "" + for name, child in model.named_children(): + fqn = module_to_fqn(child, module, ".") + if isinstance(fqn, str): + return prefix + name + fqn + return None + + +def fqn_to_module(model: Optional[nn.Module], path: str) -> Optional[nn.Module]: + """ + Given an fqn, returns the corresponding module or tensor or None if the fqn given by `path` + doesn't correspond to anything. Similar to model.get_submodule(path) but works for tensors. + """ + if path != "": + for name in path.split("."): + model = getattr(model, name, None) + return model + + +def get_arg_info_from_tensor_fqn(model: nn.Module, tensor_fqn: str) -> Dict[str, Any]: + """ + Uses tensor_fqn to obtain a dict containing module_fqn, module and tensor_name + """ + # string manip to split tensor_fqn into module_fqn and tensor_name + # if tensor_fqn is 'weight' then module_fqn and tensor_name are '' and 'weight' + # if tensor_fqn is 'linear.weight' then module_fqn and tensor_name are 'linear' and 'weight' + tensor_name = tensor_fqn.split(".")[-1] + module_fqn = tensor_fqn[: -len(tensor_name) - ("." in tensor_fqn)] + + module = fqn_to_module(model, module_fqn) + + return { + "module_fqn": module_fqn, + "module": module, + "tensor_name": tensor_name, + "tensor_fqn": tensor_fqn, + } + + +# Parametrizations +class FakeSparsity(nn.Module): + r"""Parametrization for the weights. Should be attached to the 'weight' or + any other parameter that requires a mask applied to it. + + Note:: + + Once the mask is passed, the variable should not change the id. The + contents of the mask can change, but the mask reference itself should + not. + """ + + def __init__(self, mask): + super().__init__() + self.register_buffer("mask", mask) + + def forward(self, x): + assert self.mask.shape == x.shape + return self.mask * x + + def state_dict(self, *args, **kwargs): + # We don't want to let the parametrizations to save the mask. + # That way we make sure that the linear module doesn't store the masks + # alongside their parametrizations. + return {} diff --git a/torchao/sparsity/prototype/sparsifier/weight_norm_sparsifier.py b/torchao/sparsity/prototype/sparsifier/weight_norm_sparsifier.py new file mode 100644 index 0000000000..2b24ca3d82 --- /dev/null +++ b/torchao/sparsity/prototype/sparsifier/weight_norm_sparsifier.py @@ -0,0 +1,200 @@ +from functools import reduce +from typing import Callable, Optional, Tuple, Union + +import torch +import torch.nn.functional as F + +from .base_sparsifier import BaseSparsifier +import operator + +__all__ = ["WeightNormSparsifier"] + +def _flat_idx_to_2d(idx, shape): + rows = idx // shape[1] + cols = idx % shape[1] + return rows, cols + +class WeightNormSparsifier(BaseSparsifier): + r"""Weight-Norm Sparsifier + + This sparsifier computes the norm of every sparse block and "zeroes-out" the + ones with the lowest norm. The level of sparsity defines how many of the + blocks is removed. + + This sparsifier is controlled by three variables: + 1. `sparsity_level` defines the number of *sparse blocks* that are zeroed-out + 2. `sparse_block_shape` defines the shape of the sparse blocks. Note that + the sparse blocks originate at the zero-index of the tensor. + 3. `zeros_per_block` is the number of zeros that we are expecting in each + sparse block. By default we assume that all elements within a block are + zeroed-out. However, setting this variable sets the target number of + zeros per block. The zeros within each block are chosen as the *smallest + absolute values*. + + Args: + + sparsity_level: The target level of sparsity + sparse_block_shape: The shape of a sparse block (see note below) + zeros_per_block: Number of zeros in a sparse block + norm: Norm to use. Could be either `int` or a callable. + If `int`, only L1 and L2 are implemented. + + Note:: + The `sparse_block_shape` is tuple representing (block_ROWS, block_COLS), + irrespective of what the rows / cols mean in the data tensor. That means, + if you were to sparsify a weight tensor in the nn.Linear, which has a + weight shape `(Cout, Cin)`, the `block_ROWS` would refer to the output + channels, while the `block_COLS` would refer to the input channels. + + Note:: + All arguments to the WeightNormSparsifier constructor are "default" + arguments and could be overriden by the configuration provided in the + `prepare` step. + """ + def __init__(self, + sparsity_level: float = 0.5, + sparse_block_shape: Tuple[int, int] = (1, 4), + zeros_per_block: Optional[int] = None, + norm: Optional[Union[Callable, int]] = None): + if zeros_per_block is None: + zeros_per_block = reduce(operator.mul, sparse_block_shape) + defaults = { + "sparsity_level": sparsity_level, + "sparse_block_shape": sparse_block_shape, + "zeros_per_block": zeros_per_block, + } + if norm is None: + norm = 2 + if callable(norm): + self.norm_fn = norm + elif norm == 1: + self.norm_fn = lambda T: T.abs() + elif norm == 2: + self.norm_fn = lambda T: T * T + else: + raise NotImplementedError(f"L-{norm} is not yet implemented.") + super().__init__(defaults=defaults) + + def _scatter_fold_block_mask(self, output_shape, dim, indices, block_shape, + mask=None, input_shape=None, device=None): + r"""Creates patches of size `block_shape` after scattering the indices.""" + if mask is None: + assert input_shape is not None + mask = torch.ones(input_shape, device=device) + mask.scatter_(dim=dim, index=indices, value=0) + mask.data = F.fold(mask, output_size=output_shape, kernel_size=block_shape, stride=block_shape) + return mask + + def _make_tensor_mask(self, data, input_shape, sparsity_level, sparse_block_shape, mask=None): + r"""Creates a tensor-level mask. + + Tensor-level mask is described as a mask, where the granularity of sparsification of the + smallest patch is the sparse_block_shape. That means, that for a given mask and a + sparse_block_shape, the smallest "patch" of zeros/ones could be the sparse_block_shape. + + In this context, `sparsity_level` describes the fraction of sparse patches. + """ + h, w = data.shape[-2:] + block_h, block_w = sparse_block_shape + dh = (block_h - h % block_h) % block_h + dw = (block_w - w % block_w) % block_w + + if mask is None: + mask = torch.ones(h + dh, w + dw, device=data.device) + + if sparsity_level >= 1.0: + mask.data = torch.zeros_like(mask) + return mask + elif sparsity_level <= 0.0: + mask.data = torch.ones_like(mask) + return mask + + values_per_block = reduce(operator.mul, sparse_block_shape) + if values_per_block > 1: + # Reduce the data + data = F.avg_pool2d( + data[None, None, :], kernel_size=sparse_block_shape, stride=sparse_block_shape, ceil_mode=True + ) + data = data.flatten() + num_blocks = len(data) + + data = data.repeat(1, values_per_block, 1) + + threshold_idx = int(round(sparsity_level * num_blocks)) + threshold_idx = max(0, min(num_blocks - 1, threshold_idx)) # Sanity check + _, sorted_idx = torch.topk(data, k=threshold_idx, dim=2, largest=False) + + # Temp reshape for mask + mask_reshape = mask.reshape(data.shape) # data might be reshaped + self._scatter_fold_block_mask( + dim=2, output_shape=(h + dh, w + dw), + indices=sorted_idx, block_shape=sparse_block_shape, mask=mask_reshape + ) + mask.data = mask_reshape.squeeze().reshape(mask.shape)[:h, :w].contiguous() + return mask + + def _make_block_mask(self, data, sparse_block_shape, zeros_per_block, mask=None): + r"""Creates a block-level mask. + + Block-level mask is described as a mask, where the granularity of sparsification of the + largest patch is the sparse_block_shape. That means that for a given mask and a + sparse_block_shape, the sparsity is computed only within a patch of a size sparse_block_shape. + + In this context the `zeros_per_block` describes the number of zeroed-out elements within a patch. + """ + h, w = data.shape[-2:] + block_h, block_w = sparse_block_shape + dh = (block_h - h % block_h) % block_h + dw = (block_w - w % block_w) % block_w + values_per_block = reduce(operator.mul, sparse_block_shape) + + if mask is None: + mask = torch.ones((h + dh, w + dw), device=data.device) + + if values_per_block == zeros_per_block: + # Everything should be sparsified + mask.data = torch.zeros_like(mask) + return mask + + # create a new padded tensor like data (to match the block_shape) + padded_data = torch.ones(h + dh, w + dw, dtype=data.dtype, device=data.device) + padded_data.fill_(torch.nan) + padded_data[:h, :w] = data + unfolded_data = F.unfold(padded_data[None, None, :], kernel_size=sparse_block_shape, stride=sparse_block_shape) + + # Temp reshape for mask + mask_reshape = mask.reshape(unfolded_data.shape) + _, sorted_idx = torch.topk(unfolded_data, k=zeros_per_block, dim=1, largest=False) + + self._scatter_fold_block_mask( + dim=1, indices=sorted_idx, output_shape=padded_data.shape, block_shape=sparse_block_shape, mask=mask_reshape + ) + + mask.data = mask_reshape.squeeze().reshape(mask.shape).contiguous() + return mask + + def update_mask(self, module, tensor_name, sparsity_level, sparse_block_shape, + zeros_per_block, **kwargs): + values_per_block = reduce(operator.mul, sparse_block_shape) + if zeros_per_block > values_per_block: + raise ValueError( + "Number of zeros per block cannot be more than the total number of elements in that block." + ) + if zeros_per_block < 0: + raise ValueError("Number of zeros per block should be positive.") + + mask = getattr(module.parametrizations, tensor_name)[0].mask + if sparsity_level <= 0 or zeros_per_block == 0: + mask.data = torch.ones_like(mask) + elif sparsity_level >= 1.0 and (zeros_per_block == values_per_block): + mask.data = torch.zeros_like(mask) + else: + ww = self.norm_fn(getattr(module, tensor_name)) + tensor_mask = self._make_tensor_mask( + data=ww, input_shape=ww.shape, sparsity_level=sparsity_level, sparse_block_shape=sparse_block_shape + ) + if values_per_block != zeros_per_block: + block_mask = self._make_block_mask(data=ww, sparse_block_shape=sparse_block_shape, + zeros_per_block=zeros_per_block) + tensor_mask = torch.logical_or(tensor_mask, block_mask) + mask.data = tensor_mask