diff --git a/CHANGELOG.md b/CHANGELOG.md
index 5f47ec92..a9f5d769 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -14,6 +14,9 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/).
 - Make flattened tensor storage in memory the default option (revert changed introduced in version 1.3.0)
 - Drop support for PyTorch versions prior to 1.10 (the previous supported version was 1.9).
 
+### Fixed
+- Moved the batch sampling inside the gradient step loop for DQN, DDQN, DDPG (RNN), TD3 (RNN), SAC and SAC (RNN)
+
 ### Removed
 - Remove OpenAI Gym (`gym`) from dependencies and source code. **skrl** continues to support gym environments, it is just not installed as part of the library. If it is needed, it needs to be installed manually.
diff --git a/docs/source/api/agents/ddqn.rst b/docs/source/api/agents/ddqn.rst
index 5208192e..095cd967 100644
--- a/docs/source/api/agents/ddqn.rst
+++ b/docs/source/api/agents/ddqn.rst
@@ -40,10 +40,10 @@ Learning algorithm
 
 |
 | :literal:`_update(...)`
-| :green:`# sample a batch from memory`
-| [:math:`s, a, r, s', d`] :math:`\leftarrow` states, actions, rewards, next_states, dones of size :guilabel:`batch_size`
 | :green:`# gradient steps`
 | **FOR** each gradient step up to :guilabel:`gradient_steps` **DO**
+|     :green:`# sample a batch from memory`
+|     [:math:`s, a, r, s', d`] :math:`\leftarrow` states, actions, rewards, next_states, dones of size :guilabel:`batch_size`
 |     :green:`# compute target values`
 |     :math:`Q' \leftarrow Q_{\phi_{target}}(s')`
 |     :math:`Q_{_{target}} \leftarrow Q'[\underset{a}{\arg\max} \; Q_\phi(s')] \qquad` :gray:`# the only difference with DQN`
diff --git a/docs/source/api/agents/dqn.rst b/docs/source/api/agents/dqn.rst
index 74e2ca04..d5d72e7b 100644
--- a/docs/source/api/agents/dqn.rst
+++ b/docs/source/api/agents/dqn.rst
@@ -40,10 +40,10 @@ Learning algorithm
 
 |
 | :literal:`_update(...)`
-| :green:`# sample a batch from memory`
-| [:math:`s, a, r, s', d`] :math:`\leftarrow` states, actions, rewards, next_states, dones of size :guilabel:`batch_size`
 | :green:`# gradient steps`
 | **FOR** each gradient step up to :guilabel:`gradient_steps` **DO**
+|     :green:`# sample a batch from memory`
+|     [:math:`s, a, r, s', d`] :math:`\leftarrow` states, actions, rewards, next_states, dones of size :guilabel:`batch_size`
 |     :green:`# compute target values`
 |     :math:`Q' \leftarrow Q_{\phi_{target}}(s')`
 |     :math:`Q_{_{target}} \leftarrow \underset{a}{\max} \; Q' \qquad` :gray:`# the only difference with DDQN`
diff --git a/docs/source/api/agents/sac.rst b/docs/source/api/agents/sac.rst
index b465acbb..cf1e4265 100644
--- a/docs/source/api/agents/sac.rst
+++ b/docs/source/api/agents/sac.rst
@@ -34,10 +34,10 @@ Learning algorithm
 
 |
 | :literal:`_update(...)`
-| :green:`# sample a batch from memory`
-| [:math:`s, a, r, s', d`] :math:`\leftarrow` states, actions, rewards, next_states, dones of size :guilabel:`batch_size`
 | :green:`# gradient steps`
 | **FOR** each gradient step up to :guilabel:`gradient_steps` **DO**
+|     :green:`# sample a batch from memory`
+|     [:math:`s, a, r, s', d`] :math:`\leftarrow` states, actions, rewards, next_states, dones of size :guilabel:`batch_size`
 |     :green:`# compute target values`
 |     :math:`a',\; logp' \leftarrow \pi_\theta(s')`
 |     :math:`Q_{1_{target}} \leftarrow Q_{{\phi 1}_{target}}(s', a')`
diff --git a/skrl/agents/jax/dqn/ddqn.py b/skrl/agents/jax/dqn/ddqn.py
index 522cac95..76868e68 100644
--- a/skrl/agents/jax/dqn/ddqn.py
+++ b/skrl/agents/jax/dqn/ddqn.py
@@ -336,13 +336,14 @@ def _update(self, timestep: int, timesteps: int) -> None:
         :param timesteps: Number of timesteps
         :type timesteps: int
         """
-        # sample a batch from memory
-        sampled_states, sampled_actions, sampled_rewards, sampled_next_states, sampled_dones = \
-            self.memory.sample(names=self.tensors_names, batch_size=self._batch_size)[0]
 
         # gradient steps
         for gradient_step in range(self._gradient_steps):
 
+            # sample a batch from memory
+            sampled_states, sampled_actions, sampled_rewards, sampled_next_states, sampled_dones = \
+                self.memory.sample(names=self.tensors_names, batch_size=self._batch_size)[0]
+
             sampled_states = self._state_preprocessor(sampled_states, train=True)
             sampled_next_states = self._state_preprocessor(sampled_next_states, train=True)
 
diff --git a/skrl/agents/jax/dqn/dqn.py b/skrl/agents/jax/dqn/dqn.py
index ac0afa6f..30629c09 100644
--- a/skrl/agents/jax/dqn/dqn.py
+++ b/skrl/agents/jax/dqn/dqn.py
@@ -333,13 +333,14 @@ def _update(self, timestep: int, timesteps: int) -> None:
         :param timesteps: Number of timesteps
         :type timesteps: int
         """
-        # sample a batch from memory
-        sampled_states, sampled_actions, sampled_rewards, sampled_next_states, sampled_dones = \
-            self.memory.sample(names=self.tensors_names, batch_size=self._batch_size)[0]
 
         # gradient steps
         for gradient_step in range(self._gradient_steps):
 
+            # sample a batch from memory
+            sampled_states, sampled_actions, sampled_rewards, sampled_next_states, sampled_dones = \
+                self.memory.sample(names=self.tensors_names, batch_size=self._batch_size)[0]
+
             sampled_states = self._state_preprocessor(sampled_states, train=True)
             sampled_next_states = self._state_preprocessor(sampled_next_states, train=True)
 
diff --git a/skrl/agents/jax/sac/sac.py b/skrl/agents/jax/sac/sac.py
index 54512df2..7656c26d 100644
--- a/skrl/agents/jax/sac/sac.py
+++ b/skrl/agents/jax/sac/sac.py
@@ -396,13 +396,14 @@ def _update(self, timestep: int, timesteps: int) -> None:
         :param timesteps: Number of timesteps
         :type timesteps: int
         """
-        # sample a batch from memory
-        sampled_states, sampled_actions, sampled_rewards, sampled_next_states, sampled_dones = \
-            self.memory.sample(names=self._tensors_names, batch_size=self._batch_size)[0]
 
         # gradient steps
         for gradient_step in range(self._gradient_steps):
 
+            # sample a batch from memory
+            sampled_states, sampled_actions, sampled_rewards, sampled_next_states, sampled_dones = \
+                self.memory.sample(names=self._tensors_names, batch_size=self._batch_size)[0]
+
             sampled_states = self._state_preprocessor(sampled_states, train=True)
             sampled_next_states = self._state_preprocessor(sampled_next_states, train=True)
 
diff --git a/skrl/agents/jax/td3/td3.py b/skrl/agents/jax/td3/td3.py
index 38484ea0..23f4885a 100644
--- a/skrl/agents/jax/td3/td3.py
+++ b/skrl/agents/jax/td3/td3.py
@@ -435,6 +435,7 @@ def _update(self, timestep: int, timesteps: int) -> None:
 
         # gradient steps
         for gradient_step in range(self._gradient_steps):
 
+            # sample a batch from memory
             sampled_states, sampled_actions, sampled_rewards, sampled_next_states, sampled_dones = \
                 self.memory.sample(names=self._tensors_names, batch_size=self._batch_size)[0]
diff --git a/skrl/agents/torch/ddpg/ddpg_rnn.py b/skrl/agents/torch/ddpg/ddpg_rnn.py
index 6e0fd1c9..36a98fee 100644
--- a/skrl/agents/torch/ddpg/ddpg_rnn.py
+++ b/skrl/agents/torch/ddpg/ddpg_rnn.py
@@ -360,18 +360,19 @@ def _update(self, timestep: int, timesteps: int) -> None:
         :param timesteps: Number of timesteps
         :type timesteps: int
         """
-        # sample a batch from memory
-        sampled_states, sampled_actions, sampled_rewards, sampled_next_states, sampled_dones = \
-            self.memory.sample(names=self._tensors_names, batch_size=self._batch_size, sequence_length=self._rnn_sequence_length)[0]
-
-        rnn_policy = {}
-        if self._rnn:
-            sampled_rnn = self.memory.sample_by_index(names=self._rnn_tensors_names, indexes=self.memory.get_sampling_indexes())[0]
-            rnn_policy = {"rnn": [s.transpose(0, 1) for s in sampled_rnn], "terminated": sampled_dones}
 
         # gradient steps
         for gradient_step in range(self._gradient_steps):
 
+            # sample a batch from memory
+            sampled_states, sampled_actions, sampled_rewards, sampled_next_states, sampled_dones = \
+                self.memory.sample(names=self._tensors_names, batch_size=self._batch_size, sequence_length=self._rnn_sequence_length)[0]
+
+            rnn_policy = {}
+            if self._rnn:
+                sampled_rnn = self.memory.sample_by_index(names=self._rnn_tensors_names, indexes=self.memory.get_sampling_indexes())[0]
+                rnn_policy = {"rnn": [s.transpose(0, 1) for s in sampled_rnn], "terminated": sampled_dones}
+
             sampled_states = self._state_preprocessor(sampled_states, train=True)
             sampled_next_states = self._state_preprocessor(sampled_next_states, train=True)
 
diff --git a/skrl/agents/torch/dqn/ddqn.py b/skrl/agents/torch/dqn/ddqn.py
index dbc0956d..d7e93886 100644
--- a/skrl/agents/torch/dqn/ddqn.py
+++ b/skrl/agents/torch/dqn/ddqn.py
@@ -283,13 +283,14 @@ def _update(self, timestep: int, timesteps: int) -> None:
         :param timesteps: Number of timesteps
         :type timesteps: int
         """
-        # sample a batch from memory
-        sampled_states, sampled_actions, sampled_rewards, sampled_next_states, sampled_dones = \
-            self.memory.sample(names=self.tensors_names, batch_size=self._batch_size)[0]
 
         # gradient steps
         for gradient_step in range(self._gradient_steps):
 
+            # sample a batch from memory
+            sampled_states, sampled_actions, sampled_rewards, sampled_next_states, sampled_dones = \
+                self.memory.sample(names=self.tensors_names, batch_size=self._batch_size)[0]
+
             sampled_states = self._state_preprocessor(sampled_states, train=True)
             sampled_next_states = self._state_preprocessor(sampled_next_states, train=True)
 
diff --git a/skrl/agents/torch/dqn/dqn.py b/skrl/agents/torch/dqn/dqn.py
index 0a8f9ade..03ffa320 100644
--- a/skrl/agents/torch/dqn/dqn.py
+++ b/skrl/agents/torch/dqn/dqn.py
@@ -283,13 +283,14 @@ def _update(self, timestep: int, timesteps: int) -> None:
         :param timesteps: Number of timesteps
         :type timesteps: int
         """
-        # sample a batch from memory
-        sampled_states, sampled_actions, sampled_rewards, sampled_next_states, sampled_dones = \
-            self.memory.sample(names=self.tensors_names, batch_size=self._batch_size)[0]
 
         # gradient steps
         for gradient_step in range(self._gradient_steps):
 
+            # sample a batch from memory
+            sampled_states, sampled_actions, sampled_rewards, sampled_next_states, sampled_dones = \
+                self.memory.sample(names=self.tensors_names, batch_size=self._batch_size)[0]
+
             sampled_states = self._state_preprocessor(sampled_states, train=True)
             sampled_next_states = self._state_preprocessor(sampled_next_states, train=True)
 
diff --git a/skrl/agents/torch/sac/sac.py b/skrl/agents/torch/sac/sac.py
index 6c58a51d..78f4a556 100644
--- a/skrl/agents/torch/sac/sac.py
+++ b/skrl/agents/torch/sac/sac.py
@@ -307,12 +307,13 @@ def _update(self, timestep: int, timesteps: int) -> None:
         :param timesteps: Number of timesteps
         :type timesteps: int
         """
+
         # gradient steps
         for gradient_step in range(self._gradient_steps):
 
             # sample a batch from memory
             sampled_states, sampled_actions, sampled_rewards, sampled_next_states, sampled_dones = \
-            self.memory.sample(names=self._tensors_names, batch_size=self._batch_size)[0]
+                self.memory.sample(names=self._tensors_names, batch_size=self._batch_size)[0]
 
             sampled_states = self._state_preprocessor(sampled_states, train=True)
             sampled_next_states = self._state_preprocessor(sampled_next_states, train=True)
 
diff --git a/skrl/agents/torch/sac/sac_rnn.py b/skrl/agents/torch/sac/sac_rnn.py
index 1dd2a401..4d8764b4 100644
--- a/skrl/agents/torch/sac/sac_rnn.py
+++ b/skrl/agents/torch/sac/sac_rnn.py
@@ -344,18 +344,19 @@ def _update(self, timestep: int, timesteps: int) -> None:
         :param timesteps: Number of timesteps
         :type timesteps: int
         """
-        # sample a batch from memory
-        sampled_states, sampled_actions, sampled_rewards, sampled_next_states, sampled_dones = \
-            self.memory.sample(names=self._tensors_names, batch_size=self._batch_size, sequence_length=self._rnn_sequence_length)[0]
-
-        rnn_policy = {}
-        if self._rnn:
-            sampled_rnn = self.memory.sample_by_index(names=self._rnn_tensors_names, indexes=self.memory.get_sampling_indexes())[0]
-            rnn_policy = {"rnn": [s.transpose(0, 1) for s in sampled_rnn], "terminated": sampled_dones}
 
         # gradient steps
         for gradient_step in range(self._gradient_steps):
 
+            # sample a batch from memory
+            sampled_states, sampled_actions, sampled_rewards, sampled_next_states, sampled_dones = \
+                self.memory.sample(names=self._tensors_names, batch_size=self._batch_size, sequence_length=self._rnn_sequence_length)[0]
+
+            rnn_policy = {}
+            if self._rnn:
+                sampled_rnn = self.memory.sample_by_index(names=self._rnn_tensors_names, indexes=self.memory.get_sampling_indexes())[0]
+                rnn_policy = {"rnn": [s.transpose(0, 1) for s in sampled_rnn], "terminated": sampled_dones}
+
             sampled_states = self._state_preprocessor(sampled_states, train=True)
             sampled_next_states = self._state_preprocessor(sampled_next_states, train=True)
 
diff --git a/skrl/agents/torch/td3/td3_rnn.py b/skrl/agents/torch/td3/td3_rnn.py
index 5d5d8d0a..aeb2fd78 100644
--- a/skrl/agents/torch/td3/td3_rnn.py
+++ b/skrl/agents/torch/td3/td3_rnn.py
@@ -382,18 +382,19 @@ def _update(self, timestep: int, timesteps: int) -> None:
         :param timesteps: Number of timesteps
         :type timesteps: int
         """
-        # sample a batch from memory
-        sampled_states, sampled_actions, sampled_rewards, sampled_next_states, sampled_dones = \
-            self.memory.sample(names=self._tensors_names, batch_size=self._batch_size, sequence_length=self._rnn_sequence_length)[0]
-
-        rnn_policy = {}
-        if self._rnn:
-            sampled_rnn = self.memory.sample_by_index(names=self._rnn_tensors_names, indexes=self.memory.get_sampling_indexes())[0]
-            rnn_policy = {"rnn": [s.transpose(0, 1) for s in sampled_rnn], "terminated": sampled_dones}
 
         # gradient steps
         for gradient_step in range(self._gradient_steps):
 
+            # sample a batch from memory
+            sampled_states, sampled_actions, sampled_rewards, sampled_next_states, sampled_dones = \
+                self.memory.sample(names=self._tensors_names, batch_size=self._batch_size, sequence_length=self._rnn_sequence_length)[0]
+
+            rnn_policy = {}
+            if self._rnn:
+                sampled_rnn = self.memory.sample_by_index(names=self._rnn_tensors_names, indexes=self.memory.get_sampling_indexes())[0]
+                rnn_policy = {"rnn": [s.transpose(0, 1) for s in sampled_rnn], "terminated": sampled_dones}
+
             sampled_states = self._state_preprocessor(sampled_states, train=True)
             sampled_next_states = self._state_preprocessor(sampled_next_states, train=True)
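
Note on the behavior this patch fixes, in miniature: previously `_update()` drew one minibatch before the gradient-step loop and reused it for every update, whereas each gradient step should sample its own minibatch. A minimal sketch of the two variants, assuming a toy uniformly-sampling buffer (`ReplayBuffer` and the commented-out `update(batch)` are illustrative stand-ins, not skrl's API):

    import random

    class ReplayBuffer:
        """Toy stand-in for skrl's memory: uniform sampling with replacement."""
        def __init__(self, transitions):
            self.transitions = transitions

        def sample(self, batch_size):
            # stand-in for memory.sample(names=..., batch_size=...)
            return random.choices(self.transitions, k=batch_size)

    buffer = ReplayBuffer(transitions=list(range(1000)))
    gradient_steps, batch_size = 4, 8

    # before: one batch drawn outside the loop, so every gradient step
    # updates on exactly the same transitions
    batch = buffer.sample(batch_size)
    for _ in range(gradient_steps):
        pass  # update(batch) sees identical data on each iteration

    # after: a fresh minibatch per gradient step, matching the pseudocode
    # in the updated docs above
    for _ in range(gradient_steps):
        batch = buffer.sample(batch_size)  # resampled every step
        pass  # update(batch) sees newly drawn data on each iteration

With the default `gradient_steps` of 1 the two variants coincide, so the fix only changes training behavior for configurations that take multiple gradient steps per environment step.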