-
Notifications
You must be signed in to change notification settings - Fork 7
/
visual_utils.py
91 lines (76 loc) · 2.42 KB
/
visual_utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import pandas as pd
from pandas.api.types import is_string_dtype, is_numeric_dtype
import logging
import os
import os.path as osp
import numpy as np
import json
logger = logging.getLogger(__name__)
def _flatten_dict(dt):
while any(type(v) is dict for v in dt.values()):
remove = []
add = {}
for key, value in dt.items():
if type(value) is dict:
for subkey, v in value.items():
add[":".join([key, subkey])] = v
remove.append(key)
dt.update(add)
for k in remove:
del dt[k]
return dt
def _parse_results(res_path):
    """Read the most recent result record from a line-delimited JSON file.

    Only the last non-blank line of the file is parsed; it is decoded as
    JSON and passed through ``_flatten_dict``.

    Args:
        res_path (str): Path to the result file.

    Returns:
        dict: The flattened last record, or an empty dict if the file could
        not be read, holds no non-blank line, or the record could not be
        decoded or flattened. Failures are logged, never raised.
    """
    res_dict = {}
    try:
        last_line = None
        with open(res_path) as f:
            for line in f:
                # Skip blank lines so a trailing newline at the end of the
                # file does not hide the final record.
                if line.strip():
                    last_line = line
        if last_line is None:
            raise ValueError("no result record found in file")
        res_dict = _flatten_dict(json.loads(last_line.strip()))
    except Exception:
        # Best effort: report the problem and fall back to an empty result.
        logger.exception("Importing %s failed...Perhaps empty?", res_path)
    return res_dict
def _parse_configs(cfg_path):
    """Load a JSON config file and flatten it.

    Args:
        cfg_path (str): Path to the JSON file to load.

    Returns:
        dict: The parsed JSON content after ``_flatten_dict``.

    Raises:
        Exception: Whatever error prevented opening, decoding or flattening
            the file. It is logged first and then re-raised unchanged.
    """
    try:
        with open(cfg_path) as f:
            return _flatten_dict(json.load(f))
    except Exception:
        logger.exception("Config parsing failed.")
        # Propagate the real error. Previously the function fell through to
        # returning a variable that was never assigned, which surfaced as an
        # unrelated UnboundLocalError instead of the actual cause.
        raise
def _resolve(directory, result_fname):
    """Combine the config and the latest result found in one directory.

    The result file ``result_fname`` is parsed first, then ``params.json``
    from the same directory. The result entries are merged into the config
    dict, so a result value wins when both define the same key.

    Args:
        directory (str): Directory holding both files.
        result_fname (str): File name of the result file in ``directory``.

    Returns:
        dict or None: The merged dictionary, or ``None`` if any step raised.
    """
    try:
        results = _parse_results(osp.join(directory, result_fname))
        merged = _parse_configs(osp.join(directory, "params.json"))
        merged.update(results)
    except Exception:
        return None
    else:
        return merged
def load_results_to_df(directory, result_name="result.json"):
    """Collect the resolved records below ``directory`` into a DataFrame.

    The tree is walked with ``os.walk``; every directory containing a file
    named ``result_name`` is handed to ``_resolve``. Records that come back
    falsy (``None`` or an empty dict) are left out.

    Args:
        directory (str): Root directory to search.
        result_name (str): File name that marks a directory to load.

    Returns:
        pandas.DataFrame: One row per directory that produced a record.
    """
    # Finish the whole walk before resolving anything.
    matching_dirs = []
    for dirpath, _, filenames in os.walk(directory):
        for filename in filenames:
            if filename == result_name:
                matching_dirs.append(dirpath)

    records = []
    for trial_dir in matching_dirs:
        record = _resolve(trial_dir, result_name)
        if record:
            records.append(record)
    return pd.DataFrame(records)
def generate_plotly_dim_dict(df, field):
    """Build a dimension dictionary for one DataFrame column.

    Args:
        df (pandas.DataFrame): Frame containing the column.
        field (str): Name of the column to describe.

    Returns:
        dict: Always has ``"label"`` (the field name) and ``"values"``.
        For a numeric column, ``"values"`` is the column itself. For a
        string column, ``"values"`` is a list giving, for each row, the
        position of its entry among the column's unique entries (in order
        of first appearance); ``"tickvals"`` lists those positions and
        ``"ticktext"`` holds the unique entries.

    Raises:
        TypeError: If the column is neither numeric nor string typed.
    """
    dim_dict = {"label": field}
    column = df[field]
    if is_numeric_dtype(column):
        dim_dict["values"] = column
    elif is_string_dtype(column):
        texts = column.unique()
        # Map each unique entry to its position once, instead of scanning
        # the array of unique entries again for every row.
        position = {text: idx for idx, text in enumerate(texts)}
        dim_dict["values"] = [position[entry] for entry in column]
        dim_dict["tickvals"] = list(range(len(texts)))
        dim_dict["ticktext"] = texts
    else:
        # TypeError is a subclass of Exception, so existing handlers that
        # catch Exception keep working.
        raise TypeError("Unidentifiable Type")
    return dim_dict