import os
from itertools import repeat

import numpy as np
import pandas as pd

SIMPLE_FEATURE_COLUMNS = [
    'ncl[0]', 'ncl[1]', 'ncl[2]', 'ncl[3]',
    'avg_cs[0]', 'avg_cs[1]', 'avg_cs[2]', 'avg_cs[3]',
    'ndof',
    'MatchedHit_TYPE[0]', 'MatchedHit_TYPE[1]', 'MatchedHit_TYPE[2]', 'MatchedHit_TYPE[3]',
    'MatchedHit_X[0]', 'MatchedHit_X[1]', 'MatchedHit_X[2]', 'MatchedHit_X[3]',
    'MatchedHit_Y[0]', 'MatchedHit_Y[1]', 'MatchedHit_Y[2]', 'MatchedHit_Y[3]',
    'MatchedHit_Z[0]', 'MatchedHit_Z[1]', 'MatchedHit_Z[2]', 'MatchedHit_Z[3]',
    'MatchedHit_DX[0]', 'MatchedHit_DX[1]', 'MatchedHit_DX[2]', 'MatchedHit_DX[3]',
    'MatchedHit_DY[0]', 'MatchedHit_DY[1]', 'MatchedHit_DY[2]', 'MatchedHit_DY[3]',
    'MatchedHit_DZ[0]', 'MatchedHit_DZ[1]', 'MatchedHit_DZ[2]', 'MatchedHit_DZ[3]',
    'MatchedHit_T[0]', 'MatchedHit_T[1]', 'MatchedHit_T[2]', 'MatchedHit_T[3]',
    'MatchedHit_DT[0]', 'MatchedHit_DT[1]', 'MatchedHit_DT[2]', 'MatchedHit_DT[3]',
    'Lextra_X[0]', 'Lextra_X[1]', 'Lextra_X[2]', 'Lextra_X[3]',
    'Lextra_Y[0]', 'Lextra_Y[1]', 'Lextra_Y[2]', 'Lextra_Y[3]',
    'NShared',
    'Mextra_DX2[0]', 'Mextra_DX2[1]', 'Mextra_DX2[2]', 'Mextra_DX2[3]',
    'Mextra_DY2[0]', 'Mextra_DY2[1]', 'Mextra_DY2[2]', 'Mextra_DY2[3]',
    'FOI_hits_N', 'PT', 'P',
]
TRAIN_COLUMNS = ["label", "weight"]
FOI_COLUMNS = ["FOI_hits_X", "FOI_hits_Y", "FOI_hits_T",
               "FOI_hits_Z", "FOI_hits_DX", "FOI_hits_DY", "FOI_hits_S"]
ID_COLUMN = "id"
# Version suffix of the public test file name, test_public_<VERSION>.csv.gz.
# The value below is a placeholder assumption; set it to match the local data release.
VERSION = "v1"
N_STATIONS = 4
FEATURES_PER_STATION = 6
N_FOI_FEATURES = N_STATIONS * FEATURES_PER_STATION
# The value to use for stations with missing hits
# when computing FOI features
EMPTY_FILLER = 1000


def parse_array(line, dtype=np.float32):
    return np.fromstring(line[1:-1], sep=" ", dtype=dtype)
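# parse_array above turns the bracketed, space-separated strings stored in the
# FOI columns into float32 arrays.  Illustrative call (the literal is made up,
# not taken from the dataset):
#     parse_array("[1.5 2.0 3.25]")  ->  np.array([1.5, 2.0, 3.25], dtype=np.float32)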


def load_full_test_csv(path):
    # Parse the FOI array columns with parse_array and force float32 on the
    # scalar features to keep memory usage down.
    converters = dict(zip(FOI_COLUMNS, repeat(parse_array)))
    types = dict(zip(SIMPLE_FEATURE_COLUMNS, repeat(np.float32)))
    test = pd.read_csv(os.path.join(path, "test_public_%s.csv.gz" % VERSION),
                       index_col=ID_COLUMN, converters=converters,
                       dtype=types,
                       usecols=[ID_COLUMN] + SIMPLE_FEATURE_COLUMNS + FOI_COLUMNS)
    return test
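# load_full_test_csv above returns a DataFrame indexed by "id": the columns in
# SIMPLE_FEATURE_COLUMNS are float32 scalars, while each FOI_* column holds a
# per-row numpy array produced by parse_array.  Hypothetical call (the directory
# name is an assumption):
#     test = load_full_test_csv("data")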


def find_closest_hit_per_station(row):
    # For each station, find the FOI hit closest to the extrapolated track
    # position (Lextra_X/Y) and record its properties.  The result is a flat
    # float32 vector of N_FOI_FEATURES values laid out as six blocks of
    # N_STATIONS entries; the slices below are views that write directly into it.
    result = np.empty(N_FOI_FEATURES, dtype=np.float32)
    closest_x_per_station = result[0:4]
    closest_y_per_station = result[4:8]
    closest_T_per_station = result[8:12]
    closest_z_per_station = result[12:16]
    closest_dx_per_station = result[16:20]
    closest_dy_per_station = result[20:24]

    for station in range(N_STATIONS):
        hits = (row["FOI_hits_S"] == station)
        if not hits.any():
            # No FOI hit in this station: fill every slot with the sentinel value.
            closest_x_per_station[station] = EMPTY_FILLER
            closest_y_per_station[station] = EMPTY_FILLER
            closest_T_per_station[station] = EMPTY_FILLER
            closest_z_per_station[station] = EMPTY_FILLER
            closest_dx_per_station[station] = EMPTY_FILLER
            closest_dy_per_station[station] = EMPTY_FILLER
        else:
            # Squared x/y residuals between the extrapolated track position and
            # every FOI hit in this station.
            x_distances_2 = (row["Lextra_X[%i]" % station] - row["FOI_hits_X"][hits])**2
            y_distances_2 = (row["Lextra_Y[%i]" % station] - row["FOI_hits_Y"][hits])**2
            distances_2 = x_distances_2 + y_distances_2
            closest_hit = np.argmin(distances_2)
            # The x/y slots store the squared residuals of the closest hit,
            # while the T/Z/DX/DY slots store the hit's own values.
            closest_x_per_station[station] = x_distances_2[closest_hit]
            closest_y_per_station[station] = y_distances_2[closest_hit]
            closest_T_per_station[station] = row["FOI_hits_T"][hits][closest_hit]
            closest_z_per_station[station] = row["FOI_hits_Z"][hits][closest_hit]
            closest_dx_per_station[station] = row["FOI_hits_DX"][hits][closest_hit]
            closest_dy_per_station[station] = row["FOI_hits_DY"][hits][closest_hit]
    return result
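

# Minimal usage sketch, not part of the original module: load the public test
# set and build the closest-hit FOI features row by row.  The "data" directory
# and the generated column names are assumptions for illustration only.
if __name__ == "__main__":
    test = load_full_test_csv("data")  # expects test_public_<VERSION>.csv.gz in ./data
    closest_hits = np.stack(test.apply(find_closest_hit_per_station, axis=1).values)
    closest_hits_df = pd.DataFrame(
        closest_hits,
        index=test.index,
        columns=["closest_hit_feature_%i" % i for i in range(N_FOI_FEATURES)],
    )
    print(closest_hits_df.head())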