forked from abdalazizrashid/idao-21-baseline
-
Notifications
You must be signed in to change notification settings - Fork 12
/
generate_submission.py
77 lines (60 loc) · 2.36 KB
/
generate_submission.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
import configparser
import gc
import logging
import pathlib as path
import sys
from collections import defaultdict
from itertools import chain
from collections import defaultdict
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
from idao.data_module import IDAODataModule
from idao.model import SimpleConv
def compute_predictions(mode, dataloader, checkpoint_path, cfg):
    """Run a trained ``SimpleConv`` checkpoint over ``dataloader`` and collect outputs.

    Args:
        mode: ``"classification"`` selects the classification head; any other
            value is treated as regression. Also forwarded to
            ``SimpleConv.load_from_checkpoint``.
        dataloader: iterable yielding ``(img, name)`` batches, where ``name`` is a
            sequence of image file names.
        checkpoint_path: checkpoint passed to ``SimpleConv.load_from_checkpoint``.
        cfg: currently unused; kept for interface compatibility.

    Returns:
        A ``defaultdict(list)``. In classification mode it holds ``"id"`` (file
        names with a trailing ``.png`` removed) and ``"particle"`` (column 1 of
        the model's ``"class"`` output). In regression mode it holds only
        ``"energy"`` (the model's ``"energy"`` output, squeezed on dim 1).
    """
    torch.multiprocessing.set_sharing_strategy("file_system")
    logging.info("Loading checkpoint")
    model = SimpleConv.load_from_checkpoint(checkpoint_path, mode=mode)
    model = model.cpu().eval()

    # The mode is fixed for the whole run; decide once instead of per batch.
    is_classification = mode == "classification"
    if is_classification:
        logging.info("Classification model loaded")
    else:
        logging.info("Regression model loaded")

    suffix = ".png"
    dict_pred = defaultdict(list)
    # Pure inference: disable autograd so no computation graph is kept per batch.
    with torch.no_grad():
        for img, name in dataloader:
            if is_classification:
                # str.strip(".png") would strip any of '.', 'p', 'n', 'g' from
                # BOTH ends (e.g. "png_1.png" -> "_1"); remove only the suffix.
                dict_pred["id"].extend(
                    n[: -len(suffix)] if n.endswith(suffix) else n for n in name
                )
                output = model(img)["class"].detach()[:, 1].numpy()
                dict_pred["particle"].extend(output)
            else:
                output = model(img)["energy"].detach().squeeze(1).numpy()
                dict_pred["energy"].extend(output)
    return dict_pred
def main():
    """Generate the classification and regression submission files.

    Reads ``./config.ini``, which must provide ``[DATA] DatasetPath``,
    ``[REPORT] RegressionCheckpoint`` and ``[REPORT] ClassificationCheckpoint``.
    It runs both models over the test dataloader and writes:

    * ``submission_classification.csv.gz``: ``id`` index plus the ``particle`` column.
    * ``submission_regression.csv.gz``: ``id`` index plus the ``energy`` column.

    Raises:
        FileNotFoundError: if ``./config.ini`` could not be read.
    """
    # Without a configured handler the root logger drops INFO records, which
    # would silently discard every logging.info call in this script.
    logging.basicConfig(level=logging.INFO)

    config = configparser.ConfigParser()
    # ConfigParser.read ignores missing files and returns the list it parsed;
    # fail loudly here rather than with a confusing KeyError below.
    if not config.read("./config.ini"):
        raise FileNotFoundError("Could not read ./config.ini")
    dataset_path = path.Path(config["DATA"]["DatasetPath"])

    dataset_dm = IDAODataModule(
        data_dir=dataset_path, batch_size=512, cfg=config
    )
    dataset_dm.prepare_data()
    print(dataset_dm.dataset.class_to_idx)
    dataset_dm.setup()
    dl = dataset_dm.test_dataloader()

    # NOTE(review): only the classification pass records ids, so energies and
    # particles are aligned purely by position. This assumes test_dataloader()
    # yields samples in the same order on every pass (no shuffling); confirm
    # this in IDAODataModule.
    dict_pred = defaultdict(list)
    for mode in ["regression", "classification"]:
        if mode == "classification":
            model_path = config["REPORT"]["ClassificationCheckpoint"]
        else:
            model_path = config["REPORT"]["RegressionCheckpoint"]
        dict_pred.update(compute_predictions(mode, dl, model_path, cfg=config))

    data_frame = pd.DataFrame(dict_pred, columns=["id", "energy", "particle"])
    data_frame.set_index("id", inplace=True)
    # pandas infers gzip compression from the ".gz" extension.
    data_frame.to_csv(
        "submission_classification.csv.gz",
        index=True, header=True, index_label="id", columns=["particle"],
    )
    data_frame.to_csv(
        "submission_regression.csv.gz",
        index=True, header=True, index_label="id", columns=["energy"],
    )
if __name__ == "__main__":
    # Entry point: build the submissions only when run as a script, not on import.
    main()