Skip to content

Commit ca64a91

Browse files
authored
Add files via upload
1 parent 6ed5d69 commit ca64a91

File tree

4 files changed

+858
-1
lines changed

4 files changed

+858
-1
lines changed

MSGlance/Global.py

+41
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,41 @@
1+
from .Local import LocalGlance
2+
import torch
3+
import torch.nn as nn
4+
from matplotlib import pyplot as plt
5+
6+
class GlobalGlance(nn.Module):
7+
def __init__(self, kernel_size=4, stride=4, repeat_time=10, patch_height=64, patch_width=64, drop=None, sigma=1.5):
8+
super(GlobalGlance, self).__init__()
9+
self.kernel_size = kernel_size
10+
self.stride = stride
11+
self.repeat_time = repeat_time
12+
self.patch_height = patch_height
13+
self.patch_width = patch_width
14+
self.sigma = sigma
15+
self.ssim_loss = LocalGlance(window_size=self.kernel_size, stride=self.stride, drop=drop, sigma=self.sigma)
16+
17+
def forward(self, src_vec, tar_vec):
18+
batch_size, channels, height, width = src_vec.size()
19+
loss = 0.0
20+
21+
for batch in range(batch_size):
22+
index_list = []
23+
for i in range(self.repeat_time):
24+
if i == 0:
25+
tmp_index = torch.arange(height * width)
26+
else:
27+
tmp_index = torch.randperm(height * width)
28+
index_list.append(tmp_index)
29+
30+
res_index = torch.cat(index_list)
31+
rows = res_index // width
32+
cols = res_index % width
33+
tar_all = tar_vec[batch, :, rows, cols].view(channels, -1, self.patch_height, self.patch_width * self.repeat_time)
34+
src_all = src_vec[batch, :, rows, cols].view(channels, -1, self.patch_height, self.patch_width * self.repeat_time)
35+
tar_mag = torch.clip(tar_all, 0, 1)*255
36+
src_mag = torch.clip(src_all, 0, 1)*255
37+
38+
loss += (1 - self.ssim_loss(src_mag, tar_mag))
39+
40+
loss /= batch_size
41+
return loss

MSGlance/Local.py

