
Commit 1579713

Vectorized environment support for off-policy algorithms (SAC, TD3, DDPG, etc.)
1 parent a1afc5e commit 1579713
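
In practice, this change lets an off-policy algorithm be constructed directly on a multi-environment VecEnv instead of a single environment. A minimal usage sketch, assuming the stable_baselines3 and gym APIs of this revision (the environment id, number of envs, and hyperparameters below are illustrative, not part of the commit):

import gym

from stable_baselines3 import SAC
from stable_baselines3.common.vec_env import DummyVecEnv

# Four copies of a continuous-control task stepped in one process.
env = DummyVecEnv([lambda: gym.make("Pendulum-v0") for _ in range(4)])

# With support_multi_env defaulting to True and the single-env asserts
# commented out, the constructor no longer rejects env.num_envs > 1.
model = SAC("MlpPolicy", env, learning_starts=100, verbose=1)
model.learn(total_timesteps=1_000)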

File tree

4 files changed  (+18 -14 lines)


debug.log  (+1)

@@ -0,0 +1 @@
+step 48 reward 0.0008 action

stable_baselines3/common/base_class.py  (+1 -1)

@@ -189,7 +189,7 @@ def _get_eval_env(self, eval_env: Optional[GymEnv]) -> Optional[GymEnv]:
 
         if eval_env is not None:
             eval_env = self._wrap_env(eval_env)
-            assert eval_env.num_envs == 1
+            # assert eval_env.num_envs == 1
         return eval_env
 
     def _setup_lr_schedule(self) -> None:
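
With the assert above commented out, _get_eval_env still wraps the evaluation environment but no longer requires it to be a single env. A small sketch of what is now accepted (the environment id and number of workers are illustrative):

import gym
from stable_baselines3.common.vec_env import DummyVecEnv

# Two evaluation workers; before this change, _get_eval_env asserted
# eval_env.num_envs == 1 right after wrapping, so this would have been rejected.
eval_env = DummyVecEnv([lambda: gym.make("Pendulum-v0") for _ in range(2)])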

stable_baselines3/common/buffers.py  (+1 -1)

@@ -164,7 +164,7 @@ def __init__(
     ):
         super(ReplayBuffer, self).__init__(buffer_size, observation_space, action_space, device, n_envs=n_envs)
 
-        assert n_envs == 1, "Replay buffer only support single environment for now"
+        # assert n_envs == 1, "Replay buffer only support single environment for now"
 
         # Check that the replay buffer can fit into the memory
         if psutil is not None:
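
A construction-only sketch of the buffer with several environments, assuming the ReplayBuffer signature used in the hunk above (the spaces and sizes are illustrative):

import numpy as np
from gym import spaces
from stable_baselines3.common.buffers import ReplayBuffer

obs_space = spaces.Box(low=-1.0, high=1.0, shape=(3,), dtype=np.float32)
action_space = spaces.Box(low=-1.0, high=1.0, shape=(1,), dtype=np.float32)

# With the n_envs == 1 assert commented out, this no longer raises at construction.
buffer = ReplayBuffer(10_000, obs_space, action_space, device="cpu", n_envs=4)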

stable_baselines3/common/off_policy_algorithm.py  (+15 -12)

@@ -89,7 +89,7 @@ def __init__(
         tensorboard_log: Optional[str] = None,
         verbose: int = 0,
         device: Union[th.device, str] = "auto",
-        support_multi_env: bool = False,
+        support_multi_env: bool = True,
         create_eval_env: bool = False,
         monitor_wrapper: bool = True,
         seed: Optional[int] = None,
@@ -98,7 +98,6 @@ def __init__(
         use_sde_at_warmup: bool = False,
         sde_support: bool = True,
     ):
-
         super(OffPolicyAlgorithm, self).__init__(
             policy=policy,
             env=env,
@@ -125,6 +124,7 @@ def __init__(
         self.n_episodes_rollout = n_episodes_rollout
         self.action_noise = action_noise
         self.optimize_memory_usage = optimize_memory_usage
+        self.n_envs = env.num_envs
 
         if train_freq > 0 and n_episodes_rollout > 0:
             warnings.warn(
@@ -152,6 +152,7 @@ def _setup_model(self) -> None:
             self.observation_space,
             self.action_space,
             self.device,
+            self.n_envs,
             optimize_memory_usage=self.optimize_memory_usage,
         )
         self.policy = self.policy_class(
@@ -273,7 +274,7 @@ def train(self, gradient_steps: int, batch_size: int) -> None:
         raise NotImplementedError()
 
     def _sample_action(
-        self, learning_starts: int, action_noise: Optional[ActionNoise] = None
+        self, learning_starts: int, num_envs=1, action_noise: Optional[ActionNoise] = None
     ) -> Tuple[np.ndarray, np.ndarray]:
         """
         Sample an action according to the exploration policy.
@@ -292,7 +293,7 @@ def _sample_action(
         # Select action randomly or according to policy
         if self.num_timesteps < learning_starts and not (self.use_sde and self.use_sde_at_warmup):
             # Warmup phase
-            unscaled_action = np.array([self.action_space.sample()])
+            unscaled_action = np.array([ self.action_space.sample() for i in range(num_envs) ])
         else:
             # Note: when using continuous actions,
             # we assume that the policy uses tanh to scale the action
@@ -377,10 +378,10 @@ def collect_rollouts(
         total_steps, total_episodes = 0, 0
 
         assert isinstance(env, VecEnv), "You must pass a VecEnv"
-        assert env.num_envs == 1, "OffPolicyAlgorithm only support single environment"
+        # assert env.num_envs == 1, "OffPolicyAlgorithm only support single environment"
 
         if self.use_sde:
-            self.actor.reset_noise()
+            self.actor.reset_noise(self.n_envs)
 
         callback.on_rollout_start()
         continue_training = True
@@ -393,13 +394,14 @@
 
                 if self.use_sde and self.sde_sample_freq > 0 and total_steps % self.sde_sample_freq == 0:
                     # Sample a new noise matrix
-                    self.actor.reset_noise()
+                    self.actor.reset_noise(self.n_envs)
 
                 # Select action randomly or according to policy
-                action, buffer_action = self._sample_action(learning_starts, action_noise)
-
+                action, buffer_action = self._sample_action(learning_starts, self.n_envs, action_noise)
                 # Rescale and perform action
+
                 new_obs, reward, done, infos = env.step(action)
+                done = np.all(done) # done only when all threads are done
 
                 # Give access to local variables
                 callback.update_locals(locals())
@@ -429,9 +431,9 @@
                 if self._vec_normalize_env is not None:
                     self._last_original_obs = new_obs_
 
-                self.num_timesteps += 1
-                episode_timesteps += 1
-                total_steps += 1
+                self.num_timesteps += self.n_envs
+                episode_timesteps += self.n_envs
+                total_steps += self.n_envs
                 self._update_current_progress_remaining(self.num_timesteps, self._total_timesteps)
 
                 # For DQN, check if the target network should be updated
@@ -444,6 +446,7 @@
                     break
 
             if done:
+                print('Episode Complete', self._episode_num )
                 total_episodes += 1
                 self._episode_num += 1
                 episode_rewards.append(episode_reward)
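
Two behaviours introduced above can be illustrated in isolation: the warmup branch of _sample_action now draws one random action per parallel environment, and collect_rollouts collapses the per-env done vector so a step only counts as done when every env has finished. A small sketch (the Box action space and sizes are illustrative):

import numpy as np
from gym import spaces

action_space = spaces.Box(low=-1.0, high=1.0, shape=(2,), dtype=np.float32)
num_envs = 4

# Warmup sampling: one unscaled action per env, stacked into a (num_envs, action_dim) array.
unscaled_action = np.array([action_space.sample() for i in range(num_envs)])
print(unscaled_action.shape)  # (4, 2)

# done = np.all(done): the step is treated as done only when all envs report done.
done = np.array([True, False, True, False])
print(np.all(done))  # False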
