Fix type annotation bundle (SAC, TD3, A2C, PPO, base class) #1436

Merged 7 commits on Apr 13, 2023
5 changes: 5 additions & 0 deletions docs/misc/changelog.rst
@@ -41,6 +41,11 @@ Deprecations:

Others:
^^^^^^^
- Fixed ``stable_baselines3/a2c/*.py`` type hints
- Fixed ``stable_baselines3/ppo/*.py`` type hints
- Fixed ``stable_baselines3/sac/*.py`` type hints
- Fixed ``stable_baselines3/td3/*.py`` type hints
- Fixed ``stable_baselines3/common/base_class.py`` type hints
- Upgraded docker images to use mamba/micromamba and CUDA 11.7
- Updated env checker to reflect what subset of Gymnasium is supported and improve GoalEnv checks
- Improve type annotation of wrappers
10 changes: 1 addition & 9 deletions pyproject.toml
@@ -35,17 +35,14 @@ ignore_missing_imports = true
follow_imports = "silent"
show_error_codes = true
exclude = """(?x)(
stable_baselines3/a2c/a2c.py$
| stable_baselines3/common/base_class.py$
| stable_baselines3/common/buffers.py$
stable_baselines3/common/buffers.py$
| stable_baselines3/common/callbacks.py$
| stable_baselines3/common/distributions.py$
| stable_baselines3/common/envs/bit_flipping_env.py$
| stable_baselines3/common/envs/identity_env.py$
| stable_baselines3/common/envs/multi_input_envs.py$
| stable_baselines3/common/logger.py$
| stable_baselines3/common/off_policy_algorithm.py$
| stable_baselines3/common/on_policy_algorithm.py$
| stable_baselines3/common/policies.py$
| stable_baselines3/common/save_util.py$
| stable_baselines3/common/sb2_compat/rmsprop_tf_like.py$
@@ -62,11 +59,6 @@ exclude = """(?x)(
| stable_baselines3/common/vec_env/vec_transpose.py$
| stable_baselines3/common/vec_env/vec_video_recorder.py$
| stable_baselines3/her/her_replay_buffer.py$
| stable_baselines3/ppo/ppo.py$
| stable_baselines3/sac/policies.py$
| stable_baselines3/sac/sac.py$
| stable_baselines3/td3/policies.py$
| stable_baselines3/td3/td3.py$
| tests/test_logger.py$
| tests/test_train_eval_mode.py$
)"""
73 changes: 43 additions & 30 deletions stable_baselines3/common/base_class.py
@@ -22,7 +22,7 @@
from stable_baselines3.common.policies import BasePolicy
from stable_baselines3.common.preprocessing import check_for_nested_spaces, is_image_space, is_image_space_channels_first
from stable_baselines3.common.save_util import load_from_zip_file, recursive_getattr, recursive_setattr, save_to_zip_file
from stable_baselines3.common.type_aliases import GymEnv, MaybeCallback, Schedule
from stable_baselines3.common.type_aliases import GymEnv, MaybeCallback, Schedule, TensorDict
from stable_baselines3.common.utils import (
check_for_correct_spaces,
get_device,
@@ -44,21 +44,22 @@
SelfBaseAlgorithm = TypeVar("SelfBaseAlgorithm", bound="BaseAlgorithm")


def maybe_make_env(env: Union[GymEnv, str, None], verbose: int) -> Optional[GymEnv]:
def maybe_make_env(env: Union[GymEnv, str], verbose: int) -> GymEnv:
"""If env is a string, make the environment; otherwise, return env.

:param env: The environment to learn from.
:param verbose: Verbosity level: 0 for no output, 1 for indicating if environment is created
:return: A Gym (vector) environment.
"""
if isinstance(env, str):
env_id = env
if verbose >= 1:
print(f"Creating environment from the given name '{env}'")
print(f"Creating environment from the given name '{env_id}'")
# Set render_mode to `rgb_array` as default, so we can record video
try:
env = gym.make(env, render_mode="rgb_array")
env = gym.make(env_id, render_mode="rgb_array")
except TypeError:
env = gym.make(env)
env = gym.make(env_id)
return env
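
A minimal sketch of how the narrowed `maybe_make_env` signature behaves; the `CartPole-v1` id is only an illustrative choice:

```python
import gymnasium as gym

from stable_baselines3.common.base_class import maybe_make_env

# A string id is turned into an environment (with render_mode="rgb_array" when supported).
env = maybe_make_env("CartPole-v1", verbose=1)
assert isinstance(env, gym.Env)

# An existing environment is returned unchanged; under the new
# Union[GymEnv, str] signature, passing None is no longer allowed.
same_env = maybe_make_env(env, verbose=0)
assert same_env is env
```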


@@ -95,6 +96,11 @@ class BaseAlgorithm(ABC):
# Policy aliases (see _get_policy_from_name())
policy_aliases: Dict[str, Type[BasePolicy]] = {}
policy: BasePolicy
observation_space: spaces.Space
action_space: spaces.Space
n_envs: int
lr_schedule: Schedule
_logger: Logger

def __init__(
self,
@@ -111,8 +117,8 @@ def __init__(
seed: Optional[int] = None,
use_sde: bool = False,
sde_sample_freq: int = -1,
supported_action_spaces: Optional[Tuple[spaces.Space, ...]] = None,
):
supported_action_spaces: Optional[Tuple[Type[spaces.Space], ...]] = None,
) -> None:
if isinstance(policy, str):
self.policy_class = self._get_policy_from_name(policy)
else:
@@ -122,25 +128,19 @@ def __init__(
if verbose >= 1:
print(f"Using {self.device} device")

self.env = None # type: Optional[GymEnv]
# get VecNormalize object if needed
self._vec_normalize_env = unwrap_vec_normalize(env)
self.verbose = verbose
self.policy_kwargs = {} if policy_kwargs is None else policy_kwargs
self.observation_space: spaces.Space
self.action_space: spaces.Space
self.n_envs: int

self.num_timesteps = 0
# Used for updating schedules
self._total_timesteps = 0
# Used for computing fps, it is updated at each call of learn()
self._num_timesteps_at_start = 0
self.seed = seed
self.action_noise: Optional[ActionNoise] = None
self.start_time = None
self.start_time = 0.0
self.learning_rate = learning_rate
self.tensorboard_log = tensorboard_log
self.lr_schedule = None # type: Optional[Schedule]
self._last_obs = None # type: Optional[Union[np.ndarray, Dict[str, np.ndarray]]]
self._last_episode_starts = None # type: Optional[np.ndarray]
# When using VecNormalize:
@@ -151,17 +151,17 @@ def __init__(
self.sde_sample_freq = sde_sample_freq
# Track the training progress remaining (from 1 to 0)
# this is used to update the learning rate
self._current_progress_remaining = 1
self._current_progress_remaining = 1.0
# Buffers for logging
self._stats_window_size = stats_window_size
self.ep_info_buffer = None # type: Optional[deque]
self.ep_success_buffer = None # type: Optional[deque]
# For logging (and TD3 delayed updates)
self._n_updates = 0 # type: int
# The logger object
self._logger = None # type: Logger
# Whether the user passed a custom logger or not
self._custom_logger = False
self.env: Optional[VecEnv] = None
self._vec_normalize_env: Optional[VecNormalize] = None

# Create and wrap the env if needed
if env is not None:
@@ -173,6 +173,9 @@ def __init__(
self.n_envs = env.num_envs
self.env = env

# get VecNormalize object if needed
self._vec_normalize_env = unwrap_vec_normalize(env)

if supported_action_spaces is not None:
assert isinstance(self.action_space, supported_action_spaces), (
f"The algorithm only supports {supported_action_spaces} as action spaces "
@@ -217,7 +220,7 @@ def _wrap_env(env: GymEnv, verbose: int = 0, monitor_wrapper: bool = True) -> Ve
env = Monitor(env)
if verbose >= 1:
print("Wrapping the env in a DummyVecEnv.")
env = DummyVecEnv([lambda: env])
env = DummyVecEnv([lambda: env]) # type: ignore[list-item, return-value]

# Make sure that dict-spaces are not nested (not supported)
check_for_nested_spaces(env.observation_space)
@@ -230,11 +233,11 @@ def _wrap_env(env: GymEnv, verbose: int = 0, monitor_wrapper: bool = True) -> Ve
# the other channel last), VecTransposeImage will throw an error
for space in env.observation_space.spaces.values():
wrap_with_vectranspose = wrap_with_vectranspose or (
is_image_space(space) and not is_image_space_channels_first(space)
is_image_space(space) and not is_image_space_channels_first(space) # type: ignore[arg-type]
)
else:
wrap_with_vectranspose = is_image_space(env.observation_space) and not is_image_space_channels_first(
env.observation_space
env.observation_space # type: ignore[arg-type]
)

if wrap_with_vectranspose:
@@ -416,7 +419,10 @@ def _setup_learn(

# Avoid resetting the environment when calling ``.learn()`` consecutive times
if reset_num_timesteps or self._last_obs is None:
self._last_obs = self.env.reset() # pytype: disable=annotation-type-mismatch
assert self.env is not None
# pytype: disable=annotation-type-mismatch
self._last_obs = self.env.reset() # type: ignore[assignment]
# pytype: enable=annotation-type-mismatch
self._last_episode_starts = np.ones((self.env.num_envs,), dtype=bool)
# Retrieve unnormalized observation for saving into the buffer
if self._vec_normalize_env is not None:
@@ -439,6 +445,9 @@ def _update_info_buffer(self, infos: List[Dict[str, Any]], dones: Optional[np.nd
:param infos: List of additional information about the transition.
:param dones: Termination signals
"""
assert self.ep_info_buffer is not None
assert self.ep_success_buffer is not None

if dones is None:
dones = np.array([False] * len(infos))
for idx, info in enumerate(infos):
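
The `assert ... is not None` lines added above exist to narrow `Optional` attributes for the type checker before they are used. A self-contained sketch of that pattern; the `Recorder` class and its names are illustrative, not SB3 API:

```python
from collections import deque
from typing import Optional


class Recorder:
    def __init__(self) -> None:
        # Created lazily, so the attribute starts as Optional.
        self.ep_info_buffer: Optional[deque] = None

    def setup(self, maxlen: int = 100) -> None:
        self.ep_info_buffer = deque(maxlen=maxlen)

    def record(self, info: dict) -> None:
        # Without this assert, mypy reports something like:
        #   Item "None" of "Optional[deque[Any]]" has no attribute "append"
        assert self.ep_info_buffer is not None, "call setup() first"
        self.ep_info_buffer.append(info)
```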
@@ -562,7 +571,7 @@ def set_random_seed(self, seed: Optional[int] = None) -> None:

def set_parameters(
self,
load_path_or_dict: Union[str, Dict[str, Dict]],
load_path_or_dict: Union[str, TensorDict],
exact_match: bool = True,
device: Union[th.device, str] = "auto",
) -> None:
@@ -578,7 +587,7 @@ def set_parameters(
can be used to update only specific parameters.
:param device: Device on which the code should run.
"""
params = None
params = {}
if isinstance(load_path_or_dict, dict):
params = load_path_or_dict
else:
@@ -616,7 +625,7 @@ def set_parameters(
#
# Solution: Just load the state-dict as is, and trust
# the user has provided a sensible state dictionary.
attr.load_state_dict(params[name])
attr.load_state_dict(params[name]) # type: ignore[arg-type]
else:
# Assume attr is th.nn.Module
attr.load_state_dict(params[name], strict=exact_match)
@@ -674,6 +683,9 @@ def load( # noqa: C901
print_system_info=print_system_info,
)

assert data is not None, "No data found in the saved file"
assert params is not None, "No params found in the saved file"

# Remove stored device information and replace with ours
if "policy_kwargs" in data:
if "device" in data["policy_kwargs"]:
@@ -714,13 +726,14 @@ def load( # noqa: C901
if "env" in data:
env = data["env"]

# noinspection PyArgumentList
model = cls( # pytype: disable=not-instantiable,wrong-keyword-args
# pytype: disable=not-instantiable,wrong-keyword-args
model = cls(
policy=data["policy_class"],
env=env,
device=device,
_init_setup_model=False, # pytype: disable=not-instantiable,wrong-keyword-args
_init_setup_model=False, # type: ignore[call-arg]
)
# pytype: enable=not-instantiable,wrong-keyword-args

# load parameters
model.__dict__.update(data)
@@ -758,12 +771,12 @@ def load( # noqa: C901
continue
# Set the data attribute directly to avoid issue when using optimizers
# See https://github.com/DLR-RM/stable-baselines3/issues/391
recursive_setattr(model, name + ".data", pytorch_variables[name].data)
recursive_setattr(model, f"{name}.data", pytorch_variables[name].data)

# Sample gSDE exploration matrix, so it uses the right device
# see issue #44
if model.use_sde:
model.policy.reset_noise() # pytype: disable=attribute-error
model.policy.reset_noise() # type: ignore[operator] # pytype: disable=attribute-error
return model

def get_parameters(self) -> Dict[str, Dict]:
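
Several hunks in this file replace `self.attr = None  # type: ...` initialisations with class-level annotations, so the attribute is typed as its real (non-Optional) type even though it is only assigned later (e.g. in `_setup_model()` or `set_env()`). A minimal sketch of the pattern with illustrative names, not the actual `BaseAlgorithm` definition:

```python
from typing import Optional

from gymnasium import spaces


class AlgorithmSketch:
    # Promised to exist with these types, but only assigned at setup time;
    # mypy then treats them as non-Optional at every use site.
    observation_space: spaces.Space
    n_envs: int

    def __init__(self, tensorboard_log: Optional[str] = None) -> None:
        # Attributes that may legitimately stay unset keep an Optional
        # annotation and an explicit None default.
        self.tensorboard_log = tensorboard_log

    def setup(self, observation_space: spaces.Space, n_envs: int) -> None:
        self.observation_space = observation_space
        self.n_envs = n_envs
```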
11 changes: 9 additions & 2 deletions stable_baselines3/common/buffers.py
@@ -335,6 +335,15 @@ class RolloutBuffer(BaseBuffer):
:param n_envs: Number of parallel environments
"""

observations: np.ndarray
actions: np.ndarray
rewards: np.ndarray
advantages: np.ndarray
returns: np.ndarray
episode_starts: np.ndarray
log_probs: np.ndarray
values: np.ndarray

def __init__(
self,
buffer_size: int,
@@ -348,8 +357,6 @@ def __init__(
super().__init__(buffer_size, observation_space, action_space, device, n_envs=n_envs)
self.gae_lambda = gae_lambda
self.gamma = gamma
self.observations, self.actions, self.rewards, self.advantages = None, None, None, None
self.returns, self.episode_starts, self.values, self.log_probs = None, None, None, None
self.generator_ready = False
self.reset()

2 changes: 1 addition & 1 deletion stable_baselines3/common/callbacks.py
@@ -313,7 +313,7 @@ class ConvertCallback(BaseCallback):
:param verbose: Verbosity level: 0 for no output, 1 for info messages, 2 for debug messages
"""

def __init__(self, callback: Callable[[Dict[str, Any], Dict[str, Any]], bool], verbose: int = 0):
def __init__(self, callback: Optional[Callable[[Dict[str, Any], Dict[str, Any]], bool]], verbose: int = 0):
super().__init__(verbose)
self.callback = callback

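
With `callback` now typed as `Optional`, call sites have to handle the `None` case explicitly. A minimal sketch of that guard; the `run_callback` helper is hypothetical and not SB3 code:

```python
from typing import Any, Callable, Dict, Optional

CallbackFn = Callable[[Dict[str, Any], Dict[str, Any]], bool]


def run_callback(callback: Optional[CallbackFn], locals_: Dict[str, Any], globals_: Dict[str, Any]) -> bool:
    # Treat "no callback" as "continue training"; otherwise mypy flags
    # calling a possibly-None object.
    if callback is None:
        return True
    return callback(locals_, globals_)
```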
5 changes: 3 additions & 2 deletions stable_baselines3/common/off_policy_algorithm.py
@@ -75,6 +75,8 @@ class OffPolicyAlgorithm(BaseAlgorithm):
:param supported_action_spaces: The action spaces supported by the algorithm.
"""

actor: th.nn.Module

def __init__(
self,
policy: Union[str, Type[BasePolicy]],
@@ -129,15 +131,14 @@ def __init__(
self.gradient_steps = gradient_steps
self.action_noise = action_noise
self.optimize_memory_usage = optimize_memory_usage
self.replay_buffer: Optional[ReplayBuffer] = None
self.replay_buffer_class = replay_buffer_class
self.replay_buffer_kwargs = replay_buffer_kwargs or {}
self._episode_storage = None

# Save train freq parameter, will be converted later to TrainFreq object
self.train_freq = train_freq

self.actor = None # type: Optional[th.nn.Module]
self.replay_buffer: Optional[ReplayBuffer] = None
# Update policy keyword arguments
if sde_support:
self.policy_kwargs["use_sde"] = self.use_sde