test_runs.py
import time
import os
import sys
import subprocess
import shlex
import glob
import torch
import argparse
# Wrapper script that runs test experiments for each saved checkpoint.
# Parse the checkpoint directory, dataset type, and partition from the CLI.
parser = argparse.ArgumentParser('Model Test Pipeline')
parser.add_argument('checkpoint_dir', help='directory of experiment checkpoints')
parser.add_argument('dataset_type', help='gen_models, faceforensics, etc')
parser.add_argument('partition', help='which partition to run [val|test]')
args = parser.parse_args()
checkpoint_dir = args.checkpoint_dir
checkpoints = glob.glob(os.path.join(checkpoint_dir, '*_net_D.pth'))
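# Example invocation (the checkpoint directory below is hypothetical and for
# illustration only; substitute your own experiment directory):
#   python test_runs.py checkpoints/gan_exp1 gen_models val
# This would pick up discriminator checkpoints such as
# checkpoints/gan_exp1/bestval_net_D.pth.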
def get_dataset_paths(dataroot, datasets, partition):
    # Build parallel lists of fake/real image directories and dataset names
    # from (fake_subdir, real_subdir, name) tuples.
    fake_datasets = [os.path.join(dataroot, dataset[0], partition)
                     for dataset in datasets]
    real_datasets = [os.path.join(dataroot, dataset[1], partition)
                     for dataset in datasets]
    dataset_names = [dataset[2] for dataset in datasets]
    return fake_datasets, real_datasets, dataset_names
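# For illustration, with the dataroot and first tuple used below and the 'val'
# partition (an assumed example), this helper would return:
#   (['dataset/faces/celebahq/pgan-pretrained-128-png/val'],
#    ['dataset/faces/celebahq/real-tfr-1024-resized128/val'],
#    ['celebahq-pgan-pretrained'])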
# dataset triples: (fake image dir, real image dir, dataset name)
if args.dataset_type == 'gen_models':
    dataroot = 'dataset/faces/'
    partition = args.partition
    datasets = [
        ('celebahq/pgan-pretrained-128-png',
         'celebahq/real-tfr-1024-resized128', 'celebahq-pgan-pretrained'),
        ('celebahq/sgan-pretrained-128-png',
         'celebahq/real-tfr-1024-resized128', 'celebahq-sgan-pretrained'),
        ('celebahq/glow-pretrained-128-png',
         'celebahq/real-tfr-1024-resized128', 'celebahq-glow-pretrained'),
        ('celeba/mfa-defaults', 'celeba/mfa-real', 'celeba-gmm'),
        ('ffhq/pgan-9k-128-png', 'ffhq/real-tfr-1024-resized128', 'ffhq-pgan'),
        ('ffhq/sgan-pretrained-128-png', 'ffhq/real-tfr-1024-resized128',
         'ffhq-sgan'),
        ('ffhq/sgan2-pretrained-128-png', 'ffhq/real-tfr-1024-resized128',
         'ffhq-sgan2'),
    ]
    fake_datasets, real_datasets, dataset_names = get_dataset_paths(
        dataroot, datasets, partition)
elif args.dataset_type == 'faceforensics':
    dataroot = 'dataset/faces/'
    partition = args.partition
    datasets = [
        ('faceforensics_aligned/NeuralTextures/manipulated',
         'faceforensics_aligned/NeuralTextures/original', 'NT'),
        ('faceforensics_aligned/Deepfakes/manipulated',
         'faceforensics_aligned/Deepfakes/original', 'DF'),
        ('faceforensics_aligned/Face2Face/manipulated',
         'faceforensics_aligned/Face2Face/original', 'F2F'),
        ('faceforensics_aligned/FaceSwap/manipulated',
         'faceforensics_aligned/FaceSwap/original', 'FS'),
    ]
    fake_datasets, real_datasets, dataset_names = get_dataset_paths(
        dataroot, datasets, partition)
else:
    raise NotImplementedError('unknown dataset_type: %s' % args.dataset_type)
# print the datasets to test on
print(real_datasets)
print(fake_datasets)
print(dataset_names)
for checkpoint in checkpoints:
    for fake, real, name in zip(fake_datasets, real_datasets, dataset_names):
        which_epoch = os.path.basename(checkpoint).split('_')[0]
        if which_epoch != 'bestval':
            # only run tests using the bestval checkpoint
            continue
        # assemble the test.py command for this checkpoint / dataset pair
        test_command = ('python test.py --train_config %s' %
                        (os.path.join(checkpoint_dir, 'opt.yml')))
        test_command += ' --which_epoch %s' % which_epoch
        test_command += ' --gpu_ids 0'
        test_command += ' --real_im_path %s' % real
        test_command += ' --fake_im_path %s' % fake
        test_command += ' --partition %s' % args.partition
        test_command += ' --dataset_name %s' % name
        print(test_command)
        os.system(test_command)
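# A minimal alternative sketch (not wired in above): the already-imported
# subprocess and shlex modules could run the same command without invoking a
# shell, assuming test.py accepts the identical arguments:
#   subprocess.run(shlex.split(test_command), check=True)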