evaluate.py
from argparse import ArgumentParser

import torch
from tqdm import tqdm

from fiery.data import prepare_dataloaders
from fiery.trainer import TrainingModule
from fiery.metrics import IntersectionOverUnion, PanopticMetric
from fiery.utils.network import preprocess_batch
from fiery.utils.instance import predict_instance_segmentation_and_trajectories
# Evaluation ranges as (start, stop) cell indices into the BEV grid:
# a 30m x 30m and a 100m x 100m window around the ego vehicle.
EVALUATION_RANGES = {'30x30': (70, 130),
                     '100x100': (0, 200)
                     }
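
# A minimal sketch of where the indices above come from, assuming the standard
# FIERY BEV grid of 200 x 200 cells at 0.5 m per cell. This helper is
# illustrative only and is not used below.
def _extent_to_slice_bounds(extent_m, grid_cells=200, resolution_m=0.5):
    """Map a square metric extent centred on the ego vehicle to grid indices."""
    half_cells = int(extent_m / resolution_m) // 2
    centre = grid_cells // 2
    return (centre - half_cells, centre + half_cells)
# _extent_to_slice_bounds(30) == (70, 130); _extent_to_slice_bounds(100) == (0, 200)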


def eval(checkpoint_path, dataroot, version):
    trainer = TrainingModule.load_from_checkpoint(checkpoint_path, strict=True)
    print(f'Loaded weights from \n {checkpoint_path}')
    trainer.eval()

    device = torch.device('cuda:0')
    trainer.to(device)
    model = trainer.model

    # Override the training config for single-GPU, batch-size-1 evaluation.
    cfg = model.cfg
    cfg.GPUS = "[0]"
    cfg.BATCHSIZE = 1

    cfg.DATASET.DATAROOT = dataroot
    cfg.DATASET.VERSION = version

    _, valloader = prepare_dataloaders(cfg)

    # One stateful metric accumulator per evaluation range; they are updated
    # on every batch and read out once with .compute() at the end.
    panoptic_metrics = {}
    iou_metrics = {}
    n_classes = len(cfg.SEMANTIC_SEG.WEIGHTS)
    for key in EVALUATION_RANGES.keys():
        panoptic_metrics[key] = PanopticMetric(n_classes=n_classes, temporally_consistent=True).to(device)
        iou_metrics[key] = IntersectionOverUnion(n_classes).to(device)
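
    # A minimal sketch of the stateful update/compute pattern these metrics
    # follow (illustrative only; the real IntersectionOverUnion and
    # PanopticMetric implementations live in fiery.metrics):
    #
    #   class RunningIoU:
    #       def __init__(self):
    #           self.intersection, self.union = 0, 0
    #       def __call__(self, pred, target):   # accumulate per batch
    #           self.intersection += (pred & target).sum().item()
    #           self.union += (pred | target).sum().item()
    #       def compute(self):                  # read out once at the end
    #           return self.intersection / max(self.union, 1)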

    for i, batch in enumerate(tqdm(valloader)):
        preprocess_batch(batch, device)
        image = batch['image']
        intrinsics = batch['intrinsics']
        extrinsics = batch['extrinsics']
        future_egomotion = batch['future_egomotion']

        batch_size = image.shape[0]

        labels, future_distribution_inputs = trainer.prepare_future_labels(batch)

        with torch.no_grad():
            # Evaluate with the mean prediction: zero noise selects the mean
            # of the learned future distribution.
            noise = torch.zeros((batch_size, 1, model.latent_dim), device=device)
            output = model(image, intrinsics, extrinsics, future_egomotion,
                           future_distribution_inputs, noise=noise)

        # Temporally consistent instance segmentation: instance identities
        # are matched across future frames.
        pred_consistent_instance_seg = predict_instance_segmentation_and_trajectories(
            output, compute_matched_centers=False, make_consistent=True
        )

        segmentation_pred = output['segmentation'].detach()
        segmentation_pred = torch.argmax(segmentation_pred, dim=2, keepdims=True)

        # Crop predictions and labels to each evaluation range (the slice is
        # applied to both spatial dimensions of the BEV grid), then update
        # the running metrics.
        for key, grid in EVALUATION_RANGES.items():
            limits = slice(grid[0], grid[1])
            panoptic_metrics[key](pred_consistent_instance_seg[..., limits, limits].contiguous().detach(),
                                  labels['instance'][..., limits, limits].contiguous()
                                  )

            iou_metrics[key](segmentation_pred[..., limits, limits].contiguous(),
                             labels['segmentation'][..., limits, limits].contiguous()
                             )

    # Gather per-range scores. Index 1 selects the vehicle class (index 0 is
    # background); scores are reported as percentages.
    results = {}
    for key, grid in EVALUATION_RANGES.items():
        panoptic_scores = panoptic_metrics[key].compute()
        for panoptic_key, value in panoptic_scores.items():
            results[f'{panoptic_key}'] = results.get(f'{panoptic_key}', []) + [100 * value[1].item()]

        iou_scores = iou_metrics[key].compute()
        results['iou'] = results.get('iou', []) + [100 * iou_scores[1].item()]

    # Print one row per metric, with one column per evaluation range joined
    # by ' & ' (i.e. ready to paste into a LaTeX table).
    for panoptic_key in ['iou', 'pq', 'sq', 'rq']:
        print(panoptic_key)
        print(' & '.join([f'{x:.1f}' for x in results[panoptic_key]]))


if __name__ == '__main__':
    parser = ArgumentParser(description='Fiery evaluation')
    parser.add_argument('--checkpoint', default='./fiery.ckpt', type=str, help='path to checkpoint')
    parser.add_argument('--dataroot', default='./nuscenes', type=str, help='path to the dataset')
    parser.add_argument('--version', default='trainval', type=str, choices=['mini', 'trainval'],
                        help='dataset version')
    args = parser.parse_args()

    eval(args.checkpoint, args.dataroot, args.version)
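
# Example invocation (both paths are placeholders; a CUDA device is required,
# since the model is moved to 'cuda:0' above):
#   python evaluate.py --checkpoint ./fiery.ckpt --dataroot ./nuscenes --version mini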