Merge pull request #109 from kengz/boltzmann-fix
State reshape fix in action selection
kengz authored Apr 5, 2017
2 parents 5fb91d3 + 67d38ab commit 4bdb8ef
Showing 6 changed files with 30 additions and 51 deletions.
3 changes: 1 addition & 2 deletions rl/experiment.py
@@ -102,8 +102,7 @@ def init_sys_vars(self):
         sys_vars = PROBLEMS[self.problem]
         if args.max_epis >= 0:
             sys_vars['MAX_EPISODES'] = args.max_epis
-        if not args.render:
-            sys_vars['RENDER'] = False
+        sys_vars['RENDER'] = (not args.param_selection and args.render)
         if environ.get('CI'):
             sys_vars['RENDER'] = False
         self.sys_vars = sys_vars
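A reading of the experiment.py change (not text from the commit): RENDER is now derived purely from the CLI flags, so an environment renders only when --render is passed and the run is not a param_selection sweep, and CI still forces it off; this also matches the removal of the per-problem "RENDER" keys in rl/spec/problems.json further down. A minimal sketch of the flag logic, using a hypothetical argument parser:

    import argparse

    # Hypothetical parser just to exercise the two flags used by the new line;
    # the real repo defines its own CLI arguments.
    parser = argparse.ArgumentParser()
    parser.add_argument('--render', action='store_true')
    parser.add_argument('--param_selection', action='store_true')
    args = parser.parse_args(['--render'])

    # Mirrors the new line in init_sys_vars: render only for a
    # non-param-selection run that explicitly asked for it.
    render_flag = (not args.param_selection and args.render)
    assert render_flag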
14 changes: 7 additions & 7 deletions rl/policy/boltzmann.py
@@ -18,18 +18,18 @@ def __init__(self, env_spec,
         self.final_tau = final_tau
         self.tau = self.init_tau
         self.exploration_anneal_episodes = exploration_anneal_episodes
-        self.clip_val = 100000
+        self.clip_val = 500
         log_self(self)

     def select_action(self, state):
         agent = self.agent
-        state = np.reshape(state, (1, state.shape[0]))
+        state = np.expand_dims(state, axis=0)
         Q_state = agent.model.predict(state)[0]  # extract from batch predict
         assert Q_state.ndim == 1
         Q_state = Q_state.astype('float32')  # fix precision nan issue
         Q_state = Q_state - np.amax(Q_state)  # prevent overflow
-        exp_values = np.clip(
-            np.exp(Q_state / self.tau), -self.clip_val, self.clip_val)
+        exp_values = np.exp(
+            np.clip(Q_state / self.tau, -self.clip_val, self.clip_val))
         assert not np.isnan(exp_values).any()
         probs = np.array(exp_values / np.sum(exp_values))
         probs /= probs.sum()  # renormalize to prevent floating pt error
@@ -60,16 +60,16 @@ def __init__(self, env_spec,

     def select_action(self, state):
         agent = self.agent
-        state = np.reshape(state, (1, state.shape[0]))
+        state = np.expand_dims(state, axis=0)
         # extract from batch predict
         Q_state1 = agent.model.predict(state)[0]
         Q_state2 = agent.model_2.predict(state)[0]
         Q_state = Q_state1 + Q_state2
         assert Q_state.ndim == 1
         Q_state = Q_state.astype('float32')  # fix precision nan issue
         Q_state = Q_state - np.amax(Q_state)  # prevent overflow
-        exp_values = np.clip(
-            np.exp(Q_state / self.tau), -self.clip_val, self.clip_val)
+        exp_values = np.exp(
+            np.clip(Q_state / self.tau, -self.clip_val, self.clip_val))
         assert not np.isnan(exp_values).any()
         probs = np.array(exp_values / np.sum(exp_values))
         probs /= probs.sum()  # renormalize to prevent floating pt error
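Why the two boltzmann.py changes matter, shown as a minimal standalone sketch (the shapes, tau and Q-values below are illustrative, not taken from the commit):

    import numpy as np

    # Reshape fix: np.reshape(state, (1, state.shape[0])) only works for 1-D
    # states, while np.expand_dims prepends a batch axis for any state rank.
    flat_state = np.zeros(4)              # e.g. a CartPole observation
    image_state = np.zeros((84, 84, 4))   # e.g. stacked Atari frames (illustrative shape)
    assert np.expand_dims(flat_state, axis=0).shape == (1, 4)
    assert np.expand_dims(image_state, axis=0).shape == (1, 84, 84, 4)
    # np.reshape(image_state, (1, image_state.shape[0])) would raise ValueError here.

    # Clipping fix: bounding the exponent before np.exp keeps exp_values finite,
    # whereas clipping after np.exp cannot repair a value that already overflowed.
    tau, clip_val = 0.001, 500
    Q_state = np.array([2.0, -1.0, 0.5], dtype='float32')
    Q_state = Q_state - np.amax(Q_state)  # shift so the largest value is 0
    exp_values = np.exp(np.clip(Q_state / tau, -clip_val, clip_val))
    probs = exp_values / np.sum(exp_values)
    assert not np.isnan(probs).any()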
4 changes: 2 additions & 2 deletions rl/policy/epsilon_greedy.py
@@ -41,7 +41,7 @@ def update(self, sys_vars):
         return self.e


-class DoubleDQNPolicy(EpsilonGreedyPolicy):
+class DoubleDQNEpsilonGreedyPolicy(EpsilonGreedyPolicy):

     '''
     Policy to accompany double dqn agents
@@ -54,7 +54,7 @@ class DoubleDQNPolicy(EpsilonGreedyPolicy):
     def __init__(self, env_spec,
                  init_e=1.0, final_e=0.1, exploration_anneal_episodes=30,
                  **kwargs):  # absorb generic param without breaking
-        super(DoubleDQNPolicy, self).__init__(
+        super(DoubleDQNEpsilonGreedyPolicy, self).__init__(
             env_spec, init_e, final_e,
             exploration_anneal_episodes)

32 changes: 16 additions & 16 deletions rl/spec/atari_experiment_specs.json
@@ -17,12 +17,12 @@
         [32, 4, 4, [2, 2]]
       ],
       "hidden_layers_activation": "relu",
-      "exploration_anneal_episodes": 5000,
-      "epi_change_lr": 5000,
+      "exploration_anneal_episodes": 3000,
+      "epi_change_lr": 3000,
       "auto_architecture": true,
       "num_hidden_layers" : 3,
       "num_initial_channels" : 8,
-      "mem_size" : 500000
+      "max_mem_len" : 500000

     },
     "param_range": {
@@ -41,25 +41,25 @@
     "HyperOptimizer": "GridSearch",
     "Memory": "LinearMemoryWithForgetting",
     "Optimizer": "AdamOptimizer",
-    "Policy": "EpsilonGreedyPolicy",
+    "Policy": "BoltzmannPolicy",
     "PreProcessor": "Atari",
     "param": {
       "train_per_n_new_exp": 4,
       "batch_size": 32,
       "lr": 0.001,
       "gamma": 0.99,
       "hidden_layers": [
-        [16, 8, 8, [4, 4]],
-        [32, 4, 4, [2, 2]]
+        [32, 8, 8, [4, 4]],
+        [64, 4, 4, [2, 2]],
+        [64, 3, 3, [1, 1]]
       ],
       "hidden_layers_activation": "relu",
-      "exploration_anneal_episodes": 5000,
-      "epi_change_lr": 5000,
-      "mem_size" : 500000
+      "exploration_anneal_episodes": 3000,
+      "epi_change_lr": 3000,
+      "max_mem_len" : 500000
     },
     "param_range": {
-      "lr": [0.001, 0.0001],
-      "gamma": [0.97, 0.99]
+      "lr": [0.001, 0.01]
     }
   },
   "breakout_double_dqn": {
@@ -68,7 +68,7 @@
     "HyperOptimizer": "GridSearch",
     "Memory": "LinearMemoryWithForgetting",
     "Optimizer": "AdamOptimizer",
-    "Policy": "EpsilonGreedyPolicy",
+    "Policy": "DoubleDQNBoltzmannPolicy",
     "PreProcessor": "Atari",
     "param": {
       "train_per_n_new_exp": 4,
@@ -80,9 +80,9 @@
         [32, 4, 4, [2, 2]]
       ],
       "hidden_layers_activation": "relu",
-      "exploration_anneal_episodes": 5000,
-      "epi_change_lr": 5000,
-      "mem_size" : 500000
+      "exploration_anneal_episodes": 3000,
+      "epi_change_lr": 3000,
+      "max_mem_len" : 500000
     },
     "param_range": {
       "lr": [0.001, 0.0001],
@@ -109,7 +109,7 @@
       "hidden_layers_activation": "relu",
       "exploration_anneal_episodes": 10000,
       "epi_change_lr": 10000,
-      "mem_size" : 500000
+      "max_mem_len" : 500000
     },
     "param_range": {
       "lr": [0.001, 0.0001],
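For orientation, the new breakout_dqn hidden_layers spec lists three conv layers. Assuming each entry reads as [filters, kernel_height, kernel_width, [stride, stride]] (my reading of the spec format, not stated in this diff), a rough Keras sketch of the resulting stack; the input shape and the dense head are illustrative only:

    from keras.models import Sequential
    from keras.layers import Conv2D, Flatten, Dense

    # Rough sketch only: the spec-to-layer mapping, the (84, 84, 4) input and
    # the Dense head are assumptions, not part of this commit.
    model = Sequential([
        Conv2D(32, (8, 8), strides=(4, 4), activation='relu',
               input_shape=(84, 84, 4)),
        Conv2D(64, (4, 4), strides=(2, 2), activation='relu'),
        Conv2D(64, (3, 3), strides=(1, 1), activation='relu'),
        Flatten(),
        Dense(512, activation='relu'),
    ])
    model.summary()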
24 changes: 0 additions & 24 deletions rl/spec/problems.json
@@ -1,167 +1,143 @@
 {
   "DevCartPole-v0": {
-    "RENDER": true,
     "GYM_ENV_NAME": "CartPole-v0",
     "SOLVED_MEAN_REWARD": 195.0,
     "MAX_EPISODES": 4,
     "REWARD_MEAN_LEN": 100
   },
   "TestPassCartPole-v0": {
-    "RENDER": true,
     "GYM_ENV_NAME": "CartPole-v0",
     "SOLVED_MEAN_REWARD": 50.0,
     "MAX_EPISODES": 20,
     "REWARD_MEAN_LEN": 100
   },
   "CartPole-v0": {
-    "RENDER": true,
     "GYM_ENV_NAME": "CartPole-v0",
     "SOLVED_MEAN_REWARD": 195.0,
     "MAX_EPISODES": 250,
     "REWARD_MEAN_LEN": 100
   },
   "CartPole-v1": {
-    "RENDER": true,
     "GYM_ENV_NAME": "CartPole-v1",
     "SOLVED_MEAN_REWARD": 475.0,
     "MAX_EPISODES": 500,
     "REWARD_MEAN_LEN": 100
   },
   "Acrobot-v1": {
-    "RENDER": true,
     "GYM_ENV_NAME": "Acrobot-v1",
     "SOLVED_MEAN_REWARD": null,
     "MAX_EPISODES": 600,
     "REWARD_MEAN_LEN": 100
   },
   "MountainCar-v0": {
-    "RENDER": true,
     "GYM_ENV_NAME": "MountainCar-v0",
     "SOLVED_MEAN_REWARD": -110.0,
     "MAX_EPISODES": 1000,
     "REWARD_MEAN_LEN": 100
   },
   "MountainCarContinuous-v0": {
-    "RENDER": true,
     "GYM_ENV_NAME": "MountainCarContinuous-v0",
     "SOLVED_MEAN_REWARD": 90.0,
     "MAX_EPISODES": 5000,
     "REWARD_MEAN_LEN": 100
   },
   "Pendulum-v0": {
-    "RENDER": true,
     "GYM_ENV_NAME": "Pendulum-v0",
     "SOLVED_MEAN_REWARD": null,
     "MAX_EPISODES": 1000,
     "REWARD_MEAN_LEN": 100
   },
   "LunarLander-v2": {
-    "RENDER": true,
     "GYM_ENV_NAME": "LunarLander-v2",
     "SOLVED_MEAN_REWARD": 200.0,
     "MAX_EPISODES": 600,
     "REWARD_MEAN_LEN": 100
   },
   "LunarLanderContinuous-v2": {
-    "RENDER": true,
     "GYM_ENV_NAME": "LunarLanderContinuous-v2",
     "SOLVED_MEAN_REWARD": 200.0,
     "MAX_EPISODES": 800,
     "REWARD_MEAN_LEN": 100
   },
   "BipedalWalker-v2": {
-    "RENDER": true,
     "GYM_ENV_NAME": "BipedalWalker-v2",
     "SOLVED_MEAN_REWARD": 300.0,
     "MAX_EPISODES": 5000,
     "REWARD_MEAN_LEN": 100
   },
   "BipedalWalkerHardcore-v2": {
-    "RENDER": true,
     "GYM_ENV_NAME": "BipedalWalkerHardcore-v2",
     "SOLVED_MEAN_REWARD": 300.0,
     "MAX_EPISODES": 5000,
     "REWARD_MEAN_LEN": 100
   },
   "CarRacing-v0": {
-    "RENDER": true,
     "GYM_ENV_NAME": "CarRacing-v0",
     "SOLVED_MEAN_REWARD": 900.0,
     "MAX_EPISODES": 5000,
     "REWARD_MEAN_LEN": 100
   },
   "AirRaid-v0": {
-    "RENDER": true,
     "GYM_ENV_NAME": "AirRaid-v0",
     "SOLVED_MEAN_REWARD": null,
     "MAX_EPISODES": 5000,
     "REWARD_MEAN_LEN": 100
   },
   "Alien-v0": {
-    "RENDER": true,
     "GYM_ENV_NAME": "Alien-v0",
     "SOLVED_MEAN_REWARD": null,
     "MAX_EPISODES": 5000,
     "REWARD_MEAN_LEN": 100
   },
   "Assault-v0": {
-    "RENDER": true,
     "GYM_ENV_NAME": "Assault-v0",
     "SOLVED_MEAN_REWARD": null,
     "MAX_EPISODES": 5000,
     "REWARD_MEAN_LEN": 100
   },
   "Dev-Breakout-v0": {
-    "RENDER": true,
     "GYM_ENV_NAME": "Breakout-v0",
     "SOLVED_MEAN_REWARD": null,
     "MAX_EPISODES": 1,
     "REWARD_MEAN_LEN": 100
   },
   "Breakout-v0": {
-    "RENDER": true,
     "GYM_ENV_NAME": "Breakout-v0",
     "SOLVED_MEAN_REWARD": null,
     "MAX_EPISODES": 5000,
     "REWARD_MEAN_LEN": 100
   },
   "MsPacman-v0": {
-    "RENDER": true,
     "GYM_ENV_NAME": "MsPacman-v0",
     "SOLVED_MEAN_REWARD": null,
     "MAX_EPISODES": 5000,
     "REWARD_MEAN_LEN": 100
   },
   "Pong-v0": {
-    "RENDER": true,
     "GYM_ENV_NAME": "Pong-v0",
     "SOLVED_MEAN_REWARD": null,
     "MAX_EPISODES": 5000,
     "REWARD_MEAN_LEN": 100
   },
   "Qbert-v0": {
-    "RENDER": true,
     "GYM_ENV_NAME": "Qbert-v0",
     "SOLVED_MEAN_REWARD": null,
     "MAX_EPISODES": 5000,
     "REWARD_MEAN_LEN": 100
   },
   "SpaceInvader-v0": {
-    "RENDER": true,
     "GYM_ENV_NAME": "SpaceInvader-v0",
     "SOLVED_MEAN_REWARD": null,
     "MAX_EPISODES": 5000,
     "REWARD_MEAN_LEN": 100
   },
   "FlappyBird-v0": {
-    "RENDER": true,
     "GYM_ENV_NAME": "FlappyBird-v0",
     "SOLVED_MEAN_REWARD": null,
     "MAX_EPISODES": 1000,
     "REWARD_MEAN_LEN": 100
   },
   "Snake-v0": {
-    "RENDER": true,
     "GYM_ENV_NAME": "Snake-v0",
     "SOLVED_MEAN_REWARD": null,
     "MAX_EPISODES": 1000,
4 changes: 4 additions & 0 deletions test/test_advanced.py
@@ -40,11 +40,15 @@ def test_lunar_dqn(cls):
         data_df = run('lunar_dqn')
         assert isinstance(data_df, pd.DataFrame)

+    @unittest.skipIf(environ.get('CI'),
+                     "Delay CI test until dev stable")
     @classmethod
     def test_breakout_dqn(cls):
         data_df = run('breakout_dqn')
         assert isinstance(data_df, pd.DataFrame)

+    @unittest.skipIf(environ.get('CI'),
+                     "Delay CI test until dev stable")
     @classmethod
     def test_breakout_double_dqn(cls):
         data_df = run('breakout_double_dqn')
