-
Notifications
You must be signed in to change notification settings - Fork 18
/
utils.py
137 lines (104 loc) · 3.66 KB
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
import numpy as np
# logsumexp() and expit() are used because they are
# numerically stable
# expit() is the sigmoid function
try:
    from scipy.special import logsumexp
except ImportError:  # older SciPy releases only provide it in scipy.misc
    from scipy.misc import logsumexp
from scipy.special import expit
class EpsGreedyPolicy():
    """Epsilon-greedy choice among options, backed by a value table.

    The table ``Q_Omega_table`` has shape ``(nstates, noptions)`` and starts
    out filled with zeros.
    """

    def __init__(self, rng, nstates, noptions, epsilon):
        """Record the settings and allocate the zero-initialised table.

        ``rng`` must offer ``uniform()`` and ``randint(n)``; both are called
        by :meth:`sample`.
        """
        self.Q_Omega_table = np.zeros((nstates, noptions))
        self.epsilon = epsilon
        self.noptions = noptions
        self.nstates = nstates
        self.rng = rng

    def Q_Omega(self, state, option=None):
        """Return the table row for ``state``, or one entry if ``option`` is given."""
        if option is not None:
            return self.Q_Omega_table[state, option]
        return self.Q_Omega_table[state, :]

    def sample(self, state):
        """Pick an option index for ``state`` and return it as a plain ``int``.

        One ``uniform()`` draw decides between exploring (a ``randint`` draw
        over all options) and taking the arg-max of the row for ``state``.
        """
        explore = self.rng.uniform() < self.epsilon
        if explore:
            picked = self.rng.randint(self.noptions)
        else:
            picked = np.argmax(self.Q_Omega(state))
        return int(picked)
class SoftmaxPolicy():
    """Softmax distribution over actions, parameterised by a weight table.

    ``weights`` has shape ``(nstates, nactions)`` and starts at zero, so the
    initial distribution for every state is uniform.
    """

    def __init__(self, rng, lr, nstates, nactions, temperature=1.0):
        """Record the settings and allocate the zero-initialised weights.

        ``rng`` must offer ``choice(n, p=...)``, which :meth:`sample` calls.
        """
        self.weights = np.zeros((nstates, nactions))
        self.temperature = temperature
        self.nactions = nactions
        self.nstates = nstates
        self.lr = lr
        self.rng = rng

    def Q_U(self, state, action=None):
        """Return the weight row for ``state``, or one entry if ``action`` is given."""
        if action is not None:
            return self.weights[state, action]
        return self.weights[state, :]

    def pmf(self, state):
        """Return the action probabilities for ``state``.

        The weights are divided by the temperature and normalised in log
        space via ``logsumexp`` before exponentiating.
        """
        scaled = self.Q_U(state) / self.temperature
        log_normaliser = logsumexp(scaled)
        return np.exp(scaled - log_normaliser)

    def sample(self, state):
        """Draw an action index for ``state`` from :meth:`pmf`; returns an ``int``."""
        probabilities = self.pmf(state)
        drawn = self.rng.choice(self.nactions, p=probabilities)
        return int(drawn)

    def gradient(self):
        """Placeholder with no implementation; always returns ``None``."""
        return None

    def update(self, state, action, Q_U):
        """Adjust the weights of ``state`` in place.

        Every entry of the row is decreased by ``lr * pmf * Q_U`` and the
        entry of the taken ``action`` is then increased by ``lr * Q_U``.
        The probabilities are computed before any weight is modified.
        """
        probabilities = self.pmf(state)
        self.weights[state, :] -= self.lr * probabilities * Q_U
        self.weights[state, action] += self.lr * Q_U
