|
10 | 10 |
|
11 | 11 | import torch
|
12 | 12 | import torch.nn as nn
|
13 |
| -import torch.nn.functional as F |
14 | 13 | import timm
|
15 | 14 |
|
16 | 15 | from timm.models.vision_transformer import Block
|
|
20 | 19 | from einops import rearrange
|
21 | 20 |
|
22 | 21 | from pyiqa.utils.registry import ARCH_REGISTRY
|
23 |
| -from .func_util import extract_2d_patches |
24 | 22 | from pyiqa.archs.arch_util import load_pretrained_network
|
25 | 23 |
|
26 | 24 | default_model_urls = {
|
27 |
| - 'pipal': 'https://github.com/chaofengc/IQA-PyTorch/releases/download/v0.1-weights/MANIQA_PIPAL-ae6d356b.pth' |
| 25 | + 'pipal': 'https://github.com/chaofengc/IQA-PyTorch/releases/download/v0.1-weights/MANIQA_PIPAL-ae6d356b.pth', |
| 26 | + 'koniq': 'https://github.com/IIGROUP/MANIQA/releases/download/Koniq10k/ckpt_koniq10k.pt', |
| 27 | + 'kadid': 'https://github.com/IIGROUP/MANIQA/releases/download/Kadid10k/ckpt_kadid10k.pt', |
28 | 28 | }
|
29 | 29 |
|
30 | 30 |
|
@@ -83,6 +83,7 @@ def __init__(self, embed_dim=768, num_outputs=1, patch_size=8, drop=0.1,
|
83 | 83 | img_size=224, num_tab=2, scale=0.13, test_sample=20,
|
84 | 84 | pretrained=True,
|
85 | 85 | pretrained_model_path=None,
|
| 86 | + train_dataset='pipal', |
86 | 87 | default_mean=None,
|
87 | 88 | default_std=None,
|
88 | 89 | **kwargs):
|
@@ -155,7 +156,7 @@ def __init__(self, embed_dim=768, num_outputs=1, patch_size=8, drop=0.1,
|
155 | 156 | load_pretrained_network(self, pretrained_model_path, True, weight_keys='params')
|
156 | 157 | # load_pretrained_network(self, pretrained_model_path, True, )
|
157 | 158 | elif pretrained:
|
158 |
| - load_pretrained_network(self, default_model_urls['pipal'], True) |
| 159 | + load_pretrained_network(self, default_model_urls[train_dataset], True) |
159 | 160 |
|
160 | 161 | def extract_feature(self, save_output):
|
161 | 162 | x6 = save_output.outputs[6][:, 1:]
|
|
0 commit comments