Skip to content

Commit 06762fb

Browse files
committed
fix: multiple test sets with test.py + SSIM
1 parent 1643244 commit 06762fb

File tree

1 file changed

+67
-61
lines changed

1 file changed

+67
-61
lines changed

test.py

+67-61
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
create_dataset,
1212
create_dataset_temporal,
1313
create_iterable_dataloader,
14+
list_test_sets,
1415
)
1516
from models import create_model
1617
from util.parser import get_opt
@@ -27,69 +28,74 @@ def launch_testing(opt, main_opt):
2728
torch.cuda.set_device(opt.gpu_ids[rank])
2829
opt.isTrain = False
2930

30-
testset = create_dataset(opt, phase="test")
31-
print("The number of testing images = %d" % len(testset))
32-
opt.num_test_images = len(testset)
33-
opt.train_nb_img_max_fid = min(opt.train_nb_img_max_fid, len(testset))
34-
35-
dataloader_test = create_dataloader(
36-
opt, rank, testset, batch_size=opt.test_batch_size
37-
) # create a dataset given opt.dataset_mode and other options
38-
39-
use_temporal = ("temporal" in opt.D_netDs) or opt.train_temporal_criterion
40-
41-
if use_temporal:
42-
testset_temporal = create_dataset_temporal(opt, phase="test")
43-
44-
dataloader_test_temporal = create_iterable_dataloader(
45-
opt, rank, testset_temporal, batch_size=opt.test_batch_size
46-
)
47-
else:
48-
dataloader_test_temporal = None
49-
5031
model = create_model(opt, rank) # create a model given opt.model and other options
5132
model.setup(opt) # regular setup: load and print networks; create schedulers
5233

53-
# sampling options
54-
if main_opt.sampling_steps is not None:
55-
model.netG_A.denoise_fn.model.beta_schedule["test"][
56-
"n_timestep"
57-
] = main_opt.sampling_steps
58-
if main.opt.model_type == "palette":
59-
set_new_noise_schedule(model.netG_A.denoise_fn.model, "test")
60-
if main_opt.sampling_method is not None:
61-
model.netG_A.set_new_sampling_method(main_opt.sampling_method)
62-
if main_opt.ddim_num_steps is not None:
63-
model.ddim_num_steps = main_opt.ddim_num_steps
64-
if main_opt.ddim_eta is not None:
65-
model.ddim_eta = main_opt.ddim_eta
66-
67-
model.use_temporal = use_temporal
68-
model.eval()
69-
if opt.use_cuda:
70-
model.single_gpu()
71-
model.init_metrics(dataloader_test)
72-
73-
if use_temporal:
74-
dataloaders_test = zip(dataloader_test, dataloader_test_temporal)
75-
else:
76-
dataloaders_test = zip(dataloader_test)
77-
78-
epoch = "test"
79-
total_iters = "test"
80-
with torch.no_grad():
81-
model.compute_metrics_test(dataloaders_test, epoch, total_iters)
82-
83-
metrics = model.get_current_metrics([""])
84-
for metric, value in metrics.items():
85-
print(f"{metric}: {value}")
86-
87-
metrics_dir = os.path.join(opt.test_model_dir, "metrics")
88-
os.makedirs(metrics_dir, exist_ok=True)
89-
metrics_file = os.path.join(metrics_dir, time.strftime("%Y%m%d-%H%M%S") + ".json")
90-
with open(metrics_file, "w") as f:
91-
f.write(json.dumps(metrics, indent=4))
92-
print("metrics written to:", metrics_file)
34+
all_test_sets = list_test_sets(opt)
35+
36+
for test_set in all_test_sets:
37+
testset = create_dataset(opt, phase="test", name=test_set)
38+
print("The number of testing images = %d" % len(testset))
39+
opt.num_test_images = len(testset)
40+
opt.train_nb_img_max_fid = min(opt.train_nb_img_max_fid, len(testset))
41+
42+
dataloader_test = create_dataloader(
43+
opt, rank, testset, batch_size=opt.test_batch_size
44+
) # create a dataset given opt.dataset_mode and other options
45+
46+
use_temporal = ("temporal" in opt.D_netDs) or opt.train_temporal_criterion
47+
48+
if use_temporal:
49+
testset_temporal = create_dataset_temporal(opt, phase="test")
50+
51+
dataloader_test_temporal = create_iterable_dataloader(
52+
opt, rank, testset_temporal, batch_size=opt.test_batch_size
53+
)
54+
else:
55+
dataloader_test_temporal = None
56+
57+
# sampling options
58+
if main_opt.sampling_steps is not None:
59+
model.netG_A.denoise_fn.model.beta_schedule["test"][
60+
"n_timestep"
61+
] = main_opt.sampling_steps
62+
if main.opt.model_type == "palette":
63+
set_new_noise_schedule(model.netG_A.denoise_fn.model, "test")
64+
if main_opt.sampling_method is not None:
65+
model.netG_A.set_new_sampling_method(main_opt.sampling_method)
66+
if main_opt.ddim_num_steps is not None:
67+
model.ddim_num_steps = main_opt.ddim_num_steps
68+
if main_opt.ddim_eta is not None:
69+
model.ddim_eta = main_opt.ddim_eta
70+
71+
model.use_temporal = use_temporal
72+
model.eval()
73+
if opt.use_cuda:
74+
model.single_gpu()
75+
model.init_metrics(dataloader_test)
76+
77+
if use_temporal:
78+
dataloaders_test = zip(dataloader_test, dataloader_test_temporal)
79+
else:
80+
dataloaders_test = zip(dataloader_test)
81+
82+
epoch = "test"
83+
total_iters = "test"
84+
with torch.no_grad():
85+
model.compute_metrics_test(dataloaders_test, epoch, total_iters)
86+
87+
metrics = model.get_current_metrics([""])
88+
for metric, value in metrics.items():
89+
print(f"{metric}: {value}")
90+
91+
metrics_dir = os.path.join(opt.test_model_dir, "metrics")
92+
os.makedirs(metrics_dir, exist_ok=True)
93+
metrics_file = os.path.join(
94+
metrics_dir, time.strftime("%Y%m%d-%H%M%S") + ".json"
95+
)
96+
with open(metrics_file, "w") as f:
97+
f.write(json.dumps(metrics, indent=4))
98+
print("metrics written to:", metrics_file)
9399

94100

95101
if __name__ == "__main__":
@@ -108,7 +114,7 @@ def launch_testing(opt, main_opt):
108114
"--test_metrics_list",
109115
type=str,
110116
nargs="*",
111-
choices=["FID", "KID", "MSID", "PSNR", "LPIPS"],
117+
choices=["FID", "KID", "MSID", "PSNR", "LPIPS", "SSIM"],
112118
default=["FID", "KID", "MSID", "PSNR", "LPIPS"],
113119
)
114120
main_parser.add_argument(

0 commit comments

Comments
 (0)