diff --git a/tensor2tensor/models/research/rl.py b/tensor2tensor/models/research/rl.py
index 4cb2bf161..a3ed18a6d 100644
--- a/tensor2tensor/models/research/rl.py
+++ b/tensor2tensor/models/research/rl.py
@@ -377,6 +377,7 @@ def dqn_atari_base():
       agent_epsilon_eval=0.001,
       agent_epsilon_decay_period=250000,  # agent steps
       agent_generates_trainable_dones=True,
+      agent_type="VanillaDQN",  # one of ["Rainbow", "VanillaDQN"]
 
       optimizer_class="RMSProp",
       optimizer_learning_rate=0.00025,
@@ -420,6 +421,14 @@ def dqn_guess1_params():
   return hparams
 
 
+@registry.register_hparams
+def dqn_guess1_rainbow_params():
+  """Guess 1 for DQN params, Rainbow agent."""
+  hparams = dqn_guess1_params()
+  hparams.set_hparam("agent_type", "Rainbow")
+  return hparams
+
+
 @registry.register_hparams
 def dqn_2m_replay_buffer_params():
   """Guess 1 for DQN params, 2 milions transitions in replay buffer."""
diff --git a/tensor2tensor/rl/dopamine_connector.py b/tensor2tensor/rl/dopamine_connector.py
index 45605c793..a1ab21bc1 100644
--- a/tensor2tensor/rl/dopamine_connector.py
+++ b/tensor2tensor/rl/dopamine_connector.py
@@ -24,14 +24,19 @@ import sys
 
 from dopamine.agents.dqn import dqn_agent
+from dopamine.agents.rainbow import rainbow_agent
 from dopamine.replay_memory import circular_replay_buffer
-from dopamine.replay_memory.circular_replay_buffer import OutOfGraphReplayBuffer
-from dopamine.replay_memory.circular_replay_buffer import ReplayElement
+from dopamine.replay_memory.circular_replay_buffer import \
+    OutOfGraphReplayBuffer, ReplayElement
+from dopamine.replay_memory.prioritized_replay_buffer import \
+    OutOfGraphPrioritizedReplayBuffer, WrappedPrioritizedReplayBuffer
 import numpy as np
+
 from tensor2tensor.rl.policy_learner import PolicyLearner
 import tensorflow as tf
 
 # pylint: disable=g-import-not-at-top
+# pylint: disable=ungrouped-imports
 try:
   import cv2
 except ImportError:
@@ -41,7 +46,18 @@ except ImportError:
   run_experiment = None
 # pylint: enable=g-import-not-at-top
