-
Notifications
You must be signed in to change notification settings - Fork 8
/
vanila_pg.py
72 lines (52 loc) · 2.02 KB
/
vanila_pg.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
import numpy as np
import torch
from hparams import HyperParams as hp
from utils import log_density
def get_returns(rewards, masks, gamma=None, eps=1e-8):
    """Compute per-step discounted returns and standardize them.

    Walking backwards through the batch, each step's return is
    ``rewards[t] + gamma * returns[t + 1] * masks[t]``. A zero mask therefore
    stops later rewards from flowing back past step ``t``. Presumably the
    caller sets mask = 0 on the last step of each episode; confirm this
    against the rollout code. The returns are then shifted to zero mean and
    divided by their (unbiased) standard deviation.

    Args:
        rewards: per-step rewards (assumed scalar), in time order.
        masks: per-step masks, same length as ``rewards``.
        gamma: discount factor; ``hp.gamma`` is used when None.
        eps: lower bound for the standard deviation, so a batch of
            identical returns yields zeros instead of NaN/inf.

    Returns:
        Float tensor of standardized returns, same length as ``rewards``.
        Batches with fewer than two steps have no defined (unbiased) std
        and yield zeros.
    """
    if gamma is None:
        gamma = hp.gamma
    rewards = torch.Tensor(rewards)
    masks = torch.Tensor(masks)
    returns = torch.zeros_like(rewards)

    running_return = 0
    for t in reversed(range(len(rewards))):
        running_return = rewards[t] + gamma * running_return * masks[t]
        returns[t] = running_return

    if returns.numel() < 2:
        # The unbiased std of fewer than two samples is NaN; a single
        # centered return is 0 anyway, and an empty batch stays empty.
        return torch.zeros_like(returns)
    # The clamp only affects batches whose std is below eps (e.g. all-equal
    # returns). Ordinary batches are divided by their true std.
    std = returns.std().clamp(min=eps)
    return (returns - returns.mean()) / std
def get_loss(actor, returns, states, actions):
    """Return the vanilla policy-gradient loss for one batch.

    The actor is run on ``states``. ``log_density`` turns its outputs into
    log-probabilities of the taken ``actions``. Each log-probability is
    weighted by its return. The negated mean is returned, so minimizing it
    ascends ``mean(return * log_prob)``.

    NOTE(review): ``returns`` is assumed 1-D (one entry per step) and is
    lifted to a column; ``log_density`` is assumed to return one value per
    row as shape (N, 1). Confirm in utils.
    """
    mu, std, logstd = actor(torch.Tensor(states))
    log_prob = log_density(torch.Tensor(actions), mu, std, logstd)
    weighted = returns.unsqueeze(1) * log_prob
    return -weighted.mean()
def train_critic(critic, states, returns, critic_optim, epochs=5, batch_size=None):
    """Regress the critic onto ``returns`` with shuffled minibatch MSE.

    Args:
        critic: module mapping a batch of states to predicted values; its
            output is compared against ``returns`` reshaped to a column.
        states: array-like of states, one row per step.
        returns: 1-D tensor of regression targets, one per state.
        critic_optim: optimizer over the critic's parameters.
        epochs: number of passes over the data (previously fixed at 5).
        batch_size: minibatch size; ``hp.batch_size`` is used when None.

    Each pass reshuffles the indices with NumPy's global RNG. Samples that
    do not fill a complete minibatch are skipped in that pass. If there are
    fewer than ``batch_size`` states, the critic is not updated at all.
    """
    if batch_size is None:
        batch_size = hp.batch_size
    criterion = torch.nn.MSELoss()

    # Convert once, outside the loops, rather than once per minibatch.
    all_inputs = torch.Tensor(states)
    all_targets = returns.unsqueeze(1)

    n = len(states)
    arr = np.arange(n)
    for _ in range(epochs):
        np.random.shuffle(arr)
        for i in range(n // batch_size):
            batch_index = torch.LongTensor(arr[batch_size * i: batch_size * (i + 1)])
            values = critic(all_inputs[batch_index])
            loss = criterion(values, all_targets[batch_index])
            critic_optim.zero_grad()
            loss.backward()
            critic_optim.step()
def train_actor(actor, returns, states, actions, actor_optim):
    """Apply a single gradient step to the actor.

    Clears the optimizer's gradients, back-propagates the batch loss from
    ``get_loss`` and lets ``actor_optim`` take one step. Returns nothing.
    """
    actor_optim.zero_grad()
    policy_loss = get_loss(actor, returns, states, actions)
    policy_loss.backward()
    actor_optim.step()
def train_model(actor, critic, memory, actor_optim, critic_optim):
    """Run one policy-gradient update from a batch of stored transitions.

    Args:
        actor: policy network, passed through to ``train_actor``.
        critic: value network, passed through to ``train_critic``.
        memory: sized sequence of records whose first four fields are
            ``(state, action, reward, mask)``; any extra fields are ignored.
        actor_optim: optimizer for the actor.
        critic_optim: optimizer for the critic.

    Returns:
        The standardized returns tensor produced by ``get_returns``. The
        critic is trained on it first, then the actor.

    Raises:
        ValueError: if ``memory`` is empty.
    """
    if len(memory) == 0:
        raise ValueError("train_model() needs at least one transition in memory")

    # Transpose the records column-wise with zip instead of np.array(memory).
    # Records mixing arrays (states/actions) with scalars (reward/mask) are
    # ragged, and NumPy >= 1.24 refuses to build an array from them unless
    # dtype=object is given.
    columns = list(zip(*memory))
    states = np.vstack(columns[0])
    actions = list(columns[1])
    rewards = list(columns[2])
    masks = list(columns[3])

    returns = get_returns(rewards, masks)
    train_critic(critic, states, returns, critic_optim)
    train_actor(actor, returns, states, actions, actor_optim)
    return returns