@@ -89,7 +89,7 @@ def __init__(
         tensorboard_log: Optional[str] = None,
         verbose: int = 0,
         device: Union[th.device, str] = "auto",
-        support_multi_env: bool = False,
+        support_multi_env: bool = True,
         create_eval_env: bool = False,
         monitor_wrapper: bool = True,
         seed: Optional[int] = None,
@@ -98,7 +98,6 @@ def __init__(
         use_sde_at_warmup: bool = False,
         sde_support: bool = True,
     ):
-
         super(OffPolicyAlgorithm, self).__init__(
             policy=policy,
             env=env,
@@ -125,6 +124,7 @@ def __init__(
         self.n_episodes_rollout = n_episodes_rollout
         self.action_noise = action_noise
         self.optimize_memory_usage = optimize_memory_usage
+        self.n_envs = env.num_envs

         if train_freq > 0 and n_episodes_rollout > 0:
             warnings.warn(
@@ -152,6 +152,7 @@ def _setup_model(self) -> None:
             self.observation_space,
             self.action_space,
             self.device,
+            self.n_envs,
             optimize_memory_usage=self.optimize_memory_usage,
         )
         self.policy = self.policy_class(
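The extra `self.n_envs` argument above suggests the fork's replay buffer keeps one slot per parallel environment. A hedged illustration of that storage layout (not the fork's actual buffer code; all sizes are made up):

import numpy as np

buffer_size, n_envs, obs_dim, action_dim = 1_000, 4, 8, 2

# One row per buffer position, one column per parallel environment.
observations = np.zeros((buffer_size, n_envs, obs_dim), dtype=np.float32)
actions = np.zeros((buffer_size, n_envs, action_dim), dtype=np.float32)
rewards = np.zeros((buffer_size, n_envs), dtype=np.float32)
dones = np.zeros((buffer_size, n_envs), dtype=np.float32)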
@@ -273,7 +274,7 @@ def train(self, gradient_steps: int, batch_size: int) -> None:
         raise NotImplementedError()

     def _sample_action(
-        self, learning_starts: int, action_noise: Optional[ActionNoise] = None
+        self, learning_starts: int, num_envs=1, action_noise: Optional[ActionNoise] = None
     ) -> Tuple[np.ndarray, np.ndarray]:
         """
         Sample an action according to the exploration policy.
@@ -292,7 +293,7 @@ def _sample_action(
         # Select action randomly or according to policy
         if self.num_timesteps < learning_starts and not (self.use_sde and self.use_sde_at_warmup):
             # Warmup phase
-            unscaled_action = np.array([self.action_space.sample()])
+            unscaled_action = np.array([self.action_space.sample() for i in range(num_envs)])
         else:
             # Note: when using continuous actions,
             # we assume that the policy uses tanh to scale the action
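For reference, a standalone sketch of what the patched warmup branch produces; the Box space and num_envs value below are illustrative, not taken from the diff:

import numpy as np
import gym

action_space = gym.spaces.Box(low=-1.0, high=1.0, shape=(3,), dtype=np.float32)
num_envs = 4

# One random action per parallel environment, stacked into a
# (num_envs, action_dim) array, which is the shape VecEnv.step() expects.
unscaled_action = np.array([action_space.sample() for i in range(num_envs)])
assert unscaled_action.shape == (num_envs, 3)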
@@ -377,10 +378,10 @@ def collect_rollouts(
         total_steps, total_episodes = 0, 0

         assert isinstance(env, VecEnv), "You must pass a VecEnv"
-        assert env.num_envs == 1, "OffPolicyAlgorithm only support single environment"
+        # assert env.num_envs == 1, "OffPolicyAlgorithm only support single environment"

         if self.use_sde:
-            self.actor.reset_noise()
+            self.actor.reset_noise(self.n_envs)

         callback.on_rollout_start()
         continue_training = True
@@ -393,13 +394,14 @@ def collect_rollouts(

                 if self.use_sde and self.sde_sample_freq > 0 and total_steps % self.sde_sample_freq == 0:
                     # Sample a new noise matrix
-                    self.actor.reset_noise()
+                    self.actor.reset_noise(self.n_envs)

                 # Select action randomly or according to policy
-                action, buffer_action = self._sample_action(learning_starts, action_noise)
-
+                action, buffer_action = self._sample_action(learning_starts, self.n_envs, action_noise)
                 # Rescale and perform action
+
                 new_obs, reward, done, infos = env.step(action)
+                done = np.all(done)  # done only when all threads are done

                 # Give access to local variables
                 callback.update_locals(locals())
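A small illustration of the `done = np.all(done)` line above: a VecEnv returns one done flag per environment, and the patch only treats the vectorized step as terminal once every environment has finished (the flag values below are made up):

import numpy as np

done = np.array([True, False, True, False])        # per-env flags from env.step()
print(np.all(done))                                # False: some envs still running
print(np.all(np.array([True, True, True, True])))  # True: all envs finished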
@@ -429,9 +431,9 @@ def collect_rollouts(
                 if self._vec_normalize_env is not None:
                     self._last_original_obs = new_obs_

-                self.num_timesteps += 1
-                episode_timesteps += 1
-                total_steps += 1
+                self.num_timesteps += self.n_envs
+                episode_timesteps += self.n_envs
+                total_steps += self.n_envs
                 self._update_current_progress_remaining(self.num_timesteps, self._total_timesteps)

                 # For DQN, check if the target network should be updated
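The counters above now advance by `n_envs` transitions per vectorized step; a toy illustration of the accounting (numbers are arbitrary):

n_envs = 4
num_timesteps = 0
for vec_step in range(3):    # three calls to env.step()
    num_timesteps += n_envs  # one transition collected per parallel env
print(num_timesteps)         # 12 environment transitions in total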
@@ -444,6 +446,7 @@ def collect_rollouts(
                     break

             if done:
+                print('Episode Complete', self._episode_num)
                 total_episodes += 1
                 self._episode_num += 1
                 episode_rewards.append(episode_reward)
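Assuming this diff targets a fork of stable-baselines3's OffPolicyAlgorithm, a hedged usage sketch of what the change enables, written against the upstream SB3 API (the environment id and hyperparameters are illustrative):

import gym
from stable_baselines3 import SAC
from stable_baselines3.common.vec_env import DummyVecEnv

n_envs = 4
# With the single-env assert disabled, rollouts can be collected
# from several environments in parallel ("Pendulum-v1" on newer gym).
env = DummyVecEnv([lambda: gym.make("Pendulum-v0") for _ in range(n_envs)])

model = SAC("MlpPolicy", env, verbose=1)
model.learn(total_timesteps=10_000)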