17
17
import torch .nn .functional as F
18
18
19
19
from pyiqa .utils .registry import ARCH_REGISTRY
20
- from pyiqa .archs .arch_util import load_pretrained_network
20
+ from pyiqa .archs .arch_util import load_pretrained_network , random_crop
21
21
from .func_util import extract_2d_patches
22
22
23
23
default_model_urls = {
@@ -115,12 +115,16 @@ def forward(self, dist, ref):
115
115
assert dist .shape == ref .shape , f'Input and reference images should have the same shape, but got { dist .shape } '
116
116
f' and { ref .shape } '
117
117
118
- if self .pretrained :
119
- dist = self .preprocess (dist )
120
- ref = self .preprocess (ref )
118
+ dist = self .preprocess (dist )
119
+ ref = self .preprocess (ref )
121
120
121
+ if not self .training :
122
122
image_A_patches = extract_2d_patches (dist , self .patch_size , self .stride , padding = 'none' )
123
123
image_ref_patches = extract_2d_patches (ref , self .patch_size , self .stride , padding = 'none' )
124
+ else :
125
+ image_A_patches , image_ref_patches = dist , ref
126
+ image_A_patches = image_A_patches .unsqueeze (1 )
127
+ image_ref_patches = image_ref_patches .unsqueeze (1 )
124
128
125
129
bsz , num_patches , c , psz , psz = image_A_patches .shape
126
130
image_A_patches = image_A_patches .reshape (bsz * num_patches , c , psz , psz )
@@ -138,4 +142,4 @@ def forward(self, dist, ref):
138
142
per_patch_weight = per_patch_weight .view ((- 1 , num_patches ))
139
143
140
144
score = (per_patch_weight * per_patch_score ).sum (dim = - 1 ) / per_patch_weight .sum (dim = - 1 )
141
- return score .squeeze ( )
145
+ return score .reshape ( bsz , 1 )
0 commit comments