-
Notifications
You must be signed in to change notification settings - Fork 10
/
inference.py
83 lines (72 loc) · 3.03 KB
/
inference.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
import argparse
import os
import time
from pathlib import Path
import gin
import numpy as np
import torch
import torch.nn as nn
from tqdm import tqdm
from core.raft import RAFT
from datasets import get_test_data_loader
from utils.data_utils import crop_operation, scale_operation
from utils.frame_utils import write_pfm
@gin.configurable()
def inference(
test_loader,
ckpt,
output_folder,
rescale=1,
crop=None,
do_report=False,
write_min_depth=None,
):
model = RAFT(test_mode=True).cuda()
if ckpt is not None:
tmp = torch.load(ckpt)
if list(tmp.keys())[0][:7] == "module.":
model = nn.DataParallel(model)
model.load_state_dict(tmp, strict=True)
model.eval()
output_folder = Path(output_folder)
(output_folder / "depths").mkdir(exist_ok=True, parents=True)
with torch.no_grad():
for images, poses, intrinsics, image_names, scale in tqdm(test_loader):
poses = poses.cuda()
images = images.squeeze(0)
intrinsics = intrinsics.squeeze(0)
images, intrinsics = scale_operation(images, intrinsics, rescale)
if not crop is None:
crop_h, crop_w = crop
images, intrinsics = crop_operation(images, intrinsics, crop_h, crop_w)
images = images.unsqueeze(0).cuda()
intrinsics = intrinsics.unsqueeze(0).cuda()
if do_report:
tic = time.time()
disp_est = model(images, poses, intrinsics, do_report=do_report, scale=scale)
if do_report:
print(f"per view time: {time.time() - tic}")
res = disp_est.cpu().numpy()[0, 0]
im = np.where(res == 0, 0, 1 / res).astype(np.float32)
write_pfm(output_folder / "depths" / f"{image_names[0][0]}_scale{rescale}_nf{test_loader.dataset.num_frames}.pfm", im)
if write_min_depth is not None:
write_min_depth = Path(write_min_depth)
write_min_depth.mkdir(exist_ok=True)
with open(write_min_depth / f"{image_names[0][0]}.txt", "w") as f:
min_depth = np.quantile(im[im > 0], 0.1) / 2
f.write(f"{min_depth}\n")
torch.cuda.empty_cache()
if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument('-g', '--gin_config', nargs='+', default=['inference_DTU'],
help='Set of config files for gin (separated by spaces) '
'e.g. --gin_config file1 file2 (exclude .gin from path)')
parser.add_argument('-p', '--gin_param', nargs='+', default=[],
help='Parameter settings that override config defaults '
'e.g. --gin_param module_1.a=2 module_2.b=3')
args = parser.parse_args()
gin_files = [f'configs/{g}.gin' for g in args.gin_config]
gin.parse_config_files_and_bindings(
gin_files, args.gin_param, skip_unknown=True)
test_loader = get_test_data_loader()
inference(test_loader)