@@ -11,6 +11,7 @@
     create_dataset,
     create_dataset_temporal,
     create_iterable_dataloader,
+    list_test_sets,
 )
 from models import create_model
 from util.parser import get_opt
@@ -27,69 +28,74 @@ def launch_testing(opt, main_opt):
     torch.cuda.set_device(opt.gpu_ids[rank])
     opt.isTrain = False

-    testset = create_dataset(opt, phase="test")
-    print("The number of testing images = %d" % len(testset))
-    opt.num_test_images = len(testset)
-    opt.train_nb_img_max_fid = min(opt.train_nb_img_max_fid, len(testset))
-
-    dataloader_test = create_dataloader(
-        opt, rank, testset, batch_size=opt.test_batch_size
-    )  # create a dataset given opt.dataset_mode and other options
-
-    use_temporal = ("temporal" in opt.D_netDs) or opt.train_temporal_criterion
-
-    if use_temporal:
-        testset_temporal = create_dataset_temporal(opt, phase="test")
-
-        dataloader_test_temporal = create_iterable_dataloader(
-            opt, rank, testset_temporal, batch_size=opt.test_batch_size
-        )
-    else:
-        dataloader_test_temporal = None
-
     model = create_model(opt, rank)  # create a model given opt.model and other options
     model.setup(opt)  # regular setup: load and print networks; create schedulers

-    # sampling options
-    if main_opt.sampling_steps is not None:
-        model.netG_A.denoise_fn.model.beta_schedule["test"][
-            "n_timestep"
-        ] = main_opt.sampling_steps
-        if main_opt.model_type == "palette":
-            set_new_noise_schedule(model.netG_A.denoise_fn.model, "test")
-    if main_opt.sampling_method is not None:
-        model.netG_A.set_new_sampling_method(main_opt.sampling_method)
-    if main_opt.ddim_num_steps is not None:
-        model.ddim_num_steps = main_opt.ddim_num_steps
-    if main_opt.ddim_eta is not None:
-        model.ddim_eta = main_opt.ddim_eta
-
-    model.use_temporal = use_temporal
-    model.eval()
-    if opt.use_cuda:
-        model.single_gpu()
-    model.init_metrics(dataloader_test)
-
-    if use_temporal:
-        dataloaders_test = zip(dataloader_test, dataloader_test_temporal)
-    else:
-        dataloaders_test = zip(dataloader_test)
-
-    epoch = "test"
-    total_iters = "test"
-    with torch.no_grad():
-        model.compute_metrics_test(dataloaders_test, epoch, total_iters)
-
-    metrics = model.get_current_metrics([""])
-    for metric, value in metrics.items():
-        print(f"{metric}: {value}")
-
-    metrics_dir = os.path.join(opt.test_model_dir, "metrics")
-    os.makedirs(metrics_dir, exist_ok=True)
-    metrics_file = os.path.join(metrics_dir, time.strftime("%Y%m%d-%H%M%S") + ".json")
-    with open(metrics_file, "w") as f:
-        f.write(json.dumps(metrics, indent=4))
-    print("metrics written to:", metrics_file)
+    all_test_sets = list_test_sets(opt)
+
+    for test_set in all_test_sets:
+        testset = create_dataset(opt, phase="test", name=test_set)
+        print("The number of testing images = %d" % len(testset))
+        opt.num_test_images = len(testset)
+        opt.train_nb_img_max_fid = min(opt.train_nb_img_max_fid, len(testset))
+
+        dataloader_test = create_dataloader(
+            opt, rank, testset, batch_size=opt.test_batch_size
+        )  # create a dataset given opt.dataset_mode and other options
+
+        use_temporal = ("temporal" in opt.D_netDs) or opt.train_temporal_criterion
+
+        if use_temporal:
+            testset_temporal = create_dataset_temporal(opt, phase="test")
+
+            dataloader_test_temporal = create_iterable_dataloader(
+                opt, rank, testset_temporal, batch_size=opt.test_batch_size
+            )
+        else:
+            dataloader_test_temporal = None
+
+        # sampling options
+        if main_opt.sampling_steps is not None:
+            model.netG_A.denoise_fn.model.beta_schedule["test"][
+                "n_timestep"
+            ] = main_opt.sampling_steps
+            if main_opt.model_type == "palette":
+                set_new_noise_schedule(model.netG_A.denoise_fn.model, "test")
+        if main_opt.sampling_method is not None:
+            model.netG_A.set_new_sampling_method(main_opt.sampling_method)
+        if main_opt.ddim_num_steps is not None:
+            model.ddim_num_steps = main_opt.ddim_num_steps
+        if main_opt.ddim_eta is not None:
+            model.ddim_eta = main_opt.ddim_eta
+
+        model.use_temporal = use_temporal
+        model.eval()
+        if opt.use_cuda:
+            model.single_gpu()
+        model.init_metrics(dataloader_test)
+
+        if use_temporal:
+            dataloaders_test = zip(dataloader_test, dataloader_test_temporal)
+        else:
+            dataloaders_test = zip(dataloader_test)
+
+        epoch = "test"
+        total_iters = "test"
+        with torch.no_grad():
+            model.compute_metrics_test(dataloaders_test, epoch, total_iters)
+
+        metrics = model.get_current_metrics([""])
+        for metric, value in metrics.items():
+            print(f"{metric}: {value}")
+
+        metrics_dir = os.path.join(opt.test_model_dir, "metrics")
+        os.makedirs(metrics_dir, exist_ok=True)
+        metrics_file = os.path.join(
+            metrics_dir, time.strftime("%Y%m%d-%H%M%S") + ".json"
+        )
+        with open(metrics_file, "w") as f:
+            f.write(json.dumps(metrics, indent=4))
+        print("metrics written to:", metrics_file)


 if __name__ == "__main__":
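For reference, list_test_sets (imported from data in the first hunk) is what the new loop relies on to enumerate the available test sets; its implementation is not part of this diff. A minimal sketch of the assumed behavior, using a purely hypothetical layout in which every test set is a sub-directory of opt.dataroot whose name starts with "test", could look like this:

import os


def list_test_sets_sketch(opt):
    # Hypothetical helper, for illustration only: collect the names of
    # test-set directories found directly under opt.dataroot. The real
    # list_test_sets used in the diff may follow a different convention.
    names = []
    for entry in sorted(os.listdir(opt.dataroot)):
        full_path = os.path.join(opt.dataroot, entry)
        if os.path.isdir(full_path) and entry.startswith("test"):
            names.append(entry)
    return names

Each returned name is then handed to create_dataset(opt, phase="test", name=test_set), so metrics are computed and a timestamped JSON file is written once per test set.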
@@ -108,7 +114,7 @@ def launch_testing(opt, main_opt):
         "--test_metrics_list",
         type=str,
         nargs="*",
-        choices=["FID", "KID", "MSID", "PSNR", "LPIPS"],
+        choices=["FID", "KID", "MSID", "PSNR", "LPIPS", "SSIM"],
         default=["FID", "KID", "MSID", "PSNR", "LPIPS"],
     )
     main_parser.add_argument(
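With SSIM now among the accepted choices, a run that also reports SSIM could be launched with, for example, --test_metrics_list FID PSNR SSIM; the remaining options required by the test entry point are unchanged by this diff and are omitted here.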