-
Notifications
You must be signed in to change notification settings - Fork 35
/
logging.py
66 lines (53 loc) · 2.08 KB
/
logging.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
from collections import defaultdict
import logging
import numpy as np
class Logger:
    """Collect training statistics and forward them to optional backends.

    Every value passed to :meth:`log_stat` is kept in memory in
    ``self.stats`` (``key -> list of (t, value)`` tuples). When enabled, it is
    also sent to TensorBoard (through ``tensorboard_logger.log_value``) and/or
    appended to a Sacred ``info`` mapping. :meth:`print_recent_stats` writes
    a short summary of the most recent values to ``console_logger``.
    """

    def __init__(self, console_logger):
        """
        Args:
            console_logger: object with an ``info(str)`` method (e.g. a
                ``logging.Logger``); used by :meth:`print_recent_stats`.
        """
        self.console_logger = console_logger

        self.use_tb = False
        self.use_sacred = False
        # NOTE: no HDF backend is implemented in this class; the flag is never set to True here.
        self.use_hdf = False

        # ``list`` rather than ``lambda: []``: a lambda default_factory makes
        # the defaultdict unpicklable (e.g. when stats are checkpointed).
        self.stats = defaultdict(list)

    def setup_tb(self, directory_name):
        """Enable TensorBoard logging into ``directory_name``."""
        # Import here so it doesn't have to be installed if you don't use it
        from tensorboard_logger import configure, log_value
        configure(directory_name)
        self.tb_logger = log_value
        self.use_tb = True

    def setup_sacred(self, sacred_run_dict):
        """Enable Sacred logging.

        Args:
            sacred_run_dict: object exposing an ``info`` dict (presumably a
                Sacred run); statistics are appended into that dict.
        """
        self.sacred_info = sacred_run_dict.info
        self.use_sacred = True

    # TODO: Setup hdf logger

    def log_stat(self, key, value, t, to_sacred=True):
        """Record ``value`` for ``key`` at time step ``t``.

        Args:
            key: name of the statistic.
            value: value to record.
            t: time step associated with the value.
            to_sacred: when False, skip the Sacred backend even if enabled.
        """
        self.stats[key].append((t, value))

        if self.use_tb:
            self.tb_logger(key, value, t)

        if self.use_sacred and to_sacred:
            # Sacred info keeps parallel lists: "<key>_T" holds the time
            # steps and "<key>" the values. setdefault creates each list on
            # first use and avoids a KeyError if only one of them exists.
            self.sacred_info.setdefault("{}_T".format(key), []).append(t)
            self.sacred_info.setdefault(key, []).append(value)

    def print_recent_stats(self):
        """Send a summary of the latest statistics to ``console_logger.info``.

        The header shows the most recent ``(t, value)`` pair logged under
        ``"episode"``, or ``N/A`` if none was logged yet. Every other
        statistic is averaged over its last 3 entries (last 1 for
        ``"epsilon"``), formatted to 4 decimal places, and laid out four per
        line in sorted key order.
        """
        # .get avoids inserting an empty "episode" list into the defaultdict
        # as a side effect of reading it.
        episode_log = self.stats.get("episode")
        if episode_log:
            t_env, episode = episode_log[-1]
        else:
            # Nothing logged under "episode" yet: show placeholders rather
            # than crashing the caller with an IndexError.
            t_env, episode = "N/A", "N/A"
        log_str = "Recent Stats | t_env: {:>10} | Episode: {:>8}\n".format(t_env, episode)

        i = 0
        for k, v in sorted(self.stats.items()):
            if k == "episode":
                continue
            i += 1
            window = 1 if k == "epsilon" else 3
            item = "{:.4f}".format(np.mean([x[1] for x in v[-window:]]))
            log_str += "{:<25}{:>8}".format(k + ":", item)
            # Four statistics per line, tab-separated.
            log_str += "\n" if i % 4 == 0 else "\t"
        self.console_logger.info(log_str)
def get_logger() -> logging.Logger:
    """Configure and return the root logger for console output.

    Any handlers already attached to the root logger are detached and closed,
    the same way ``logging.basicConfig(force=True)`` does. A single
    ``StreamHandler`` is then installed, formatting records as
    ``[LEVEL HH:MM:SS] name message``, and the level is set to DEBUG.

    Returns:
        The root ``logging.Logger``.
    """
    logger = logging.getLogger()
    # Detach and close previous handlers rather than just replacing the list:
    # removeHandler takes the logging module's lock, and close() releases
    # resources such as file descriptors held by FileHandlers.
    for handler in logger.handlers[:]:
        logger.removeHandler(handler)
        handler.close()
    ch = logging.StreamHandler()
    formatter = logging.Formatter('[%(levelname)s %(asctime)s] %(name)s %(message)s', '%H:%M:%S')
    ch.setFormatter(formatter)
    logger.addHandler(ch)
    logger.setLevel('DEBUG')
    return logger