@@ -158,27 +158,6 @@ def _get_eval_env(self, eval_env: Optional[GymEnv]) -> Optional[GymEnv]:
             assert eval_env.num_envs == 1
         return eval_env
 
-    def scale_action(self, action: np.ndarray) -> np.ndarray:
-        """
-        Rescale the action from [low, high] to [-1, 1]
-        (no need for symmetric action space)
-
-        :param action: (np.ndarray) Action to scale
-        :return: (np.ndarray) Scaled action
-        """
-        low, high = self.action_space.low, self.action_space.high
-        return 2.0 * ((action - low) / (high - low)) - 1.0
-
-    def unscale_action(self, scaled_action: np.ndarray) -> np.ndarray:
-        """
-        Rescale the action from [-1, 1] to [low, high]
-        (no need for symmetric action space)
-
-        :param scaled_action: Action to un-scale
-        """
-        low, high = self.action_space.low, self.action_space.high
-        return low + (0.5 * (scaled_action + 1.0) * (high - low))
-
     def _setup_lr_schedule(self) -> None:
         """Transform to callable if needed."""
         self.lr_schedule = get_schedule_fn(self.learning_rate)
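For reference, the removed helpers implement a plain affine map between the action-space bounds and [-1, 1]; the later hunks show the same calls now going through `self.policy`. A minimal standalone sketch of that math, with made-up `low`/`high` bounds instead of a real gym space:

```python
import numpy as np

# Hypothetical action bounds, standing in for self.action_space.low/high
low, high = np.array([-2.0, 0.0]), np.array([2.0, 1.0])

def scale_action(action: np.ndarray) -> np.ndarray:
    # Map [low, high] -> [-1, 1]
    return 2.0 * ((action - low) / (high - low)) - 1.0

def unscale_action(scaled_action: np.ndarray) -> np.ndarray:
    # Map [-1, 1] -> [low, high], the exact inverse of scale_action
    return low + (0.5 * (scaled_action + 1.0) * (high - low))

a = np.array([1.0, 0.25])
assert np.allclose(unscale_action(scale_action(a)), a)  # round trip
```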
@@ -318,57 +297,6 @@ def learn(self, total_timesteps: int,
         """
         raise NotImplementedError()
 
-    @staticmethod
-    def _is_vectorized_observation(observation: np.ndarray, observation_space: gym.spaces.Space) -> bool:
-        """
-        For every observation type, detects and validates the shape,
-        then returns whether or not the observation is vectorized.
-
-        :param observation: (np.ndarray) the input observation to validate
-        :param observation_space: (gym.spaces) the observation space
-        :return: (bool) whether the given observation is vectorized or not
-        """
-        if isinstance(observation_space, gym.spaces.Box):
-            if observation.shape == observation_space.shape:
-                return False
-            elif observation.shape[1:] == observation_space.shape:
-                return True
-            else:
-                raise ValueError("Error: Unexpected observation shape {} for ".format(observation.shape) +
-                                 "Box environment, please use {} ".format(observation_space.shape) +
-                                 "or (n_env, {}) for the observation shape."
-                                 .format(", ".join(map(str, observation_space.shape))))
-        elif isinstance(observation_space, gym.spaces.Discrete):
-            if observation.shape == ():  # A numpy array of a number, has shape empty tuple '()'
-                return False
-            elif len(observation.shape) == 1:
-                return True
-            else:
-                raise ValueError("Error: Unexpected observation shape {} for ".format(observation.shape) +
-                                 "Discrete environment, please use (1,) or (n_env, 1) for the observation shape.")
-        # TODO: add support for MultiDiscrete and MultiBinary observation spaces
-        # elif isinstance(observation_space, gym.spaces.MultiDiscrete):
-        #     if observation.shape == (len(observation_space.nvec),):
-        #         return False
-        #     elif len(observation.shape) == 2 and observation.shape[1] == len(observation_space.nvec):
-        #         return True
-        #     else:
-        #         raise ValueError("Error: Unexpected observation shape {} for MultiDiscrete ".format(observation.shape) +
-        #                          "environment, please use ({},) or ".format(len(observation_space.nvec)) +
-        #                          "(n_env, {}) for the observation shape.".format(len(observation_space.nvec)))
-        # elif isinstance(observation_space, gym.spaces.MultiBinary):
-        #     if observation.shape == (observation_space.n,):
-        #         return False
-        #     elif len(observation.shape) == 2 and observation.shape[1] == observation_space.n:
-        #         return True
-        #     else:
-        #         raise ValueError("Error: Unexpected observation shape {} for MultiBinary ".format(observation.shape) +
-        #                          "environment, please use ({},) or ".format(observation_space.n) +
-        #                          "(n_env, {}) for the observation shape.".format(observation_space.n))
-        else:
-            raise ValueError("Error: Cannot determine if the observation is vectorized with the space type {}."
-                             .format(observation_space))
-
     def predict(self, observation: np.ndarray,
                 state: Optional[np.ndarray] = None,
                 mask: Optional[np.ndarray] = None,
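The removed `_is_vectorized_observation` boils down to a shape comparison: an exact match with the space's shape means a single observation, a match after stripping the leading batch axis means a vectorized batch. A toy sketch of the Box-space case, with a plain shape tuple standing in for the gym space:

```python
import numpy as np

def is_vectorized(observation: np.ndarray, space_shape: tuple) -> bool:
    if observation.shape == space_shape:
        return False  # a single observation
    if observation.shape[1:] == space_shape:
        return True   # a batch with shape (n_env, *space_shape)
    raise ValueError(f"Unexpected observation shape {observation.shape}")

assert is_vectorized(np.zeros((4,)), (4,)) is False
assert is_vectorized(np.zeros((8, 4)), (4,)) is True
```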
@@ -383,36 +311,7 @@ def predict(self, observation: np.ndarray,
         :return: (Tuple[np.ndarray, Optional[np.ndarray]]) the model's action and the next state
             (used in recurrent policies)
         """
-        # TODO: move this block to BasePolicy
-        # if state is None:
-        #     state = self.initial_state
-        # if mask is None:
-        #     mask = [False for _ in range(self.n_envs)]
-        observation = np.array(observation)
-        vectorized_env = self._is_vectorized_observation(observation, self.observation_space)
-
-        observation = observation.reshape((-1,) + self.observation_space.shape)
-        observation = th.as_tensor(observation).to(self.device)
-        with th.no_grad():
-            actions = self.policy.predict(observation, deterministic=deterministic)
-        # Convert to numpy
-        actions = actions.cpu().numpy()
-
-        # Rescale to proper domain when using squashing
-        if isinstance(self.action_space, gym.spaces.Box) and self.policy.squash_output:
-            actions = self.unscale_action(actions)
-
-        clipped_actions = actions
-        # Clip the actions to avoid out of bound error when using gaussian distribution
-        if isinstance(self.action_space, gym.spaces.Box) and not self.policy.squash_output:
-            clipped_actions = np.clip(actions, self.action_space.low, self.action_space.high)
-
-        if not vectorized_env:
-            if state is not None:
-                raise ValueError("Error: The environment must be vectorized when using recurrent policies.")
-            clipped_actions = clipped_actions[0]
-
-        return clipped_actions, state
+        return self.policy.predict(observation, state, mask, deterministic)
 
     @classmethod
     def load(cls, load_path: str, env: Optional[GymEnv] = None, **kwargs):
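The one-line replacement body turns `predict` into a thin delegation to the policy object, which now owns the preprocessing, rescaling, and clipping shown above. A toy sketch of the pattern (the class names and dummy policy are illustrative, not SB3's real API):

```python
import numpy as np

class Policy:
    def predict(self, observation, state=None, mask=None, deterministic=False):
        # Dummy policy: return a zero action per observation
        return np.zeros_like(observation), state

class Algorithm:
    def __init__(self):
        self.policy = Policy()

    def predict(self, observation, state=None, mask=None, deterministic=False):
        # Thin wrapper, mirroring the one-line body in this diff
        return self.policy.predict(observation, state, mask, deterministic)

action, state = Algorithm().predict(np.zeros(4), deterministic=True)
```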
@@ -484,10 +383,7 @@ def _load_from_file(load_path: str, load_data: bool = True) -> (Tuple[Optional[D
         raise ValueError(f"Error: the file {load_path} could not be found")
 
     # set device to cpu if cuda is not available
-    if th.cuda.is_available():
-        device = th.device('cuda')
-    else:
-        device = th.device('cpu')
+    device = th.device('cuda') if th.cuda.is_available() else th.device('cpu')
 
     # Open the zip archive and load data
     try:
@@ -534,20 +430,6 @@ def _load_from_file(load_path: str, load_data: bool = True) -> (Tuple[Optional[D
                     # load the parameters with the right `map_location`
                     params[os.path.splitext(file_path)[0]] = th.load(file_content, map_location=device)
 
-        # for backward compatibility
-        if params.get('params') is not None:
-            params_copy = {}
-            for name in params:
-                if name == 'params':
-                    params_copy['policy'] = params[name]
-                elif name == 'opt':
-                    params_copy['policy.optimizer'] = params[name]
-                # Special case for SAC
-                elif name == 'ent_coef_optimizer':
-                    params_copy[name] = params[name]
-                else:
-                    params_copy[name + '.optimizer'] = params[name]
-            params = params_copy
     except zipfile.BadZipFile:
         # load_path wasn't a zip file
         raise ValueError(f"Error: the file {load_path} wasn't a zip-file")
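The deleted backward-compatibility block renamed legacy archive keys to the newer naming scheme before handing `params` back. A self-contained sketch of that mapping, using placeholder string values in place of real tensors:

```python
# Legacy keys as they appeared in old save archives (values are dummies)
legacy = {'params': 'policy-weights', 'opt': 'optimizer-state'}

renamed = {}
for name, value in legacy.items():
    if name == 'params':
        renamed['policy'] = value
    elif name == 'opt':
        renamed['policy.optimizer'] = value
    elif name == 'ent_coef_optimizer':  # special case for SAC
        renamed[name] = value
    else:
        renamed[name + '.optimizer'] = value

assert set(renamed) == {'policy', 'policy.optimizer'}
```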
@@ -925,7 +807,7 @@ def collect_rollouts(self,
                 unscaled_action, _ = self.predict(obs, deterministic=False)
 
                 # Rescale the action from [low, high] to [-1, 1]
-                scaled_action = self.scale_action(unscaled_action)
+                scaled_action = self.policy.scale_action(unscaled_action)
 
                 if self.use_sde:
                     # When using SDE, the action can be out of bounds
@@ -941,7 +823,7 @@ def collect_rollouts(self,
                     clipped_action = np.clip(clipped_action + action_noise(), -1, 1)
 
                 # Rescale and perform action
-                new_obs, reward, done, infos = env.step(self.unscale_action(clipped_action))
+                new_obs, reward, done, infos = env.step(self.policy.unscale_action(clipped_action))
 
                 # Only stop training if return value is False, not when it is None.
                 if callback.on_step() is False:
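Taken together, the two `collect_rollouts` hunks describe this action pipeline: predict an action in environment coordinates, scale it to [-1, 1], add exploration noise and clip, then unscale before `env.step`. A numeric sketch under assumed bounds and Gaussian noise (all values illustrative):

```python
import numpy as np

low, high = np.array([-2.0]), np.array([2.0])  # hypothetical bounds
rng = np.random.default_rng(0)

unscaled_action = np.array([1.5])                        # as from self.predict(...)
scaled = 2.0 * ((unscaled_action - low) / (high - low)) - 1.0
noisy = np.clip(scaled + rng.normal(scale=0.1, size=1), -1, 1)
env_action = low + 0.5 * (noisy + 1.0) * (high - low)    # what env.step receives
assert np.all((env_action >= low) & (env_action <= high))
```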