# DQRC.py
import torch
import numpy as np
import torch.nn.functional as f
from TDRC.utils import getBatchColumns


class DQRC:
    def __init__(self, features, actions, policy_net, target_net, optimizer, params, device=None):
        self.features = features
        self.actions = actions
        self.params = params
        self.device = device

        self.policy_net = policy_net
        self.target_net = target_net
        self.optimizer = optimizer

        # step-size for the secondary weights, exploration rate, and regularization strength
        self.alpha = params['alpha']
        self.epsilon = params['epsilon']
        self.beta = params['beta']

        # secondary weights optimization parameters
        self.beta_1 = params.get('beta_1', 0.99)
        self.beta_2 = params.get('beta_2', 0.999)
        self.eps = params.get('eps', 1e-8)

        # learnable parameters for secondary weights (one vector of size `features` per action)
        self.h = torch.zeros(self.actions, features, requires_grad=False).to(device)

        # ADAM optimizer parameters for secondary weights
        self.v = torch.zeros(self.actions, features, requires_grad=False).to(device)
        self.m = torch.zeros(self.actions, features, requires_grad=False).to(device)
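
    # NOTE: the interfaces this class relies on are inferred from how the attributes are used
    # below, not stated in the original file:
    #   - policy_net(x) and target_net(x) are assumed to return a pair (q_values, features),
    #     with q_values of shape (batch, actions) and features of shape (batch, self.features)
    #   - params is assumed to be a dict with required keys 'alpha', 'epsilon', 'beta' and
    #     optional keys 'beta_1', 'beta_2', 'eps' for the secondary-weight ADAM update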
    def selectAction(self, x):
        # take a random action with probability epsilon
        if np.random.rand() < self.epsilon:
            a = np.random.randint(self.actions)
            return a

        # otherwise take a greedy action
        q_s, _ = self.policy_net(x)
        return q_s.argmax().detach().cpu().numpy()
    def updateNetwork(self, samples):
        # organize the mini-batch so that we can request "columns" from the data,
        # e.g. we can get all of the actions, or all of the states, with a single call
        batch = getBatchColumns(samples)

        # compute Q(s, a) for each sample in the mini-batch
        Qs, x = self.policy_net(batch.states)
        Qsa = Qs.gather(1, batch.actions).squeeze()

        # by default Q(s', a') = 0 unless the next states are non-terminal
        Qspap = torch.zeros(batch.size, device=self.device)

        # if we don't have any non-terminal next states, then there is no need to bootstrap
        if batch.nterm_sp.shape[0] > 0:
            Qsp, _ = self.target_net(batch.nterm_sp)

            # the bootstrapping term is the max Q value for the next state
            # only assign to indices where the next state is non-terminal
            Qspap[batch.nterm] = Qsp.max(1).values

        # compute the empirical MSBE for this mini-batch and let torch's auto-diff optimize it
        # don't worry about detaching the bootstrapping term for semi-gradient Q-learning;
        # the target network handles that
        target = batch.rewards + batch.gamma * Qspap.detach()
        td_loss = 0.5 * f.mse_loss(target, Qsa)

        # compute E[\delta | x] ~= <h, x>
        with torch.no_grad():
            delta_hats = torch.matmul(x, self.h.t())
            # take the estimate for the action actually taken; squeeze to shape (batch,)
            # so it lines up with Qspap and the TD error below
            delta_hat = delta_hats.gather(1, batch.actions).squeeze(1)

        # the gradient correction term is gamma * <h, x> * \nabla_w Q(s', a')
        # to compute this gradient, we use pytorch auto-diff
        correction_loss = torch.mean(batch.gamma * delta_hat * Qspap)
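
        # For reference, writing
        #     delta     = r + gamma * max_a' Q(s', a') - Q(s, a)
        #     delta_hat = <h_a, x>
        # the combined gradient assembled below is, averaged over the mini-batch,
        #     -delta * grad_w Q(s, a) + gamma * delta_hat * grad_w Q(s', a')
        # so a plain gradient step moves w in the TDC/QRC-corrected direction
        #     delta * grad_w Q(s, a) - gamma * delta_hat * grad_w Q(s', a')
        # (the optimizer passed in may precondition this, e.g. if it is Adam).
        # The first term comes from td_loss.backward() on the policy network; the second from
        # correction_loss.backward() on the target network, whose gradients are added onto the
        # policy network's gradients before the optimizer step.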
        # make sure we have no gradients left over from the previous update
        self.optimizer.zero_grad()
        self.target_net.zero_grad()

        # compute the entire gradient of the network using only the td error
        td_loss.backward()

        # if we have non-terminal states in the mini-batch,
        # then compute the correction term using the gradient of the *target network*
        if batch.nterm_sp.shape[0] > 0:
            correction_loss.backward()

            # add the gradients of the target network for the correction term to the gradients for the td error
            for (policy_param, target_param) in zip(self.policy_net.parameters(), self.target_net.parameters()):
                policy_param.grad.add_(target_param.grad)

        # update the *policy network* using the combined gradients
        self.optimizer.step()
        # update the secondary weights using a *fixed* feature representation generated by the policy network
        with torch.no_grad():
            delta = target - Qsa
            # broadcast the scalar per-sample error over each sample's feature vector
            dh = (delta - delta_hat).unsqueeze(1) * x

            # compute the update for each action independently
            # assume that there is a separate `h` vector for each individual action
            for a in range(self.actions):
                mask = (batch.actions == a).squeeze(1)
                # if this action was never taken in this mini-batch,
                # then skip the update for this action
                if mask.sum() == 0:
                    continue

                # the update for `h` minus the regularizer
                h_update = dh[mask].mean(0) - self.beta * self.h[a]

                # ADAM optimizer
                # keep a separate set of weights for each action here as well
                self.v[a] = self.beta_2 * self.v[a] + (1 - self.beta_2) * (h_update**2)
                self.m[a] = self.beta_1 * self.m[a] + (1 - self.beta_1) * h_update
                self.h[a] = self.h[a] + self.alpha * self.m[a] / (torch.sqrt(self.v[a]) + self.eps)
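

# ----------------------------------------------------------------------------
# Minimal usage sketch. The two-headed network below and the parameter values
# are illustrative assumptions, not part of the original agent; the only
# requirement inferred from the class above is that the network's forward pass
# returns both the action values and the feature layer they were computed from.
# The exact transition format expected by `updateNetwork` is determined by
# `TDRC.utils.getBatchColumns`, so a full training loop is not shown here.
if __name__ == '__main__':
    import torch.nn as nn

    class TwoHeadedNet(nn.Module):
        # a small MLP that exposes its last hidden layer as the "features" x
        def __init__(self, state_dim, features, actions):
            super().__init__()
            self.body = nn.Sequential(nn.Linear(state_dim, features), nn.ReLU())
            self.head = nn.Linear(features, actions)

        def forward(self, s):
            x = self.body(s)
            return self.head(x), x

    state_dim, features, actions = 4, 32, 2
    policy_net = TwoHeadedNet(state_dim, features, actions)
    target_net = TwoHeadedNet(state_dim, features, actions)
    target_net.load_state_dict(policy_net.state_dict())

    # illustrative hyperparameters; tune for the actual environment
    params = {'alpha': 1e-3, 'epsilon': 0.1, 'beta': 1.0}
    optimizer = torch.optim.Adam(policy_net.parameters(), lr=1e-3)

    agent = DQRC(features, actions, policy_net, target_net, optimizer, params, device='cpu')

    # epsilon-greedy action selection for a single (batched) state
    s = torch.zeros(1, state_dim)
    a = agent.selectAction(s)
    print('selected action:', a)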