Merge pull request #73 from deel-ai/refactor/power_iteration_generic
Power iteration algorithm is now generic to any linear operator
Showing 3 changed files with 115 additions and 137 deletions.
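In essence, the refactor replaces two hard-coded power-iteration loops (one for dense kernels, one for convolutions) with a single routine that only sees a linear operator and its adjoint as callables. A minimal, self-contained sketch of that pattern, assuming TensorFlow 2 (`power_iteration` and the lambdas below are illustrative names, not the library's API):

```python
import tensorflow as tf

def power_iteration(linear_op, adjoint_op, u, eps=1e-6, maxiter=100):
    """Estimate the largest singular vector of the operator pair."""
    u = tf.math.l2_normalize(u)
    for _ in range(maxiter):
        old_u = u
        # One round trip through the operator and its adjoint amplifies the
        # component of u along the top singular direction the fastest.
        u = tf.math.l2_normalize(adjoint_op(linear_op(u)))
        if tf.norm(u - old_u) < eps:
            break
    return u

# Dense case: the operator is u @ W^T and its adjoint is v @ W.
W = tf.random.normal((8, 5))
u = power_iteration(lambda u: u @ tf.transpose(W),
                    lambda v: v @ W,
                    tf.random.uniform((1, 5)))
sigma = tf.norm(u @ tf.transpose(W))  # estimate of the largest singular value
print(float(sigma), float(tf.linalg.svd(W, compute_uv=False)[0]))  # close values
```

The version in the diff below runs the same update inside `tf.while_loop`, with an optional `tf.stop_gradient` at the end so the iteration is not back-propagated through.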
```diff
@@ -158,116 +158,128 @@ def body(w, old_w):
         return w
 
 
-def _power_iteration(w, u, eps=DEFAULT_EPS_SPECTRAL, maxiter=DEFAULT_MAXITER_SPECTRAL):
-    """
-    Internal function that performs the power iteration algorithm.
+def _power_iteration(
+    linear_operator,
+    adjoint_operator,
+    u,
+    eps=DEFAULT_EPS_SPECTRAL,
+    maxiter=DEFAULT_MAXITER_SPECTRAL,
+    axis=None,
+):
+    """Internal function that performs the power iteration algorithm to estimate the
+    largest singular vector of a linear operator.
 
     Args:
-        w: weights matrix that we want to find eigen vector
-        u: initialization of the eigen vector
-        eps: epsilon stopping criterion: norm(ut - ut-1) must be less than eps
-        maxiter: maximum number of iterations for the algorithm
+        linear_operator (Callable): a callable object that maps a linear operation.
+        adjoint_operator (Callable): a callable object that maps the adjoint of the
+            linear operator.
+        u (tf.Tensor): initialization of the singular vector.
+        eps (float, optional): stopping criterion of the algorithm, when
+            norm(u[t] - u[t-1]) is less than eps. Defaults to DEFAULT_EPS_SPECTRAL.
+        maxiter (int, optional): maximum number of iterations for the algorithm.
+            Defaults to DEFAULT_MAXITER_SPECTRAL.
+        axis (int/list, optional): dimension along which to normalize. Can be set for
+            depthwise convolution for example. Defaults to None.
 
     Returns:
-        u and v corresponding to the maximum eigenvalue
+        tf.Tensor: the maximum singular vector.
     """
-    # build _u and _v (_v is size of w@tf.transpose(w), will be set on the first body
-    # iteration)
-    if u is None:
-        u = tf.linalg.l2_normalize(
-            tf.random.uniform(
-                shape=(1, w.shape[-1]), minval=0.0, maxval=1.0, dtype=w.dtype
-            )
-        )
-    _u = u
-    _v = tf.zeros((1,) + (w.shape[0],), dtype=w.dtype)
-
-    # create a fake old_w that doesn't pass the loop condition
-    # it won't affect computation as the first action done in the loop overwrites it.
-    _old_u = 10 * _u
+    # Prepare while loop variables
+    u = tf.math.l2_normalize(u, axis=axis)
+    # create a fake old_u that doesn't pass the loop condition, it will be overwritten
+    old_u = u + 2 * eps
 
-    # define the loop condition
-    def cond(_u, _v, old_u):
-        return tf.linalg.norm(_u - old_u) >= eps
+    # Loop body
+    def body(u, old_u):
+        old_u = u
+        v = linear_operator(u)
+        u = adjoint_operator(v)
 
-    # define the loop body
-    def body(_u, _v, _old_u):
-        _old_u = _u
-        _v = tf.math.l2_normalize(_u @ tf.transpose(w))
-        _u = tf.math.l2_normalize(_v @ w)
-        return _u, _v, _old_u
+        u = tf.math.l2_normalize(u, axis=axis)
 
-    # apply the loop
-    _u, _v, _old_u = tf.while_loop(
+        return u, old_u
+
+    # Loop stopping condition
+    def cond(u, old_u):
+        return tf.linalg.norm(u - old_u) >= eps
+
+    # Run the while loop
+    u, _ = tf.while_loop(
         cond,
         body,
-        (_u, _v, _old_u),
-        parallel_iterations=30,
+        (u, old_u),
         maximum_iterations=maxiter,
         swap_memory=SWAP_MEMORY,
     )
 
     # Prevent gradient to back-propagate into the while loop
     if STOP_GRAD_SPECTRAL:
-        _u = tf.stop_gradient(_u)
-        _v = tf.stop_gradient(_v)
-    return _u, _v
+        u = tf.stop_gradient(u)
+
+    return u
 
 
 def spectral_normalization(
     kernel, u, eps=DEFAULT_EPS_SPECTRAL, maxiter=DEFAULT_MAXITER_SPECTRAL
 ):
     """
-    Normalize the kernel to have it's max eigenvalue == 1.
+    Normalize the kernel to have its maximum singular value equal to 1.
 
     Args:
-        kernel (tf.Tensor): the kernel to normalize, assuming a 2D kernel
-        u (tf.Tensor): initialization for the max eigen vector
-        eps (float): epsilon stopping criterion: norm(ut - ut-1) must be less than eps
-        maxiter (int): maximum number of iterations for the algorithm
+        kernel (tf.Tensor): the kernel to normalize, assuming a 2D kernel.
+        u (tf.Tensor): initialization of the maximum singular vector.
+        eps (float, optional): stopping criterion of the algorithm, when
+            norm(u[t] - u[t-1]) is less than eps. Defaults to DEFAULT_EPS_SPECTRAL.
+        maxiter (int, optional): maximum number of iterations for the algorithm.
+            Defaults to DEFAULT_MAXITER_SPECTRAL.
 
     Returns:
-        the normalized kernel w_bar, the maximum eigen vector, and the maximum singular
+        the normalized kernel, the maximum singular vector, and the maximum singular
         value.
     """
-    _u, _v = _power_iteration(kernel, u, eps, maxiter)
-    # compute Sigma
-    sigma = _v @ kernel
-    sigma = sigma @ tf.transpose(_u)
-    # normalize it
-    # we assume that in the worst case we converged to sigma + eps (as u and v are
+    if u is None:
+        u = tf.random.uniform(
+            shape=(1, kernel.shape[-1]), minval=0.0, maxval=1.0, dtype=kernel.dtype
+        )
+
+    def linear_op(u):
+        return u @ tf.transpose(kernel)
+
+    def adjoint_op(v):
+        return v @ kernel
+
+    u = _power_iteration(linear_op, adjoint_op, u, eps, maxiter)
+
+    # Compute the largest singular value and the normalized kernel.
+    # We assume that in the worst case we converged to sigma + eps (as u and v are
     # normalized after each iteration)
-    # in order to be sure that operator norm of W_bar is strictly less than one we
-    # use sigma + eps, which ensure stability of the bjorck even when beta=0.5
-    W_bar = kernel / (sigma + eps)
-    return W_bar, _u, sigma
+    # In order to be sure that operator norm of normalized kernel is strictly less than
+    # one we use sigma + eps, which ensures stability of Björck algorithm even when
+    # beta=0.5
+    sigma = tf.reshape(tf.norm(linear_op(u)), (1, 1))
+    normalized_kernel = kernel / (sigma + eps)
+    return normalized_kernel, u, sigma
 
 
-def _power_iteration_conv(
-    w,
-    u,
-    stride=1.0,
-    conv_first=True,
-    pad_func=None,
-    eps=DEFAULT_EPS_SPECTRAL,
-    maxiter=DEFAULT_MAXITER_SPECTRAL,
-    big_constant=-1,
-):
+def get_conv_operators(kernel, u_shape, stride=1.0, conv_first=True, pad_func=None):
     """
-    Internal function that performs the power iteration algorithm for convolution.
+    Return two functions corresponding to the linear convolution operator and its
+    adjoint.
 
     Args:
-        w: weights matrix that we want to find eigen vector
-        u: initialization of the eigen matrix should be ||u||=1 for L2_norm
-        stride: stride parameter of the convolution
-        conv_first: RO or CO case, should be True in CO case (stride^2*C<M)
-        pad_func: function for applying padding (None is padding same)
-        eps: epsilon stopping criterion: norm(ut - ut-1) must be less than eps
-        maxiter: maximum number of iterations for the algorithm
-        big_constant: only for computing the minimum singular value (otherwise -1)
-    Returns:
-        u and v corresponding to the maximum eigenvalue
+        kernel (tf.Tensor): the convolution kernel to normalize.
+        u_shape (tuple): shape of a singular vector (as a 4D tensor).
+        stride (int, optional): stride parameter of convolutions. Defaults to 1.
+        conv_first (bool, optional): RO or CO case, should be True in CO case
+            (stride^2*C<M). Defaults to True.
+        pad_func (Callable, optional): function for applying padding (None is padding
+            same). Defaults to None.
+
+    Returns:
+        tuple: two functions for the linear convolution operator and its adjoint
+            operator.
     """
 
     def identity(x):
```
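In the dense path above, sigma is recovered after the loop as the norm of `linear_op(u)`. That works because, once `u` has converged to the unit top right-singular vector of `W`, the norm of `u @ W^T` is exactly the largest singular value. A quick numerical check of that identity (a sketch; variable names are mine, not the library's):

```python
import tensorflow as tf

# If u is the (unit) top right-singular vector of W, then norm(u @ W^T) is
# exactly the largest singular value sigma_1(W).
W = tf.random.normal((8, 5))
s, U, V = tf.linalg.svd(W)              # columns of V are right-singular vectors
u_top = tf.reshape(V[:, 0], (1, -1))    # row-vector convention, as in the diff
print(float(tf.norm(u_top @ tf.transpose(W))), float(s[0]))  # equal values
```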
```diff
@@ -295,66 +307,28 @@ def _conv_transpose(u, w, output_shape, stride):
         w_adj = _maybe_transpose_kernel(w, True)
         return _conv(u_upscale, w_adj, stride=1)
 
-    def body(_u, _v, _old_u, _norm_u):
-        # _u is supposed to be normalized when entering in the body function
-        _old_u = _u
-        u = _u
-
-        if conv_first:  # Conv, then transposed conv
-            v = _conv(u, w, stride)
-            unew = _conv_transpose(v, w, u.shape, stride)
-        else:  # Transposed conv, then conv
-            v = _conv_transpose(u, w, _v.shape, stride)
-            unew = _conv(v, w, stride)
-
-        if big_constant > 0:
-            unew = big_constant * u - unew
-
-        _norm_unew = tf.norm(unew)
-        unew = tf.math.l2_normalize(unew)
-        return unew, v, _old_u, _norm_unew
-
-    # define the loop condition
-    def cond(_u, _v, old_u, _norm_u):
-        return tf.linalg.norm(_u - old_u) >= eps
-
-    # v shape
-    if conv_first:
-        v_shape = (
-            (u.shape[0],)
-            + (u.shape[1] // stride, u.shape[2] // stride)
-            + (w.shape[-1],)
-        )
-    else:
-        v_shape = (
-            (u.shape[0],) + (u.shape[1] * stride, u.shape[2] * stride) + (w.shape[-2],)
-        )
-
-    # build _u and _v
-    _norm_u = tf.norm(u)
-    _u = tf.math.l2_normalize(u)
-    _u += tf.random.uniform(_u.shape, minval=-eps, maxval=eps)
-    _v = tf.zeros(v_shape)  # _v will be set on the first body iteration
-
-    # create a fake old_w that doesn't pass the loop condition
-    # it won't affect computation as the first action done in the loop overwrites it.
-    _old_u = 10 * _u
-
-    # apply the loop
-    _u, _v, _old_u, _norm_u = tf.while_loop(
-        cond,
-        body,
-        (_u, _v, _old_u, _norm_u),
-        parallel_iterations=1,
-        maximum_iterations=maxiter,
-        swap_memory=SWAP_MEMORY,
-    )
-    if STOP_GRAD_SPECTRAL:
-        _u = tf.stop_gradient(_u)
-        _v = tf.stop_gradient(_v)
-
-    return _u, _v, _norm_u
+    if conv_first:
+
+        def linear_op(u):
+            return _conv(u, kernel, stride)
+
+        def adjoint_op(v):
+            return _conv_transpose(v, kernel, u_shape, stride)
+
+    else:
+        v_shape = (
+            (u_shape[0],)
+            + (u_shape[1] * stride, u_shape[2] * stride)
+            + (kernel.shape[-2],)
+        )
+
+        def linear_op(u):
+            return _conv_transpose(u, kernel, v_shape, stride)
+
+        def adjoint_op(v):
+            return _conv(v, kernel, stride)
+
+    return linear_op, adjoint_op
```
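In the `conv_first` branch, the operator pair is a strided convolution and the matching transposed convolution, which is its exact adjoint. A self-contained illustration with plain `tf.nn` ops; the library's `_conv`/`_conv_transpose` helpers layer padding and kernel-transposition handling on top of this, so the shapes and names below are assumptions for the sketch:

```python
import tensorflow as tf

kernel = tf.random.normal((3, 3, 4, 16))   # (h, w, C_in, C_out)
u_shape = [1, 32, 32, 4]                   # NHWC input-shaped singular "vector"

def linear_op(u):
    return tf.nn.conv2d(u, kernel, strides=1, padding="SAME")

def adjoint_op(v):
    return tf.nn.conv2d_transpose(
        v, kernel, output_shape=u_shape, strides=1, padding="SAME"
    )

# Adjointness check: <linear_op(u), v> == <u, adjoint_op(v)> for any u, v.
u = tf.random.normal(u_shape)
v = tf.random.normal([1, 32, 32, 16])
lhs = tf.reduce_sum(linear_op(u) * v)
rhs = tf.reduce_sum(u * adjoint_op(v))
print(float(lhs), float(rhs))  # equal up to float rounding
```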
```diff
@@ -386,11 +360,14 @@ def spectral_normalization_conv(
     if eps < 0:
         return kernel, u, 1.0
 
-    _u, _v, _ = _power_iteration_conv(
-        kernel, u, stride, conv_first, pad_func, eps, maxiter
+    linear_op, adjoint_op = get_conv_operators(
+        kernel, u.shape, stride, conv_first, pad_func
     )
 
-    # Calculate Sigma
-    sigma = tf.norm(_v)
-    W_bar = kernel / (sigma + eps)
-    return W_bar, _u, sigma
+    u = tf.math.l2_normalize(u) + tf.random.uniform(u.shape, minval=-eps, maxval=eps)
+    u = _power_iteration(linear_op, adjoint_op, u, eps, maxiter)
+
+    # Compute the largest singular value and the normalized kernel
+    sigma = tf.norm(linear_op(u))
+    normalized_kernel = kernel / (sigma + eps)
+    return normalized_kernel, u, sigma
```
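Taken together, `spectral_normalization_conv` now just builds the operator pair, runs the same generic loop, and rescales the kernel by the estimated sigma. A simplified end-to-end sketch of that flow (illustrative: plain Python loop instead of `tf.while_loop`, stride 1 and SAME padding assumed):

```python
import tensorflow as tf

def power_iteration(linear_op, adjoint_op, u, maxiter=50):
    # Same generic loop as in the sketch above.
    u = tf.math.l2_normalize(u)
    for _ in range(maxiter):
        u = tf.math.l2_normalize(adjoint_op(linear_op(u)))
    return u

kernel = tf.random.normal((3, 3, 4, 16))
u_shape = [1, 32, 32, 4]
linear_op = lambda u: tf.nn.conv2d(u, kernel, strides=1, padding="SAME")
adjoint_op = lambda v: tf.nn.conv2d_transpose(
    v, kernel, output_shape=u_shape, strides=1, padding="SAME"
)

u = power_iteration(linear_op, adjoint_op, tf.random.uniform(u_shape))
sigma = tf.norm(linear_op(u))                # largest singular value of the conv operator
normalized_kernel = kernel / (sigma + 1e-3)  # spectral norm strictly below 1
```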