Fix: with SAC, a new training batch should be sampled for each gradient step (#208)

* Fix: with SAC, a new training batch should be sampled for each gradient_step

* Apply format

---------

Co-authored-by: yibo di <[email protected]>
Co-authored-by: Toni-SM <[email protected]>
3 people authored Nov 3, 2024
1 parent ae4e09e commit 5fce807
Showing 1 changed file with 4 additions and 4 deletions.
8 changes: 4 additions & 4 deletions skrl/agents/torch/sac/sac.py
@@ -307,13 +307,13 @@ def _update(self, timestep: int, timesteps: int) -> None:
         :param timesteps: Number of timesteps
         :type timesteps: int
         """
-        # sample a batch from memory
-        sampled_states, sampled_actions, sampled_rewards, sampled_next_states, sampled_dones = \
-            self.memory.sample(names=self._tensors_names, batch_size=self._batch_size)[0]
-
         # gradient steps
         for gradient_step in range(self._gradient_steps):
 
+            # sample a batch from memory
+            sampled_states, sampled_actions, sampled_rewards, sampled_next_states, sampled_dones = \
+                self.memory.sample(names=self._tensors_names, batch_size=self._batch_size)[0]
+
             sampled_states = self._state_preprocessor(sampled_states, train=True)
             sampled_next_states = self._state_preprocessor(sampled_next_states, train=True)

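The change moves the sampling call inside the gradient-step loop, so each of the `_gradient_steps` updates trains on a freshly drawn mini-batch instead of reusing the single batch sampled before the loop. A minimal sketch of the corrected control flow, using a hypothetical `ToyMemory` and `update` function as stand-ins for skrl's memory and agent (not the library API):

```python
import random

class ToyMemory:
    """Hypothetical stand-in for skrl's replay memory: sample() draws a random mini-batch."""
    def __init__(self, size: int):
        self.data = list(range(size))
        self.sample_calls = 0  # track how many batches were drawn

    def sample(self, batch_size: int):
        self.sample_calls += 1
        return random.sample(self.data, batch_size)

def update(memory: ToyMemory, gradient_steps: int, batch_size: int):
    """After the fix: draw a fresh batch for every gradient step."""
    batches = []
    for _ in range(gradient_steps):
        batch = memory.sample(batch_size)  # re-sampled on each iteration
        batches.append(batch)
        # ... compute critic/actor losses on `batch` and step the optimizers ...
    return batches

memory = ToyMemory(size=100)
batches = update(memory, gradient_steps=4, batch_size=8)
assert memory.sample_calls == 4  # one fresh batch per gradient step
```

Before the fix, `memory.sample(...)` sat above the loop, so all four gradient steps would have reused the same batch, effectively performing repeated updates on correlated data.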