Merge pull request #131 from kengz/schedule
Fine-tune PER, fix bugs
kengz authored Apr 27, 2017
2 parents 49be028 + d54676e commit 4f123b8
Showing 15 changed files with 286 additions and 281 deletions.
4 changes: 4 additions & 0 deletions rl/agent/actor_critic.py
@@ -114,7 +114,11 @@ def train_critic(self, minibatch):
actor_delta = Q_next_vals - Q_vals
loss = self.critic.train_on_batch(minibatch['states'], Q_targets)

# update memory, needed for PER
errors = abs(np.sum(Q_vals - Q_targets, axis=1))
# Q size is only 1, from critic
assert Q_targets.shape == (self.batch_size, 1)
assert errors.shape == (self.batch_size, )
self.memory.update(errors)
return loss, actor_delta

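Note: the PER error added here is the absolute difference between the critic's prediction and its target; since the critic outputs a single value per state, summing over axis 1 just collapses the (batch_size, 1) column into a flat vector of per-sample errors. A minimal NumPy sketch of that same computation, with made-up batch values (not from the repo):

import numpy as np

batch_size = 4
Q_vals = np.array([[0.5], [1.2], [-0.3], [2.0]])     # critic predictions, shape (4, 1)
Q_targets = np.array([[0.7], [1.0], [0.1], [2.0]])   # bootstrapped targets, shape (4, 1)

# same formula as the diff: per-sample absolute TD error
errors = abs(np.sum(Q_vals - Q_targets, axis=1))
assert Q_targets.shape == (batch_size, 1)
assert errors.shape == (batch_size,)
print(errors)  # [0.2 0.2 0.4 0. ]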
8 changes: 8 additions & 0 deletions rl/agent/ddpg.py
@@ -242,6 +242,7 @@ def train_an_epoch(self):

# train critic
mu_prime = self.actor.target_predict(minibatch['next_states'])
q_val = self.critic.target_predict(minibatch['states'], mu_prime)
q_prime = self.critic.target_predict(
minibatch['next_states'], mu_prime)
# reshape for element-wise multiplication
@@ -250,6 +251,13 @@ def train_an_epoch(self):
(1 - minibatch['terminals']) * np.reshape(q_prime, (-1))
y = np.reshape(y, (-1, 1))

# update memory, needed for PER
errors = abs(np.sum(q_val - y, axis=1))
# Q size is only 1, from critic
assert y.shape == (self.batch_size, 1)
assert errors.shape == (self.batch_size, )
self.memory.update(errors)

_, _, critic_loss = self.critic.train_tf(
minibatch['states'], minibatch['actions'], y)

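Note: the DDPG critic target is y = r + gamma * (1 - terminal) * Q'(s', mu'(s')), reshaped into a column before training; the new q_val prediction for the current state is only used to compute the PER error above. A sketch of the target computation with toy numbers (gamma and the arrays here are illustrative, not taken from the repo):

import numpy as np

gamma = 0.99
rewards = np.array([1.0, 0.0, -1.0])
terminals = np.array([0, 0, 1])            # episode ends on the last sample
q_prime = np.array([[0.5], [2.0], [3.0]])  # target critic at (s', mu'(s'))

# same shape gymnastics as the diff: flatten q_prime, zero out terminal
# transitions, then reshape the target back into a column for training
y = rewards + gamma * (1 - terminals) * np.reshape(q_prime, (-1))
y = np.reshape(y, (-1, 1))
print(y)  # column of targets: 1.495, 1.98, -1.0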
7 changes: 7 additions & 0 deletions rl/agent/deep_sarsa.py
@@ -1,3 +1,4 @@
import numpy as np
from rl.agent.dqn import DQN


@@ -30,4 +31,10 @@ def train_an_epoch(self):
Q_targets = self.compute_Q_targets(
minibatch, Q_states, Q_next_states_selected)
loss = self.model.train_on_batch(minibatch['states'], Q_targets)

errors = abs(np.sum(Q_states - Q_targets, axis=1))
assert Q_targets.shape == (
self.batch_size, self.env_spec['action_dim'])
assert errors.shape == (self.batch_size, )
self.memory.update(errors)
return loss
4 changes: 3 additions & 1 deletion rl/agent/dqn.py
@@ -190,10 +190,12 @@ def train_an_epoch(self):
minibatch)
Q_targets = self.compute_Q_targets(
minibatch, Q_states, Q_next_states_max)

loss = self.model.train_on_batch(minibatch['states'], Q_targets)

errors = abs(np.sum(Q_states - Q_targets, axis=1))
assert Q_targets.shape == (
self.batch_size, self.env_spec['action_dim'])
assert errors.shape == (self.batch_size, )
self.memory.update(errors)
return loss

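Note: for DQN (and Deep SARSA above) the network outputs one Q value per action, so Q_targets has shape (batch_size, action_dim). The target only differs from the prediction at the action actually taken, so summing the difference over the action axis recovers that single TD error per sample. A small worked example under that assumption, with invented values:

import numpy as np

# Q predictions for a batch of 2 states over 3 actions
Q_states = np.array([[1.0, 2.0, 3.0],
                     [0.5, 0.5, 0.5]])
# targets copy the predictions except at the chosen action
# (sample 0 took action 1, sample 1 took action 2)
Q_targets = Q_states.copy()
Q_targets[0, 1] = 2.5    # r + gamma * max_a' Q(s', a') for sample 0
Q_targets[1, 2] = -0.5

errors = abs(np.sum(Q_states - Q_targets, axis=1))
print(errors)  # [0.5 1. ] -- one absolute TD error per sample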
7 changes: 4 additions & 3 deletions rl/analytics.py
@@ -317,9 +317,10 @@ def compose_data(trial):
}

# param variables for independent vars of trials
default_param = trial.experiment_spec['param']
param_variables = {
pv: trial.experiment_spec['param'][pv] for
pv in trial.param_variables}
pv: default_param[pv] for
pv in trial.param_variables if pv in default_param}

trial.data['metrics'].update(metrics)
trial.data['param_variables'] = param_variables
@@ -459,7 +460,7 @@ def analyze_data(experiment_data_or_experiment_id):

data_df.sort_values(
['fitness_score'], ascending=False, inplace=True)
data_df.reset_index(inplace=True)
data_df.reset_index(drop=True, inplace=True)

trial_id = experiment_data[0]['trial_id']
save_experiment_data(data_df, trial_id)
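Note on the analytics fix: after sorting by fitness_score, reset_index() without drop=True would keep the old row positions as an extra 'index' column in the saved data; drop=True simply renumbers the rows. A minimal pandas illustration with toy data:

import pandas as pd

df = pd.DataFrame({'fitness_score': [0.2, 0.9, 0.5]})
df.sort_values(['fitness_score'], ascending=False, inplace=True)

# without drop=True the old positions (1, 2, 0) would be inserted as an
# 'index' column; drop=True discards them and renumbers rows 0..n-1
df.reset_index(drop=True, inplace=True)
print(df)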
23 changes: 12 additions & 11 deletions rl/memory/prioritized_exp_replay.py
@@ -12,8 +12,13 @@ class PrioritizedExperienceReplay(LinearMemoryWithForgetting):
memory unit
'''

def __init__(self, env_spec, max_mem_len=10000, e=0.01, alpha=0.6,
def __init__(self, env_spec, max_mem_len=None, e=0.01, alpha=0.6,
**kwargs):
if max_mem_len is None: # auto calculate mem len
max_timestep = env_spec['timestep_limit']
max_epis = env_spec['problem']['MAX_EPISODES']
memory_epi = np.ceil(max_epis / 3.).astype(int)
max_mem_len = max(10**6, max_timestep * memory_epi)
super(PrioritizedExperienceReplay, self).__init__(
env_spec, max_mem_len)
self.exp_keys.append('error')
@@ -27,21 +32,18 @@ def __init__(self, env_spec, max_mem_len=None, e=0.01, alpha=0.6,
self.prio_tree = SumTree(self.max_mem_len)
self.head = 0

# bump to account for negative terms in reward get_priority
# and we cannot abs(reward) cuz it's sign sensitive
SOLVED_MEAN_REWARD = self.env_spec['problem']['SOLVED_MEAN_REWARD'] or 10000
self.min_priority = abs(10 * SOLVED_MEAN_REWARD)

def get_priority(self, error):
# add min_priority to prevent root of negative = complex
p = (self.min_priority + error + self.e) ** self.alpha
assert not np.isnan(p)
p = (error + self.e) ** self.alpha
assert np.isfinite(p)
return p

def add_exp(self, action, reward, next_state, terminal):
'''Round robin memory updating'''
# roughly the error between estimated Q and true q is the reward
error = reward
# init error to reward first, update later
error = abs(reward)
p = self.get_priority(error)

if self.size() < self.max_mem_len: # add as usual
super(PrioritizedExperienceReplay, self).add_exp(
action, reward, next_state, terminal)
@@ -59,7 +61,6 @@ def add_exp(self, action, reward, next_state, terminal):
if self.head >= self.max_mem_len:
self.head = 0 # reset for round robin

p = self.get_priority(error)
self.prio_tree.add(p)

assert self.head == self.prio_tree.head, 'prio_tree head is wrong'
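Note: two of the PER changes can be sketched directly. The priority is now just (error + e) ** alpha on the raw absolute TD error (the min_priority offset is gone because the error passed in is already non-negative, seeded from abs(reward) and later refreshed by memory.update), and max_mem_len, when not given, is derived from the environment's timestep limit and episode budget. A rough sketch under those assumptions, with illustrative spec numbers:

import numpy as np

e, alpha = 0.01, 0.6

def get_priority(error):
    # same form as the diff: the small epsilon keeps zero-error transitions
    # sampleable, and alpha < 1 flattens the priority distribution
    p = (error + e) ** alpha
    assert np.isfinite(p)
    return p

print(get_priority(0.0))   # ~0.063
print(get_priority(5.0))   # ~2.63

# auto memory length, mirroring the new default in __init__
# (timestep_limit and MAX_EPISODES here are made-up spec values)
max_timestep, max_epis = 1000, 600
memory_epi = np.ceil(max_epis / 3.).astype(int)
max_mem_len = max(10**6, max_timestep * memory_epi)
print(max_mem_len)  # 1000000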
10 changes: 6 additions & 4 deletions rl/policy/actor_critic.py
@@ -39,19 +39,18 @@ class SoftmaxPolicy(Policy):
def __init__(self, env_spec,
**kwargs): # absorb generic param without breaking
super(SoftmaxPolicy, self).__init__(env_spec)
self.clip_val = 500
self.clip_val = 500.
log_self(self)

def select_action(self, state):
agent = self.agent
state = np.expand_dims(state, axis=0)
A_score = agent.actor.predict(state)[0] # extract from batch predict
assert A_score.ndim == 1
A_score = A_score.astype('float32') # fix precision nan issue
A_score = A_score - np.amax(A_score) # prevent overflow
A_score = A_score.astype('float64') # fix precision overflow
exp_values = np.exp(
np.clip(A_score, -self.clip_val, self.clip_val))
assert not np.isnan(exp_values).any()
assert np.isfinite(exp_values).all()
probs = np.array(exp_values / np.sum(exp_values))
probs /= probs.sum() # renormalize to prevent floating pt error
action = np.random.choice(agent.env_spec['actions'], p=probs)
@@ -83,6 +82,9 @@ def select_action(self, state):
a_mean = agent.actor.predict(state)[0] # extract from batch predict
action = a_mean + np.random.normal(
loc=0.0, scale=self.variance, size=a_mean.shape)
action = np.clip(action,
self.env_spec['action_bound_low'],
self.env_spec['action_bound_high'])
return action

def update(self, sys_vars):
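Note: the softmax policy fix combines three standard tricks against overflow and NaN probabilities: subtract the maximum score so the largest exponent is 0, work in float64 for extra headroom, and clip the scores as a last resort before exponentiating. A standalone sketch of that stable softmax (not the repo's exact function):

import numpy as np

def stable_softmax(scores, clip_val=500.):
    scores = np.asarray(scores, dtype='float64')   # more headroom than float32
    scores = scores - np.amax(scores)              # largest exponent becomes 0
    exp_values = np.exp(np.clip(scores, -clip_val, clip_val))
    assert np.isfinite(exp_values).all()
    probs = exp_values / np.sum(exp_values)
    return probs / probs.sum()                     # renormalize float error away

print(stable_softmax([1000., 1001., 1002.]))  # ~[0.09, 0.24, 0.67], no overflow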
12 changes: 5 additions & 7 deletions rl/policy/boltzmann.py
@@ -18,19 +18,18 @@ def __init__(self, env_spec,
self.final_tau = final_tau
self.tau = self.init_tau
self.exploration_anneal_episodes = exploration_anneal_episodes
self.clip_val = 500
self.clip_val = 500.
log_self(self)

def select_action(self, state):
agent = self.agent
state = np.expand_dims(state, axis=0)
Q_state = agent.model.predict(state)[0] # extract from batch predict
assert Q_state.ndim == 1
Q_state = Q_state.astype('float32') # fix precision nan issue
Q_state = Q_state - np.amax(Q_state) # prevent overflow
Q_state = Q_state.astype('float64') # fix precision overflow
exp_values = np.exp(
np.clip(Q_state / self.tau, -self.clip_val, self.clip_val))
assert not np.isnan(exp_values).any()
assert np.isfinite(exp_values).all()
probs = np.array(exp_values / np.sum(exp_values))
probs /= probs.sum() # renormalize to prevent floating pt error
action = np.random.choice(agent.env_spec['actions'], p=probs)
@@ -66,11 +65,10 @@ def select_action(self, state):
Q_state2 = agent.model_2.predict(state)[0]
Q_state = Q_state1 + Q_state2
assert Q_state.ndim == 1
Q_state = Q_state.astype('float32') # fix precision nan issue
Q_state = Q_state - np.amax(Q_state) # prevent overflow
Q_state = Q_state.astype('float64') # fix precision overflow
exp_values = np.exp(
np.clip(Q_state / self.tau, -self.clip_val, self.clip_val))
assert not np.isnan(exp_values).any()
assert np.isfinite(exp_values).all()
probs = np.array(exp_values / np.sum(exp_values))
probs /= probs.sum() # renormalize to prevent floating pt error
action = np.random.choice(agent.env_spec['actions'], p=probs)
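Note: the Boltzmann policy uses the same stabilized softmax but divides the Q values by the temperature tau before exponentiating; a high tau gives a near-uniform action distribution (exploration), a low tau is near-greedy (exploitation). A tiny illustration of that temperature effect, with arbitrarily chosen Q values:

import numpy as np

def boltzmann_probs(Q_state, tau, clip_val=500.):
    Q_state = np.asarray(Q_state, dtype='float64')
    Q_state = Q_state - np.amax(Q_state)
    exp_values = np.exp(np.clip(Q_state / tau, -clip_val, clip_val))
    return exp_values / np.sum(exp_values)

Q = [1.0, 2.0, 3.0]
print(boltzmann_probs(Q, tau=5.0))   # ~[0.27, 0.33, 0.40]  near-uniform
print(boltzmann_probs(Q, tau=0.1))   # ~[2e-9, 5e-5, 1.0]   near-greedy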
24 changes: 24 additions & 0 deletions rl/policy/noise.py
@@ -1,6 +1,7 @@
import numpy as np
from rl.util import log_self
from rl.policy.base_policy import Policy
from rl.policy.epsilon_greedy import EpsilonGreedyPolicy


class NoNoisePolicy(Policy):
@@ -25,6 +26,9 @@ def select_action(self, state):
state = np.expand_dims(state, axis=0)
if self.env_spec['actions'] == 'continuous':
action = agent.actor.predict(state)[0] + self.sample()
action = np.clip(action,
self.env_spec['action_bound_low'],
self.env_spec['action_bound_high'])
else:
Q_state = agent.actor.predict(state)[0]
assert Q_state.ndim == 1
@@ -60,6 +64,26 @@ def update(self, sys_vars):
self.n_step = sys_vars['epi']


class EpsilonGreedyNoisePolicy(EpsilonGreedyPolicy, NoNoisePolicy):

'''
akin to epsilon greedy decay,
but return random sample instead
'''

def sample(self):
if self.e > np.random.rand():
noise = np.random.uniform(
0.5 * self.env_spec['action_bound_low'],
0.5 * self.env_spec['action_bound_high'])
else:
noise = 0
return noise

def select_action(self, state):
return NoNoisePolicy.select_action(self, state)


class AnnealedGaussianPolicy(LinearNoisePolicy):

'''
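Note: the new EpsilonGreedyNoisePolicy mixes epsilon-greedy decay with additive noise: with probability e it samples uniform noise over half of the action range, otherwise it adds nothing, and the deterministic action from NoNoisePolicy is then clipped to the action bounds. A rough standalone sketch of that sampling rule (the bounds and e are placeholders, and this is not the class's real interface):

import numpy as np

def epsilon_greedy_noise(e, action_bound_low, action_bound_high):
    # with probability e, return uniform noise over half the action range;
    # otherwise return no noise (pure exploitation)
    if e > np.random.rand():
        return np.random.uniform(0.5 * action_bound_low, 0.5 * action_bound_high)
    return 0.0

np.random.seed(0)
noise = epsilon_greedy_noise(e=0.3, action_bound_low=-2.0, action_bound_high=2.0)
action = np.clip(1.5 + noise, -2.0, 2.0)  # add to the actor's output, then clip
print(noise, action)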
63 changes: 32 additions & 31 deletions rl/spec/box2d_experiment_specs.json
@@ -97,35 +97,6 @@
]
}
},
"lunar_double_dqn_per": {
"problem": "LunarLander-v2",
"Agent": "DoubleDQN",
"HyperOptimizer": "GridSearch",
"Memory": "PrioritizedExperienceReplay",
"Optimizer": "AdamOptimizer",
"Policy": "DoubleDQNBoltzmannPolicy",
"PreProcessor": "StackStates",
"param": {
"train_per_n_new_exp": 2,
"lr": 0.005,
"gamma": 0.99,
"hidden_layers": [800, 400],
"hidden_layers_activation": "sigmoid",
"output_layer_activation": "linear",
"exploration_anneal_episodes": 150,
"epi_change_lr": 200,
"max_mem_len": 30000
},
"param_range": {
"lr": [0.001, 0.005, 0.01],
"gamma": [0.97, 0.99, 0.999],
"hidden_layers": [
[400, 200],
[800, 400],
[400, 200, 100]
]
}
},
"lunar_double_dqn_nopreprocess": {
"problem": "LunarLander-v2",
"Agent": "DoubleDQN",
@@ -266,11 +237,11 @@
]
}
},
"lunar_ddpg_linearnoise": {
"lunar_cont_ddpg_per_linearnoise": {
"problem": "LunarLanderContinuous-v2",
"Agent": "DDPG",
"HyperOptimizer": "GridSearch",
"Memory": "LinearMemoryWithForgetting",
"Memory": "PrioritizedExperienceReplay",
"Optimizer": "AdamOptimizer",
"Policy": "LinearNoisePolicy",
"PreProcessor": "NoPreProcessor",
@@ -327,5 +298,35 @@
[800, 400, 200]
]
}
},
"walker_ddpg_per_linearnoise": {
"problem": "BipedalWalker-v2",
"Agent": "DDPG",
"HyperOptimizer": "GridSearch",
"Memory": "PrioritizedExperienceReplay",
"Optimizer": "AdamOptimizer",
"Policy": "LinearNoisePolicy",
"PreProcessor": "NoPreProcessor",
"param": {
"batch_size": 64,
"n_epoch": 1,
"tau": 0.005,
"lr": 0.0005,
"critic_lr": 0.001,
"gamma": 0.97,
"hidden_layers": [400, 200],
"hidden_layers_activation": "relu",
"output_layer_activation": "tanh"
},
"param_range": {
"lr": [0.0001, 0.0005],
"critic_lr": [0.001, 0.005],
"gamma": [0.95, 0.97, 0.99],
"hidden_layers": [
[200, 100],
[400, 300],
[800, 400]
]
}
}
}
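Note: the new walker_ddpg_per_linearnoise spec pairs DDPG and PER with a GridSearch over param_range. Assuming the grid search expands param_range as a plain Cartesian product (how the HyperOptimizer actually does this is not shown in this diff), the walker spec yields 2 x 2 x 3 x 3 = 36 trials. A quick sketch of that expansion:

import itertools

param_range = {
    "lr": [0.0001, 0.0005],
    "critic_lr": [0.001, 0.005],
    "gamma": [0.95, 0.97, 0.99],
    "hidden_layers": [[200, 100], [400, 300], [800, 400]],
}

# Cartesian product over the listed values -> one param override per trial
keys = list(param_range)
trials = [dict(zip(keys, combo))
          for combo in itertools.product(*(param_range[k] for k in keys))]
print(len(trials))   # 36
print(trials[0])     # {'lr': 0.0001, 'critic_lr': 0.001, 'gamma': 0.95, 'hidden_layers': [200, 100]}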