Skip to content

Commit

Permalink
port over torch.ao.pruning to protype folder (#148)
Browse files Browse the repository at this point in the history
  • Loading branch information
jcaip authored Apr 24, 2024
1 parent f05c215 commit af048aa
Show file tree
Hide file tree
Showing 30 changed files with 4,511 additions and 0 deletions.
175 changes: 175 additions & 0 deletions test/sparsity/test_parametrization.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,175 @@
import logging
import torch
import unittest
from torch import nn
from torch.nn.utils import parametrize
from torch.testing._internal.common_utils import TestCase

from torchao.sparsity.prototype.sparsifier import utils

logging.basicConfig(
format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", level=logging.INFO
)


class ModelUnderTest(nn.Module):
def __init__(self, bias=True):
super().__init__()
self.linear = nn.Linear(16, 16, bias=bias)
self.seq = nn.Sequential(
nn.Linear(16, 16, bias=bias), nn.Linear(16, 16, bias=bias)
)

# Make sure the weights are not random
self.linear.weight = nn.Parameter(torch.zeros_like(self.linear.weight) + 1.0)
self.seq[0].weight = nn.Parameter(torch.zeros_like(self.seq[0].weight) + 2.0)
self.seq[1].weight = nn.Parameter(torch.zeros_like(self.seq[1].weight) + 3.0)
if bias:
self.linear = nn.Parameter(torch.zeros_like(self.linear.bias) + 10.0)
self.seq[0] = nn.Parameter(torch.zeros_like(self.seq[0].bias) + 20.0)
self.seq[0] = nn.Parameter(torch.zeros_like(self.seq[0].bias) + 30.0)

def forward(self, x):
x = self.linear(x)
x = self.seq(x)
return x


class TestFakeSparsity(TestCase):
def test_masking_logic(self):
model = nn.Linear(16, 16, bias=False)
model.weight = nn.Parameter(torch.eye(16))
x = torch.randn(3, 16)
self.assertEqual(torch.mm(x, torch.eye(16)), model(x))

mask = torch.zeros(16, 16)
sparsity = utils.FakeSparsity(mask)
parametrize.register_parametrization(model, "weight", sparsity)

x = torch.randn(3, 16)
self.assertEqual(torch.zeros(3, 16), model(x))

def test_weights_parametrized(self):
model = ModelUnderTest(bias=False)

assert not hasattr(model.linear, "parametrizations")
assert not hasattr(model.seq[0], "parametrizations")
assert not hasattr(model.seq[1], "parametrizations")
mask = torch.eye(16)
parametrize.register_parametrization(
model.linear, "weight", utils.FakeSparsity(mask)
)
mask = torch.eye(16)
parametrize.register_parametrization(
model.seq[0], "weight", utils.FakeSparsity(mask)
)
mask = torch.eye(16)
parametrize.register_parametrization(
model.seq[1], "weight", utils.FakeSparsity(mask)
)

assert hasattr(model.linear, "parametrizations")
assert parametrize.is_parametrized(model.linear, "weight")
assert hasattr(model.seq[0], "parametrizations")
assert parametrize.is_parametrized(model.linear, "weight")
assert hasattr(model.seq[1], "parametrizations")
assert parametrize.is_parametrized(model.linear, "weight")

def test_state_dict_preserved(self):
model_save = ModelUnderTest(bias=False)

mask = torch.eye(16)
parametrize.register_parametrization(
model_save.linear, "weight", utils.FakeSparsity(mask)
)
mask = torch.eye(16)
parametrize.register_parametrization(
model_save.seq[0], "weight", utils.FakeSparsity(mask)
)
mask = torch.eye(16)
parametrize.register_parametrization(
model_save.seq[1], "weight", utils.FakeSparsity(mask)
)
state_dict = model_save.state_dict()

model_load = ModelUnderTest(bias=False)
mask = torch.zeros(model_load.linear.weight.shape)
parametrize.register_parametrization(
model_load.linear, "weight", utils.FakeSparsity(mask)
)
mask = torch.zeros(model_load.seq[0].weight.shape)
parametrize.register_parametrization(
model_load.seq[0], "weight", utils.FakeSparsity(mask)
)
mask = torch.zeros(model_load.seq[1].weight.shape)
parametrize.register_parametrization(
model_load.seq[1], "weight", utils.FakeSparsity(mask)
)
# Keep this strict, as we are not loading the 'mask'
model_load.load_state_dict(state_dict, strict=False)

# Check the parametrizations are preserved
assert hasattr(model_load.linear, "parametrizations")
assert parametrize.is_parametrized(model_load.linear, "weight")
assert hasattr(model_load.seq[0], "parametrizations")
assert parametrize.is_parametrized(model_load.linear, "weight")
assert hasattr(model_load.seq[1], "parametrizations")
assert parametrize.is_parametrized(model_load.linear, "weight")

# Check the weights are preserved
self.assertEqual(
model_save.linear.parametrizations["weight"].original,
model_load.linear.parametrizations["weight"].original,
)
self.assertEqual(
model_save.seq[0].parametrizations["weight"].original,
model_load.seq[0].parametrizations["weight"].original,
)
self.assertEqual(
model_save.seq[1].parametrizations["weight"].original,
model_load.seq[1].parametrizations["weight"].original,
)

# Check the masks are not preserved in the state_dict
# We store the state_dicts in the sparsifier, not in the model itself.
# TODO: Need to find a clean way of exporting the parametrized model
self.assertNotEqual(
model_save.linear.parametrizations["weight"][0].mask,
model_load.linear.parametrizations["weight"][0].mask,
)
self.assertNotEqual(
model_save.seq[0].parametrizations["weight"][0].mask,
model_load.seq[0].parametrizations["weight"][0].mask,
)
self.assertNotEqual(
model_save.seq[1].parametrizations["weight"][0].mask,
model_load.seq[1].parametrizations["weight"][0].mask,
)

def test_jit_trace(self):
model = ModelUnderTest(bias=False)

mask = torch.eye(16)
parametrize.register_parametrization(
model.linear, "weight", utils.FakeSparsity(mask)
)
mask = torch.eye(16)
parametrize.register_parametrization(
model.seq[0], "weight", utils.FakeSparsity(mask)
)
mask = torch.eye(16)
parametrize.register_parametrization(
model.seq[1], "weight", utils.FakeSparsity(mask)
)

# Tracing
example_x = torch.ones(3, 16)
model_trace = torch.jit.trace_module(model, {"forward": example_x})

x = torch.randn(3, 16)
y = model(x)
y_hat = model_trace(x)
self.assertEqual(y_hat, y)

if __name__ == "__main__":
unittest.main()
194 changes: 194 additions & 0 deletions test/sparsity/test_scheduler.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,194 @@
import warnings
import unittest

from torch import nn
from torch.testing._internal.common_utils import TestCase

from torchao.sparsity.prototype import BaseScheduler, CubicSL, LambdaSL, WeightNormSparsifier

class ImplementedScheduler(BaseScheduler):
def get_sl(self):
if self.last_epoch > 0:
return [group["sparsity_level"] * 0.5 for group in self.sparsifier.groups]
else:
return list(self.base_sl)


class TestScheduler(TestCase):
def test_constructor(self):
model = nn.Sequential(nn.Linear(16, 16))
sparsifier = WeightNormSparsifier()
sparsifier.prepare(model, config=None)
scheduler = ImplementedScheduler(sparsifier)

assert scheduler.sparsifier is sparsifier
assert scheduler._step_count == 1
assert scheduler.base_sl == [sparsifier.groups[0]["sparsity_level"]]

def test_order_of_steps(self):
"""Checks if the warning is thrown if the scheduler step is called
before the sparsifier step"""

model = nn.Sequential(nn.Linear(16, 16))
sparsifier = WeightNormSparsifier()
sparsifier.prepare(model, config=None)
scheduler = ImplementedScheduler(sparsifier)

# Sparsifier step is not called
with self.assertWarns(UserWarning):
scheduler.step()

# Correct order has no warnings
# Note: This will trigger if other warnings are present.
with warnings.catch_warnings(record=True) as w:
sparsifier.step()
scheduler.step()
# Make sure there is no warning related to the base_scheduler
for warning in w:
fname = warning.filename
fname = "/".join(fname.split("/")[-5:])
assert fname != "torch/ao/sparsity/scheduler/base_scheduler.py"

def test_step(self):
model = nn.Sequential(nn.Linear(16, 16))
sparsifier = WeightNormSparsifier()
sparsifier.prepare(model, config=None)
assert sparsifier.groups[0]["sparsity_level"] == 0.5
scheduler = ImplementedScheduler(sparsifier)
assert sparsifier.groups[0]["sparsity_level"] == 0.5

sparsifier.step()
scheduler.step()
assert sparsifier.groups[0]["sparsity_level"] == 0.25

def test_lambda_scheduler(self):
model = nn.Sequential(nn.Linear(16, 16))
sparsifier = WeightNormSparsifier()
sparsifier.prepare(model, config=None)
assert sparsifier.groups[0]["sparsity_level"] == 0.5
scheduler = LambdaSL(sparsifier, lambda epoch: epoch * 10)
assert sparsifier.groups[0]["sparsity_level"] == 0.0 # Epoch 0
scheduler.step()
assert sparsifier.groups[0]["sparsity_level"] == 5.0 # Epoch 1


class TestCubicScheduler(TestCase):
def setUp(self):
self.model_sparse_config = [
{"tensor_fqn": "0.weight", "sparsity_level": 0.8},
{"tensor_fqn": "2.weight", "sparsity_level": 0.4},
]
self.sorted_sparse_levels = [
conf["sparsity_level"] for conf in self.model_sparse_config
]
self.initial_sparsity = 0.1
self.initial_step = 3

def _make_model(self, **kwargs):
model = nn.Sequential(
nn.Linear(13, 17),
nn.Dropout(0.5),
nn.Linear(17, 3),
)
return model

def _make_scheduler(self, model, **kwargs):
sparsifier = WeightNormSparsifier()
sparsifier.prepare(model, config=self.model_sparse_config)

scheduler_args = {
"init_sl": self.initial_sparsity,
"init_t": self.initial_step,
}
scheduler_args.update(kwargs)

scheduler = CubicSL(sparsifier, **scheduler_args)
return sparsifier, scheduler

@staticmethod
def _get_sparsity_levels(sparsifier, precision=32):
r"""Gets the current levels of sparsity in a sparsifier."""
return [
round(group["sparsity_level"], precision) for group in sparsifier.groups
]

def test_constructor(self):
model = self._make_model()
sparsifier, scheduler = self._make_scheduler(model=model, initially_zero=True)
self.assertIs(
scheduler.sparsifier, sparsifier, msg="Sparsifier is not properly attached"
)
self.assertEqual(
scheduler._step_count,
1,
msg="Scheduler is initialized with incorrect step count",
)
self.assertEqual(
scheduler.base_sl,
self.sorted_sparse_levels,
msg="Scheduler did not store the target sparsity levels correctly",
)

# Value before t_0 is 0
self.assertEqual(
self._get_sparsity_levels(sparsifier),
scheduler._make_sure_a_list(0.0),
msg="Sparsifier is not reset correctly after attaching to the Scheduler",
)

# Value before t_0 is s_0
model = self._make_model()
sparsifier, scheduler = self._make_scheduler(model=model, initially_zero=False)
self.assertEqual(
self._get_sparsity_levels(sparsifier),
scheduler._make_sure_a_list(self.initial_sparsity),
msg="Sparsifier is not reset correctly after attaching to the Scheduler",
)

def test_step(self):
# For n=5, dt=2, there will be totally 10 steps between s_0 and s_f, starting from t_0
model = self._make_model()
sparsifier, scheduler = self._make_scheduler(
model=model, initially_zero=True, init_t=3, delta_t=2, total_t=5
)

scheduler.step()
scheduler.step()
self.assertEqual(
scheduler._step_count,
3,
msg="Scheduler step_count is expected to increment",
)
# Value before t_0 is supposed to be 0
self.assertEqual(
self._get_sparsity_levels(sparsifier),
scheduler._make_sure_a_list(0.0),
msg="Scheduler step updating the sparsity level before t_0",
)

scheduler.step() # Step = 3 => sparsity = initial_sparsity
self.assertEqual(
self._get_sparsity_levels(sparsifier),
scheduler._make_sure_a_list(self.initial_sparsity),
msg="Sparsifier is not reset to initial sparsity at the first step",
)

scheduler.step() # Step = 4 => sparsity ~ [0.3, 0.2]
self.assertEqual(
self._get_sparsity_levels(sparsifier, 1),
[0.3, 0.2],
msg="Sparsity level is not set correctly after the first step",
)

current_step = scheduler._step_count - scheduler.init_t[0] - 1
more_steps_needed = scheduler.delta_t[0] * scheduler.total_t[0] - current_step
for _ in range(more_steps_needed): # More steps needed to final sparsity level
scheduler.step()
self.assertEqual(
self._get_sparsity_levels(sparsifier),
self.sorted_sparse_levels,
msg="Sparsity level is not reaching the target level afer delta_t * n steps ",
)

if __name__ == "__main__":
unittest.main()
Loading

0 comments on commit af048aa

Please sign in to comment.