From 2ee84d94fd90e65776fc935dd7134684d7efb300 Mon Sep 17 00:00:00 2001 From: cofri Date: Tue, 23 May 2023 09:55:33 +0200 Subject: [PATCH 1/7] refactor: power_iteration function generic to any linear operator The function `power_iteration()` is now generic to any linear operator. A linear operator function and its adjoint must be passed as arguments. The maximum singular vector is returned. Note that this commit is atomic but is part of larger modifications. It won't pass the tests as is. --- deel/lip/normalizers.py | 86 +++++++++++++++++++++++------------------ 1 file changed, 48 insertions(+), 38 deletions(-) diff --git a/deel/lip/normalizers.py b/deel/lip/normalizers.py index 800c8ea1..6dd9a8ed 100644 --- a/deel/lip/normalizers.py +++ b/deel/lip/normalizers.py @@ -158,59 +158,69 @@ def body(w, old_w): return w -def _power_iteration(w, u, eps=DEFAULT_EPS_SPECTRAL, maxiter=DEFAULT_MAXITER_SPECTRAL): - """ - Internal function that performs the power iteration algorithm. +def _power_iteration( + linear_operator, + adjoint_operator, + u, + eps=DEFAULT_EPS_SPECTRAL, + maxiter=DEFAULT_MAXITER_SPECTRAL, + big_constant=-1, +): + """Internal function that performs the power iteration algorithm to estimate the + largest singular vector of a linear operator. Args: - w: weights matrix that we want to find eigen vector - u: initialization of the eigen vector - eps: epsilon stopping criterion: norm(ut - ut-1) must be less than eps - maxiter: maximum number of iterations for the algorithm + linear_operator (Callable): a callable object that maps a linear operation. + adjoint_operator (Callable): a callable object that maps the adjoint of the + linear operator. + u (tf.Tensor): initialization of the singular vector. + eps (float, optional): stopping criterion of the algorithm, when + norm(u[t] - u[t-1]) is less than eps. Defaults to DEFAULT_EPS_SPECTRAL. + maxiter (int, optional): maximum number of iterations for the algorithm. + Defaults to DEFAULT_MAXITER_SPECTRAL. + big_constant (int, optional): Set to a large value to compute the minimum + singular value. Defaults to -1, to compute the maximum singular value. Returns: - u and v corresponding to the maximum eigenvalue - + tf.Tensor: the maximum singular vector. """ - # build _u and _v (_v is size of _u@tf.transpose(w), will be set on the first body - # iteration) - if u is None: - u = tf.linalg.l2_normalize( - tf.random.uniform( - shape=(1, w.shape[-1]), minval=0.0, maxval=1.0, dtype=w.dtype - ) - ) - _u = u - _v = tf.zeros((1,) + (w.shape[0],), dtype=w.dtype) - # create a fake old_w that doesn't pass the loop condition - # it won't affect computation as the first action done in the loop overwrite it. - _old_u = 10 * _u + # Prepare while loop variables + u = tf.math.l2_normalize(u) + # create a fake old_w that doesn't pass the loop condition, it will be overwritten + old_u = u + 2 * eps - # define the loop condition - def cond(_u, _v, old_u): - return tf.linalg.norm(_u - old_u) >= eps + # Loop body + def body(u, old_u): + old_u = u + v = linear_operator(u) + u = adjoint_operator(v) - # define the loop body - def body(_u, _v, _old_u): - _old_u = _u - _v = tf.math.l2_normalize(_u @ tf.transpose(w)) - _u = tf.math.l2_normalize(_v @ w) - return _u, _v, _old_u + if big_constant > 0: + u = big_constant * old_u - u - # apply the loop - _u, _v, _old_u = tf.while_loop( + u = tf.math.l2_normalize(u) + + return u, old_u + + # Loop stopping condition + def cond(u, old_u): + return tf.linalg.norm(u - old_u) >= eps + + # Run the while loop + u, _ = tf.while_loop( cond, body, - (_u, _v, _old_u), - parallel_iterations=30, + (u, old_u), maximum_iterations=maxiter, swap_memory=SWAP_MEMORY, ) + + # Prevent gradient to back-propagate into the while loop if STOP_GRAD_SPECTRAL: - _u = tf.stop_gradient(_u) - _v = tf.stop_gradient(_v) - return _u, _v + u = tf.stop_gradient(u) + + return u def spectral_normalization( From 647247fb70499cd58b78c2b2ca364e72727c9826 Mon Sep 17 00:00:00 2001 From: cofri Date: Tue, 23 May 2023 10:10:54 +0200 Subject: [PATCH 2/7] refactor: spectral_normalization using generic power iteration The function `spectral_normalization()` is now based on the generic power iteration function. --- deel/lip/normalizers.py | 47 ++++++++++++++++++++++++++--------------- 1 file changed, 30 insertions(+), 17 deletions(-) diff --git a/deel/lip/normalizers.py b/deel/lip/normalizers.py index 6dd9a8ed..f3c2d20a 100644 --- a/deel/lip/normalizers.py +++ b/deel/lip/normalizers.py @@ -227,30 +227,43 @@ def spectral_normalization( kernel, u, eps=DEFAULT_EPS_SPECTRAL, maxiter=DEFAULT_MAXITER_SPECTRAL ): """ - Normalize the kernel to have it's max eigenvalue == 1. + Normalize the kernel to have its maximum singular value equal to 1. Args: - kernel (tf.Tensor): the kernel to normalize, assuming a 2D kernel - u (tf.Tensor): initialization for the max eigen vector - eps (float): epsilon stopping criterion: norm(ut - ut-1) must be less than eps - maxiter (int): maximum number of iterations for the algorithm + kernel (tf.Tensor): the kernel to normalize, assuming a 2D kernel. + u (tf.Tensor): initialization of the maximum singular vector. + eps (float, optional): stopping criterion of the algorithm, when + norm(u[t] - u[t-1]) is less than eps. Defaults to DEFAULT_EPS_SPECTRAL. + maxiter (int, optional): maximum number of iterations for the algorithm. + Defaults to DEFAULT_MAXITER_SPECTRAL. Returns: - the normalized kernel w_bar, the maximum eigen vector, and the maximum singular + the normalized kernel, the maximum singular vector, and the maximum singular value. - """ - _u, _v = _power_iteration(kernel, u, eps, maxiter) - # compute Sigma - sigma = _v @ kernel - sigma = sigma @ tf.transpose(_u) - # normalize it - # we assume that in the worst case we converged to sigma + eps (as u and v are + + if u is None: + u = tf.random.uniform( + shape=(1, kernel.shape[-1]), minval=0.0, maxval=1.0, dtype=kernel.dtype + ) + + def linear_op(u): + return u @ tf.transpose(kernel) + + def adjoint_op(v): + return v @ kernel + + u = _power_iteration(linear_op, adjoint_op, u, eps, maxiter) + + # Compute the largest singular value and the normalized kernel. + # We assume that in the worst case we converged to sigma + eps (as u and v are # normalized after each iteration) - # in order to be sure that operator norm of W_bar is strictly less than one we - # use sigma + eps, which ensure stability of the bjorck even when beta=0.5 - W_bar = kernel / (sigma + eps) - return W_bar, _u, sigma + # In order to be sure that operator norm of normalized kernel is strictly less than + # one we use sigma + eps, which ensures stability of Björck algorithm even when + # beta=0.5 + sigma = tf.reshape(tf.norm(linear_op(u)), (1, 1)) + normalized_kernel = kernel / (sigma + eps) + return normalized_kernel, u, sigma def _power_iteration_conv( From 1487cb21cf19e8ea8053af691949325e5d67201f Mon Sep 17 00:00:00 2001 From: cofri Date: Tue, 23 May 2023 13:55:12 +0200 Subject: [PATCH 3/7] refactor: spectral_normalization_conv with generic power iteration The function spectral_normalization_conv() now uses the generic power iteration. Unlike spectral normalization for matrices, a get_convolution_operators() function is introduced because it can be required in some future operations. --- deel/lip/normalizers.py | 113 +++++++++++++--------------------------- 1 file changed, 35 insertions(+), 78 deletions(-) diff --git a/deel/lip/normalizers.py b/deel/lip/normalizers.py index f3c2d20a..d363b494 100644 --- a/deel/lip/normalizers.py +++ b/deel/lip/normalizers.py @@ -266,31 +266,23 @@ def adjoint_op(v): return normalized_kernel, u, sigma -def _power_iteration_conv( - w, - u, - stride=1.0, - conv_first=True, - pad_func=None, - eps=DEFAULT_EPS_SPECTRAL, - maxiter=DEFAULT_MAXITER_SPECTRAL, - big_constant=-1, -): +def get_conv_operators(kernel, u_shape, stride=1.0, conv_first=True, pad_func=None): """ - Internal function that performs the power iteration algorithm for convolution. + Return two functions corresponding to the linear convolution operator and its + adjoint. Args: - w: weights matrix that we want to find eigen vector - u: initialization of the eigen matrix should be ||u||=1 for L2_norm - stride: stride parameter of the convolution - conv_first: RO or CO case , should be True in CO case (stride^2*C 0: - unew = big_constant * u - unew - - _norm_unew = tf.norm(unew) - unew = tf.math.l2_normalize(unew) - return unew, v, _old_u, _norm_unew + if conv_first: - # define the loop condition + def linear_op(u): + return _conv(u, kernel, stride) - def cond(_u, _v, old_u, _norm_u): - return tf.linalg.norm(_u - old_u) >= eps + def adjoint_op(v): + return _conv_transpose(v, kernel, u_shape, stride) - # v shape - if conv_first: - v_shape = ( - (u.shape[0],) - + (u.shape[1] // stride, u.shape[2] // stride) - + (w.shape[-1],) - ) else: v_shape = ( - (u.shape[0],) + (u.shape[1] * stride, u.shape[2] * stride) + (w.shape[-2],) + (u_shape[0],) + + (u_shape[1] * stride, u_shape[2] * stride) + + (kernel.shape[-2],) ) - # build _u and _v - _norm_u = tf.norm(u) - _u = tf.math.l2_normalize(u) - _u += tf.random.uniform(_u.shape, minval=-eps, maxval=eps) - _v = tf.zeros(v_shape) # _v will be set on the first body iteration + def linear_op(u): + return _conv_transpose(u, kernel, v_shape, stride) - # create a fake old_w that doesn't pass the loop condition - # it won't affect computation as the first action done in the loop overwrites it. - _old_u = 10 * _u - - # apply the loop - _u, _v, _old_u, _norm_u = tf.while_loop( - cond, - body, - (_u, _v, _old_u, _norm_u), - parallel_iterations=1, - maximum_iterations=maxiter, - swap_memory=SWAP_MEMORY, - ) - if STOP_GRAD_SPECTRAL: - _u = tf.stop_gradient(_u) - _v = tf.stop_gradient(_v) + def adjoint_op(v): + return _conv(v, kernel, stride) - return _u, _v, _norm_u + return linear_op, adjoint_op def spectral_normalization_conv( @@ -409,11 +363,14 @@ def spectral_normalization_conv( if eps < 0: return kernel, u, 1.0 - _u, _v, _ = _power_iteration_conv( - kernel, u, stride, conv_first, pad_func, eps, maxiter + linear_op, adjoint_op = get_conv_operators( + kernel, u.shape, stride, conv_first, pad_func ) - # Calculate Sigma - sigma = tf.norm(_v) - W_bar = kernel / (sigma + eps) - return W_bar, _u, sigma + u = tf.math.l2_normalize(u) + tf.random.uniform(u.shape, minval=-eps, maxval=eps) + u = _power_iteration(linear_op, adjoint_op, u, eps, maxiter) + + # Compute the largest singular value and the normalized kernel + sigma = tf.norm(linear_op(u)) + normalized_kernel = kernel / (sigma + eps) + return normalized_kernel, u, sigma From 08e941bcc932305dd7e322518591945cd49e146b Mon Sep 17 00:00:00 2001 From: cofri Date: Wed, 12 Jul 2023 16:37:27 +0200 Subject: [PATCH 4/7] feat: axis for normalization in power iteration Setting the axis for normalization is useful to handle depthwise convolution filterwise --- deel/lip/normalizers.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/deel/lip/normalizers.py b/deel/lip/normalizers.py index d363b494..0311edd4 100644 --- a/deel/lip/normalizers.py +++ b/deel/lip/normalizers.py @@ -164,6 +164,7 @@ def _power_iteration( u, eps=DEFAULT_EPS_SPECTRAL, maxiter=DEFAULT_MAXITER_SPECTRAL, + axis=None, big_constant=-1, ): """Internal function that performs the power iteration algorithm to estimate the @@ -178,6 +179,8 @@ def _power_iteration( norm(u[t] - u[t-1]) is less than eps. Defaults to DEFAULT_EPS_SPECTRAL. maxiter (int, optional): maximum number of iterations for the algorithm. Defaults to DEFAULT_MAXITER_SPECTRAL. + axis (int/list, optional): dimension along which to normalize. Can be set for + depthwise convolution for example. Defaults to None. big_constant (int, optional): Set to a large value to compute the minimum singular value. Defaults to -1, to compute the maximum singular value. @@ -186,7 +189,7 @@ def _power_iteration( """ # Prepare while loop variables - u = tf.math.l2_normalize(u) + u = tf.math.l2_normalize(u, axis=axis) # create a fake old_w that doesn't pass the loop condition, it will be overwritten old_u = u + 2 * eps @@ -199,7 +202,7 @@ def body(u, old_u): if big_constant > 0: u = big_constant * old_u - u - u = tf.math.l2_normalize(u) + u = tf.math.l2_normalize(u, axis=axis) return u, old_u From e0ca6326b572681eb3ab4cd173436e1c5da69ee6 Mon Sep 17 00:00:00 2001 From: cofri Date: Tue, 17 Oct 2023 10:19:24 +0200 Subject: [PATCH 5/7] feat: remove computation of min SV in power iteration --- deel/lip/normalizers.py | 6 ------ 1 file changed, 6 deletions(-) diff --git a/deel/lip/normalizers.py b/deel/lip/normalizers.py index 0311edd4..e8abf778 100644 --- a/deel/lip/normalizers.py +++ b/deel/lip/normalizers.py @@ -165,7 +165,6 @@ def _power_iteration( eps=DEFAULT_EPS_SPECTRAL, maxiter=DEFAULT_MAXITER_SPECTRAL, axis=None, - big_constant=-1, ): """Internal function that performs the power iteration algorithm to estimate the largest singular vector of a linear operator. @@ -181,8 +180,6 @@ def _power_iteration( Defaults to DEFAULT_MAXITER_SPECTRAL. axis (int/list, optional): dimension along which to normalize. Can be set for depthwise convolution for example. Defaults to None. - big_constant (int, optional): Set to a large value to compute the minimum - singular value. Defaults to -1, to compute the maximum singular value. Returns: tf.Tensor: the maximum singular vector. @@ -199,9 +196,6 @@ def body(u, old_u): v = linear_operator(u) u = adjoint_operator(v) - if big_constant > 0: - u = big_constant * old_u - u - u = tf.math.l2_normalize(u, axis=axis) return u, old_u From 2ba716dad0b76eff76d0dd557e1a00f0a7b14e8f Mon Sep 17 00:00:00 2001 From: cofri Date: Tue, 17 Oct 2023 10:33:02 +0200 Subject: [PATCH 6/7] feat: update setup.cfg with latest Python and TF versions --- setup.cfg | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/setup.cfg b/setup.cfg index a02d1ef5..8c802b0b 100644 --- a/setup.cfg +++ b/setup.cfg @@ -9,8 +9,8 @@ per-file-ignores = [tox:tox] envlist = - py{37,38,39,310}-tf{22,23,24,25,26,27,28,29,210,211,212,213,latest} - py{37,38,39,310}-lint + py{37,38,39,310,311}-tf{22,23,24,25,26,27,28,29,210,211,212,213,214,latest} + py{37,38,39,310,311}-lint [testenv] deps = @@ -28,11 +28,12 @@ deps = tf211: tensorflow ~= 2.11.0 tf212: tensorflow ~= 2.12.0 tf213: tensorflow ~= 2.13.0 + tf214: tensorflow ~= 2.14.0 commands = python -m unittest -[testenv:py{37,38,39,310}-lint] +[testenv:py{37,38,39,310,311}-lint] skip_install = true deps = black From f09c2be3bff361b366dfd84a7cba75509e1b8820 Mon Sep 17 00:00:00 2001 From: cofri Date: Tue, 17 Oct 2023 11:52:38 +0200 Subject: [PATCH 7/7] test: enforce tensors to be float32 --- tests/test_normalizers.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_normalizers.py b/tests/test_normalizers.py index a193f965..07f190de 100644 --- a/tests/test_normalizers.py +++ b/tests/test_normalizers.py @@ -40,7 +40,7 @@ def _test_kernel(self, kernel): ).numpy() SVmax = np.max(sigmas_svd) - u = rng.normal(size=(1, kernel.shape[-1])) + u = rng.normal(size=(1, kernel.shape[-1])).astype("float32") W_bar, _u, sigma = spectral_normalization(kernel, u=u, eps=1e-6) # Test sigma is close to the one computed with svd first run @ 1e-1 np.testing.assert_approx_equal(