+121
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,121 @@
1+
'''SSIM in PyTorch.
2+
3+
The source code is adopted from:
4+
https://github.com/Po-Hsun-Su/pytorch-ssim
5+
6+
7+
Reference:
8+
[1] Wang Z, Bovik A C, Sheikh H R, et al.
9+
Image quality assessment: from error visibility to structural similarity. IEEE transactions on image processing
10+
'''
11+
12+
import torch
13+
import torch.nn.functional as F
14+
from torch.autograd import Variable
15+
import numpy as np
16+
from math import exp
17+
import math
18+
19+
20+
# def gaussian(window_size, sigma):
21+
# gauss = torch.Tensor([exp(-(x - window_size//2)**2/float(2*sigma**2)) for x in range(window_size)])
22+
# return gauss/gauss.sum()
23+
24+
def uniform(window_size,sigma):
25+
uniform_tensor = torch.ones(window_size)
26+
return uniform_tensor / uniform_tensor.sum()
27+
28+
29+
def create_window(window_size, channel, sigma=1.5):
30+
_1D_window = uniform(window_size, sigma).unsqueeze(1)
31+
_2D_window = _1D_window.mm(_1D_window.t()).float().unsqueeze(0).unsqueeze(0)
32+
window = Variable(_2D_window.expand(channel, 1, window_size, window_size).contiguous())
33+
return window
34+
35+
def _ssim(img1, img2, window, window_size, channel, size_average = True, stride=None, drop=None):
36+
mu1 = F.conv2d(img1, window, padding = (window_size-1)//2, groups = channel, stride=stride)
37+
mu2 = F.conv2d(img2, window, padding = (window_size-1)//2, groups = channel, stride=stride)
38+
39+
mu1_sq = mu1.pow(2)
40+
mu2_sq = mu2.pow(2)
41+
mu1_mu2 = mu1*mu2
42+
43+
sigma1_sq = F.conv2d(img1*img1, window, padding = (window_size-1)//2, groups = channel, stride=stride) - mu1_sq
44+
sigma2_sq = F.conv2d(img2*img2, window, padding = (window_size-1)//2, groups = channel, stride=stride) - mu2_sq
45+
sigma12 = F.conv2d(img1*img2, window, padding = (window_size-1)//2, groups = channel, stride=stride) - mu1_mu2
46+
47+
C1 = 0.01**2
48+
C2 = 0.03**2
49+
C3 = C2/2
50+
51+
L = (2*mu1_mu2 + C1) / (mu1_sq + mu2_sq + C1)
52+
C = (2*torch.sqrt(sigma1_sq)*torch.sqrt(sigma2_sq) + C2) / (sigma1_sq + sigma2_sq + C2)
53+
S = (sigma12 + C3) / (torch.sqrt(sigma1_sq)*torch.sqrt(sigma2_sq) + C3)
54+
55+
if drop == "L":
56+
ssim_map = C*S
57+
elif drop == "C":
58+
ssim_map = L*S
59+
elif drop == "S":
60+
ssim_map = L*C
61+
elif drop == "LC":
62+
ssim_map = S
63+
else:
64+
ssim_map = ((2*mu1_mu2 + C1)*(2*sigma12 + C2))/((mu1_sq + mu2_sq + C1)*(sigma1_sq + sigma2_sq + C2))
65+
66+
if size_average:
67+
return ssim_map.mean()
68+
else:
69+
return ssim_map.mean(1).mean(1).mean(1)
70+
71+
class LocalGlance(torch.nn.Module):
72+
def __init__(self, window_size = 3, size_average = True, stride=3, drop=None, sigma=1.5, channel=1):
73+
super(LocalGlance, self).__init__()
74+
self.window_size = window_size
75+
self.size_average = size_average
76+
self.channel = channel
77+
self.stride = stride
78+
self.window = create_window(window_size, self.channel, sigma)
79+
self.drop = drop
80+
self.sigma = sigma
81+
82+
83+
def forward(self, img1, img2):
84+
"""
85+
img1, img2: torch.Tensor([b,2,h,w]) - 2表示复数的实部和虚部
86+
"""
87+
# 计算幅值
88+
(_, channel, _, _) = img1.size()
89+
90+
if channel == self.channel and self.window.data.type() == img1.data.type():
91+
window = self.window
92+
else:
93+
window = create_window(self.window_size, channel, self.sigma)
94+
95+
if img1.is_cuda:
96+
window = window.cuda(img1.get_device())
97+
window = window.type_as(img1)
98+
99+
self.window = window
100+
self.channel = channel
101+
102+
return _ssim(img1, img2, window, self.window_size, channel, self.size_average, stride=self.stride, drop=self.drop)
103+
104+
105+
def ssim(img1, img2, window_size = 11, size_average = True, sigma=1.5):
106+
(_, channel, _, _) = img1.size()
107+
window = create_window(window_size, channel, sigma)
108+
109+
if img1.is_cuda:
110+
window = window.cuda(img1.get_device())
111+
window = window.type_as(img1)
112+
113+
return _ssim(img1, img2, window, window_size, channel, size_average)
114+
115+
116+
117+
118+
119+
120+
121+

README.md

+15-1
Original file line numberDiff line numberDiff line change
@@ -1 +1,15 @@
1-
1+
# [MS-Glance] WACV25 submission 1940
2+
3+
This repo contains the code of MSGlance.
4+
5+
## Content
6+
7+
The source code is located in the MSGlance folder.
8+
9+
We provide an easy-to-use Jupyter notebook, `explore_MSGlance.ipynb`, which demonstrates how to use MS-Glance for INR fitting with SIREN. Additionally, it includes the script used to generate Figure 1 in the Supplementary Materials.
10+
11+
## Usage
12+
13+
You can simply upload the files to your Google Drive and run them with Colab, which comes with all the necessary pre-installed packages to execute the `explore_MSGlance.ipynb` notebook.
14+
15+
I’d be happy to assist you with this, but doing so would reveal my username.

0 commit comments

Comments
 (0)