-
+# pylint: enable=ungrouped-imports
+
+# TODO: Vanilla DQN and Rainbow have a lot of common code. Most likely we want
+# to remove Vanilla DQN and only have Rainbow. To do so, one needs to remove
+# the following:
+# * _DQNAgent
+# * BatchDQNAgent
+# * _OutOfGraphReplayBuffer
+# * "if" clause in create_agent()
+# * parameter "agent_type" from dqn_atari_base() hparams and possibly other
+#   rlmb dqn hparams sets
+# If we want to keep both Vanilla DQN and Rainbow, a larger refactor is needed.
 
 
 class _DQNAgent(dqn_agent.DQNAgent):
   """Modify dopamine DQNAgent to match our needs.
@@ -178,6 +194,201 @@ def choose_action(ix):
     return np.array([choose_action(ix) for ix in range(self.env_batch_size)])
 
 
+class _OutOfGraphReplayBuffer(OutOfGraphReplayBuffer):
+  """Replay not sampling artificial_terminal transition.
+
+  Adds to stored tuples "artificial_done" field (as last ReplayElement).
+  When sampling, ignores tuples for which artificial_done is True.
+
+  When adding new attributes, check whether they are loaded from disk when
+  using the load() method.
+
+  Attributes:
+    are_terminal_valid: A boolean indicating if newly added terminal
+      transitions should be marked as artificially done. Replay data loaded
+      from disk will not be overridden.
+ """ + + def __init__(self, artificial_done, **kwargs): + extra_storage_types = kwargs.pop("extra_storage_types", None) or [] + extra_storage_types.append(ReplayElement("artificial_done", (), np.uint8)) + super(_OutOfGraphReplayBuffer, self).__init__( + extra_storage_types=extra_storage_types, **kwargs) + self._artificial_done = artificial_done + + def is_valid_transition(self, index): + valid = super(_OutOfGraphReplayBuffer, self).is_valid_transition(index) + valid &= not self.get_artificial_done_stack(index).any() + return valid + + def get_artificial_done_stack(self, index): + return self.get_range(self._store["artificial_done"], + index - self._stack_size + 1, index + 1) + + def add(self, observation, action, reward, terminal, *args): + """Append artificial_done to *args and run parent method.""" + # If this will be a problem for maintenance, we could probably override + # DQNAgent.add() method instead. + artificial_done = self._artificial_done and terminal + args = list(args) + args.append(artificial_done) + return super(_OutOfGraphReplayBuffer, self).add(observation, action, reward, + terminal, *args) + + def load(self, *args, **kwargs): + # Check that appropriate attributes are not overridden + are_terminal_valid = self._artificial_done + super(_OutOfGraphReplayBuffer, self).load(*args, **kwargs) + assert self._artificial_done == are_terminal_valid + + +class _WrappedPrioritizedReplayBuffer(WrappedPrioritizedReplayBuffer): + """ + + Allows to pass out-of-graph-replay-buffer via wrapped_memory. + """ + def __init__(self, wrapped_memory, batch_size, use_staging): + self.batch_size = batch_size + self.memory = wrapped_memory + self.create_sampling_ops(use_staging) + + +class _RainbowAgent(rainbow_agent.RainbowAgent): + """Modify dopamine DQNAgent to match our needs. + + Allow passing batch_size and replay_capacity to ReplayBuffer, allow not using + (some of) terminal episode transitions in training. + """ + + def __init__(self, replay_capacity, buffer_batch_size, + generates_trainable_dones, **kwargs): + self._replay_capacity = replay_capacity + self._buffer_batch_size = buffer_batch_size + self._generates_trainable_dones = generates_trainable_dones + super(_RainbowAgent, self).__init__(**kwargs) + + def _build_replay_buffer(self, use_staging): + """Build WrappedReplayBuffer with custom OutOfGraphReplayBuffer.""" + replay_buffer_kwargs = dict( + observation_shape=dqn_agent.NATURE_DQN_OBSERVATION_SHAPE, + stack_size=dqn_agent.NATURE_DQN_STACK_SIZE, + replay_capacity=self._replay_capacity, + batch_size=self._buffer_batch_size, + update_horizon=self.update_horizon, + gamma=self.gamma, + extra_storage_types=None, + observation_dtype=np.uint8, + ) + + replay_memory = _OutOfGraphPrioritizedReplayBuffer( + artificial_done=not self._generates_trainable_dones, + **replay_buffer_kwargs) + + return _WrappedPrioritizedReplayBuffer( + wrapped_memory=replay_memory, + use_staging=use_staging, batch_size=self._buffer_batch_size) + # **replay_buffer_kwargs) + + +class BatchRainbowAgent(_RainbowAgent): + """Batch agent for DQN. + + Episodes are stored on done. + + Assumes that all rollouts in batch would end at the same moment. 
+ """ + + def __init__(self, env_batch_size, *args, **kwargs): + super(BatchRainbowAgent, self).__init__(*args, **kwargs) + self.env_batch_size = env_batch_size + obs_size = dqn_agent.NATURE_DQN_OBSERVATION_SHAPE + state_shape = [self.env_batch_size, obs_size[0], obs_size[1], + dqn_agent.NATURE_DQN_STACK_SIZE] + self.state_batch = np.zeros(state_shape) + self.state = None # assure it will be not used + self._observation = None # assure it will be not used + self.reset_current_rollouts() + + def reset_current_rollouts(self): + self._current_rollouts = [[] for _ in range(self.env_batch_size)] + + def _record_observation(self, observation_batch): + # Set current observation. Represents an (batch_size x 84 x 84 x 1) image + # frame. + observation_batch = np.array(observation_batch) + self._observation_batch = observation_batch[:, :, :, 0] + # Swap out the oldest frames with the current frames. + self.state_batch = np.roll(self.state_batch, -1, axis=3) + self.state_batch[:, :, :, -1] = self._observation_batch + + def _reset_state(self): + self.state_batch.fill(0) + + def begin_episode(self, observation): + self._reset_state() + self._record_observation(observation) + + if not self.eval_mode: + self._train_step() + + self.action = self._select_action() + return self.action + + def _update_current_rollouts(self, last_observation, action, reward, + are_terminal): + transitions = zip(last_observation, action, reward, are_terminal) + for transition, rollout in zip(transitions, self._current_rollouts): + rollout.append(transition) + + def _store_current_rollouts(self): + for rollout in self._current_rollouts: + for transition in rollout: + self._store_transition(*transition) + self.reset_current_rollouts() + + def step(self, reward, observation): + self._last_observation = self._observation_batch + self._record_observation(observation) + + if not self.eval_mode: + self._update_current_rollouts(self._last_observation, self.action, reward, + [False] * self.env_batch_size) + # We want to have the same train_step:env_step ratio not depending on + # batch size. + for _ in range(self.env_batch_size): + self._train_step() + + self.action = self._select_action() + return self.action + + def end_episode(self, reward): + if not self.eval_mode: + self._update_current_rollouts( + self._observation_batch, self.action, reward, + [True] * self.env_batch_size) + self._store_current_rollouts() + + def _select_action(self): + epsilon = self.epsilon_eval + if not self.eval_mode: + epsilon = self.epsilon_fn( + self.epsilon_decay_period, + self.training_steps, + self.min_replay_history, + self.epsilon_train) + + def choose_action(ix): + if random.random() <= epsilon: + # Choose a random action with probability epsilon. + return random.randint(0, self.num_actions - 1) + else: + # Choose the action with highest Q-value at the current state. + return self._sess.run(self._q_argmax, + {self.state_ph: self.state_batch[ix:ix+1]}) + + return np.array([choose_action(ix) for ix in range(self.env_batch_size)]) + + class BatchRunner(run_experiment.Runner): """Run a batch of environments. @@ -223,7 +434,7 @@ def close(self): self._environment.close() -class _OutOfGraphReplayBuffer(OutOfGraphReplayBuffer): +class _OutOfGraphPrioritizedReplayBuffer(OutOfGraphPrioritizedReplayBuffer): """Replay not sampling artificial_terminal transition. Adds to stored tuples "artificial_done" field (as last ReplayElement). 
@@ -240,34 +451,47 @@ class _OutOfGraphReplayBuffer(OutOfGraphReplayBuffer):
 
   def __init__(self, artificial_done, **kwargs):
     extra_storage_types = kwargs.pop("extra_storage_types", None) or []
+    assert not extra_storage_types, "Other extra_storage_types are " \
+                                    "currently not supported for this " \
+                                    "class."
     extra_storage_types.append(ReplayElement("artificial_done", (), np.uint8))
-    super(_OutOfGraphReplayBuffer, self).__init__(
+    super(_OutOfGraphPrioritizedReplayBuffer, self).__init__(
         extra_storage_types=extra_storage_types, **kwargs)
     self._artificial_done = artificial_done
 
   def is_valid_transition(self, index):
-    valid = super(_OutOfGraphReplayBuffer, self).is_valid_transition(index)
-    valid &= not self.get_artificial_done_stack(index).any()
+    valid = super(_OutOfGraphPrioritizedReplayBuffer, self).\
+        is_valid_transition(index)
+    if valid:
+      valid = not self.get_artificial_done_stack(index).any()
     return valid
 
   def get_artificial_done_stack(self, index):
     return self.get_range(self._store["artificial_done"],
                           index - self._stack_size + 1, index + 1)
 
-  def add(self, observation, action, reward, terminal, *args):
-    """Append artificial_done to *args and run parent method."""
+  def add(self, observation, action, reward, terminal, priority):
+    """Infer artificial_done and call the parent method.
+
+    Note that OutOfGraphPrioritizedReplayBuffer (implicitly) assumes that
+    priority is the last argument of add(). Here we pass it explicitly.
+    Passing *args to this method is disabled on purpose; the code starts to
+    get too convoluted with it.
+    """
     # If this will be a problem for maintenance, we could probably override
     # DQNAgent.add() method instead.
+    if not isinstance(priority, (float, np.floating)):
+      raise ValueError("priority should be float, got type {}"
+                       .format(type(priority)))
     artificial_done = self._artificial_done and terminal
-    args = list(args)
-    args.append(artificial_done)
-    return super(_OutOfGraphReplayBuffer, self).add(observation, action, reward,
-                                                    terminal, *args)
+    return super(_OutOfGraphPrioritizedReplayBuffer, self).add(
+        observation, action, reward, terminal, artificial_done, priority
+    )
 
   def load(self, *args, **kwargs):
     # Check that appropriate attributes are not overridden
     are_terminal_valid = self._artificial_done
-    super(_OutOfGraphReplayBuffer, self).load(*args, **kwargs)
+    super(_OutOfGraphPrioritizedReplayBuffer, self).load(*args, **kwargs)
     assert self._artificial_done == are_terminal_valid
 
@@ -280,6 +504,8 @@ def get_create_agent(agent_kwargs):
   Returns:
     Function(sess, environment, summary_writer) -> BatchDQNAgent instance.
   """
