diff --git a/deel/lip/normalizers.py b/deel/lip/normalizers.py
index 800c8ea1..e8abf778 100644
--- a/deel/lip/normalizers.py
+++ b/deel/lip/normalizers.py
@@ -158,116 +158,128 @@ def body(w, old_w):
         return w
 
 
-def _power_iteration(w, u, eps=DEFAULT_EPS_SPECTRAL, maxiter=DEFAULT_MAXITER_SPECTRAL):
-    """
-    Internal function that performs the power iteration algorithm.
+def _power_iteration(
+    linear_operator,
+    adjoint_operator,
+    u,
+    eps=DEFAULT_EPS_SPECTRAL,
+    maxiter=DEFAULT_MAXITER_SPECTRAL,
+    axis=None,
+):
+    """Internal function that performs the power iteration algorithm to estimate the
+    largest singular vector of a linear operator.
 
     Args:
-        w: weights matrix that we want to find eigen vector
-        u: initialization of the eigen vector
-        eps: epsilon stopping criterion: norm(ut - ut-1) must be less than eps
-        maxiter: maximum number of iterations for the algorithm
+        linear_operator (Callable): a callable object that maps a linear operation.
+        adjoint_operator (Callable): a callable object that maps the adjoint of the
+            linear operator.
+        u (tf.Tensor): initialization of the singular vector.
+        eps (float, optional): stopping criterion of the algorithm, when
+            norm(u[t] - u[t-1]) is less than eps. Defaults to DEFAULT_EPS_SPECTRAL.
+        maxiter (int, optional): maximum number of iterations for the algorithm.
+            Defaults to DEFAULT_MAXITER_SPECTRAL.
+        axis (int/list, optional): dimension along which to normalize. Can be set for
+            depthwise convolution for example. Defaults to None.
 
     Returns:
-         u and v corresponding to the maximum eigenvalue
-
+        tf.Tensor: the maximum singular vector.
     """
-    # build _u and _v (_v is size of _u@tf.transpose(w), will be set on the first body
-    # iteration)
-    if u is None:
-        u = tf.linalg.l2_normalize(
-            tf.random.uniform(
-                shape=(1, w.shape[-1]), minval=0.0, maxval=1.0, dtype=w.dtype
-            )
-        )
-    _u = u
-    _v = tf.zeros((1,) + (w.shape[0],), dtype=w.dtype)
-    # create a fake old_w that doesn't pass the loop condition
-    # it won't affect computation as the first action done in the loop overwrite it.
-    _old_u = 10 * _u
+    # Prepare while loop variables
+    u = tf.math.l2_normalize(u, axis=axis)
+    # create a fake old_w that doesn't pass the loop condition, it will be overwritten
+    old_u = u + 2 * eps
 
-    # define the loop condition
-    def cond(_u, _v, old_u):
-        return tf.linalg.norm(_u - old_u) >= eps
+    # Loop body
+    def body(u, old_u):
+        old_u = u
+        v = linear_operator(u)
+        u = adjoint_operator(v)
 
-    # define the loop body
-    def body(_u, _v, _old_u):
-        _old_u = _u
-        _v = tf.math.l2_normalize(_u @ tf.transpose(w))
-        _u = tf.math.l2_normalize(_v @ w)
-        return _u, _v, _old_u
+        u = tf.math.l2_normalize(u, axis=axis)
 
-    # apply the loop
-    _u, _v, _old_u = tf.while_loop(
+        return u, old_u
+
+    # Loop stopping condition
+    def cond(u, old_u):
+        return tf.linalg.norm(u - old_u) >= eps
+
+    # Run the while loop
+    u, _ = tf.while_loop(
         cond,
         body,
-        (_u, _v, _old_u),
-        parallel_iterations=30,
+        (u, old_u),
        maximum_iterations=maxiter,
         swap_memory=SWAP_MEMORY,
     )
+
+    # Prevent gradient to back-propagate into the while loop
     if STOP_GRAD_SPECTRAL:
-        _u = tf.stop_gradient(_u)
-        _v = tf.stop_gradient(_v)
-    return _u, _v
+        u = tf.stop_gradient(u)
+
+    return u
 
 
 def spectral_normalization(
     kernel, u, eps=DEFAULT_EPS_SPECTRAL, maxiter=DEFAULT_MAXITER_SPECTRAL
 ):
     """
-    Normalize the kernel to have it's max eigenvalue == 1.
+    Normalize the kernel to have its maximum singular value equal to 1.
 
     Args:
-        kernel (tf.Tensor): the kernel to normalize, assuming a 2D kernel
-        u (tf.Tensor): initialization for the max eigen vector
-        eps (float): epsilon stopping criterion: norm(ut - ut-1) must be less than eps
-        maxiter (int): maximum number of iterations for the algorithm
+        kernel (tf.Tensor): the kernel to normalize, assuming a 2D kernel.
+        u (tf.Tensor): initialization of the maximum singular vector.
+        eps (float, optional): stopping criterion of the algorithm, when
+            norm(u[t] - u[t-1]) is less than eps. Defaults to DEFAULT_EPS_SPECTRAL.
+        maxiter (int, optional): maximum number of iterations for the algorithm.
+            Defaults to DEFAULT_MAXITER_SPECTRAL.
 
     Returns:
-        the normalized kernel w_bar, the maximum eigen vector, and the maximum singular
+        the normalized kernel, the maximum singular vector, and the maximum singular
             value.
-
     """
-    _u, _v = _power_iteration(kernel, u, eps, maxiter)
-    # compute Sigma
-    sigma = _v @ kernel
-    sigma = sigma @ tf.transpose(_u)
-    # normalize it
-    # we assume that in the worst case we converged to sigma + eps (as u and v are
+
+    if u is None:
+        u = tf.random.uniform(
+            shape=(1, kernel.shape[-1]), minval=0.0, maxval=1.0, dtype=kernel.dtype
+        )
+
+    def linear_op(u):
+        return u @ tf.transpose(kernel)
+
+    def adjoint_op(v):
+        return v @ kernel
+
+    u = _power_iteration(linear_op, adjoint_op, u, eps, maxiter)
+
+    # Compute the largest singular value and the normalized kernel.
+    # We assume that in the worst case we converged to sigma + eps (as u and v are
     # normalized after each iteration)
-    # in order to be sure that operator norm of W_bar is strictly less than one we
-    # use sigma + eps, which ensure stability of the bjorck even when beta=0.5
-    W_bar = kernel / (sigma + eps)
-    return W_bar, _u, sigma
+    # In order to be sure that operator norm of normalized kernel is strictly less than
+    # one we use sigma + eps, which ensures stability of Björck algorithm even when
+    # beta=0.5
+    sigma = tf.reshape(tf.norm(linear_op(u)), (1, 1))
+    normalized_kernel = kernel / (sigma + eps)
+    return normalized_kernel, u, sigma
 
 
-def _power_iteration_conv(
-    w,
-    u,
-    stride=1.0,
-    conv_first=True,
-    pad_func=None,
-    eps=DEFAULT_EPS_SPECTRAL,
-    maxiter=DEFAULT_MAXITER_SPECTRAL,
-    big_constant=-1,
-):
+def get_conv_operators(kernel, u_shape, stride=1.0, conv_first=True, pad_func=None):
     """
-    Internal function that performs the power iteration algorithm for convolution.
+    Return two functions corresponding to the linear convolution operator and its
+    adjoint.
 
     Args:
-        w: weights matrix that we want to find eigen vector
-        u: initialization of the eigen matrix should be ||u||=1 for L2_norm
-        stride: stride parameter of the convolution
-        conv_first: RO or CO case , should be True in CO case (stride^2*C<M)
-        if big_constant > 0:
-            unew = big_constant * u - unew
+    if conv_first:
 
-        _norm_unew = tf.norm(unew)
-        unew = tf.math.l2_normalize(unew)
-        return unew, v, _old_u, _norm_unew
+        def linear_op(u):
+            return _conv(u, kernel, stride)
 
-    # define the loop condition
+        def adjoint_op(v):
+            return _conv_transpose(v, kernel, u_shape, stride)
 
-    def cond(_u, _v, old_u, _norm_u):
-        return tf.linalg.norm(_u - old_u) >= eps
-
-    # v shape
-    if conv_first:
-        v_shape = (
-            (u.shape[0],)
-            + (u.shape[1] // stride, u.shape[2] // stride)
-            + (w.shape[-1],)
-        )
     else:
         v_shape = (
-            (u.shape[0],) + (u.shape[1] * stride, u.shape[2] * stride) + (w.shape[-2],)
+            (u_shape[0],)
+            + (u_shape[1] * stride, u_shape[2] * stride)
+            + (kernel.shape[-2],)
         )
 
-    # build _u and _v
-    _norm_u = tf.norm(u)
-    _u = tf.math.l2_normalize(u)
-    _u += tf.random.uniform(_u.shape, minval=-eps, maxval=eps)
-    _v = tf.zeros(v_shape)  # _v will be set on the first body iteration
+        def linear_op(u):
+            return _conv_transpose(u, kernel, v_shape, stride)
 
-    # create a fake old_w that doesn't pass the loop condition
-    # it won't affect computation as the first action done in the loop overwrites it.
-    _old_u = 10 * _u
+        def adjoint_op(v):
+            return _conv(v, kernel, stride)
 
-    # apply the loop
-    _u, _v, _old_u, _norm_u = tf.while_loop(
-        cond,
-        body,
-        (_u, _v, _old_u, _norm_u),
-        parallel_iterations=1,
-        maximum_iterations=maxiter,
-        swap_memory=SWAP_MEMORY,
-    )
-    if STOP_GRAD_SPECTRAL:
-        _u = tf.stop_gradient(_u)
-        _v = tf.stop_gradient(_v)
-
-    return _u, _v, _norm_u
+    return linear_op, adjoint_op
 
 
 def spectral_normalization_conv(
@@ -386,11 +360,14 @@ def spectral_normalization_conv(
     if eps < 0:
         return kernel, u, 1.0
 
-    _u, _v, _ = _power_iteration_conv(
-        kernel, u, stride, conv_first, pad_func, eps, maxiter
+    linear_op, adjoint_op = get_conv_operators(
+        kernel, u.shape, stride, conv_first, pad_func
     )
 
-    # Calculate Sigma
-    sigma = tf.norm(_v)
-    W_bar = kernel / (sigma + eps)
-    return W_bar, _u, sigma
+    u = tf.math.l2_normalize(u) + tf.random.uniform(u.shape, minval=-eps, maxval=eps)
+    u = _power_iteration(linear_op, adjoint_op, u, eps, maxiter)
+
+    # Compute the largest singular value and the normalized kernel
+    sigma = tf.norm(linear_op(u))
+    normalized_kernel = kernel / (sigma + eps)
+    return normalized_kernel, u, sigma
diff --git a/setup.cfg b/setup.cfg
index a02d1ef5..8c802b0b 100644
--- a/setup.cfg
+++ b/setup.cfg
@@ -9,8 +9,8 @@ per-file-ignores =
 
 [tox:tox]
 envlist =
-    py{37,38,39,310}-tf{22,23,24,25,26,27,28,29,210,211,212,213,latest}
-    py{37,38,39,310}-lint
+    py{37,38,39,310,311}-tf{22,23,24,25,26,27,28,29,210,211,212,213,214,latest}
+    py{37,38,39,310,311}-lint
 
 [testenv]
 deps =
@@ -28,11 +28,12 @@ deps =
     tf211: tensorflow ~= 2.11.0
     tf212: tensorflow ~= 2.12.0
     tf213: tensorflow ~= 2.13.0
+    tf214: tensorflow ~= 2.14.0
 
 commands =
     python -m unittest
 
-[testenv:py{37,38,39,310}-lint]
+[testenv:py{37,38,39,310,311}-lint]
 skip_install = true
 deps =
     black
diff --git a/tests/test_normalizers.py b/tests/test_normalizers.py
index a193f965..07f190de 100644
--- a/tests/test_normalizers.py
+++ b/tests/test_normalizers.py
@@ -40,7 +40,7 @@ def _test_kernel(self, kernel):
         ).numpy()
         SVmax = np.max(sigmas_svd)
 
-        u = rng.normal(size=(1, kernel.shape[-1]))
+        u = rng.normal(size=(1, kernel.shape[-1])).astype("float32")
         W_bar, _u, sigma = spectral_normalization(kernel, u=u, eps=1e-6)
         # Test sigma is close to the one computed with svd first run @ 1e-1
         np.testing.assert_approx_equal(
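
Reviewer note: the heart of this patch is that `_power_iteration` no longer hard-codes the matrix products `u @ W^T` / `v @ W`, but receives the linear operator and its adjoint as callables, so the dense and convolutional cases share a single loop. Below is a minimal, self-contained sketch of that design, not part of the patch; the `eps`/`maxiter` values, the plain Python loop (standing in for the patch's `tf.while_loop`), and the SVD check are illustrative only.

```python
import tensorflow as tf

def power_iteration(linear_op, adjoint_op, u, eps=1e-5, maxiter=100):
    # Alternate u -> v = linear_op(u) -> u = adjoint_op(v), re-normalizing u,
    # until u stops moving; u then approximates the top right-singular vector.
    u = tf.math.l2_normalize(u)
    old_u = u + 2 * eps  # guarantees the first iteration runs
    for _ in range(maxiter):
        if tf.linalg.norm(u - old_u) < eps:
            break
        old_u = u
        u = tf.math.l2_normalize(adjoint_op(linear_op(u)))
    return u

# Dense case, mirroring the new spectral_normalization: the two operators are
# matrix products, and sigma is recovered as ||linear_op(u)||.
kernel = tf.random.normal((8, 5))
u0 = tf.random.uniform((1, 5))
u = power_iteration(lambda u: u @ tf.transpose(kernel), lambda v: v @ kernel, u0)
sigma = tf.norm(u @ tf.transpose(kernel))
print(float(sigma))  # close to tf.reduce_max(tf.linalg.svd(kernel, compute_uv=False))
```

The convolutional case only swaps in the `_conv`/`_conv_transpose` closures built by `get_conv_operators`, which is exactly what `spectral_normalization_conv` now does.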