diff --git a/docs/misc/changelog.rst b/docs/misc/changelog.rst
index cf2a2a520..5acd634c5 100644
--- a/docs/misc/changelog.rst
+++ b/docs/misc/changelog.rst
@@ -224,6 +224,7 @@ New Features:
 - Improved error message when mixing Gym API with VecEnv API (see GH#1694)
 - Add support for setting ``options`` at reset with VecEnv via the ``set_options()`` method. Same as seeds logic, options are reset at the end of an episode (@ReHoss)
 - Added ``rollout_buffer_class`` and ``rollout_buffer_kwargs`` arguments to on-policy algorithms (A2C and PPO)
+- Added Prioritized Experience Replay for DQN (@AlexPasqua)
 
 Bug Fixes:
 
diff --git a/docs/modules/dqn.rst b/docs/modules/dqn.rst
index 85d486661..401fc1a47 100644
--- a/docs/modules/dqn.rst
+++ b/docs/modules/dqn.rst
@@ -27,8 +27,9 @@ Notes
 - Further reference: https://www.nature.com/articles/nature14236
 
 .. note::
-    This implementation provides only vanilla Deep Q-Learning and has no extensions such as Double-DQN, Dueling-DQN and Prioritized Experience Replay.
+    This implementation provides only vanilla Deep Q-Learning and has no extensions such as Double-DQN or Dueling-DQN.
+    To use Prioritized Experience Replay, you need to pass it via the ``replay_buffer_class`` argument.
 
 Can I use?
 ----------
@@ -48,6 +49,15 @@
 MultiBinary   ❌      ✔️
 Dict          ❌      ✔️️
 ============= ====== ===========
+- Rainbow DQN extensions:
+
+  - Double Q-Learning: ❌
+  - Prioritized Experience Replay: ✔️ (``from stable_baselines3.common.prioritized_replay_buffer import PrioritizedReplayBuffer``)
+  - Dueling Networks: ❌
+  - Multi-step Learning: ❌
+  - Distributional RL: ✔️ (``QR-DQN`` is implemented in the SB3 contrib repo)
+  - Noisy Nets: ❌
+
 
 Example
 -------
diff --git a/stable_baselines3/common/buffers.py b/stable_baselines3/common/buffers.py
index b2fc5a710..cee2fc3b9 100644
--- a/stable_baselines3/common/buffers.py
+++ b/stable_baselines3/common/buffers.py
@@ -291,7 +291,7 @@ def sample(self, batch_size: int, env: Optional[VecNormalize] = None) -> ReplayB
         :param batch_size: Number of element to sample
         :param env: associated gym VecEnv
             to normalize the observations/rewards when sampling
-        :return:
+        :return: a batch of sampled experiences from the buffer.
         """
         if not self.optimize_memory_usage:
             return super().sample(batch_size=batch_size, env=env)
@@ -321,7 +321,7 @@ def _get_samples(self, batch_inds: np.ndarray, env: Optional[VecNormalize] = Non
             (self.dones[batch_inds, env_indices] * (1 - self.timeouts[batch_inds, env_indices])).reshape(-1, 1),
             self._normalize_reward(self.rewards[batch_inds, env_indices].reshape(-1, 1), env),
         )
-        return ReplayBufferSamples(*tuple(map(self.to_torch, data)))
+        return ReplayBufferSamples(*tuple(map(self.to_torch, data)))  # type: ignore[arg-type]
 
     @staticmethod
     def _maybe_cast_dtype(dtype: np.typing.DTypeLike) -> np.typing.DTypeLike:
diff --git a/stable_baselines3/common/prioritized_replay_buffer.py b/stable_baselines3/common/prioritized_replay_buffer.py
new file mode 100644
index 000000000..f68388365
--- /dev/null
+++ b/stable_baselines3/common/prioritized_replay_buffer.py
@@ -0,0 +1,260 @@
+from typing import Any, Dict, List, Optional, Tuple, Union
+
+import numpy as np
+import torch as th
+from gymnasium import spaces
+
+from stable_baselines3.common.buffers import ReplayBuffer
+from stable_baselines3.common.type_aliases import ReplayBufferSamples
+from stable_baselines3.common.utils import get_linear_fn
+from stable_baselines3.common.vec_env.vec_normalize import VecNormalize
+
+
+class SumTree:
+    """
+    SumTree data structure for Prioritized Replay Buffer.
+    This code is inspired by: https://github.com/Howuhh/prioritized_experience_replay
+
+    :param buffer_size: Max number of elements in the buffer.
+    """
+
+    def __init__(self, buffer_size: int) -> None:
+        self.nodes = np.zeros(2 * buffer_size - 1)
+        # The data array stores transition indices
+        self.data = np.zeros(buffer_size)
+        self.buffer_size = buffer_size
+        self.pos = 0
+        self.full = False
+
+    @property
+    def total_sum(self) -> float:
+        """
+        Returns the root node value, which represents the total sum of all priorities in the tree.
+
+        :return: Total sum of all priorities in the tree.
+        """
+        return self.nodes[0].item()
+
+    def update(self, leaf_node_idx: int, value: float) -> None:
+        """
+        Update the priority of a leaf node.
+
+        :param leaf_node_idx: Index of the leaf node to update.
+        :param value: New priority value.
+        """
+        idx = leaf_node_idx + self.buffer_size - 1  # child index in tree array
+        change = value - self.nodes[idx]
+        self.nodes[idx] = value
+        parent = (idx - 1) // 2
+        while parent >= 0:
+            self.nodes[parent] += change
+            parent = (parent - 1) // 2
+
+    def add(self, value: float, data: int) -> None:
+        """
+        Add a new transition with priority value:
+        it adds a new leaf node and updates the cumulative sum.
+
+        :param value: Priority value.
+        :param data: Data for the new leaf node, storing transition index
+            in the case of the prioritized replay buffer.
+        """
+        # Note: transition_indices should be constant
+        # as the replay buffer already updates a pointer
+        self.data[self.pos] = data
+        self.update(self.pos, value)
+        self.pos = (self.pos + 1) % self.buffer_size
+
+    def get(self, cumulative_sum: float) -> Tuple[int, float, float]:
+        """
+        Get a leaf node index, its priority value and transition index by cumulative_sum value.
+
+        :param cumulative_sum: Cumulative sum value.
+        :return: Leaf node index, its priority value and transition index.
+        """
+        assert cumulative_sum <= self.total_sum
+
+        idx = 0
+        while 2 * idx + 1 < len(self.nodes):
+            left, right = 2 * idx + 1, 2 * idx + 2
+            if cumulative_sum <= self.nodes[left]:
+                idx = left
+            else:
+                idx = right
+                cumulative_sum = cumulative_sum - self.nodes[left]
+
+        leaf_node_idx = idx - self.buffer_size + 1
+        return leaf_node_idx, self.nodes[idx].item(), self.data[leaf_node_idx]
+
+    def __repr__(self) -> str:
+        return f"SumTree(nodes={self.nodes!r}, data={self.data!r})"
+
+
+class PrioritizedReplayBuffer(ReplayBuffer):
+    """
+    Prioritized Replay Buffer (proportional priorities version).
+    Paper: https://arxiv.org/abs/1511.05952
+    This code is inspired by: https://github.com/Howuhh/prioritized_experience_replay
+
+    :param buffer_size: Max number of elements in the buffer
+    :param observation_space: Observation space
+    :param action_space: Action space
+    :param device: PyTorch device
+    :param n_envs: Number of parallel environments
+    :param alpha: How much prioritization is used (0 - no prioritization aka uniform case, 1 - full prioritization)
+    :param beta: To what degree to use importance weights (0 - no corrections, 1 - full correction)
+    :param final_beta: Value of beta at the end of training.
+        Linear annealing is used to interpolate between initial value of beta and final beta.
+    :param min_priority: Minimum priority, prevents zero probabilities, so that all samples
+        always have a non-zero probability to be sampled.
+ """ + + def __init__( + self, + buffer_size: int, + observation_space: spaces.Space, + action_space: spaces.Space, + device: Union[th.device, str] = "auto", + n_envs: int = 1, + alpha: float = 0.5, + beta: float = 0.4, + final_beta: float = 1.0, + optimize_memory_usage: bool = False, + min_priority: float = 1e-6, + ): + super().__init__(buffer_size, observation_space, action_space, device, n_envs) + + assert optimize_memory_usage is False, "PrioritizedReplayBuffer doesn't support optimize_memory_usage=True" + + self.min_priority = min_priority + self.alpha = alpha + self.max_priority = self.min_priority # priority for new samples, init as eps + # Track the training progress remaining (from 1 to 0) + # this is used to update beta + self._current_progress_remaining = 1.0 + self.inital_beta = beta + self.final_beta = final_beta + self.beta_schedule = get_linear_fn( + self.inital_beta, + self.final_beta, + end_fraction=1.0, + ) + # SumTree: data structure to store priorities + self.tree = SumTree(buffer_size=buffer_size) + + @property + def beta(self) -> float: + # Linear schedule + return self.beta_schedule(self._current_progress_remaining) + + def add( + self, + obs: np.ndarray, + next_obs: np.ndarray, + action: np.ndarray, + reward: np.ndarray, + done: np.ndarray, + infos: List[Dict[str, Any]], + ) -> None: + """ + Add a new transition to the buffer. + + :param obs: Starting observation of the transition to be stored. + :param next_obs: Destination observation of the transition to be stored. + :param action: Action performed in the transition to be stored. + :param reward: Reward received in the transition to be stored. + :param done: Whether the episode was finished after the transition to be stored. + :param infos: Eventual information given by the environment. + """ + # store transition index with maximum priority in sum tree + self.tree.add(self.max_priority, self.pos) + + # store transition in the buffer + super().add(obs, next_obs, action, reward, done, infos) + + def sample(self, batch_size: int, env: Optional[VecNormalize] = None) -> ReplayBufferSamples: + """ + Sample elements from the prioritized replay buffer. + + :param batch_size: Number of element to sample + :param env:associated gym VecEnv + to normalize the observations/rewards when sampling + :return: a batch of sampled experiences from the buffer. + """ + assert self.buffer_size >= batch_size, "The buffer contains less samples than the batch size requires." + + leaf_nodes_indices = np.zeros(batch_size, dtype=np.uint32) + priorities = np.zeros((batch_size, 1)) + sample_indices = np.zeros(batch_size, dtype=np.uint32) + + # To sample a minibatch of size k, the range [0, total_sum] is divided equally into k ranges. + # Next, a value is uniformly sampled from each range. Finally the transitions that correspond + # to each of these sampled values are retrieved from the tree. 
+        segment_size = self.tree.total_sum / batch_size
+        for batch_idx in range(batch_size):
+            # extremes of the current segment
+            start, end = segment_size * batch_idx, segment_size * (batch_idx + 1)
+
+            # uniformly sample a value from the current segment
+            cumulative_sum = np.random.uniform(start, end)
+
+            # leaf_node_idx is the index of a sample in the tree, needed later to update priorities
+            # sample_idx is the index of a sample in the buffer, needed later to retrieve the actual transition
+            leaf_node_idx, priority, sample_idx = self.tree.get(cumulative_sum)
+
+            leaf_nodes_indices[batch_idx] = leaf_node_idx
+            priorities[batch_idx] = priority
+            sample_indices[batch_idx] = sample_idx
+
+        # probability of sampling transition i as P(i) = p_i^alpha / \sum_{k} p_k^alpha
+        # where p_i > 0 is the priority of transition i.
+        probs = priorities / self.tree.total_sum
+
+        # Importance sampling weights.
+        # All weights w_i are scaled so that max_i w_i = 1.
+        weights = (self.size() * probs) ** -self.beta
+        weights = weights / weights.max()
+
+        # TODO: add proper support for multi env
+        # env_indices = np.random.randint(0, high=self.n_envs, size=(batch_size,))
+        env_indices = np.zeros(batch_size, dtype=np.uint32)
+
+        if self.optimize_memory_usage:
+            next_obs = self._normalize_obs(self.observations[(sample_indices + 1) % self.buffer_size, env_indices, :], env)
+        else:
+            next_obs = self._normalize_obs(self.next_observations[sample_indices, env_indices, :], env)
+
+        batch = (
+            self._normalize_obs(self.observations[sample_indices, env_indices, :], env),
+            self.actions[sample_indices, env_indices, :],
+            next_obs,
+            # Only use dones that are not due to timeouts (same convention as the base ReplayBuffer)
+            (self.dones[sample_indices, env_indices] * (1 - self.timeouts[sample_indices, env_indices])).reshape(-1, 1),
+            self._normalize_reward(self.rewards[sample_indices, env_indices].reshape(-1, 1), env),
+            weights,
+        )
+        return ReplayBufferSamples(*tuple(map(self.to_torch, batch)), leaf_nodes_indices)  # type: ignore[arg-type,call-arg]
+
+    def update_priorities(self, leaf_nodes_indices: np.ndarray, td_errors: th.Tensor, progress_remaining: float) -> None:
+        """
+        Update transition priorities.
+
+        :param leaf_nodes_indices: Indices for the leaf nodes to update
+            (corresponding to the transitions)
+        :param td_errors: New priorities, td error in the case of
+            proportional prioritized replay buffer.
+        :param progress_remaining: Current progress remaining (starts at 1 and ends at 0)
+            to linearly anneal beta from its start value to 1.0 at the end of training
+        """
+        # Update beta schedule
+        self._current_progress_remaining = progress_remaining
+        if isinstance(td_errors, th.Tensor):
+            td_errors = td_errors.detach().cpu().numpy().flatten()
+
+        for leaf_node_idx, td_error in zip(leaf_nodes_indices, td_errors):
+            # Proportional prioritization priority = (abs(td_error) + eps) ^ alpha
+            # where eps is a small positive constant that prevents the edge-case of transitions not being
+            # revisited once their error is zero. (Section 3.3)
+            priority = (abs(td_error) + self.min_priority) ** self.alpha
+            self.tree.update(leaf_node_idx, priority)
+            # Update max priority for new samples
+            self.max_priority = max(self.max_priority, priority)
diff --git a/stable_baselines3/common/type_aliases.py b/stable_baselines3/common/type_aliases.py
index 042c66f9c..8f7b3d1cd 100644
--- a/stable_baselines3/common/type_aliases.py
+++ b/stable_baselines3/common/type_aliases.py
@@ -52,6 +52,8 @@ class ReplayBufferSamples(NamedTuple):
     next_observations: th.Tensor
     dones: th.Tensor
     rewards: th.Tensor
+    weights: Union[th.Tensor, float] = 1.0
+    leaf_nodes_indices: Optional[np.ndarray] = None
 
 
 class DictReplayBufferSamples(NamedTuple):
@@ -60,6 +62,8 @@ class DictReplayBufferSamples(NamedTuple):
     next_observations: TensorDict
     dones: th.Tensor
     rewards: th.Tensor
+    weights: Union[th.Tensor, float] = 1.0
+    leaf_nodes_indices: Optional[np.ndarray] = None
 
 
 class RolloutReturn(NamedTuple):
diff --git a/stable_baselines3/dqn/dqn.py b/stable_baselines3/dqn/dqn.py
index 894ed9f04..189f944a0 100644
--- a/stable_baselines3/dqn/dqn.py
+++ b/stable_baselines3/dqn/dqn.py
@@ -9,6 +9,7 @@
 from stable_baselines3.common.buffers import ReplayBuffer
 from stable_baselines3.common.off_policy_algorithm import OffPolicyAlgorithm
 from stable_baselines3.common.policies import BasePolicy
+from stable_baselines3.common.prioritized_replay_buffer import PrioritizedReplayBuffer
 from stable_baselines3.common.type_aliases import GymEnv, MaybeCallback, Schedule
 from stable_baselines3.common.utils import get_linear_fn, get_parameters_by_name, polyak_update
 from stable_baselines3.dqn.policies import CnnPolicy, DQNPolicy, MlpPolicy, MultiInputPolicy, QNetwork
@@ -208,8 +209,20 @@ def train(self, gradient_steps: int, batch_size: int = 100) -> None:
             # Retrieve the q-values for the actions from the replay buffer
             current_q_values = th.gather(current_q_values, dim=1, index=replay_data.actions.long())
 
-            # Compute Huber loss (less sensitive to outliers)
-            loss = F.smooth_l1_loss(current_q_values, target_q_values)
+            # Special case when using PrioritizedReplayBuffer (PER)
+            if isinstance(self.replay_buffer, PrioritizedReplayBuffer):
+                # TD error in absolute value
+                td_error = th.abs(current_q_values - target_q_values)
+                # Weighted Huber loss using importance sampling weights
+                loss = (replay_data.weights * th.where(td_error < 1.0, 0.5 * td_error**2, td_error - 0.5)).mean()
+                # Update priorities, they will be proportional to the td error
+                assert replay_data.leaf_nodes_indices is not None, "No leaf node indices provided"
+                self.replay_buffer.update_priorities(
+                    replay_data.leaf_nodes_indices, td_error, self._current_progress_remaining
+                )
+            else:
+                # Compute Huber loss (less sensitive to outliers)
+                loss = F.smooth_l1_loss(current_q_values, target_q_values)
             losses.append(loss.item())
 
             # Optimize the policy
diff --git a/tests/test_buffers.py b/tests/test_buffers.py
index 18171dd21..a59f1fe98 100644
--- a/tests/test_buffers.py
+++ b/tests/test_buffers.py
@@ -8,6 +8,7 @@
 from stable_baselines3.common.buffers import DictReplayBuffer, DictRolloutBuffer, ReplayBuffer, RolloutBuffer
 from stable_baselines3.common.env_checker import check_env
 from stable_baselines3.common.env_util import make_vec_env
+from stable_baselines3.common.prioritized_replay_buffer import PrioritizedReplayBuffer
 from stable_baselines3.common.type_aliases import DictReplayBufferSamples, ReplayBufferSamples
 from stable_baselines3.common.utils import get_device
 from stable_baselines3.common.vec_env import VecNormalize
@@ -109,7 +110,9 @@ def test_replay_buffer_normalization(replay_buffer_cls):
     assert np.allclose(sample.rewards.mean(0), np.zeros(1), atol=1)
 
 
-@pytest.mark.parametrize("replay_buffer_cls", [DictReplayBuffer, DictRolloutBuffer, ReplayBuffer, RolloutBuffer])
+@pytest.mark.parametrize(
+    "replay_buffer_cls", [DictReplayBuffer, DictRolloutBuffer, ReplayBuffer, RolloutBuffer, PrioritizedReplayBuffer]
+)
 @pytest.mark.parametrize("device", ["cpu", "cuda", "auto"])
 def test_device_buffer(replay_buffer_cls, device):
     if device == "cuda" and not th.cuda.is_available():
@@ -120,6 +123,7 @@ def test_device_buffer(replay_buffer_cls, device):
         DictRolloutBuffer: DummyDictEnv,
         ReplayBuffer: DummyEnv,
         DictReplayBuffer: DummyDictEnv,
+        PrioritizedReplayBuffer: DummyEnv,
     }[replay_buffer_cls]
 
     env = make_vec_env(env)
@@ -141,8 +145,8 @@ def test_device_buffer(replay_buffer_cls, device):
     if replay_buffer_cls in [RolloutBuffer, DictRolloutBuffer]:
         # get returns an iterator over minibatches
         data = buffer.get(50)
-    elif replay_buffer_cls in [ReplayBuffer, DictReplayBuffer]:
-        data = [buffer.sample(50)]
+    elif replay_buffer_cls in [ReplayBuffer, DictReplayBuffer, PrioritizedReplayBuffer]:
+        data = buffer.sample(50)
 
     # Check that all data are on the desired device
     desired_device = get_device(device).type
diff --git a/tests/test_run.py b/tests/test_run.py
index 4acabb692..390d53558 100644
--- a/tests/test_run.py
+++ b/tests/test_run.py
@@ -6,6 +6,7 @@
 from stable_baselines3 import A2C, DDPG, DQN, PPO, SAC, TD3
 from stable_baselines3.common.env_util import make_vec_env
 from stable_baselines3.common.noise import NormalActionNoise, OrnsteinUhlenbeckActionNoise
+from stable_baselines3.common.prioritized_replay_buffer import PrioritizedReplayBuffer
 
 normal_action_noise = NormalActionNoise(np.zeros(1), 0.1 * np.ones(1))
 
@@ -101,7 +102,8 @@ def test_n_critics(n_critics):
     model.learn(total_timesteps=200)
 
 
-def test_dqn():
+@pytest.mark.parametrize("replay_buffer_class", [None, PrioritizedReplayBuffer])
+def test_dqn(replay_buffer_class):
     model = DQN(
         "MlpPolicy",
         "CartPole-v1",
@@ -110,6 +112,7 @@ def test_dqn():
         buffer_size=500,
         learning_rate=3e-4,
         verbose=1,
+        replay_buffer_class=replay_buffer_class,
     )
     model.learn(total_timesteps=200)
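Usage sketch (not part of the patch): the snippet below shows how the new ``PrioritizedReplayBuffer`` is expected to be enabled for DQN via the ``replay_buffer_class`` argument, mirroring the updated note in ``docs/modules/dqn.rst`` and the parametrized test in ``tests/test_run.py``. Tuning ``alpha``/``beta`` through the existing ``replay_buffer_kwargs`` argument is an assumption based on the constructor added above; it is not exercised in this diff::

    from stable_baselines3 import DQN
    from stable_baselines3.common.prioritized_replay_buffer import PrioritizedReplayBuffer

    model = DQN(
        "MlpPolicy",
        "CartPole-v1",
        # Enable Prioritized Experience Replay (proportional variant)
        replay_buffer_class=PrioritizedReplayBuffer,
        # Assumed: forwarded to PrioritizedReplayBuffer.__init__ via the existing
        # replay_buffer_kwargs mechanism of off-policy algorithms
        replay_buffer_kwargs=dict(alpha=0.5, beta=0.4, final_beta=1.0, min_priority=1e-6),
        learning_starts=100,
        buffer_size=50_000,
        verbose=1,
    )
    model.learn(total_timesteps=20_000)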