Skip to content

Commit

Permalink
add SpectralConvTranspose2d layer and update test for tf compatibility
Browse files Browse the repository at this point in the history
  • Loading branch information
Franck Mamalet committed Oct 21, 2024
1 parent c2aca24 commit 9d514c1
Show file tree
Hide file tree
Showing 5 changed files with 147 additions and 52 deletions.
1 change: 1 addition & 0 deletions deel/torchlip/modules/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@
from .activation import MaxMin
from .conv import FrobeniusConv2d
from .conv import SpectralConv2d
from .conv import SpectralConvTranspose2d
from .downsampling import InvertibleDownSampling
from .linear import FrobeniusLinear
from .linear import SpectralLinear
Expand Down
74 changes: 74 additions & 0 deletions deel/torchlip/modules/conv.py
Original file line number Diff line number Diff line change
Expand Up @@ -190,3 +190,77 @@ def vanilla_export(self):
if self.bias is not None:
layer.bias.data = self.bias.detach()
return layer


class SpectralConvTranspose2d(torch.nn.ConvTranspose2d, LipschitzModule):
    r"""Applies a 2D transposed convolution operator over an input image,
    with a Lipschitz-constrained weight.

    The weight is initialized orthogonally and re-parametrized through
    spectral normalization, Björck orthonormalization and the Lipschitz
    convolution correction (``lconv_norm``); the resulting layer is then
    scaled by ``k_coef_lip`` (see :meth:`apply_lipschitz_factor`).

    Args:
        in_channels: number of channels in the input image.
        out_channels: number of channels produced by the convolution.
        kernel_size: size of the convolving kernel.
        stride: stride of the convolution. Default: 1.
        padding: zero-padding added to both sides of the input. Default: 0.
        output_padding: must be 0 (or ``None``); other values are rejected.
        groups: number of blocked connections from input to output channels.
        bias: if ``True``, adds a learnable bias to the output.
        dilation: must be 1; other rates are rejected.
        padding_mode: padding mode forwarded to ``torch.nn.ConvTranspose2d``.
        device: device on which to allocate the parameters.
        dtype: dtype of the parameters.
        k_coef_lip: Lipschitz constant to ensure.
        eps_spectral: stopping criterion of the spectral (power-iteration)
            normalization.
        eps_bjorck: stopping criterion of the Björck orthonormalization.

    Raises:
        ValueError: if ``dilation != 1`` or ``output_padding`` is neither
            0 nor ``None``.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: _size_2_t,
        stride: _size_2_t = 1,
        padding: _size_2_t = 0,
        output_padding: _size_2_t = 0,
        groups: int = 1,
        bias: bool = True,
        dilation: _size_2_t = 1,
        padding_mode: str = "zeros",
        device=None,
        dtype=None,
        k_coef_lip: float = 1.0,
        eps_spectral: float = DEFAULT_EPS_SPECTRAL,
        eps_bjorck: float = DEFAULT_EPS_BJORCK,
    ) -> None:
        # The Lipschitz re-parametrization below is only valid for
        # dilation 1 and no output padding; reject anything else early.
        if dilation != 1:
            raise ValueError("SpectralConvTranspose2d does not support dilation rate")
        if output_padding not in (0, None):
            raise ValueError("SpectralConvTranspose2d only supports output_padding=0")
        torch.nn.ConvTranspose2d.__init__(
            self,
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            output_padding=output_padding,
            groups=groups,
            bias=bias,
            dilation=dilation,
            padding_mode=padding_mode,
            device=device,
            dtype=dtype,
        )
        LipschitzModule.__init__(self, k_coef_lip)

        # Orthogonal weight / zero bias give a well-conditioned starting
        # point for the iterative normalizations registered below.
        torch.nn.init.orthogonal_(self.weight)
        if self.bias is not None:
            self.bias.data.fill_(0.0)

        spectral_norm(
            self,
            name="weight",
            eps=eps_spectral,
        )
        bjorck_norm(self, name="weight", eps=eps_bjorck)
        lconv_norm(self, name="weight")
        self.apply_lipschitz_factor()

    def vanilla_export(self):
        """Export this layer as a plain ``torch.nn.ConvTranspose2d``.

        The normalized weight (and bias, if any) are detached and copied
        into a freshly constructed standard layer with the same
        hyper-parameters.

        Returns:
            torch.nn.ConvTranspose2d: an unconstrained copy of this layer.
        """
        layer = torch.nn.ConvTranspose2d(
            in_channels=self.in_channels,
            out_channels=self.out_channels,
            kernel_size=self.kernel_size,
            stride=self.stride,
            padding=self.padding,
            output_padding=self.output_padding,
            dilation=self.dilation,
            groups=self.groups,
            bias=self.bias is not None,
            padding_mode=self.padding_mode,
        )
        layer.weight.data = self.weight.detach()
        if self.bias is not None:
            layer.bias.data = self.bias.detach()
        return layer
55 changes: 30 additions & 25 deletions tests/test_condense.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,17 +33,14 @@
from . import utils_framework as uft

from tests.utils_framework import (
vanillaModel,
vanilla_require_a_copy,
copy_model_parameters,
Sequential,
tModel,
)

from tests.utils_framework import (
SpectralLinear,
SpectralConv2d,
SpectralConv2dTranspose,
SpectralConvTranspose2d,
FrobeniusLinear,
FrobeniusConv2d,
ScaledL2NormPool2d,
Expand Down Expand Up @@ -75,7 +72,7 @@ def sequential_layers(input_shape):
{"in_channels": 2, "out_channels": 2, "kernel_size": (3, 3), "padding": 1},
),
uft.get_instance_framework(
SpectralConv2dTranspose,
SpectralConvTranspose2d,
{"in_channels": 2, "out_channels": 5, "kernel_size": (3, 3), "padding": 1},
),
uft.get_instance_framework(Flatten, {}),
Expand Down Expand Up @@ -120,7 +117,7 @@ def get_functional_tensors(input_shape):
},
)
dict_functional_tensors["convt2"] = uft.get_instance_framework(
SpectralConv2dTranspose,
SpectralConvTranspose2d,
{"in_channels": 2, "out_channels": 5, "kernel_size": (3, 3), "padding": 1},
)
dict_functional_tensors["flatten"] = uft.get_instance_framework(Flatten, {})
Expand Down Expand Up @@ -157,39 +154,42 @@ def functional_input_output_tensors(dict_functional_tensors, x):
# return x


def get_model(layer_type, layer_params, input_shape, k_coef_lip):
if layer_type == tModel:
def get_model(model_type, layer_params, input_shape, k_coef_lip):
if model_type == tModel:
return uft.get_functional_model(
tModel,
layer_params["dict_tensors"],
layer_params["functional_input_output_tensors"],
)
else:
return uft.generate_k_lip_model(
layer_type, layer_params, input_shape=input_shape, k=k_coef_lip
model_type, layer_params, input_shape=input_shape, k=k_coef_lip
)


@pytest.mark.skipif(
hasattr(SpectralConv2dTranspose, "unavailable_class"),
reason="SpectralConv2dTranspose not available",
hasattr(SpectralConvTranspose2d, "unavailable_class"),
reason="SpectralConvTranspose2d not available",
)
@pytest.mark.parametrize(
"layer_type, layer_params, k_coef_lip, input_shape",
"model_type, params_type, param_fct, dict_other_params, k_coef_lip, input_shape",
[
(Sequential, {"layers": sequential_layers((3, 8, 8))}, 5.0, (3, 8, 8)),
(Sequential, "layers", sequential_layers, {}, 5.0, (3, 8, 8)),
(
tModel,
"dict_tensors",
get_functional_tensors,
{
"dict_tensors": get_functional_tensors((3, 8, 8)),
"functional_input_output_tensors": functional_input_output_tensors,
},
5.0,
(3, 8, 8),
),
],
)
def test_model(layer_type, layer_params, k_coef_lip, input_shape):
def test_model(
model_type, params_type, param_fct, dict_other_params, k_coef_lip, input_shape
):
batch_size = 250
epochs = 1
steps_per_epoch = 125
Expand All @@ -198,9 +198,11 @@ def test_model(layer_type, layer_params, k_coef_lip, input_shape):
# clear session to avoid side effects from previous train
uft.init_session() # K.clear_session()
np.random.seed(42)
input_shape_CHW = input_shape
input_shape = uft.to_framework_channel(input_shape)

model = get_model(layer_type, layer_params, input_shape, k_coef_lip)
layer_params = {params_type: param_fct(input_shape_CHW)}
layer_params.update(dict_other_params)
model = get_model(model_type, layer_params, input_shape, k_coef_lip)

    # create the keras model, define opt, and compile it
optimizer = uft.get_instance_framework(
Expand Down Expand Up @@ -254,12 +256,14 @@ def test_model(layer_type, layer_params, k_coef_lip, input_shape):
# verbose=0,
# )
# generate vanilla
if vanilla_require_a_copy():
model2 = get_model(layer_type, layer_params, input_shape, k_coef_lip)
copy_model_parameters(model, model2)
vanilla_model = vanillaModel(model2)
if uft.vanilla_require_a_copy():
layer_params = {params_type: param_fct(input_shape_CHW)}
layer_params.update(dict_other_params)
model2 = get_model(model_type, layer_params, input_shape, k_coef_lip)
uft.copy_model_parameters(model, model2)
vanilla_model = uft.vanillaModel(model2)
else:
vanilla_model = vanillaModel(model)
vanilla_model = uft.vanillaModel(model)
# vanilla_model = model.vanilla_export()
loss_fn, optimizer, metrics = uft.compile_model(
vanilla_model,
Expand Down Expand Up @@ -290,12 +294,13 @@ def test_model(layer_type, layer_params, k_coef_lip, input_shape):
# steps=10,
# verbose=0,
# )
model.summary()
vanilla_model.summary()
# model.summary()
# vanilla_model.summary()

np.testing.assert_equal(
np.testing.assert_almost_equal(
mse,
vanilla_mse,
3,
"the exported vanilla model must have same behaviour as original",
)
np.testing.assert_equal(
Expand Down
58 changes: 32 additions & 26 deletions tests/test_layers.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@
from .utils_framework import (
SpectralLinear,
SpectralConv2d,
SpectralConv2dTranspose,
SpectralConvTranspose2d,
FrobeniusLinear,
FrobeniusConv2d,
ScaledAvgPool2d,
Expand Down Expand Up @@ -253,7 +253,6 @@ def train_k_lip_model(


def _check_mse_results(mse, from_disk_mse, test_params):
print("aaaaa", mse, from_disk_mse)
assert from_disk_mse == pytest.approx(
mse, 1e-5
), "serialization must not change the performance of a layer"
Expand Down Expand Up @@ -588,14 +587,14 @@ def test_spectralconv2d(test_params):


@pytest.mark.skipif(
hasattr(SpectralConv2dTranspose, "unavailable_class"),
reason="SpectralConv2dTranspose not available",
hasattr(SpectralConvTranspose2d, "unavailable_class"),
reason="SpectralConvTranspose2d not available",
)
@pytest.mark.parametrize(
"test_params",
[
dict(
layer_type=SpectralConv2dTranspose,
layer_type=SpectralConvTranspose2d,
layer_params={
"in_channels": 1,
"out_channels": 2,
Expand All @@ -611,7 +610,7 @@ def test_spectralconv2d(test_params):
callbacks=[],
),
dict(
layer_type=SpectralConv2dTranspose,
layer_type=SpectralConvTranspose2d,
layer_params={"in_channels": 1, "out_channels": 2, "kernel_size": (3, 3)},
batch_size=250,
steps_per_epoch=125,
Expand All @@ -622,7 +621,7 @@ def test_spectralconv2d(test_params):
callbacks=[],
),
dict(
layer_type=SpectralConv2dTranspose,
layer_type=SpectralConvTranspose2d,
layer_params={"in_channels": 1, "out_channels": 2, "kernel_size": (3, 3)},
batch_size=250,
steps_per_epoch=125,
Expand All @@ -634,7 +633,7 @@ def test_spectralconv2d(test_params):
),
],
)
def test_SpectralConv2dTranspose(test_params):
def test_SpectralConvTranspose2d(test_params):
_apply_tests_bank(test_params)


Expand Down Expand Up @@ -1165,15 +1164,15 @@ def test_invertibleupsampling(test_params):


@pytest.mark.skipif(
hasattr(SpectralConv2dTranspose, "unavailable_class"),
reason="SpectralConv2dTranspose not available",
hasattr(SpectralConvTranspose2d, "unavailable_class"),
reason="SpectralConvTranspose2d not available",
)
@pytest.mark.parametrize(
"test_params,msg",
[
(dict(in_channels=1, out_channels=5, kernel_size=3), ""),
(
dict(in_channels=1, out_channels=12, kernel_size=5, strides=2, bias=False),
dict(in_channels=1, out_channels=12, kernel_size=5, stride=2, bias=False),
"",
),
(
Expand All @@ -1182,7 +1181,7 @@ def test_invertibleupsampling(test_params):
out_channels=3,
kernel_size=3,
padding="same",
dilation_rate=1,
dilation=1,
),
"",
),
Expand Down Expand Up @@ -1216,7 +1215,7 @@ def test_invertibleupsampling(test_params):
"Wrong padding",
),
(
dict(in_channels=1, out_channels=10, kernel_size=3, dilation_rate=2),
dict(in_channels=1, out_channels=10, kernel_size=3, dilation=2),
"Wrong dilation rate",
),
(
Expand All @@ -1225,52 +1224,59 @@ def test_invertibleupsampling(test_params):
),
],
)
def test_SpectralConv2dTranspose_instantiation(test_params, msg):
def test_SpectralConvTranspose2d_instantiation(test_params, msg):
if msg == "":
uft.get_instance_framework(SpectralConv2dTranspose, test_params)
uft.get_instance_framework(SpectralConvTranspose2d, test_params)
else:
with pytest.raises(ValueError):
uft.get_instance_framework(SpectralConv2dTranspose, test_params)
uft.get_instance_framework(SpectralConvTranspose2d, test_params)


@pytest.mark.skipif(
hasattr(SpectralConv2dTranspose, "unavailable_class"),
reason="SpectralConv2dTranspose not available",
hasattr(SpectralConvTranspose2d, "unavailable_class"),
reason="SpectralConvTranspose2d not available",
)
def test_SpectralConv2dTranspose_vanilla_export():
def test_SpectralConvTranspose2d_vanilla_export():
kwargs = dict(
in_channels=3,
out_channels=16,
kernel_size=5,
strides=2,
stride=2,
activation="relu",
data_format="channels_first",
input_shape=(3, 28, 28),
)

model = uft.generate_k_lip_model(
SpectralConv2dTranspose, kwargs, kwargs["input_shape"], 1.0
SpectralConvTranspose2d, kwargs, kwargs["input_shape"], 1.0
)

# lay = SpectralConv2dTranspose(**kwargs)
# lay = SpectralConvTranspose2d(**kwargs)
# model = Sequential([lay])
x = np.random.normal(size=(5,) + kwargs["input_shape"])

x = uft.to_tensor(x)
y1 = model(x)

# Test vanilla export inference comparison
vanilla_model = model.vanilla_export()
if uft.vanilla_require_a_copy():
model2 = uft.generate_k_lip_model(
SpectralConvTranspose2d, kwargs, kwargs["input_shape"], 1.0
)
uft.copy_model_parameters(model, model2)
vanilla_model = uft.vanillaModel(model2)
else:
vanilla_model = uft.vanillaModel(model) # .vanilla_export()
y2 = vanilla_model(x)
np.testing.assert_allclose(y1, y2, atol=1e-6)
np.testing.assert_allclose(uft.to_numpy(y1), uft.to_numpy(y2), atol=1e-6)

# Test saving/loading model
with tempfile.TemporaryDirectory() as tmpdir:
uft.MODEL_PATH = os.path.join(tmpdir, uft.MODEL_PATH)
model.save(uft.MODEL_PATH)
uft.save_model(model, uft.MODEL_PATH, overwrite=True)
uft.load_model(
uft.MODEL_PATH,
layer_type=SpectralConv2dTranspose,
layer_type=SpectralConvTranspose2d,
layer_params=kwargs,
input_shape=kwargs["input_shape"],
k=1.0,
Expand Down
Loading

0 comments on commit 9d514c1

Please sign in to comment.