-
Notifications
You must be signed in to change notification settings - Fork 728
/
reinforce_agent.py
129 lines (106 loc) · 4.35 KB
/
reinforce_agent.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
import copy
import pylab
import numpy as np
from environment import Env
from keras.layers import Dense
from keras.optimizers import Adam
from keras.models import Sequential
from keras import backend as K
EPISODES = 2500
# REINFORCE agent for GridWorld
class ReinforceAgent:
    """Policy-gradient (REINFORCE) agent with a small dense softmax policy.

    The agent buffers the states, one-hot actions and rewards of one episode
    via ``append_sample`` and performs a single policy update from that
    buffer in ``train_model``.
    """

    def __init__(self):
        """Build the policy network and training function, then load weights.

        When ``load_model`` is true the weights stored in
        ``./save_model/reinforce_trained.h5`` are loaded into the network.
        """
        self.load_model = True
        # actions which agent can do
        self.action_space = [0, 1, 2, 3, 4]
        # get size of state and action
        self.action_size = len(self.action_space)
        self.state_size = 15
        self.discount_factor = 0.99
        self.learning_rate = 0.001

        self.model = self.build_model()
        # NOTE: this rebinds ``optimizer`` on the instance, so from here on
        # ``self.optimizer`` is the compiled training function, not the
        # builder method defined below.
        self.optimizer = self.optimizer()
        # per-episode buffers, cleared after every update
        self.states, self.actions, self.rewards = [], [], []

        if self.load_model:
            self.model.load_weights('./save_model/reinforce_trained.h5')

    def build_model(self):
        """Return the policy network.

        The state is the input and the probability of each action (the
        policy) is the output: two hidden ReLU layers of 24 units followed by
        a softmax layer with ``action_size`` units. The model summary is
        printed as a side effect.
        """
        model = Sequential()
        model.add(Dense(24, input_dim=self.state_size, activation='relu'))
        model.add(Dense(24, activation='relu'))
        model.add(Dense(self.action_size, activation='softmax'))
        model.summary()
        return model

    def optimizer(self):
        """Create the error function and the training function.

        The loss is ``-sum(log(p(chosen action)) * discounted_reward)``.

        Returns:
            A backend function taking
            ``[model input, one-hot actions, discounted rewards]`` that
            applies one Adam update to the policy network.
        """
        # Width follows the network's output size instead of a literal 5 so
        # the placeholder cannot drift from ``action_space``.
        action = K.placeholder(shape=[None, self.action_size])
        discounted_rewards = K.placeholder(shape=[None, ])

        # probability the policy assigned to the action actually taken
        action_prob = K.sum(action * self.model.output, axis=1)
        cross_entropy = K.log(action_prob) * discounted_rewards
        loss = -K.sum(cross_entropy)

        # create training function
        optimizer = Adam(lr=self.learning_rate)
        updates = optimizer.get_updates(self.model.trainable_weights, [],
                                        loss)
        train = K.function([self.model.input, action, discounted_rewards], [],
                           updates=updates)
        return train

    def get_action(self, state):
        """Sample an action index from the policy predicted for ``state``.

        ``state`` is passed to ``model.predict`` and the first row of the
        result is used as the sampling distribution.
        """
        policy = self.model.predict(state)[0]
        return np.random.choice(self.action_size, 1, p=policy)[0]

    def discount_rewards(self, rewards):
        """Return the discounted return for every step of ``rewards``.

        Args:
            rewards: sequence of per-step rewards, in time order.

        Returns:
            A float64 array shaped like ``rewards`` where element ``t`` is
            ``sum(discount_factor ** k * rewards[t + k])``.
        """
        # Force a float buffer: inheriting an integer dtype from integer
        # rewards would truncate every fractional discounted value.
        discounted_rewards = np.zeros_like(rewards, dtype=np.float64)
        running_add = 0
        for t in reversed(range(0, len(rewards))):
            running_add = running_add * self.discount_factor + rewards[t]
            discounted_rewards[t] = running_add
        return discounted_rewards

    def append_sample(self, state, action, reward):
        """Save the state, one-hot action and reward of one step.

        Only ``state[0]`` is stored, i.e. the first row of the given state.
        """
        self.states.append(state[0])
        self.rewards.append(reward)
        act = np.zeros(self.action_size)
        act[action] = 1
        self.actions.append(act)

    def train_model(self):
        """Update the policy network from the buffered episode, then clear it.

        Discounted returns are centred and, when they are not all equal,
        scaled to unit standard deviation before being used as weights.
        Does nothing when no sample has been buffered.
        """
        if not self.rewards:
            # nothing collected; mean/std of an empty array would be NaN
            return

        discounted_rewards = np.asarray(self.discount_rewards(self.rewards),
                                        dtype=np.float32)
        discounted_rewards -= np.mean(discounted_rewards)
        spread = np.std(discounted_rewards)
        # A one-step or constant-return episode has zero spread; dividing
        # would feed NaN into the weight update.
        if spread > 0:
            discounted_rewards /= spread

        self.optimizer([self.states, self.actions, discounted_rewards])
        self.states, self.actions, self.rewards = [], [], []
if __name__ == "__main__":
    # Training driver: run EPISODES episodes, updating the policy once per
    # finished episode and checkpointing every 100 episodes.
    env = Env()
    agent = ReinforceAgent()

    total_steps = 0
    score_history, episode_history = [], []

    for episode in range(EPISODES):
        episode_score = 0
        # start each episode from a freshly reset environment
        observation = np.reshape(env.reset(), [1, 15])

        finished = False
        while not finished:
            total_steps += 1

            # sample an action for the current state and advance one step
            chosen = agent.get_action(observation)
            following, reward, finished = env.step(chosen)
            following = np.reshape(following, [1, 15])

            agent.append_sample(observation, chosen, reward)
            episode_score += reward
            observation = copy.deepcopy(following)

        # The loop above only exits once the episode is done, so the
        # per-episode policy update and bookkeeping happen here.
        agent.train_model()
        score_history.append(episode_score)
        episode_history.append(episode)
        print("episode:", episode, " score:", round(episode_score, 2),
              " time_step:", total_steps)

        # periodic progress plot and weight checkpoint
        if episode % 100 == 0:
            pylab.plot(episode_history, score_history, 'b')
            pylab.savefig("./save_graph/reinforce.png")
            agent.model.save_weights("./save_model/reinforce.h5")