Skip to content

Commit 036f2cc

Browse files
committed
feat: 🧑‍💻 improve codes to allow training
1 parent e73c719 commit 036f2cc

File tree

3 files changed

+12
-8
lines changed

3 files changed

+12
-8
lines changed

pyiqa/archs/dists_arch.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -143,6 +143,6 @@ def forward(self, x, y):
143143
S2 = (2 * xy_cov + c2) / (x_var + y_var + c2)
144144
dist2 = dist2 + (beta[k] * S2).sum(1, keepdim=True)
145145

146-
score = 1 - (dist1 + dist2).squeeze()
146+
score = 1 - (dist1 + dist2)
147147

148-
return score
148+
return score.squeeze(-1).squeeze(-1)

pyiqa/archs/lpips_arch.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -173,7 +173,7 @@ def forward(self, in1, in0, retPerLayer=False, normalize=True):
173173
if (retPerLayer):
174174
return (val, res)
175175
else:
176-
return val.squeeze()
176+
return val.squeeze(-1).squeeze(-1)
177177

178178

179179
class ScalingLayer(nn.Module):

pyiqa/archs/pieapp_arch.py

+9-5
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
import torch.nn.functional as F
1818

1919
from pyiqa.utils.registry import ARCH_REGISTRY
20-
from pyiqa.archs.arch_util import load_pretrained_network
20+
from pyiqa.archs.arch_util import load_pretrained_network, random_crop
2121
from .func_util import extract_2d_patches
2222

2323
default_model_urls = {
@@ -115,12 +115,16 @@ def forward(self, dist, ref):
115115
assert dist.shape == ref.shape, f'Input and reference images should have the same shape, but got {dist.shape}'
116116
f' and {ref.shape}'
117117

118-
if self.pretrained:
119-
dist = self.preprocess(dist)
120-
ref = self.preprocess(ref)
118+
dist = self.preprocess(dist)
119+
ref = self.preprocess(ref)
121120

121+
if not self.training:
122122
image_A_patches = extract_2d_patches(dist, self.patch_size, self.stride, padding='none')
123123
image_ref_patches = extract_2d_patches(ref, self.patch_size, self.stride, padding='none')
124+
else:
125+
image_A_patches, image_ref_patches = dist, ref
126+
image_A_patches = image_A_patches.unsqueeze(1)
127+
image_ref_patches = image_ref_patches.unsqueeze(1)
124128

125129
bsz, num_patches, c, psz, psz = image_A_patches.shape
126130
image_A_patches = image_A_patches.reshape(bsz * num_patches, c, psz, psz)
@@ -138,4 +142,4 @@ def forward(self, dist, ref):
138142
per_patch_weight = per_patch_weight.view((-1, num_patches))
139143

140144
score = (per_patch_weight * per_patch_score).sum(dim=-1) / per_patch_weight.sum(dim=-1)
141-
return score.squeeze()
145+
return score.reshape(bsz, 1)

0 commit comments

Comments
 (0)