-
Notifications
You must be signed in to change notification settings - Fork 185
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
port over torch.ao.pruning to protype folder (#148)
- Loading branch information
Showing
30 changed files
with
4,511 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,175 @@ | ||
import logging | ||
import torch | ||
import unittest | ||
from torch import nn | ||
from torch.nn.utils import parametrize | ||
from torch.testing._internal.common_utils import TestCase | ||
|
||
from torchao.sparsity.prototype.sparsifier import utils | ||
|
||
logging.basicConfig( | ||
format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", level=logging.INFO | ||
) | ||
|
||
|
||
class ModelUnderTest(nn.Module): | ||
def __init__(self, bias=True): | ||
super().__init__() | ||
self.linear = nn.Linear(16, 16, bias=bias) | ||
self.seq = nn.Sequential( | ||
nn.Linear(16, 16, bias=bias), nn.Linear(16, 16, bias=bias) | ||
) | ||
|
||
# Make sure the weights are not random | ||
self.linear.weight = nn.Parameter(torch.zeros_like(self.linear.weight) + 1.0) | ||
self.seq[0].weight = nn.Parameter(torch.zeros_like(self.seq[0].weight) + 2.0) | ||
self.seq[1].weight = nn.Parameter(torch.zeros_like(self.seq[1].weight) + 3.0) | ||
if bias: | ||
self.linear = nn.Parameter(torch.zeros_like(self.linear.bias) + 10.0) | ||
self.seq[0] = nn.Parameter(torch.zeros_like(self.seq[0].bias) + 20.0) | ||
self.seq[0] = nn.Parameter(torch.zeros_like(self.seq[0].bias) + 30.0) | ||
|
||
def forward(self, x): | ||
x = self.linear(x) | ||
x = self.seq(x) | ||
return x | ||
|
||
|
||
class TestFakeSparsity(TestCase): | ||
def test_masking_logic(self): | ||
model = nn.Linear(16, 16, bias=False) | ||
model.weight = nn.Parameter(torch.eye(16)) | ||
x = torch.randn(3, 16) | ||
self.assertEqual(torch.mm(x, torch.eye(16)), model(x)) | ||
|
||
mask = torch.zeros(16, 16) | ||
sparsity = utils.FakeSparsity(mask) | ||
parametrize.register_parametrization(model, "weight", sparsity) | ||
|
||
x = torch.randn(3, 16) | ||
self.assertEqual(torch.zeros(3, 16), model(x)) | ||
|
||
def test_weights_parametrized(self): | ||
model = ModelUnderTest(bias=False) | ||
|
||
assert not hasattr(model.linear, "parametrizations") | ||
assert not hasattr(model.seq[0], "parametrizations") | ||
assert not hasattr(model.seq[1], "parametrizations") | ||
mask = torch.eye(16) | ||
parametrize.register_parametrization( | ||
model.linear, "weight", utils.FakeSparsity(mask) | ||
) | ||
mask = torch.eye(16) | ||
parametrize.register_parametrization( | ||
model.seq[0], "weight", utils.FakeSparsity(mask) | ||
) | ||
mask = torch.eye(16) | ||
parametrize.register_parametrization( | ||
model.seq[1], "weight", utils.FakeSparsity(mask) | ||
) | ||
|
||
assert hasattr(model.linear, "parametrizations") | ||
assert parametrize.is_parametrized(model.linear, "weight") | ||
assert hasattr(model.seq[0], "parametrizations") | ||
assert parametrize.is_parametrized(model.linear, "weight") | ||
assert hasattr(model.seq[1], "parametrizations") | ||
assert parametrize.is_parametrized(model.linear, "weight") | ||
|
||
def test_state_dict_preserved(self): | ||
model_save = ModelUnderTest(bias=False) | ||
|
||
mask = torch.eye(16) | ||
parametrize.register_parametrization( | ||
model_save.linear, "weight", utils.FakeSparsity(mask) | ||
) | ||
mask = torch.eye(16) | ||
parametrize.register_parametrization( | ||
model_save.seq[0], "weight", utils.FakeSparsity(mask) | ||
) | ||
mask = torch.eye(16) | ||
parametrize.register_parametrization( | ||
model_save.seq[1], "weight", utils.FakeSparsity(mask) | ||
) | ||
state_dict = model_save.state_dict() | ||
|
||
model_load = ModelUnderTest(bias=False) | ||
mask = torch.zeros(model_load.linear.weight.shape) | ||
parametrize.register_parametrization( | ||
model_load.linear, "weight", utils.FakeSparsity(mask) | ||
) | ||
mask = torch.zeros(model_load.seq[0].weight.shape) | ||
parametrize.register_parametrization( | ||
model_load.seq[0], "weight", utils.FakeSparsity(mask) | ||
) | ||
mask = torch.zeros(model_load.seq[1].weight.shape) | ||
parametrize.register_parametrization( | ||
model_load.seq[1], "weight", utils.FakeSparsity(mask) | ||
) | ||
# Keep this strict, as we are not loading the 'mask' | ||
model_load.load_state_dict(state_dict, strict=False) | ||
|
||
# Check the parametrizations are preserved | ||
assert hasattr(model_load.linear, "parametrizations") | ||
assert parametrize.is_parametrized(model_load.linear, "weight") | ||
assert hasattr(model_load.seq[0], "parametrizations") | ||
assert parametrize.is_parametrized(model_load.linear, "weight") | ||
assert hasattr(model_load.seq[1], "parametrizations") | ||
assert parametrize.is_parametrized(model_load.linear, "weight") | ||
|
||
# Check the weights are preserved | ||
self.assertEqual( | ||
model_save.linear.parametrizations["weight"].original, | ||
model_load.linear.parametrizations["weight"].original, | ||
) | ||
self.assertEqual( | ||
model_save.seq[0].parametrizations["weight"].original, | ||
model_load.seq[0].parametrizations["weight"].original, | ||
) | ||
self.assertEqual( | ||
model_save.seq[1].parametrizations["weight"].original, | ||
model_load.seq[1].parametrizations["weight"].original, | ||
) | ||
|
||
# Check the masks are not preserved in the state_dict | ||
# We store the state_dicts in the sparsifier, not in the model itself. | ||
# TODO: Need to find a clean way of exporting the parametrized model | ||
self.assertNotEqual( | ||
model_save.linear.parametrizations["weight"][0].mask, | ||
model_load.linear.parametrizations["weight"][0].mask, | ||
) | ||
self.assertNotEqual( | ||
model_save.seq[0].parametrizations["weight"][0].mask, | ||
model_load.seq[0].parametrizations["weight"][0].mask, | ||
) | ||
self.assertNotEqual( | ||
model_save.seq[1].parametrizations["weight"][0].mask, | ||
model_load.seq[1].parametrizations["weight"][0].mask, | ||
) | ||
|
||
def test_jit_trace(self): | ||
model = ModelUnderTest(bias=False) | ||
|
||
mask = torch.eye(16) | ||
parametrize.register_parametrization( | ||
model.linear, "weight", utils.FakeSparsity(mask) | ||
) | ||
mask = torch.eye(16) | ||
parametrize.register_parametrization( | ||
model.seq[0], "weight", utils.FakeSparsity(mask) | ||
) | ||
mask = torch.eye(16) | ||
parametrize.register_parametrization( | ||
model.seq[1], "weight", utils.FakeSparsity(mask) | ||
) | ||
|
||
# Tracing | ||
example_x = torch.ones(3, 16) | ||
model_trace = torch.jit.trace_module(model, {"forward": example_x}) | ||
|
||
x = torch.randn(3, 16) | ||
y = model(x) | ||
y_hat = model_trace(x) | ||
self.assertEqual(y_hat, y) | ||
|
||
if __name__ == "__main__": | ||
unittest.main() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,194 @@ | ||
import warnings | ||
import unittest | ||
|
||
from torch import nn | ||
from torch.testing._internal.common_utils import TestCase | ||
|
||
from torchao.sparsity.prototype import BaseScheduler, CubicSL, LambdaSL, WeightNormSparsifier | ||
|
||
class ImplementedScheduler(BaseScheduler): | ||
def get_sl(self): | ||
if self.last_epoch > 0: | ||
return [group["sparsity_level"] * 0.5 for group in self.sparsifier.groups] | ||
else: | ||
return list(self.base_sl) | ||
|
||
|
||
class TestScheduler(TestCase): | ||
def test_constructor(self): | ||
model = nn.Sequential(nn.Linear(16, 16)) | ||
sparsifier = WeightNormSparsifier() | ||
sparsifier.prepare(model, config=None) | ||
scheduler = ImplementedScheduler(sparsifier) | ||
|
||
assert scheduler.sparsifier is sparsifier | ||
assert scheduler._step_count == 1 | ||
assert scheduler.base_sl == [sparsifier.groups[0]["sparsity_level"]] | ||
|
||
def test_order_of_steps(self): | ||
"""Checks if the warning is thrown if the scheduler step is called | ||
before the sparsifier step""" | ||
|
||
model = nn.Sequential(nn.Linear(16, 16)) | ||
sparsifier = WeightNormSparsifier() | ||
sparsifier.prepare(model, config=None) | ||
scheduler = ImplementedScheduler(sparsifier) | ||
|
||
# Sparsifier step is not called | ||
with self.assertWarns(UserWarning): | ||
scheduler.step() | ||
|
||
# Correct order has no warnings | ||
# Note: This will trigger if other warnings are present. | ||
with warnings.catch_warnings(record=True) as w: | ||
sparsifier.step() | ||
scheduler.step() | ||
# Make sure there is no warning related to the base_scheduler | ||
for warning in w: | ||
fname = warning.filename | ||
fname = "/".join(fname.split("/")[-5:]) | ||
assert fname != "torch/ao/sparsity/scheduler/base_scheduler.py" | ||
|
||
def test_step(self): | ||
model = nn.Sequential(nn.Linear(16, 16)) | ||
sparsifier = WeightNormSparsifier() | ||
sparsifier.prepare(model, config=None) | ||
assert sparsifier.groups[0]["sparsity_level"] == 0.5 | ||
scheduler = ImplementedScheduler(sparsifier) | ||
assert sparsifier.groups[0]["sparsity_level"] == 0.5 | ||
|
||
sparsifier.step() | ||
scheduler.step() | ||
assert sparsifier.groups[0]["sparsity_level"] == 0.25 | ||
|
||
def test_lambda_scheduler(self): | ||
model = nn.Sequential(nn.Linear(16, 16)) | ||
sparsifier = WeightNormSparsifier() | ||
sparsifier.prepare(model, config=None) | ||
assert sparsifier.groups[0]["sparsity_level"] == 0.5 | ||
scheduler = LambdaSL(sparsifier, lambda epoch: epoch * 10) | ||
assert sparsifier.groups[0]["sparsity_level"] == 0.0 # Epoch 0 | ||
scheduler.step() | ||
assert sparsifier.groups[0]["sparsity_level"] == 5.0 # Epoch 1 | ||
|
||
|
||
class TestCubicScheduler(TestCase): | ||
def setUp(self): | ||
self.model_sparse_config = [ | ||
{"tensor_fqn": "0.weight", "sparsity_level": 0.8}, | ||
{"tensor_fqn": "2.weight", "sparsity_level": 0.4}, | ||
] | ||
self.sorted_sparse_levels = [ | ||
conf["sparsity_level"] for conf in self.model_sparse_config | ||
] | ||
self.initial_sparsity = 0.1 | ||
self.initial_step = 3 | ||
|
||
def _make_model(self, **kwargs): | ||
model = nn.Sequential( | ||
nn.Linear(13, 17), | ||
nn.Dropout(0.5), | ||
nn.Linear(17, 3), | ||
) | ||
return model | ||
|
||
def _make_scheduler(self, model, **kwargs): | ||
sparsifier = WeightNormSparsifier() | ||
sparsifier.prepare(model, config=self.model_sparse_config) | ||
|
||
scheduler_args = { | ||
"init_sl": self.initial_sparsity, | ||
"init_t": self.initial_step, | ||
} | ||
scheduler_args.update(kwargs) | ||
|
||
scheduler = CubicSL(sparsifier, **scheduler_args) | ||
return sparsifier, scheduler | ||
|
||
@staticmethod | ||
def _get_sparsity_levels(sparsifier, precision=32): | ||
r"""Gets the current levels of sparsity in a sparsifier.""" | ||
return [ | ||
round(group["sparsity_level"], precision) for group in sparsifier.groups | ||
] | ||
|
||
def test_constructor(self): | ||
model = self._make_model() | ||
sparsifier, scheduler = self._make_scheduler(model=model, initially_zero=True) | ||
self.assertIs( | ||
scheduler.sparsifier, sparsifier, msg="Sparsifier is not properly attached" | ||
) | ||
self.assertEqual( | ||
scheduler._step_count, | ||
1, | ||
msg="Scheduler is initialized with incorrect step count", | ||
) | ||
self.assertEqual( | ||
scheduler.base_sl, | ||
self.sorted_sparse_levels, | ||
msg="Scheduler did not store the target sparsity levels correctly", | ||
) | ||
|
||
# Value before t_0 is 0 | ||
self.assertEqual( | ||
self._get_sparsity_levels(sparsifier), | ||
scheduler._make_sure_a_list(0.0), | ||
msg="Sparsifier is not reset correctly after attaching to the Scheduler", | ||
) | ||
|
||
# Value before t_0 is s_0 | ||
model = self._make_model() | ||
sparsifier, scheduler = self._make_scheduler(model=model, initially_zero=False) | ||
self.assertEqual( | ||
self._get_sparsity_levels(sparsifier), | ||
scheduler._make_sure_a_list(self.initial_sparsity), | ||
msg="Sparsifier is not reset correctly after attaching to the Scheduler", | ||
) | ||
|
||
def test_step(self): | ||
# For n=5, dt=2, there will be totally 10 steps between s_0 and s_f, starting from t_0 | ||
model = self._make_model() | ||
sparsifier, scheduler = self._make_scheduler( | ||
model=model, initially_zero=True, init_t=3, delta_t=2, total_t=5 | ||
) | ||
|
||
scheduler.step() | ||
scheduler.step() | ||
self.assertEqual( | ||
scheduler._step_count, | ||
3, | ||
msg="Scheduler step_count is expected to increment", | ||
) | ||
# Value before t_0 is supposed to be 0 | ||
self.assertEqual( | ||
self._get_sparsity_levels(sparsifier), | ||
scheduler._make_sure_a_list(0.0), | ||
msg="Scheduler step updating the sparsity level before t_0", | ||
) | ||
|
||
scheduler.step() # Step = 3 => sparsity = initial_sparsity | ||
self.assertEqual( | ||
self._get_sparsity_levels(sparsifier), | ||
scheduler._make_sure_a_list(self.initial_sparsity), | ||
msg="Sparsifier is not reset to initial sparsity at the first step", | ||
) | ||
|
||
scheduler.step() # Step = 4 => sparsity ~ [0.3, 0.2] | ||
self.assertEqual( | ||
self._get_sparsity_levels(sparsifier, 1), | ||
[0.3, 0.2], | ||
msg="Sparsity level is not set correctly after the first step", | ||
) | ||
|
||
current_step = scheduler._step_count - scheduler.init_t[0] - 1 | ||
more_steps_needed = scheduler.delta_t[0] * scheduler.total_t[0] - current_step | ||
for _ in range(more_steps_needed): # More steps needed to final sparsity level | ||
scheduler.step() | ||
self.assertEqual( | ||
self._get_sparsity_levels(sparsifier), | ||
self.sorted_sparse_levels, | ||
msg="Sparsity level is not reaching the target level afer delta_t * n steps ", | ||
) | ||
|
||
if __name__ == "__main__": | ||
unittest.main() |
Oops, something went wrong.