
JAX code #30

Open · lucasliunju opened this issue Sep 8, 2022 · 11 comments
Labels: enhancement, good first issue

Comments

@lucasliunju

Hi,

I would like to ask whether there is a JAX-based version of the code.

I would also like to ask whether there are any recommendations for JAX-based offline RL algorithms.

Thanks!

@agarwl
Contributor

agarwl commented Sep 10, 2022

Releasing the JAX code might take some time, but it should be easy to modify the existing dopamine agents. In the meantime, here are some tips to get started with JAX agents:

  • You can easily modify the JAX dopamine agents to run offline RL on Atari datasets.
  • For the replay buffer, you can use the FixedReplayBuffer class in fixed_replay_buffer.py to create the offline buffer, for example:

# Assumed imports (paths follow the dopamine and batch_rl codebases):
from absl import logging

from batch_rl.fixed_replay.replay_memory import fixed_replay_buffer
from dopamine.jax.agents.dqn import dqn_agent
import gin
import numpy as onp


@gin.configurable
class OfflineJaxDQNAgent(dqn_agent.JaxDQNAgent):
  """A JAX implementation of the Offline DQN agent."""

  def __init__(self,
               num_actions,
               replay_data_dir,
               summary_writer=None):
    """Initializes the agent and constructs the necessary components.

    Args:
      num_actions: int, number of actions the agent can take at any state.
      replay_data_dir: str, directory from which to load the replay buffer data.
      summary_writer: SummaryWriter object for outputting training statistics.
    """
    logging.info('Creating %s agent with the following parameters:',
                 self.__class__.__name__)
    logging.info('\t replay directory: %s', replay_data_dir)
    self.replay_data_dir = replay_data_dir
    super().__init__(
        num_actions, update_period=1, summary_writer=summary_writer)

  def _build_replay_buffer(self):
    """Creates the fixed replay buffer used by the agent."""
    return fixed_replay_buffer.FixedReplayBuffer(
        data_dir=self.replay_data_dir,
        observation_shape=self.observation_shape,
        stack_size=self.stack_size,
        update_horizon=self.update_horizon,
        gamma=self.gamma,
        observation_dtype=self.observation_dtype)

  def reload_data(self):
    # This needs to be called every iteration to subsample a portion of the dataset.
    self._replay.reload_data()

  def step(self, reward, observation):
    """Returns the agent's next action and update agent's state.

    Args:
      reward: float, the reward received from the agent's most recent action.
      observation: numpy array, the most recent observation.

    Returns:
      int, the selected action.
    """
    self._record_observation(observation)
    self._rng, self.action = dqn_agent.select_action(
        self.network_def, self.online_params, self.state, self._rng,
        self.num_actions, self.eval_mode, self.epsilon_eval, self.epsilon_train,
        self.epsilon_decay_period, self.training_steps, self.min_replay_history,
        self.epsilon_fn)
    self.action = onp.asarray(self.action)
    return self.action

  def train_step(self):
    """Exposes the train step for offline learning."""
    super()._train_step()
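
As a hypothetical usage sketch (not from the repo; the replay path, num_actions, and step counts below are placeholders), the agent can then be driven by a plain offline training loop that reloads a subsample of the dataset at the start of every iteration:

agent = OfflineJaxDQNAgent(
    num_actions=4,
    replay_data_dir='/path/to/dqn_replay/Breakout/1/replay_logs')

num_iterations = 200
training_steps_per_iteration = 62_500  # placeholder; see the update_period discussion below

for _ in range(num_iterations):
  agent.reload_data()  # re-subsample the offline dataset every iteration
  for _ in range(training_steps_per_iteration):
    agent.train_step()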

@agarwl added the enhancement and good first issue labels on Sep 10, 2022
@lucasliunju
Author

Thank you very much! May I ask whether I can also run the code on a TPU VM with JAX?

Best,
Lucas

@agarwl
Contributor

agarwl commented Sep 14, 2022

I think so -- you would probably want to use the tfds datasets or much larger batch sizes with the dopamine codebase.
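
As a quick sanity check (my addition, assuming a Cloud TPU VM with the jax[tpu] wheel installed), you can confirm that JAX actually sees the TPU before launching training:

import jax

print(jax.default_backend())  # expected: 'tpu'
print(jax.devices())          # expected: a list of TPU devices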

@lucasliunju
Author

Thank you very much! I'll give it a try.

@lucasliunju
Author

Dear agarwl,

I tried to follow the code you provided and reproduce the offline DQN results with JAX. I find that training with JAX is quite slow compared with TensorFlow. May I ask what the possible reason could be? These are the parts of the vanilla dopamine code that I changed:
(1) I rewrote the Runner in dopamine/dopamine/discrete_domains/run_experiment.py based on the code in batch_rl/batch_rl/fixed_replay/run_experiment.py:

# Assumed imports (paths follow the dopamine and batch_rl codebases):
import time

from dopamine.discrete_domains import checkpointer
from dopamine.discrete_domains import iteration_statistics
from dopamine.discrete_domains import run_experiment
import gin
import tensorflow.compat.v1 as tf


@gin.configurable
class FixedReplayRunner(run_experiment.Runner):
  """Object that handles running Dopamine experiments with fixed replay buffer."""

  def _initialize_checkpointer_and_maybe_resume(self, checkpoint_file_prefix):
    super(FixedReplayRunner, self)._initialize_checkpointer_and_maybe_resume(
        checkpoint_file_prefix)

    # Code for loading a checkpoint at initialization.
    init_checkpoint_dir = self._agent._init_checkpoint_dir  # pylint: disable=protected-access
    if (self._start_iteration == 0) and (init_checkpoint_dir is not None):
      if checkpointer.get_latest_checkpoint_number(self._checkpoint_dir) < 0:
        # No checkpoint loaded yet, read init_checkpoint_dir
        init_checkpointer = checkpointer.Checkpointer(
            init_checkpoint_dir, checkpoint_file_prefix)
        latest_init_checkpoint = checkpointer.get_latest_checkpoint_number(
            init_checkpoint_dir)
        if latest_init_checkpoint >= 0:
          experiment_data = init_checkpointer.load_checkpoint(
              latest_init_checkpoint)
          if self._agent.unbundle(
              init_checkpoint_dir, latest_init_checkpoint, experiment_data):
            if experiment_data is not None:
              assert 'logs' in experiment_data
              assert 'current_iteration' in experiment_data
              self._logger.data = experiment_data['logs']
              self._start_iteration = experiment_data['current_iteration'] + 1
            tf.logging.info(
                'Reloaded checkpoint from %s and will start from iteration %d',
                init_checkpoint_dir, self._start_iteration)

  def _run_train_phase(self):
    """Run training phase."""
    self._agent.eval_mode = False
    start_time = time.time()
    for _ in range(self._training_steps):
      self._agent._train_step()  # pylint: disable=protected-access
    time_delta = time.time() - start_time
    tf.logging.info('Average training steps per second: %.2f',
                    self._training_steps / time_delta)

  def _run_one_iteration(self, iteration):
    """Runs one iteration of agent/environment interaction."""
    statistics = iteration_statistics.IterationStatistics()
    tf.logging.info('Starting iteration %d', iteration)
    # pylint: disable=protected-access
    if not self._agent._replay_suffix:
      # Reload the replay buffer
      self._agent._replay.memory.reload_buffer(num_buffers=5)
    # pylint: enable=protected-access
    self._run_train_phase()

    num_episodes_eval, average_reward_eval = self._run_eval_phase(statistics)

    self._save_tensorboard_summaries(
        iteration, num_episodes_eval, average_reward_eval)
    return statistics.data_lists

  def _save_tensorboard_summaries(self, iteration,
                                  num_episodes_eval,
                                  average_reward_eval):
    """Save statistics as tensorboard summaries.
    Args:
      iteration: int, The current iteration number.
      num_episodes_eval: int, number of evaluation episodes run.
      average_reward_eval: float, The average evaluation reward.
    """
    summary = tf.Summary(value=[
        tf.Summary.Value(tag='Eval/NumEpisodes',
                         simple_value=num_episodes_eval),
        tf.Summary.Value(tag='Eval/AverageReturns',
                         simple_value=average_reward_eval)
    ])
    self._summary_writer.add_summary(summary, iteration)

(2) I created the offline buffer: fixed_replay_buffer.py

(3) I created the OfflineJaxDQNAgent exactly as in the snippet you provided above.


(4) I compared the JAX code with the vanilla TF code and found that they use different replay buffers (FixedReplayBuffer in JAX and WrappedFixedReplayBuffer in TF). I'm not sure whether this is the main reason.
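
One way to narrow down where the time goes (a hedged sketch, not from the thread; it assumes the dopamine JaxDQNAgent exposes _sample_from_replay_buffer and online_params, as in the dopamine codebase) is to time the host-side replay sampling separately from the train step that contains the jitted update:

import time
import jax

def profile_offline_agent(agent, num_steps=100):
  """Roughly splits wall-clock time between replay sampling and training."""
  sample_time, train_time = 0.0, 0.0
  for _ in range(num_steps):
    t0 = time.time()
    agent._sample_from_replay_buffer()  # host-side numpy sampling only
    sample_time += time.time() - t0
    t0 = time.time()
    agent.train_step()  # full step, including the jitted update
    jax.block_until_ready(agent.online_params)
    train_time += time.time() - t0
  print('replay sampling: %.2fs, train steps: %.2fs over %d steps'
        % (sample_time, train_time, num_steps))

If most of the time is spent in sampling, the bottleneck is the dopamine replay buffer on the host rather than JAX itself.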

Best

@lucasliunju
Author

Hi, I find that update_period is 1 in the JAX code while the TF code uses 4. Maybe that is the main reason.

@agarwl
Contributor

agarwl commented Jan 28, 2023

Yeah, update_period = 1 corresponds to 1 gradient step every environment step (the default is 4, which corresponds to 1 gradient step every 4 environment steps). In each iteration we do 62.5K gradient steps, so with update_period = 1 we can also set num_training_steps to 62.5K.
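
To make the budget concrete (a small sketch of the arithmetic above; the 250K figure is dopamine's standard training_steps setting per iteration):

# Per-iteration gradient budget of the online agent, using dopamine's defaults.
env_steps_per_iteration = 250_000
default_update_period = 4
grad_steps_per_iteration = env_steps_per_iteration // default_update_period
print(grad_steps_per_iteration)  # 62500 -> with update_period = 1, set
                                 # training_steps to 62_500 per iteration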

@lucasliunju
Author

lucasliunju commented Jan 28, 2023

Hi @agarwl, thanks for your reply. I will try it. By the way, may I ask whether I can run the TF code on a TPU VM, since I find TF is still a little bit faster?

@agarwl
Contributor

agarwl commented Feb 9, 2023

Sure -- you may not see much benefit from using TPUs (due to the small batch size and the dopamine replay buffer), but the code can be run on a TPU.

@agarwl
Contributor

agarwl commented Apr 10, 2023

@lucasliunju
Author

@agarwl Thank you very much! I will give it a try.