+  agent_kwargs = copy.deepcopy(agent_kwargs)
+  agent_type = agent_kwargs.pop("type")
 
   def create_agent(sess, environment, summary_writer=None):
     """Creates a DQN agent.
@@ -294,13 +520,24 @@ def create_agent(sess, environment, summary_writer=None):
 
     Returns:
      a DQN agent.
""" - return BatchDQNAgent( - env_batch_size=environment.batch_size, - sess=sess, - num_actions=environment.action_space.n, - summary_writer=summary_writer, - tf_device="/gpu:*", - **agent_kwargs) + if agent_type == "Rainbow": + return BatchRainbowAgent( + env_batch_size=environment.batch_size, + sess=sess, + num_actions=environment.action_space.n, + summary_writer=summary_writer, + tf_device="/gpu:*", + **agent_kwargs) + elif agent_type == "VanillaDQN": + return BatchDQNAgent( + env_batch_size=environment.batch_size, + sess=sess, + num_actions=environment.action_space.n, + summary_writer=summary_writer, + tf_device="/gpu:*", + **agent_kwargs) + else: + raise ValueError("Unknown agent_type {}".format(agent_type)) return create_agent diff --git a/tensor2tensor/rl/trainer_model_based_params.py b/tensor2tensor/rl/trainer_model_based_params.py index 4c5350fed..185d2b522 100644 --- a/tensor2tensor/rl/trainer_model_based_params.py +++ b/tensor2tensor/rl/trainer_model_based_params.py @@ -221,6 +221,14 @@ def rlmb_dqn_guess1(): return hparams +@registry.register_hparams +def rlmb_dqn_guess1_rainbow(): + """rlmb_dqn guess1 params""" + hparams = rlmb_dqn_guess1() + hparams.set_hparam("base_algo_params", "dqn_guess1_rainbow_params") + return hparams + + @registry.register_hparams def rlmb_dqn_guess1_2m_replay_buffer(): """DQN guess1 params, 2M replay buffer."""