class SigmoidTermination():
    """Per-state Bernoulli outcome whose probability is a sigmoid of a weight.

    ``weights`` is a vector of length ``nstates`` that starts at zero, so the
    initial probability for every state is 0.5.
    """

    def __init__(self, rng, lr, nstates):
        """Record the settings and allocate the zero-initialised weights.

        ``rng`` must offer ``uniform()``, which :meth:`sample` calls.
        """
        self.weights = np.zeros((nstates,))
        self.nstates = nstates
        self.lr = lr
        self.rng = rng

    def pmf(self, state):
        """Return ``expit`` (the sigmoid) of the weight for ``state``."""
        weight = self.weights[state]
        return expit(weight)

    def sample(self, state):
        """Return 1 if a ``uniform()`` draw falls below :meth:`pmf`, else 0."""
        draw = self.rng.uniform()
        return int(draw < self.pmf(state))

    def gradient(self, state):
        """Return ``(p * (1 - p), state)`` where ``p`` is :meth:`pmf` of ``state``.

        The second element is the index of the weight the first one refers to.
        """
        prob = self.pmf(state)
        return prob * (1.0 - prob), state

    def update(self, state, advantage):
        """Decrease the weight for ``state`` by ``lr * gradient * advantage``."""
        magnitude, direction = self.gradient(state)
        step = self.lr * magnitude * advantage
        self.weights[direction] -= step
class Critic():
    """Tabular value estimates updated from one-step targets.

    Holds two tables: ``Q_Omega_table`` indexed by ``(state, option)``, which
    is the very array object passed to the constructor (it is not copied, so
    updates are visible to whoever else holds it), and ``Q_U_table`` indexed
    by ``(state, option, action)``, which is created here filled with zeros.

    :meth:`cache` must be called before the first :meth:`update_Qs`, because
    the update reads the ``last_*`` attributes that ``cache`` sets.
    """

    def __init__(self, lr, discount, Q_Omega_table, nstates, noptions, nactions):
        """Store the step size, discount and tables."""
        self.Q_U_table = np.zeros((nstates, noptions, nactions))
        self.Q_Omega_table = Q_Omega_table
        self.discount = discount
        self.lr = lr

    def cache(self, state, option, action):
        """Remember a ``(state, option, action)`` triple and its current Q_Omega value."""
        self.last_state, self.last_option, self.last_action = state, option, action
        self.last_Q_Omega = self.Q_Omega(state, option)

    def Q_Omega(self, state, option=None):
        """Return the Q_Omega row for ``state``, or one entry if ``option`` is given."""
        if option is not None:
            return self.Q_Omega_table[state, option]
        return self.Q_Omega_table[state, :]

    def Q_U(self, state, option, action):
        """Return the ``Q_U_table`` entry for ``(state, option, action)``."""
        return self.Q_U_table[state, option, action]

    def A_Omega(self, state, option=None):
        """Return the Q_Omega row for ``state`` minus its maximum.

        With ``option`` given, only that entry of the result is returned.
        """
        row = self.Q_Omega(state)
        advantage = row - np.max(row)
        return advantage if option is None else advantage[option]

    def update_Qs(self, state, option, action, reward, done, terminations):
        """Move both tables toward a one-step target, then refresh the cache.

        The entries updated are those of the previously cached triple.
        ``terminations`` is indexed by the cached option and the selected
        element's ``pmf(state)`` is used as the mixing weight; it is not
        touched when ``done`` is true.
        """
        prev_state = self.last_state
        prev_option = self.last_option
        prev_action = self.last_action

        # One-step target: the reward alone when done, otherwise the reward
        # plus the discounted mix of "keep the cached option" and "best option".
        target = reward
        if not done:
            beta_omega = terminations[prev_option].pmf(state)
            keep_value = (1.0 - beta_omega) * self.Q_Omega(state, prev_option)
            switch_value = beta_omega * np.max(self.Q_Omega(state))
            target += self.discount * (keep_value + switch_value)

        # Temporal-difference updates; the Q_Omega error uses the cached value.
        tderror_Q_Omega = target - self.last_Q_Omega
        self.Q_Omega_table[prev_state, prev_option] += self.lr * tderror_Q_Omega

        tderror_Q_U = target - self.Q_U(prev_state, prev_option, prev_action)
        self.Q_U_table[prev_state, prev_option, prev_action] += self.lr * tderror_Q_U

        # Refresh the cache; the cached Q_Omega value is left as it was when done.
        self.last_state, self.last_option, self.last_action = state, option, action
        if done:
            return
        self.last_Q_Omega = self.Q_Omega(state, option)