Fix Sampling inside gradient loop issue (#183)
* Move batch sampling inside the gradient step loop in TD3 (RNN), DDPG (RNN), SAC, SAC (RNN), DQN and DDQN

* Update CHANGELOG.md

* Apply format

---------

Co-authored-by: Deniz Seven <[email protected]>
Co-authored-by: Toni-SM <[email protected]>
3 people authored Nov 3, 2024
1 parent 5fce807 commit 9252ec9
Showing 14 changed files with 59 additions and 46 deletions.
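For context, sampling a single batch before the gradient-step loop means every gradient step reuses the same transitions; sampling inside the loop gives each step a fresh batch. A minimal, self-contained sketch of the corrected pattern (all names here are illustrative placeholders, not skrl's API):

```python
import random
from collections import deque


# Hypothetical replay memory used only for illustration (not skrl's Memory class)
class ReplayMemory:
    def __init__(self, capacity=10000):
        self.buffer = deque(maxlen=capacity)

    def add(self, transition):
        self.buffer.append(transition)

    def sample(self, batch_size):
        return random.sample(self.buffer, batch_size)


def update(memory, gradient_steps=4, batch_size=2):
    # Before the fix (conceptually): a single batch sampled here would be reused
    # by every gradient step in the loop below
    for _ in range(gradient_steps):
        # After the fix: a fresh batch is drawn for each gradient step,
        # so successive steps train on different transitions
        batch = memory.sample(batch_size)
        # ... compute targets and losses, then apply the optimizer step on `batch`


memory = ReplayMemory()
for i in range(8):
    memory.add((f"s{i}", f"a{i}", float(i), f"s{i + 1}", False))
update(memory)
```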
3 changes: 3 additions & 0 deletions CHANGELOG.md
@@ -14,6 +14,9 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/).
- Make flattened tensor storage in memory the default option (revert change introduced in version 1.3.0)
- Drop support for PyTorch versions prior to 1.10 (the previous supported version was 1.9).

### Fixed
- Moved the batch sampling inside the gradient step loop for DQN, DDQN, DDPG (RNN), TD3 (RNN), SAC and SAC (RNN)

### Removed
- Remove OpenAI Gym (`gym`) from dependencies and source code. **skrl** continues to support gym environments;
  it is just no longer installed as part of the library. If needed, it must be installed manually.
4 changes: 2 additions & 2 deletions docs/source/api/agents/ddqn.rst
@@ -40,10 +40,10 @@ Learning algorithm

|
| :literal:`_update(...)`
| :green:`# sample a batch from memory`
| [:math:`s, a, r, s', d`] :math:`\leftarrow` states, actions, rewards, next_states, dones of size :guilabel:`batch_size`
| :green:`# gradient steps`
| **FOR** each gradient step up to :guilabel:`gradient_steps` **DO**
| :green:`# sample a batch from memory`
| [:math:`s, a, r, s', d`] :math:`\leftarrow` states, actions, rewards, next_states, dones of size :guilabel:`batch_size`
| :green:`# compute target values`
| :math:`Q' \leftarrow Q_{\phi_{target}}(s')`
| :math:`Q_{_{target}} \leftarrow Q'[\underset{a}{\arg\max} \; Q_\phi(s')] \qquad` :gray:`# the only difference with DQN`
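For reference, a minimal PyTorch-style sketch of the target computation in the pseudocode above, assuming the standard double-DQN target y = r + γ(1 − d) · Q_target; the function, network, and tensor names are illustrative, not the skrl implementation:

```python
import torch


def double_dqn_targets(q_network, target_q_network, rewards, next_states, dones, discount=0.99):
    # Illustrative only: names and shapes are placeholders, not skrl's API
    with torch.no_grad():
        next_q_online = q_network(next_states)                     # Q_phi(s')
        next_q_target = target_q_network(next_states)              # Q' = Q_phi_target(s')
        best_actions = next_q_online.argmax(dim=1, keepdim=True)   # argmax_a Q_phi(s')
        q_target = next_q_target.gather(1, best_actions)           # Q'[argmax_a Q_phi(s')]
        # assumed standard bootstrap: y = r + gamma * (1 - d) * Q_target
        return rewards + discount * (1 - dones) * q_target


# Toy usage: linear Q-networks with 3-dimensional states and 5 discrete actions
q_net = torch.nn.Linear(3, 5)
target_q_net = torch.nn.Linear(3, 5)
y = double_dqn_targets(q_net, target_q_net,
                       rewards=torch.zeros(4, 1),
                       next_states=torch.randn(4, 3),
                       dones=torch.zeros(4, 1))
```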
4 changes: 2 additions & 2 deletions docs/source/api/agents/dqn.rst
@@ -40,10 +40,10 @@ Learning algorithm

|
| :literal:`_update(...)`
| :green:`# sample a batch from memory`
| [:math:`s, a, r, s', d`] :math:`\leftarrow` states, actions, rewards, next_states, dones of size :guilabel:`batch_size`
| :green:`# gradient steps`
| **FOR** each gradient step up to :guilabel:`gradient_steps` **DO**
| :green:`# sample a batch from memory`
| [:math:`s, a, r, s', d`] :math:`\leftarrow` states, actions, rewards, next_states, dones of size :guilabel:`batch_size`
| :green:`# compute target values`
| :math:`Q' \leftarrow Q_{\phi_{target}}(s')`
| :math:`Q_{_{target}} \leftarrow \underset{a}{\max} \; Q' \qquad` :gray:`# the only difference with DDQN`
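The DQN counterpart differs only in how the bootstrap value is chosen: the target network's maximum replaces the double-DQN action selection. A hedged sketch under the same assumptions as the DDQN example above:

```python
import torch


def dqn_targets(target_q_network, rewards, next_states, dones, discount=0.99):
    # Illustrative only: names are placeholders, not skrl's API
    with torch.no_grad():
        # Q_target <- max_a Q_phi_target(s')   (the only difference with DDQN)
        q_target = target_q_network(next_states).max(dim=1, keepdim=True).values
        # assumed standard bootstrap: y = r + gamma * (1 - d) * Q_target
        return rewards + discount * (1 - dones) * q_target
```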
4 changes: 2 additions & 2 deletions docs/source/api/agents/sac.rst
@@ -34,10 +34,10 @@ Learning algorithm

|
| :literal:`_update(...)`
| :green:`# sample a batch from memory`
| [:math:`s, a, r, s', d`] :math:`\leftarrow` states, actions, rewards, next_states, dones of size :guilabel:`batch_size`
| :green:`# gradient steps`
| **FOR** each gradient step up to :guilabel:`gradient_steps` **DO**
| :green:`# sample a batch from memory`
| [:math:`s, a, r, s', d`] :math:`\leftarrow` states, actions, rewards, next_states, dones of size :guilabel:`batch_size`
| :green:`# compute target values`
| :math:`a',\; logp' \leftarrow \pi_\theta(s')`
| :math:`Q_{1_{target}} \leftarrow Q_{{\phi 1}_{target}}(s', a')`
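A sketch of the SAC critic target started in the pseudocode above, assuming the standard soft target y = r + γ(1 − d)(min(Q'₁, Q'₂) − α logp'); the callables and names are placeholders, not the skrl implementation, and `entropy_coefficient` stands in for the (possibly learned) temperature α:

```python
import torch


def sac_critic_targets(policy, target_q1, target_q2, rewards, next_states, dones,
                       discount=0.99, entropy_coefficient=0.2):
    # Illustrative only: names are placeholders, not skrl's API
    with torch.no_grad():
        next_actions, next_log_prob = policy(next_states)   # a', logp' <- pi_theta(s')
        q1 = target_q1(next_states, next_actions)           # Q_phi1_target(s', a')
        q2 = target_q2(next_states, next_actions)           # Q_phi2_target(s', a')
        # assumed soft target: y = r + gamma * (1 - d) * (min(Q1', Q2') - alpha * logp')
        return rewards + discount * (1 - dones) * (torch.min(q1, q2)
                                                   - entropy_coefficient * next_log_prob)
```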
7 changes: 4 additions & 3 deletions skrl/agents/jax/dqn/ddqn.py
@@ -336,13 +336,14 @@ def _update(self, timestep: int, timesteps: int) -> None:
:param timesteps: Number of timesteps
:type timesteps: int
"""
# sample a batch from memory
sampled_states, sampled_actions, sampled_rewards, sampled_next_states, sampled_dones = \
self.memory.sample(names=self.tensors_names, batch_size=self._batch_size)[0]

# gradient steps
for gradient_step in range(self._gradient_steps):

# sample a batch from memory
sampled_states, sampled_actions, sampled_rewards, sampled_next_states, sampled_dones = \
self.memory.sample(names=self.tensors_names, batch_size=self._batch_size)[0]

sampled_states = self._state_preprocessor(sampled_states, train=True)
sampled_next_states = self._state_preprocessor(sampled_next_states, train=True)

7 changes: 4 additions & 3 deletions skrl/agents/jax/dqn/dqn.py
@@ -333,13 +333,14 @@ def _update(self, timestep: int, timesteps: int) -> None:
:param timesteps: Number of timesteps
:type timesteps: int
"""
# sample a batch from memory
sampled_states, sampled_actions, sampled_rewards, sampled_next_states, sampled_dones = \
self.memory.sample(names=self.tensors_names, batch_size=self._batch_size)[0]

# gradient steps
for gradient_step in range(self._gradient_steps):

# sample a batch from memory
sampled_states, sampled_actions, sampled_rewards, sampled_next_states, sampled_dones = \
self.memory.sample(names=self.tensors_names, batch_size=self._batch_size)[0]

sampled_states = self._state_preprocessor(sampled_states, train=True)
sampled_next_states = self._state_preprocessor(sampled_next_states, train=True)

7 changes: 4 additions & 3 deletions skrl/agents/jax/sac/sac.py
@@ -396,13 +396,14 @@ def _update(self, timestep: int, timesteps: int) -> None:
:param timesteps: Number of timesteps
:type timesteps: int
"""
# sample a batch from memory
sampled_states, sampled_actions, sampled_rewards, sampled_next_states, sampled_dones = \
self.memory.sample(names=self._tensors_names, batch_size=self._batch_size)[0]

# gradient steps
for gradient_step in range(self._gradient_steps):

# sample a batch from memory
sampled_states, sampled_actions, sampled_rewards, sampled_next_states, sampled_dones = \
self.memory.sample(names=self._tensors_names, batch_size=self._batch_size)[0]

sampled_states = self._state_preprocessor(sampled_states, train=True)
sampled_next_states = self._state_preprocessor(sampled_next_states, train=True)

1 change: 1 addition & 0 deletions skrl/agents/jax/td3/td3.py
@@ -435,6 +435,7 @@ def _update(self, timestep: int, timesteps: int) -> None:

# gradient steps
for gradient_step in range(self._gradient_steps):

# sample a batch from memory
sampled_states, sampled_actions, sampled_rewards, sampled_next_states, sampled_dones = \
self.memory.sample(names=self._tensors_names, batch_size=self._batch_size)[0]
17 changes: 9 additions & 8 deletions skrl/agents/torch/ddpg/ddpg_rnn.py
@@ -360,18 +360,19 @@ def _update(self, timestep: int, timesteps: int) -> None:
:param timesteps: Number of timesteps
:type timesteps: int
"""
# sample a batch from memory
sampled_states, sampled_actions, sampled_rewards, sampled_next_states, sampled_dones = \
self.memory.sample(names=self._tensors_names, batch_size=self._batch_size, sequence_length=self._rnn_sequence_length)[0]

rnn_policy = {}
if self._rnn:
sampled_rnn = self.memory.sample_by_index(names=self._rnn_tensors_names, indexes=self.memory.get_sampling_indexes())[0]
rnn_policy = {"rnn": [s.transpose(0, 1) for s in sampled_rnn], "terminated": sampled_dones}

# gradient steps
for gradient_step in range(self._gradient_steps):

# sample a batch from memory
sampled_states, sampled_actions, sampled_rewards, sampled_next_states, sampled_dones = \
self.memory.sample(names=self._tensors_names, batch_size=self._batch_size, sequence_length=self._rnn_sequence_length)[0]

rnn_policy = {}
if self._rnn:
sampled_rnn = self.memory.sample_by_index(names=self._rnn_tensors_names, indexes=self.memory.get_sampling_indexes())[0]
rnn_policy = {"rnn": [s.transpose(0, 1) for s in sampled_rnn], "terminated": sampled_dones}

sampled_states = self._state_preprocessor(sampled_states, train=True)
sampled_next_states = self._state_preprocessor(sampled_next_states, train=True)

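In the RNN variants (this file and the SAC/TD3 RNN files below), each gradient step now re-samples both the sequence batch and the matching recurrent states. A rough, generic sketch of why the initial hidden states must be drawn with the same indexes as the sampled sequences; the toy buffer here is purely illustrative and unrelated to skrl's Memory API:

```python
import numpy as np

rng = np.random.default_rng(0)

# Toy sequence buffer: one observation and one recorded RNN hidden state per step
num_steps, sequence_length, hidden_size = 64, 8, 4
observations = rng.normal(size=(num_steps, 3))
hidden_states = rng.normal(size=(num_steps, hidden_size))


def sample_sequences(batch_size=2):
    # Pick sequence start indexes, then expand them into windows of sequence_length steps
    starts = rng.integers(0, num_steps - sequence_length, size=batch_size)
    indexes = starts[:, None] + np.arange(sequence_length)[None, :]
    # The initial hidden state for each sequence must come from the same start indexes,
    # otherwise the recurrent rollout would not match the sampled transitions
    return observations[indexes], hidden_states[starts]


for gradient_step in range(3):
    sequence_observations, initial_hidden = sample_sequences()
    # ... roll the recurrent policy over sequence_observations starting from
    # initial_hidden, compute the losses and apply the optimizer step here
```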
7 changes: 4 additions & 3 deletions skrl/agents/torch/dqn/ddqn.py
@@ -283,13 +283,14 @@ def _update(self, timestep: int, timesteps: int) -> None:
:param timesteps: Number of timesteps
:type timesteps: int
"""
# sample a batch from memory
sampled_states, sampled_actions, sampled_rewards, sampled_next_states, sampled_dones = \
self.memory.sample(names=self.tensors_names, batch_size=self._batch_size)[0]

# gradient steps
for gradient_step in range(self._gradient_steps):

# sample a batch from memory
sampled_states, sampled_actions, sampled_rewards, sampled_next_states, sampled_dones = \
self.memory.sample(names=self.tensors_names, batch_size=self._batch_size)[0]

sampled_states = self._state_preprocessor(sampled_states, train=True)
sampled_next_states = self._state_preprocessor(sampled_next_states, train=True)

7 changes: 4 additions & 3 deletions skrl/agents/torch/dqn/dqn.py
@@ -283,13 +283,14 @@ def _update(self, timestep: int, timesteps: int) -> None:
:param timesteps: Number of timesteps
:type timesteps: int
"""
# sample a batch from memory
sampled_states, sampled_actions, sampled_rewards, sampled_next_states, sampled_dones = \
self.memory.sample(names=self.tensors_names, batch_size=self._batch_size)[0]

# gradient steps
for gradient_step in range(self._gradient_steps):

# sample a batch from memory
sampled_states, sampled_actions, sampled_rewards, sampled_next_states, sampled_dones = \
self.memory.sample(names=self.tensors_names, batch_size=self._batch_size)[0]

sampled_states = self._state_preprocessor(sampled_states, train=True)
sampled_next_states = self._state_preprocessor(sampled_next_states, train=True)

3 changes: 2 additions & 1 deletion skrl/agents/torch/sac/sac.py
@@ -307,12 +307,13 @@ def _update(self, timestep: int, timesteps: int) -> None:
:param timesteps: Number of timesteps
:type timesteps: int
"""

# gradient steps
for gradient_step in range(self._gradient_steps):

# sample a batch from memory
sampled_states, sampled_actions, sampled_rewards, sampled_next_states, sampled_dones = \
self.memory.sample(names=self._tensors_names, batch_size=self._batch_size)[0]
self.memory.sample(names=self._tensors_names, batch_size=self._batch_size)[0]

sampled_states = self._state_preprocessor(sampled_states, train=True)
sampled_next_states = self._state_preprocessor(sampled_next_states, train=True)
17 changes: 9 additions & 8 deletions skrl/agents/torch/sac/sac_rnn.py
@@ -344,18 +344,19 @@ def _update(self, timestep: int, timesteps: int) -> None:
:param timesteps: Number of timesteps
:type timesteps: int
"""
# sample a batch from memory
sampled_states, sampled_actions, sampled_rewards, sampled_next_states, sampled_dones = \
self.memory.sample(names=self._tensors_names, batch_size=self._batch_size, sequence_length=self._rnn_sequence_length)[0]

rnn_policy = {}
if self._rnn:
sampled_rnn = self.memory.sample_by_index(names=self._rnn_tensors_names, indexes=self.memory.get_sampling_indexes())[0]
rnn_policy = {"rnn": [s.transpose(0, 1) for s in sampled_rnn], "terminated": sampled_dones}

# gradient steps
for gradient_step in range(self._gradient_steps):

# sample a batch from memory
sampled_states, sampled_actions, sampled_rewards, sampled_next_states, sampled_dones = \
self.memory.sample(names=self._tensors_names, batch_size=self._batch_size, sequence_length=self._rnn_sequence_length)[0]

rnn_policy = {}
if self._rnn:
sampled_rnn = self.memory.sample_by_index(names=self._rnn_tensors_names, indexes=self.memory.get_sampling_indexes())[0]
rnn_policy = {"rnn": [s.transpose(0, 1) for s in sampled_rnn], "terminated": sampled_dones}

sampled_states = self._state_preprocessor(sampled_states, train=True)
sampled_next_states = self._state_preprocessor(sampled_next_states, train=True)

17 changes: 9 additions & 8 deletions skrl/agents/torch/td3/td3_rnn.py
@@ -382,18 +382,19 @@ def _update(self, timestep: int, timesteps: int) -> None:
:param timesteps: Number of timesteps
:type timesteps: int
"""
# sample a batch from memory
sampled_states, sampled_actions, sampled_rewards, sampled_next_states, sampled_dones = \
self.memory.sample(names=self._tensors_names, batch_size=self._batch_size, sequence_length=self._rnn_sequence_length)[0]

rnn_policy = {}
if self._rnn:
sampled_rnn = self.memory.sample_by_index(names=self._rnn_tensors_names, indexes=self.memory.get_sampling_indexes())[0]
rnn_policy = {"rnn": [s.transpose(0, 1) for s in sampled_rnn], "terminated": sampled_dones}

# gradient steps
for gradient_step in range(self._gradient_steps):

# sample a batch from memory
sampled_states, sampled_actions, sampled_rewards, sampled_next_states, sampled_dones = \
self.memory.sample(names=self._tensors_names, batch_size=self._batch_size, sequence_length=self._rnn_sequence_length)[0]

rnn_policy = {}
if self._rnn:
sampled_rnn = self.memory.sample_by_index(names=self._rnn_tensors_names, indexes=self.memory.get_sampling_indexes())[0]
rnn_policy = {"rnn": [s.transpose(0, 1) for s in sampled_rnn], "terminated": sampled_dones}

sampled_states = self._state_preprocessor(sampled_states, train=True)
sampled_next_states = self._state_preprocessor(sampled_next_states, train=True)

