diff --git a/deel/torchlip/init.py b/deel/torchlip/init.py index d4285b0..573cf29 100644 --- a/deel/torchlip/init.py +++ b/deel/torchlip/init.py @@ -24,8 +24,8 @@ # rights reserved. DEEL is a research program operated by IVADO, IRT Saint Exupéry, # CRIAQ and ANITI - https://www.deel.ai/ # ===================================================================================== -""" -""" +""" """ +import warnings import torch from .normalizers import bjorck_normalization @@ -57,6 +57,9 @@ def spectral_( eps_spectral (float): stopping criterion of iterative power method maxiter_spectral (int): maximum number of iterations for the power iteration """ + warnings.warn( + "spectral_ initialization is deprecated, use torch.nn.init.orthogonal_ instead" + ) with torch.no_grad(): tensor.copy_( spectral_normalization( @@ -91,6 +94,9 @@ def bjorck_( maxiter_bjorck (int): maximum number of iterations for bjorck algorithm beta: Value to use for the :math:`\beta` parameter. """ + warnings.warn( + "bjorck_ initialization is deprecated, use torch.nn.init.orthogonal_ instead" + ) with torch.no_grad(): spectral_tensor = spectral_normalization( tensor, None, eps=eps_spectral, maxiter=maxiter_spectral diff --git a/tests/test_compute_layer_sv.py b/tests/test_compute_layer_sv.py index 9feca85..1d830ae 100644 --- a/tests/test_compute_layer_sv.py +++ b/tests/test_compute_layer_sv.py @@ -24,8 +24,7 @@ # rights reserved. DEEL is a research program operated by IVADO, IRT Saint Exupéry, # CRIAQ and ANITI - https://www.deel.ai/ # ===================================================================================== -"""Tests for singular value computation (in compute_layer_sv.py) -""" +"""Tests for singular value computation (in compute_layer_sv.py)""" import os import pprint import pytest diff --git a/tests/utils_framework.py b/tests/utils_framework.py index a070176..4e326dd 100644 --- a/tests/utils_framework.py +++ b/tests/utils_framework.py @@ -64,7 +64,6 @@ CategoricalHingeLoss, ) -from deel.torchlip.init import spectral_, bjorck_ from deel.torchlip.normalizers import spectral_normalization from deel.torchlip.normalizers import bjorck_normalization from deel.torchlip.normalizers import DEFAULT_EPS_SPECTRAL @@ -174,6 +173,7 @@ def __call__(self, **kwargs): Model = module_Unavailable_class compute_layer_sv = module_Unavailable_class OrthLinearRegularizer = module_Unavailable_class +SpectralInitializer = module_Unavailable_class MODEL_PATH = "model.h5" LIP_LAYERS = "torchlip_layers" @@ -591,13 +591,6 @@ def scaleDivAlpha(alpha): return 1.0 / (1 + 1.0 / alpha) -def SpectralInitializer(eps_spectral, eps_bjorck): - if eps_bjorck is None: - return partial(spectral_, eps_spectral=eps_spectral) - else: - return partial(bjorck_, eps_spectral=eps_spectral, eps_bjorck=eps_bjorck) - - class tAdd(torch.nn.Module): def __init__(self): super(tAdd, self).__init__()