diff --git a/rl/agent/double_dqn.py b/rl/agent/double_dqn.py index 556baca..9a6cf45 100644 --- a/rl/agent/double_dqn.py +++ b/rl/agent/double_dqn.py @@ -31,6 +31,25 @@ def compile_model(self): optimizer=self.optimizer.keras_optimizer_2) logger.info("Models 1 and 2 compiled") + def switch_models(self): + # Switch model 1 and model 2, also the optimizers + temp = self.model + self.model = self.model_2 + self.model_2 = temp + + temp_optimizer = self.optimizer.keras_optimizer + self.optimizer.keras_optimizer = self.optimizer.keras_optimizer_2 + self.optimizer.keras_optimizer_2 = temp_optimizer + + def recompile_model(self, sys_vars): + '''rotate and recompile both models''' + if self.epi_change_lr is not None: + self.switch_models() # to model_2 + self.recompile_model(sys_vars) + self.switch_models() # back to model + self.recompile_model(sys_vars) + return self.model + def compute_Q_states(self, minibatch): (Q_states, Q_next_states_select, _max) = super( DoubleDQN, self).compute_Q_states(minibatch) @@ -45,16 +64,6 @@ def compute_Q_states(self, minibatch): return (Q_states, Q_next_states, Q_next_states_max) - def switch_models(self): - # Switch model 1 and model 2, also the optimizers - temp = self.model - self.model = self.model_2 - self.model_2 = temp - - temp_optimizer = self.optimizer.keras_optimizer - self.optimizer.keras_optimizer = self.optimizer.keras_optimizer_2 - self.optimizer.keras_optimizer_2 = temp_optimizer - def train_an_epoch(self): self.switch_models() return super(DoubleDQN, self).train_an_epoch()