diff --git a/deel/torchlip/modules/__init__.py b/deel/torchlip/modules/__init__.py
index c612564..bcc2e36 100644
--- a/deel/torchlip/modules/__init__.py
+++ b/deel/torchlip/modules/__init__.py
@@ -77,3 +77,5 @@ from .pooling import ScaledL2NormPool2d
 from .pooling import ScaledGlobalL2NormPool2d
 from .upsampling import InvertibleUpSampling
+from .normalization import LayerCentering
+from .normalization import BatchCentering
diff --git a/deel/torchlip/modules/normalization.py b/deel/torchlip/modules/normalization.py
index 5deeb13..ddc7d2f 100644
--- a/deel/torchlip/modules/normalization.py
+++ b/deel/torchlip/modules/normalization.py
@@ -1,49 +1,60 @@
+from typing import Optional
 import torch
 import torch.nn as nn
 import torch.distributed as dist
+
 class LayerCentering(nn.Module):
-    def __init__(self,size = -1, dim=[-2,-1],bias = True):
+    def __init__(self, size: int = 1, dim: tuple = (-2, -1), bias: bool = True):
         super(LayerCentering, self).__init__()
-        self.bias = bias
-        if isinstance(size, tuple):
-            self.alpha = nn.Parameter(torch.zeros(size), requires_grad=True)
+        if bias:
+            self.bias = nn.Parameter(torch.zeros((size,)), requires_grad=True)
         else:
-            self.alpha = nn.Parameter(torch.zeros(1,size,1,1), requires_grad=True)
+            self.register_parameter("bias", None)
         self.dim = dim
 
     def forward(self, x):
         mean = x.mean(dim=self.dim, keepdim=True)
-        if self.bias:
-            return x - mean+ self.alpha
-        return x - mean
-
+        if self.bias is not None:
+            bias_shape = (1, -1) + (1,) * (len(x.shape) - 2)
+            return x - mean + self.bias.view(bias_shape)
+        else:
+            return x - mean
 
-class LayerCentering2D(LayerCentering):
-    def __init__(self, size = 1, dim=[-2,-1]):
-        super(LayerCentering2D, self).__init__(size = size,dim=[-2,-1])
+
+LayerCentering2d = LayerCentering
 
 
 class BatchCentering(nn.Module):
-    def __init__(self, size =1, dim=[0,-2,-1], momentum=0.05):
+    def __init__(
+        self,
+        size: int = 1,
+        dim: Optional[tuple] = None,
+        momentum: float = 0.05,
+        bias: bool = True,
+    ):
         super(BatchCentering, self).__init__()
         self.dim = dim
         self.momentum = momentum
-        if isinstance(size, tuple):
-            self.register_buffer("running_mean", torch.zeros(size))
+        self.register_buffer("running_mean", torch.zeros((size,)))
+        if bias:
+            self.bias = nn.Parameter(torch.zeros((size,)), requires_grad=True)
         else:
-            self.register_buffer("running_mean", torch.zeros(1,size,1,1))
+            self.register_parameter("bias", None)
         self.first = True
 
     def forward(self, x):
-
+        if self.dim is None:  # (0, 2, 3) for a 4D tensor; (0,) for a 2D tensor
+            self.dim = (0,) + tuple(range(2, len(x.shape)))
+        mean_shape = (1, -1) + (1,) * (len(x.shape) - 2)
         if self.training:
-            mean = x.mean(dim=self.dim, keepdim=True)
-            #print(mean.shape)
+            mean = x.mean(dim=self.dim)
             with torch.no_grad():
                 if self.first:
-                    #print("first")
                     self.running_mean = mean
                     self.first = False
                 else:
@@ -53,28 +64,33 @@ def forward(self, x):
             if dist.is_initialized():
                 dist.all_reduce(self.running_mean, op=dist.ReduceOp.SUM)
                 self.running_mean /= dist.get_world_size()
-
-        else :
+        else:
             mean = self.running_mean
-        return x - mean
-
-class BatchCenteringBiases(BatchCentering):
-    def __init__(self, size =1, dim=[0,-2,-1], momentum=0.05):
-        super(BatchCenteringBiases, self).__init__(size = size, dim = dim, momentum = momentum)
-        if isinstance(size, tuple):
-            self.alpha = nn.Parameter(torch.zeros(size), requires_grad=True)
-        else:
-            self.alpha = nn.Parameter(torch.zeros(1,size,1,1), requires_grad=True)
-
-    def forward(self, x):
-        return super().forward(x) + self.alpha
+        if self.bias is not None:
+            return x - mean.view(mean_shape) + self.bias.view(mean_shape)
+        else:
+            return x - mean.view(mean_shape)
 
-class BatchCenteringBiases2D(BatchCenteringBiases):
-    def __init__(self, size =1, momentum=0.05):
-        super(BatchCenteringBiases2D, self).__init__(size = size, dim=[0,-2,-1],momentum=momentum)
 
-class BatchCentering2D(BatchCentering):
-    def __init__(self, size =1, momentum=0.05):
-        super(BatchCentering2D, self).__init__(size = size, dim=[0,-2,-1],momentum=momentum)
+BatchCentering2d = BatchCentering
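
For context, a minimal usage sketch of the two new layers (shapes and values are illustrative; assumes only the imports added to deel/torchlip/modules/__init__.py above):

    import torch
    from deel.torchlip.modules import LayerCentering, BatchCentering

    x = torch.randn(8, 4, 16, 16)  # (batch, channels, H, W)
    lc = LayerCentering(size=4)    # subtracts the per-sample spatial mean
    bc = BatchCentering(size=4)    # tracks a per-channel running mean (EMA)

    y = lc(x)   # centered over dims (-2, -1), learnable bias added
    bc.train()
    _ = bc(x)   # training mode: updates running_mean
    bc.eval()
    z = bc(x)   # eval mode: subtracts the frozen running_mean
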
diff --git a/tests/test_normalization.py b/tests/test_normalization.py
new file mode 100644
index 0000000..382180c
--- /dev/null
+++ b/tests/test_normalization.py
@@ -0,0 +1,254 @@
+# -*- coding: utf-8 -*-
+# Copyright IRT Antoine de Saint Exupéry et Université Paul Sabatier Toulouse III - All
+# rights reserved. DEEL is a research program operated by IVADO, IRT Saint Exupéry,
+# CRIAQ and ANITI - https://www.deel.ai/
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in all
+# copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+# SOFTWARE.
+# =====================================================================================
+import os
+import pytest
+
+import numpy as np
+
+from . import utils_framework as uft
+
+from .utils_framework import BatchCentering, LayerCentering
+
+
+def check_serialization(layer_type, layer_params, input_shape=(10,)):
+    m = uft.generate_k_lip_model(layer_type, layer_params, input_shape=input_shape, k=1)
+    if m is None:
+        pytest.skip()
+    loss, optimizer, _ = uft.compile_model(
+        m,
+        optimizer=uft.get_instance_framework(uft.SGD, inst_params={"model": m}),
+        loss=uft.CategoricalCrossentropy(from_logits=True),
+    )
+    name = layer_type.__name__
+    path = os.path.join("logs", "normalization", name)
+    xnp = np.random.uniform(-10, 10, (255,) + input_shape)
+    x = uft.to_tensor(xnp)
+    y1 = m(x)
+    uft.save_model(m, path)
+    m2 = uft.load_model(
+        path,
+        compile=True,
+        layer_type=layer_type,
+        layer_params=layer_params,
+        input_shape=input_shape,
+        k=1,
+    )
+    y2 = m2(x)
+    np.testing.assert_allclose(uft.to_numpy(y1), uft.to_numpy(y2))
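+
+
+# LayerCentering keeps no running statistics: the test below expects identical
+# outputs in train() and eval() modes, since the layer always subtracts the
+# per-sample mean over the spatial dimensions.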
+@pytest.mark.parametrize(
+    "size, input_shape, bias",
+    [
+        (4, (3, 4, 8, 8), False),
+        (4, (3, 4, 8, 8), True),
+    ],
+)
+def test_LayerCentering(size, input_shape, bias):
+    """Evaluate LayerCentering against a manual per-sample mean subtraction."""
+    input_shape = uft.to_framework_channel(input_shape)
+    x = np.arange(np.prod(input_shape)).reshape(input_shape)
+    bn = uft.get_instance_framework(LayerCentering, {"size": size, "bias": bias})
+
+    mean_x = np.mean(x, axis=(2, 3))
+    mean_shape = (-1, size, 1, 1)
+    x = uft.to_tensor(x)
+    y = bn(x)
+    np.testing.assert_allclose(
+        uft.to_numpy(y), x - np.reshape(mean_x, mean_shape), atol=1e-5
+    )
+    y = bn(2 * x)
+    np.testing.assert_allclose(
+        uft.to_numpy(y), 2 * x - 2 * np.reshape(mean_x, mean_shape), atol=1e-5
+    )  # still subtracts the per-sample mean
+    bn.eval()
+    y = bn(2 * x)
+    np.testing.assert_allclose(
+        uft.to_numpy(y), 2 * x - 2 * np.reshape(mean_x, mean_shape), atol=1e-5
+    )  # eval mode behaves the same: no running statistics
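+
+
+# BatchCentering's running-mean update is an exponential moving average:
+#   running_mean <- (1 - momentum) * running_mean + momentum * batch_mean,
+# except on the very first batch, where running_mean is set to the batch mean.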
+@pytest.mark.parametrize(
+    "size, input_shape, bias",
+    [
+        (4, (3, 4), False),
+        (4, (3, 4), True),
+        (4, (3, 4, 8, 8), False),
+        (4, (3, 4, 8, 8), True),
+    ],
+)
+def test_BatchCentering(size, input_shape, bias):
+    """Evaluate BatchCentering in training and eval modes."""
+    input_shape = uft.to_framework_channel(input_shape)
+    x = np.arange(np.prod(input_shape)).reshape(input_shape)
+    bn = uft.get_instance_framework(BatchCentering, {"size": size, "bias": bias})
+    bn_mom = bn.momentum
+    if len(input_shape) == 2:
+        mean_x = np.mean(x, axis=0)
+        mean_shape = (1, size)
+    else:
+        mean_x = np.mean(x, axis=(0, 2, 3))
+        mean_shape = (1, size, 1, 1)
+    x = uft.to_tensor(x)
+    y = bn(x)
+    np.testing.assert_allclose(bn.running_mean, mean_x, atol=1e-5)
+    np.testing.assert_allclose(
+        uft.to_numpy(y), x - np.reshape(mean_x, mean_shape), atol=1e-5
+    )
+    y = bn(2 * x)
+    new_runningmean = mean_x * (1 - bn_mom) + 2 * mean_x * bn_mom
+    np.testing.assert_allclose(bn.running_mean, new_runningmean, atol=1e-5)
+    np.testing.assert_allclose(
+        uft.to_numpy(y), 2 * x - 2 * np.reshape(mean_x, mean_shape), atol=1e-5
+    )  # training mode still subtracts the current batch mean
+    bn.eval()
+    y = bn(2 * x)
+    np.testing.assert_allclose(
+        bn.running_mean, new_runningmean, atol=1e-5
+    )  # eval mode: the running mean is frozen
+    np.testing.assert_allclose(
+        uft.to_numpy(y), 2 * x - np.reshape(new_runningmean, mean_shape), atol=1e-5
+    )  # eval mode subtracts the frozen running_mean
+
+
+@pytest.mark.parametrize(
+    "norm_type",
+    [LayerCentering, BatchCentering],
+)
+@pytest.mark.parametrize(
+    "size, input_shape, bias",
+    [
+        (10, (10,), False),
+        (10, (10,), True),
+        (7, (7, 8, 8), False),
+        (7, (7, 8, 8), True),
+    ],
+)
+def test_Normalization_serialization(norm_type, size, input_shape, bias):
+    # Check that the layer survives a save/load round trip unchanged
+    check_serialization(
+        norm_type, layer_params={"size": size, "bias": bias}, input_shape=input_shape
+    )
+
+
+def linear_generator(batch_size, input_shape: tuple):
+    """
+    Generate identity-target data (batch_y == batch_x).
+
+    Args:
+        batch_size: size of each batch
+        input_shape: shape of the desired input
+
+    Returns:
+        a generator for the data
+    """
+    input_shape = tuple(input_shape)
+    while True:
+        # pick random samples in [-10, 10] with the input shape
+        batch_x = np.array(
+            np.random.uniform(-10, 10, (batch_size,) + input_shape), dtype=np.float16
+        )
+        # the target is the input itself (identity mapping)
+        batch_y = batch_x
+        yield batch_x, batch_y
+
+
+@pytest.mark.parametrize(
+    "norm_type",
+    [LayerCentering, BatchCentering],
+)
+@pytest.mark.parametrize(
+    "size, input_shape, bias",
+    [
+        (10, (10,), True),
+        (7, (7, 8, 8), True),
+    ],
+)
+def test_Normalization_bias(norm_type, size, input_shape, bias):
+    """Check that the learnable bias starts at zero and moves during training."""
+    m = uft.generate_k_lip_model(
+        norm_type,
+        layer_params={"size": size, "bias": bias},
+        input_shape=input_shape,
+        k=1,
+    )
+    if m is None:
+        pytest.skip()
+    loss, optimizer, _ = uft.compile_model(
+        m,
+        optimizer=uft.get_instance_framework(uft.SGD, inst_params={"model": m}),
+        loss=uft.CategoricalCrossentropy(from_logits=True),
+    )
+    batch_size = 10
+    bb = uft.to_numpy(uft.get_layer_by_index(m, 0).bias)
+    np.testing.assert_allclose(bb, np.zeros((size,)), atol=1e-5)
+
+    train_ds = linear_generator(batch_size, input_shape)
+    uft.train(
+        train_ds,
+        m,
+        loss,
+        optimizer,
+        2,
+        batch_size,
+        steps_per_epoch=10,
+    )
+
+    bb = uft.to_numpy(uft.get_layer_by_index(m, 0).bias)
+    assert np.linalg.norm(bb) != 0.0
+
+
+@pytest.mark.parametrize(
+    "size, input_shape, bias",
+    [
+        (4, (3, 4), False),
+        (4, (3, 4), True),
+        (4, (3, 4, 8, 8), False),
+        (4, (3, 4, 8, 8), True),
+    ],
+)
+def test_BatchCentering_runningmean(size, input_shape, bias):
+    """Evaluate the convergence of the BatchCentering running mean."""
+    input_shape = uft.to_framework_channel(input_shape)
+    # start with zeros so that the running mean is initialized to zero
+    x = np.zeros(input_shape)
+    bn = uft.get_instance_framework(BatchCentering, {"size": size, "bias": bias})
+    x = uft.to_tensor(x)
+    y = bn(x)
+
+    np.testing.assert_allclose(bn.running_mean, 0.0, atol=1e-5)
+
+    x = np.random.normal(0.0, 1.0, input_shape)
+    if len(input_shape) == 2:
+        mean_x = np.mean(x, axis=0)
+    else:
+        mean_x = np.mean(x, axis=(0, 2, 3))
+    x = uft.to_tensor(x)
+    for _ in range(1000):
+        y = bn(x)
+
+    np.testing.assert_allclose(bn.running_mean, mean_x, atol=1e-5)
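
A note on test_BatchCentering_runningmean: with a fixed input batch, the running mean follows an exponential moving average, so its gap to the batch mean shrinks by a factor of (1 - momentum) per step; after 1000 steps with momentum 0.05 the residual is on the order of 0.95^1000 ≈ 5e-23, far below the 1e-5 tolerance. A standalone sketch of that convergence (illustrative values only):

    m, r, mu = 0.05, 0.0, 3.7     # momentum, running mean, fixed batch mean
    for _ in range(1000):
        r = (1 - m) * r + m * mu  # same EMA update as BatchCentering
    assert abs(r - mu) < 1e-5     # converged well within the test tolerance
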
diff --git a/tests/utils_framework.py b/tests/utils_framework.py
index e18e211..fc16035 100644
--- a/tests/utils_framework.py
+++ b/tests/utils_framework.py
@@ -45,6 +45,8 @@ from deel.torchlip.modules import InvertibleDownSampling
 from deel.torchlip.modules import InvertibleUpSampling
 from deel.torchlip.modules import Reshape as tReshape
+from deel.torchlip.modules import LayerCentering
+from deel.torchlip.modules import BatchCentering
 from deel.torchlip.utils import evaluate_lip_const
 
 from deel.torchlip.modules import (
@@ -104,6 +106,8 @@
     "FrobeniusConv2d",
     "InvertibleDownSampling",
     "InvertibleUpSampling",
+    "LayerCentering",
+    "BatchCentering",
     "evaluate_lip_const",
     "DEFAULT_EPS_SPECTRAL",
     "invertible_downsample",
@@ -476,6 +480,10 @@ def get_layer_weights_by_index(model, layer_idx):
     return get_layer_weights(model[layer_idx])
 
 
+def get_layer_by_index(model, layer_idx):
+    return model[layer_idx]
+
+
 # .weight.detach().cpu().numpy()
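
Note: get_layer_by_index, like get_layer_weights_by_index above it, indexes the model directly, so it assumes an indexable container such as nn.Sequential. Typical use, as in test_Normalization_bias:

    bias = uft.to_numpy(uft.get_layer_by_index(model, 0).bias)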