import math

import torch


def get_action(mu, std):
    """Sample an action from the Gaussian policy N(mu, std) and return it as a numpy array."""
    action = torch.normal(mu, std)
    action = action.data.numpy()
    return action
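
# A minimal usage sketch (shapes are illustrative assumptions, not taken from
# this file):
#
#   mu = torch.zeros(1, 3)        # batch of one state, 3-dim action
#   std = torch.ones(1, 3)
#   action = get_action(mu, std)  # numpy array of shape (1, 3)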


def log_density(x, mu, std, logstd):
    """Log-density of x under the diagonal Gaussian N(mu, std), summed over action dimensions."""
    var = std.pow(2)
    log_density = -(x - mu).pow(2) / (2 * var) \
        - 0.5 * math.log(2 * math.pi) - logstd
    return log_density.sum(1, keepdim=True)
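
# Sanity-check sketch (illustrative, not part of the original file): the result
# should agree with torch.distributions.Normal up to floating-point error:
#
#   x, mu, std = torch.randn(2, 3), torch.zeros(2, 3), torch.ones(2, 3)
#   ours = log_density(x, mu, std, std.log())
#   ref = torch.distributions.Normal(mu, std).log_prob(x).sum(1, keepdim=True)
#   assert torch.allclose(ours, ref)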


def flat_grad(grads):
    """Concatenate a sequence of gradient tensors into one flat vector."""
    grad_flatten = []
    for grad in grads:
        grad_flatten.append(grad.view(-1))
    grad_flatten = torch.cat(grad_flatten)
    return grad_flatten
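
# Typical use (a sketch; `loss` and `model` are assumed to exist and are not
# defined in this file):
#
#   grads = torch.autograd.grad(loss, model.parameters())
#   g = flat_grad(grads)  # 1-D tensor with one entry per model parameter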


def flat_hessian(hessians):
    """Concatenate second-order gradient tensors into one flat, detached vector."""
    hessians_flatten = []
    for hessian in hessians:
        hessians_flatten.append(hessian.contiguous().view(-1))
    hessians_flatten = torch.cat(hessians_flatten).data
    return hessians_flatten
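
# In TRPO-style code this usually flattens a Hessian-vector product rather than
# a full Hessian. A sketch, assuming `kl` is a scalar KL loss on `model` and
# `p` is a flat vector with one entry per parameter:
#
#   grads = torch.autograd.grad(kl, model.parameters(), create_graph=True)
#   gvp = (flat_grad(grads) * p).sum()
#   hvp = flat_hessian(torch.autograd.grad(gvp, model.parameters()))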


def flat_params(model):
    """Return all model parameters concatenated into one flat vector."""
    params = []
    for param in model.parameters():
        params.append(param.data.view(-1))
    params_flatten = torch.cat(params)
    return params_flatten


def update_model(model, new_params):
    """Copy a flat parameter vector back into the model, one slice per parameter tensor."""
    index = 0
    for params in model.parameters():
        params_length = params.numel()
        new_param = new_params[index: index + params_length]
        new_param = new_param.view(params.size())
        params.data.copy_(new_param)
        index += params_length
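
# Round-trip sketch showing how flat_params and update_model pair up
# (illustrative; torch.nn is not imported in the original file):
#
#   import torch.nn as nn
#   model = nn.Linear(4, 2)
#   theta = flat_params(model)        # flatten
#   update_model(model, theta + 0.1)  # write back a perturbed copy
#   assert torch.allclose(flat_params(model), theta + 0.1)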


def kl_divergence(new_actor, old_actor, states):
    """KL divergence D( pi_old || pi_new ) between two diagonal Gaussian policies,
    summed over action dimensions. The old policy's outputs are detached so that
    gradients flow only through the new policy."""
    mu, std, logstd = new_actor(torch.Tensor(states))
    mu_old, std_old, logstd_old = old_actor(torch.Tensor(states))
    mu_old = mu_old.detach()
    std_old = std_old.detach()
    logstd_old = logstd_old.detach()

    # KL divergence between the old policy and the new policy: D( pi_old || pi_new ).
    # pi_old -> mu_old, std_old, logstd_old / pi_new -> mu, std, logstd.
    # Be careful with the argument order: KL divergence is not symmetric.
    # Closed form for diagonal Gaussians:
    #   D( p || q ) = log(sigma_q / sigma_p)
    #                 + (sigma_p^2 + (mu_p - mu_q)^2) / (2 * sigma_q^2) - 0.5
    kl = logstd - logstd_old + (std_old.pow(2) + (mu_old - mu).pow(2)) / \
        (2.0 * std.pow(2)) - 0.5
    return kl.sum(1, keepdim=True)
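
# Sanity-check sketch (illustrative; `Actor` stands in for any policy network
# returning (mu, std, logstd) and is not defined in this file):
#
#   actor = Actor(obs_dim, act_dim)
#   states = torch.randn(8, obs_dim)
#   kl = kl_divergence(actor, actor, states)
#   assert torch.allclose(kl, torch.zeros_like(kl))  # D( p || p ) == 0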