Skip to content

Commit

Permalink
Merge pull request #119 from kengz/schedule
Browse files Browse the repository at this point in the history
Fix DoubleDQN recompile_model missing model_2
  • Loading branch information
kengz authored Apr 9, 2017
2 parents f75ba48 + 9e141e5 commit ca0a3cf
Show file tree
Hide file tree
Showing 2 changed files with 28 additions and 22 deletions.
29 changes: 19 additions & 10 deletions rl/agent/double_dqn.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,25 @@ def compile_model(self):
optimizer=self.optimizer.keras_optimizer_2)
logger.info("Models 1 and 2 compiled")

def switch_models(self):
# Switch model 1 and model 2, also the optimizers
temp = self.model
self.model = self.model_2
self.model_2 = temp

temp_optimizer = self.optimizer.keras_optimizer
self.optimizer.keras_optimizer = self.optimizer.keras_optimizer_2
self.optimizer.keras_optimizer_2 = temp_optimizer

def recompile_model(self, sys_vars):
'''rotate and recompile both models'''
if self.epi_change_lr is not None:
self.switch_models() # to model_2
super(DoubleDQN, self).recompile_model(sys_vars)
self.switch_models() # back to model
super(DoubleDQN, self).recompile_model(sys_vars)
return self.model

def compute_Q_states(self, minibatch):
(Q_states, Q_next_states_select, _max) = super(
DoubleDQN, self).compute_Q_states(minibatch)
Expand All @@ -45,16 +64,6 @@ def compute_Q_states(self, minibatch):

return (Q_states, Q_next_states, Q_next_states_max)

def switch_models(self):
# Switch model 1 and model 2, also the optimizers
temp = self.model
self.model = self.model_2
self.model_2 = temp

temp_optimizer = self.optimizer.keras_optimizer
self.optimizer.keras_optimizer = self.optimizer.keras_optimizer_2
self.optimizer.keras_optimizer_2 = temp_optimizer

def train_an_epoch(self):
self.switch_models()
return super(DoubleDQN, self).train_an_epoch()
21 changes: 9 additions & 12 deletions rl/spec/classic_experiment_specs.json
Original file line number Diff line number Diff line change
Expand Up @@ -705,16 +705,15 @@
"hidden_layers": [128, 64],
"hidden_layers_activation": "sigmoid",
"output_layer_activation": "linear",
"exploration_anneal_episodes": 400,
"epi_change_lr": 800
"exploration_anneal_episodes": 50,
"epi_change_lr": 100
},
"param_range": {
"lr": [0.01, 0.02],
"gamma": [0.99, 0.999],
"hidden_layers": [
[200],
[400],
[800]
[400]
]
}
},
Expand All @@ -733,16 +732,15 @@
"hidden_layers": [200],
"hidden_layers_activation": "sigmoid",
"output_layer_activation": "linear",
"exploration_anneal_episodes": 400,
"epi_change_lr": 800
"exploration_anneal_episodes": 50,
"epi_change_lr": 100
},
"param_range": {
"lr": [0.01, 0.02],
"gamma": [0.99, 0.999],
"hidden_layers": [
[200],
[400],
[800]
[400]
]
}
},
Expand Down Expand Up @@ -790,16 +788,15 @@
"hidden_layers": [128, 64],
"hidden_layers_activation": "sigmoid",
"output_layer_activation": "linear",
"exploration_anneal_episodes": 400,
"epi_change_lr": 800
"exploration_anneal_episodes": 50,
"epi_change_lr": 100
},
"param_range": {
"lr": [0.01, 0.02],
"gamma": [0.99, 0.999],
"hidden_layers": [
[200],
[400],
[800]
[400]
]
}
}
Expand Down

0 comments on commit ca0a3cf

Please sign in to comment.