Skip to content

Commit

Permalink
clean normalization
Browse files Browse the repository at this point in the history
  • Loading branch information
franckma31 committed Oct 22, 2024
1 parent 3c61caf commit 0015925
Show file tree
Hide file tree
Showing 2 changed files with 1 addition and 26 deletions.
24 changes: 0 additions & 24 deletions deel/torchlip/modules/normalization.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,9 +23,6 @@ def forward(self, x):


LayerCentering2d = LayerCentering
# class LayerCentering2D(LayerCentering):
# def __init__(self, size = 1, dim=[-2,-1]):
# super(LayerCentering2D, self).__init__(size = size,dim=[-2,-1])


class BatchCentering(nn.Module):
Expand Down Expand Up @@ -72,25 +69,4 @@ def forward(self, x):
return x - mean.view(mean_shape)


# class BatchCenteringBiases(BatchCentering):
# def __init__(self, size =1, dim=[0,-2,-1], momentum=0.05):
# super(BatchCenteringBiases, self).__init__(size = size, dim = dim, momentum = momentum)
# if isinstance(size, tuple):
# self.alpha = nn.Parameter(torch.zeros(size), requires_grad=True)
# else:
# self.alpha = nn.Parameter(torch.zeros(1,size,1,1), requires_grad=True)

# def forward(self, x):
# #print(x.mean(dim=self.dim, keepdim=True).abs().mean().cpu().numpy(), self.running_mean.abs().cpu().mean().numpy(), self.alpha.abs().mean().cpu().numpy())
# #print(x.mean(dim=self.dim, keepdim=True).abs().mean().cpu().numpy(),(x.mean(dim=self.dim, keepdim=True)-self.running_mean).abs().mean().cpu().numpy())
# return super().forward(x) + self.alpha

BatchCentering2d = BatchCentering

# class BatchCenteringBiases2D(BatchCenteringBiases):
# def __init__(self, size =1, momentum=0.05):
# super(BatchCenteringBiases2D, self).__init__(size = size, dim=[0,-2,-1],momentum=momentum)

# class BatchCentering2D(BatchCentering):
# def __init__(self, size =1, momentum=0.05):
# super(BatchCentering2D, self).__init__(size = size, dim=[0,-2,-1],momentum=momentum)
3 changes: 1 addition & 2 deletions tests/test_normalization.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,6 @@
# =====================================================================================
import os
import pytest
from functools import partial

import numpy as np

Expand Down Expand Up @@ -249,6 +248,6 @@ def test_BatchCentering_runningmean(size, input_shape, bias):
mean_x = np.mean(x, axis=(0, 2, 3))
x = uft.to_tensor(x)
for _ in range(1000):
y = bn(x)
y = bn(x) # noqa: F841

np.testing.assert_allclose(bn.running_mean, mean_x, atol=1e-5)

0 comments on commit 0015925

Please sign in to comment.