Yinhuai Wang, Yujie Hu, Jiwen Yu, Jian Zhang
Peking University and PCL
Range-null space decomposition can significantly improve data consistency.
Check out our follow-up work DDNM.
Since our method is independent of the choice of network, it is quite simple to apply it to your own networks😊.
Let's take the image SR task as an example. Given an input LR image $\mathbf{y}$, a typical SR network $\mathcal{D}$ produces the SR result
$\mathbf{x}_{r}=\mathcal{D}(\mathbf{y})$
However, this SR result $\mathbf{x}_{r}$ is generally not consistent with the input: downsampling $\mathbf{x}_{r}$ does not reproduce $\mathbf{y}$.
We provide a very simple way to assure low-frequency consistency. We first define an average-pooling downsampler $\mathbf{A}$ and its pseudo-inverse $\mathbf{A}^{\dagger}$ (an upsampler) satisfying
$\mathbf{A}\mathbf{A}^{\dagger}=\mathbf{I}$
Then we correct $\mathbf{x}_{r}$ via the range-null space decomposition:
$\hat{\mathbf{x}}=\mathbf{A}^{\dagger}\mathbf{y}+(\mathbf{I}-\mathbf{A}^{\dagger}\mathbf{A})\mathbf{x}_{r}$
The result $\hat{\mathbf{x}}$ is strictly consistent with the input, since $\mathbf{A}\hat{\mathbf{x}}=\mathbf{A}\mathbf{A}^{\dagger}\mathbf{y}+(\mathbf{A}-\mathbf{A}\mathbf{A}^{\dagger}\mathbf{A})\mathbf{x}_{r}=\mathbf{y}$.
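As a quick numerical sanity check, here is a minimal sketch (assuming a 4× average-pooling $\mathbf{A}$; its exact pseudo-inverse $\mathbf{A}^{\dagger}$ is nearest-neighbor, i.e., replicate, upsampling) verifying that $\mathbf{A}\hat{\mathbf{x}}=\mathbf{y}$ holds for an arbitrary $\mathbf{x}_{r}$:

import torch

# A: 4x average pooling; Ap: replicate upsampling, the exact pseudo-inverse of A
A = torch.nn.AvgPool2d(4)
Ap = torch.nn.Upsample(scale_factor=4, mode='nearest')

y = torch.randn(1, 3, 16, 16)         # LR input
x_r = torch.randn(1, 3, 64, 64)       # any SR result, consistent or not
x_hat = Ap(y) + x_r - Ap(A(x_r))      # range-null space correction
print(torch.allclose(A(x_hat), y, atol=1e-5))  # True: A x_hat == y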
The detailed implementation is below; you may copy this code and apply it to your own sr_backbone:
import torch

def MeanUpsample(vec, scale):
    # Pseudo-inverse of average pooling: replicate each pixel into a scale x scale block.
    n, c, h, w = vec.shape
    vec = vec.repeat(1, 1, scale, scale)
    vec = vec.reshape(n, c, scale, h, scale, w)
    out = vec.permute(0, 1, 3, 2, 5, 4).reshape(n, c, h * scale, w * scale)
    return out

class RND_sr(torch.nn.Module):
    def __init__(self, sr_backbone, image_size=1024, sr_scale=8):
        super(RND_sr, self).__init__()
        self.sr_backbone = sr_backbone()
        # A: average-pooling downsampler; Ap: its pseudo-inverse (an upsampler)
        self.A = torch.nn.AdaptiveAvgPool2d((image_size // sr_scale, image_size // sr_scale))
        self.Ap = lambda z: MeanUpsample(z, sr_scale)

    def forward(self, y):
        x = self.sr_backbone(y)                   # x_r = D(y)
        x = self.Ap(y) + x - self.Ap(self.A(x))   # x_hat = A†y + (I - A†A)x_r
        return x
You may try any sr_backbone; for the backbones used in this paper, refer to GLEAN and Panini.
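A minimal usage sketch (the toy backbone below is hypothetical and only stands in for a real SR network such as GLEAN or Panini):

import torch

# Hypothetical toy backbone: bilinear upsampling followed by one conv layer.
class ToyBackbone(torch.nn.Module):
    def __init__(self):
        super(ToyBackbone, self).__init__()
        self.up = torch.nn.Upsample(scale_factor=8, mode='bilinear', align_corners=False)
        self.conv = torch.nn.Conv2d(3, 3, 3, padding=1)
    def forward(self, y):
        return self.conv(self.up(y))

model = RND_sr(ToyBackbone, image_size=256, sr_scale=8)
y = torch.randn(1, 3, 32, 32)                        # LR input (256 / 8 = 32)
x_hat = model(y)                                     # consistent SR output
print(torch.allclose(model.A(x_hat), y, atol=1e-5))  # True: A x_hat == y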
Similarly, for the colorization task, $\mathbf{A}$ is the RGB-to-grayscale conversion and $\mathbf{A}^{\dagger}$ its pseudo-inverse:

import torch

def color2gray(x):
    # A: RGB -> gray (uniform channel mixing), replicated to 3 channels
    coef_r, coef_g, coef_b = 1/3, 1/3, 1/3
    g = x[:, 0:1] * coef_r + x[:, 1:2] * coef_g + x[:, 2:3] * coef_b
    return g.repeat(1, 3, 1, 1)

def gray2color(x):
    # Ap: pseudo-inverse of color2gray, i.e., A† = Aᵀ(AAᵀ)⁻¹ applied per pixel
    coef_r, coef_g, coef_b = 1/3, 1/3, 1/3
    g = x[:, 0, :, :]
    base = coef_r**2 + coef_g**2 + coef_b**2
    return torch.stack((g * coef_r / base, g * coef_g / base, g * coef_b / base), 1)

class RND_colorization(torch.nn.Module):
    def __init__(self, backbone):
        super(RND_colorization, self).__init__()
        self.backbone = backbone()
        self.A = lambda z: color2gray(z)
        self.Ap = lambda z: gray2color(z)

    def forward(self, y):
        x = self.backbone(y)                      # x_r = D(y)
        x = self.Ap(y) + x - self.Ap(self.A(x))   # x_hat = A†y + (I - A†A)x_r
        return x
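A minimal usage sketch for colorization (again with a hypothetical toy backbone; the gray input is represented as 3 identical channels, matching the format of color2gray). Converting the output back to gray recovers the input exactly:

import torch

# Hypothetical toy backbone standing in for a real colorization network.
toy_backbone = lambda: torch.nn.Conv2d(3, 3, 3, padding=1)

model = RND_colorization(toy_backbone)
gray = torch.rand(1, 1, 64, 64)
y = gray.repeat(1, 3, 1, 1)                             # gray input, 3 identical channels
x_hat = model(y)
print(torch.allclose(color2gray(x_hat), y, atol=1e-5))  # True: gray consistency holds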
For the compressed sensing (CS) task, $\mathbf{A}$ is a block-based random sampling matrix implemented as a strided convolution; since its rows are orthonormal, the transposed convolution is the exact pseudo-inverse $\mathbf{A}^{\dagger}$:

import torch

# Initialize the CS sampling matrix (example values; adjust to your setting).
batch_size = 1
C = 3                  # channels
B = 32                 # block size
N = B * B
cs_ratio = 0.1         # sampling ratio
q = int(torch.ceil(torch.tensor(cs_ratio * N)))  # measurements per block
image_size = 256       # example image size
H, W = image_size, image_size
# Random orthogonal matrix: its first q rows are orthonormal, so A A† = I.
U, S, Vh = torch.linalg.svd(torch.randn(N, N))
A_weight = U.mm(Vh)    # CS sampling matrix weight (N x N orthogonal)
A_weight = A_weight.view(N, 1, B, B)[:q].to("cuda")
A = lambda z: torch.nn.functional.conv2d(z.view(batch_size * C, 1, H, W), A_weight, stride=B)
Ap = lambda z: torch.nn.functional.conv_transpose2d(z, A_weight, stride=B).view(batch_size, C, H, W)

class RND_cs(torch.nn.Module):
    def __init__(self, backbone):
        super(RND_cs, self).__init__()
        self.backbone = backbone()

    def forward(self, y):
        x = self.backbone(y)          # x_r = D(y)
        x = Ap(y) + x - Ap(A(x))      # x_hat = A†y + (I - A†A)x_r
        return x
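A minimal usage sketch for CS (assuming the setup block above has been run, so A, Ap, and the size constants exist and A_weight lives on CUDA; the toy backbone is hypothetical and simply refines the back-projection $\mathbf{A}^{\dagger}\mathbf{y}$):

import torch

# Hypothetical toy backbone: back-project the measurements, then refine with a conv.
class ToyCSBackbone(torch.nn.Module):
    def __init__(self):
        super(ToyCSBackbone, self).__init__()
        self.conv = torch.nn.Conv2d(C, C, 3, padding=1)
    def forward(self, y):
        return self.conv(Ap(y))   # start from the least-squares back-projection A†y

model = RND_cs(ToyCSBackbone).to("cuda")
x = torch.randn(batch_size, C, H, W, device="cuda")
y = A(x)                                       # simulate the CS measurements
x_hat = model(y)
print(torch.allclose(A(x_hat), y, atol=1e-4))  # True: measurements are preserved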
When applied to small networks (e.g., ESRGAN) or complex datasets (e.g., DIV2K), RND may not bring significant improvement. Our method implicitly assumes that the network is "powerful" enough in generative performance; this "power" is determined jointly by the dataset and the network itself. For single-class datasets such as human faces, the data distribution is relatively easy to learn.