Skip to content

Commit fe95923

Browse files
committed
feat: 📝 add maniqa-koniq, kadid
1 parent ac70d7a commit fe95923

File tree

2 files changed

+19
-4
lines changed

2 files changed

+19
-4
lines changed

pyiqa/archs/maniqa_arch.py

+5-4
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,6 @@
1010

1111
import torch
1212
import torch.nn as nn
13-
import torch.nn.functional as F
1413
import timm
1514

1615
from timm.models.vision_transformer import Block
@@ -20,11 +19,12 @@
2019
from einops import rearrange
2120

2221
from pyiqa.utils.registry import ARCH_REGISTRY
23-
from .func_util import extract_2d_patches
2422
from pyiqa.archs.arch_util import load_pretrained_network
2523

2624
default_model_urls = {
27-
'pipal': 'https://github.com/chaofengc/IQA-PyTorch/releases/download/v0.1-weights/MANIQA_PIPAL-ae6d356b.pth'
25+
'pipal': 'https://github.com/chaofengc/IQA-PyTorch/releases/download/v0.1-weights/MANIQA_PIPAL-ae6d356b.pth',
26+
'koniq': 'https://github.com/IIGROUP/MANIQA/releases/download/Koniq10k/ckpt_koniq10k.pt',
27+
'kadid': 'https://github.com/IIGROUP/MANIQA/releases/download/Kadid10k/ckpt_kadid10k.pt',
2828
}
2929

3030

@@ -83,6 +83,7 @@ def __init__(self, embed_dim=768, num_outputs=1, patch_size=8, drop=0.1,
8383
img_size=224, num_tab=2, scale=0.13, test_sample=20,
8484
pretrained=True,
8585
pretrained_model_path=None,
86+
train_dataset='pipal',
8687
default_mean=None,
8788
default_std=None,
8889
**kwargs):
@@ -155,7 +156,7 @@ def __init__(self, embed_dim=768, num_outputs=1, patch_size=8, drop=0.1,
155156
load_pretrained_network(self, pretrained_model_path, True, weight_keys='params')
156157
# load_pretrained_network(self, pretrained_model_path, True, )
157158
elif pretrained:
158-
load_pretrained_network(self, default_model_urls['pipal'], True)
159+
load_pretrained_network(self, default_model_urls[train_dataset], True)
159160

160161
def extract_feature(self, save_output):
161162
x6 = save_output.outputs[6][:, 1:]

pyiqa/default_model_configs.py

+14
Original file line numberDiff line numberDiff line change
@@ -261,6 +261,20 @@
261261
},
262262
'metric_mode': 'NR',
263263
},
264+
'maniqa-koniq': {
265+
'metric_opts': {
266+
'type': 'MANIQA',
267+
'train_dataset': 'koniq',
268+
},
269+
'metric_mode': 'NR',
270+
},
271+
'maniqa-kadid': {
272+
'metric_opts': {
273+
'type': 'MANIQA',
274+
'train_dataset': 'kadid',
275+
},
276+
'metric_mode': 'NR',
277+
},
264278
'clipiqa': {
265279
'metric_opts': {
266280
'type': 'CLIPIQA',

0 commit comments

Comments
 (0)