-
Notifications
You must be signed in to change notification settings - Fork 26
/
Copy pathutils.py
52 lines (41 loc) · 1.84 KB
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
# -*- coding: utf-8 -*-
import plotly
from plotly.graph_objs import Scatter, Line
import torch
from torch import multiprocessing as mp
# Global counter
class Counter:
    """Integer counter backed by a multiprocessing ``Value``.

    The count is stored in ``val`` (a C int created with ``mp.Value('i', 0)``)
    and every read or update is performed while holding ``lock``.
    """

    def __init__(self):
        """Create the lock and a counter starting at zero."""
        self.lock = mp.Lock()
        self.val = mp.Value('i', 0)

    def increment(self):
        """Add one to the count while holding the lock."""
        with self.lock:
            self.val.value = self.val.value + 1

    def value(self):
        """Return the current count, read while holding the lock."""
        with self.lock:
            current = self.val.value
        return current
# Converts a state from the OpenAI Gym (a numpy array) to a batch tensor
def state_to_tensor(state):
    """Return ``state`` as a float tensor with a new leading axis of size 1.

    ``state`` is handed to ``torch.from_numpy``, so it must be a numpy array.
    """
    as_float = torch.from_numpy(state).float()
    return as_float.unsqueeze(0)
# Plots min, max and mean + standard deviation bars of a population over time
def plot_line(xs, ys_population):
    """Write an offline plotly chart of population rewards to ``rewards.html``.

    For every x value the chart shows the population minimum and maximum
    (dashed lines), the mean (solid line) and a shaded band spanning one
    standard deviation either side of the mean.

    Args:
        xs: x-axis values (the axis is labelled 'Step'); expected to hold one
            entry per row of ``ys_population``.
        ys_population: 2-D nested sequence (or tensor) of values. Statistics
            are taken along dimension 1, i.e. across each row.

    Returns:
        None. The chart is written to ``rewards.html`` without being opened.
    """
    max_colour = 'rgb(0, 132, 180)'
    mean_colour = 'rgb(0, 172, 237)'
    std_colour = 'rgba(29, 202, 255, 0.2)'
    # Invisible outline for the band edges, spelled as rgba because the bare
    # string 'transparent' is not one of the colour formats plotly documents.
    transparent = 'rgba(0, 0, 0, 0)'

    ys = torch.as_tensor(ys_population, dtype=torch.float32)
    # Reducing along dim 1 already leaves one value per row. No squeeze() is
    # applied: it would collapse a single-row input to a 0-d tensor and the
    # traces below would then receive a scalar instead of a length-1 array.
    ys_min = ys.min(1)[0]
    ys_max = ys.max(1)[0]
    ys_mean = ys.mean(1)
    ys_std = ys.std(1)
    ys_upper, ys_lower = ys_mean + ys_std, ys_mean - ys_std

    # Line styles are plain dicts rather than the deprecated graph_objs.Line.
    trace_max = Scatter(x=xs, y=ys_max.numpy(),
                        line=dict(color=max_colour, dash='dash'), name='Max')
    trace_upper = Scatter(x=xs, y=ys_upper.numpy(),
                          line=dict(color=transparent),
                          name='+1 Std. Dev.', showlegend=False)
    # 'tonexty' fills towards the previous trace in the data list, so the
    # upper -> mean -> lower ordering below is what produces the shaded band.
    trace_mean = Scatter(x=xs, y=ys_mean.numpy(), fill='tonexty',
                         fillcolor=std_colour, line=dict(color=mean_colour),
                         name='Mean')
    trace_lower = Scatter(x=xs, y=ys_lower.numpy(), fill='tonexty',
                          fillcolor=std_colour, line=dict(color=transparent),
                          name='-1 Std. Dev.', showlegend=False)
    trace_min = Scatter(x=xs, y=ys_min.numpy(),
                        line=dict(color=max_colour, dash='dash'), name='Min')

    plotly.offline.plot({
        'data': [trace_upper, trace_mean, trace_lower, trace_min, trace_max],
        'layout': dict(title='Rewards',
                       xaxis={'title': 'Step'},
                       yaxis={'title': 'Average Reward'})
    }, filename='rewards.html', auto_open=False)