diff --git a/docs/guide/callbacks.rst b/docs/guide/callbacks.rst index e08c00367..5b2cfaee5 100644 --- a/docs/guide/callbacks.rst +++ b/docs/guide/callbacks.rst @@ -119,23 +119,18 @@ A child callback is for instance :ref:`StopTrainingOnRewardThreshold bool: - if self.callback is not None: - return self.callback() - return True - + return self.callback() Callback Collection diff --git a/docs/misc/changelog.rst b/docs/misc/changelog.rst index fd209dea8..d09c00705 100644 --- a/docs/misc/changelog.rst +++ b/docs/misc/changelog.rst @@ -3,6 +3,40 @@ Changelog ========== + +Release 2.2.0a0 (WIP) +-------------------------- + +Breaking Changes: +^^^^^^^^^^^^^^^^^ + +New Features: +^^^^^^^^^^^^^ + +`SB3-Contrib`_ +^^^^^^^^^^^^^^ + +`RL Zoo`_ +^^^^^^^^^ + +Bug Fixes: +^^^^^^^^^^ + +Deprecations: +^^^^^^^^^^^^^ + +Others: +^^^^^^^ +- Fixed ``stable_baselines3/common/callbacks.py`` type hints +- Fixed ``stable_baselines3/common/utils.py`` type hints +- Fixed ``stable_baselines3/common/vec_env/vec_transpose.py`` type hints +- Fixed ``stable_baselines3/common/vec_env/vec_video_recorder.py`` type hints +- Fixed ``stable_baselines3/common/save_util.py`` type hints + +Documentation: +^^^^^^^^^^^^^^ + + Release 2.1.0 (2023-08-17) -------------------------- diff --git a/pyproject.toml b/pyproject.toml index 518fd26b7..1c1837ace 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -39,16 +39,11 @@ follow_imports = "silent" show_error_codes = true exclude = """(?x)( stable_baselines3/common/buffers.py$ - | stable_baselines3/common/callbacks.py$ | stable_baselines3/common/distributions.py$ | stable_baselines3/common/off_policy_algorithm.py$ | stable_baselines3/common/policies.py$ - | stable_baselines3/common/save_util.py$ - | stable_baselines3/common/utils.py$ | stable_baselines3/common/vec_env/__init__.py$ | stable_baselines3/common/vec_env/vec_normalize.py$ - | stable_baselines3/common/vec_env/vec_transpose.py$ - | stable_baselines3/common/vec_env/vec_video_recorder.py$ | 
stable_baselines3/her/her_replay_buffer.py$ | tests/test_logger.py$ | tests/test_train_eval_mode.py$ diff --git a/stable_baselines3/common/callbacks.py b/stable_baselines3/common/callbacks.py index c9b8a3367..f16b57976 100644 --- a/stable_baselines3/common/callbacks.py +++ b/stable_baselines3/common/callbacks.py @@ -34,12 +34,9 @@ class BaseCallback(ABC): # The RL model # Type hint as string to avoid circular import model: "base_class.BaseAlgorithm" - logger: Logger def __init__(self, verbose: int = 0): super().__init__() - # An alias for self.model.get_env(), the environment used for training - self.training_env = None # type: Union[gym.Env, VecEnv, None] # Number of time the callback was called self.n_calls = 0 # type: int # n_envs * n times env.step() was called @@ -51,6 +48,18 @@ def __init__(self, verbose: int = 0): # to have access to the parent object self.parent = None # type: Optional[BaseCallback] + @property + def training_env(self) -> VecEnv: + training_env = self.model.get_env() + assert ( + training_env is not None + ), "`model.get_env()` returned None, you must initialize the model with an environment to use callbacks" + return training_env + + @property + def logger(self) -> Logger: + return self.model.logger + # Type hint as string to avoid circular import def init_callback(self, model: "base_class.BaseAlgorithm") -> None: """ @@ -58,8 +67,6 @@ def init_callback(self, model: "base_class.BaseAlgorithm") -> None: RL model and the training environment for convenience. 
""" self.model = model - self.training_env = model.get_env() - self.logger = model.logger self._init_callback() def _init_callback(self) -> None: @@ -147,6 +154,7 @@ def __init__(self, callback: Optional[BaseCallback] = None, verbose: int = 0): self.callback = callback # Give access to the parent if callback is not None: + assert self.callback is not None self.callback.parent = self def init_callback(self, model: "base_class.BaseAlgorithm") -> None: @@ -291,14 +299,14 @@ def _on_step(self) -> bool: if self.save_replay_buffer and hasattr(self.model, "replay_buffer") and self.model.replay_buffer is not None: # If model has a replay buffer, save it too replay_buffer_path = self._checkpoint_path("replay_buffer_", extension="pkl") - self.model.save_replay_buffer(replay_buffer_path) + self.model.save_replay_buffer(replay_buffer_path) # type: ignore[attr-defined] if self.verbose > 1: print(f"Saving model replay buffer checkpoint to {replay_buffer_path}") if self.save_vecnormalize and self.model.get_vec_normalize_env() is not None: # Save the VecNormalize statistics vec_normalize_path = self._checkpoint_path("vecnormalize_", extension="pkl") - self.model.get_vec_normalize_env().save(vec_normalize_path) + self.model.get_vec_normalize_env().save(vec_normalize_path) # type: ignore[union-attr] if self.verbose >= 2: print(f"Saving model VecNormalize to {vec_normalize_path}") @@ -382,7 +390,7 @@ def __init__( # Convert to VecEnv for consistency if not isinstance(eval_env, VecEnv): - eval_env = DummyVecEnv([lambda: eval_env]) + eval_env = DummyVecEnv([lambda: eval_env]) # type: ignore[list-item, return-value] self.eval_env = eval_env self.best_model_save_path = best_model_save_path @@ -390,12 +398,12 @@ def __init__( if log_path is not None: log_path = os.path.join(log_path, "evaluations") self.log_path = log_path - self.evaluations_results = [] - self.evaluations_timesteps = [] - self.evaluations_length = [] + self.evaluations_results: List[List[float]] = [] + 
self.evaluations_timesteps: List[int] = [] + self.evaluations_length: List[List[int]] = [] # For computing success rate - self._is_success_buffer = [] - self.evaluations_successes = [] + self._is_success_buffer: List[bool] = [] + self.evaluations_successes: List[List[bool]] = [] def _init_callback(self) -> None: # Does not work in some corner cases, where the wrapper is not the same @@ -458,6 +466,8 @@ def _on_step(self) -> bool: ) if self.log_path is not None: + assert isinstance(episode_rewards, list) + assert isinstance(episode_lengths, list) self.evaluations_timesteps.append(self.num_timesteps) self.evaluations_results.append(episode_rewards) self.evaluations_length.append(episode_lengths) @@ -478,7 +488,7 @@ def _on_step(self) -> bool: mean_reward, std_reward = np.mean(episode_rewards), np.std(episode_rewards) mean_ep_length, std_ep_length = np.mean(episode_lengths), np.std(episode_lengths) - self.last_mean_reward = mean_reward + self.last_mean_reward = float(mean_reward) if self.verbose >= 1: print(f"Eval num_timesteps={self.num_timesteps}, " f"episode_reward={mean_reward:.2f} +/- {std_reward:.2f}") @@ -502,7 +512,7 @@ def _on_step(self) -> bool: print("New best mean reward!") if self.best_model_save_path is not None: self.model.save(os.path.join(self.best_model_save_path, "best_model")) - self.best_mean_reward = mean_reward + self.best_mean_reward = float(mean_reward) # Trigger callback on new best model, if needed if self.callback_on_new_best is not None: continue_training = self.callback_on_new_best.on_step() @@ -536,12 +546,14 @@ class StopTrainingOnRewardThreshold(BaseCallback): threshold reached """ + parent: EvalCallback + def __init__(self, reward_threshold: float, verbose: int = 0): super().__init__(verbose=verbose) self.reward_threshold = reward_threshold def _on_step(self) -> bool: - assert self.parent is not None, "``StopTrainingOnMinimumReward`` callback must be used " "with an ``EvalCallback``" + assert self.parent is not None, 
"``StopTrainingOnMinimumReward`` callback must be used with an ``EvalCallback``" # Convert np.bool_ to bool, otherwise callback() is False won't work continue_training = bool(self.parent.best_mean_reward < self.reward_threshold) if self.verbose >= 1 and not continue_training: @@ -630,6 +642,8 @@ class StopTrainingOnNoModelImprovement(BaseCallback): :param verbose: Verbosity level: 0 for no output, 1 for indicating when training ended because no new best model """ + parent: EvalCallback + def __init__(self, max_no_improvement_evals: int, min_evals: int = 0, verbose: int = 0): super().__init__(verbose=verbose) self.max_no_improvement_evals = max_no_improvement_evals @@ -666,6 +680,8 @@ class ProgressBarCallback(BaseCallback): using tqdm and rich packages. """ + pbar: tqdm # pytype: disable=invalid-annotation + def __init__(self) -> None: super().__init__() if tqdm is None: @@ -674,7 +690,6 @@ def __init__(self) -> None: "It is included if you install stable-baselines with the extra packages: " "`pip install stable-baselines3[extra]`" ) - self.pbar = None def _on_training_start(self) -> None: # Initialize progress bar diff --git a/stable_baselines3/common/save_util.py b/stable_baselines3/common/save_util.py index 3c01a3f26..b24b4654b 100644 --- a/stable_baselines3/common/save_util.py +++ b/stable_baselines3/common/save_util.py @@ -178,7 +178,9 @@ def json_to_data(json_string: str, custom_objects: Optional[Dict[str, Any]] = No @functools.singledispatch -def open_path(path: Union[str, pathlib.Path, io.BufferedIOBase], mode: str, verbose: int = 0, suffix: Optional[str] = None): +def open_path( + path: Union[str, pathlib.Path, io.BufferedIOBase], mode: str, verbose: int = 0, suffix: Optional[str] = None +) -> Union[io.BufferedWriter, io.BufferedReader, io.BytesIO]: """ Opens a path for reading or writing with a preferred suffix and raises debug information. 
If the provided path is a derivative of io.BufferedIOBase it ensures that the file @@ -201,18 +203,21 @@ def open_path(path: Union[str, pathlib.Path, io.BufferedIOBase], mode: str, verb is not None, we attempt to open the path with the suffix. :return: """ - if not isinstance(path, io.BufferedIOBase): - raise TypeError("Path parameter has invalid type.", io.BufferedIOBase) + # Note(antonin): the true annotation should be IO[bytes] + # but there is no easy way to check that + allowed_types = (io.BufferedWriter, io.BufferedReader, io.BytesIO) + if not isinstance(path, allowed_types): + raise TypeError(f"Path {path} parameter has invalid type: expected one of {allowed_types}.") if path.closed: - raise ValueError("File stream is closed.") + raise ValueError(f"File stream {path} is closed.") mode = mode.lower() try: mode = {"write": "w", "read": "r", "w": "w", "r": "r"}[mode] except KeyError as e: raise ValueError("Expected mode to be either 'w' or 'r'.") from e if ("w" == mode) and not path.writable() or ("r" == mode) and not path.readable(): - e1 = "writable" if "w" == mode else "readable" - raise ValueError(f"Expected a {e1} file.") + error_msg = "writable" if "w" == mode else "readable" + raise ValueError(f"Expected a {error_msg} file.") return path @@ -231,7 +236,7 @@ def open_path_str(path: str, mode: str, verbose: int = 0, suffix: Optional[str] is not None, we attempt to open the path with the suffix. 
:return: """ - return open_path(pathlib.Path(path), mode, verbose, suffix) + return open_path_pathlib(pathlib.Path(path), mode, verbose, suffix) @open_path.register(pathlib.Path) @@ -255,7 +260,7 @@ def open_path_pathlib(path: pathlib.Path, mode: str, verbose: int = 0, suffix: O if mode == "r": try: - path = path.open("rb") + return open_path(path.open("rb"), mode, verbose, suffix) except FileNotFoundError as error: if suffix is not None and suffix != "": newpath = pathlib.Path(f"{path}.{suffix}") @@ -270,7 +275,7 @@ def open_path_pathlib(path: pathlib.Path, mode: str, verbose: int = 0, suffix: O path = pathlib.Path(f"{path}.{suffix}") if path.exists() and path.is_file() and verbose >= 2: warnings.warn(f"Path '{path}' exists, will overwrite it.") - path = path.open("wb") + return open_path(path.open("wb"), mode, verbose, suffix) except IsADirectoryError: warnings.warn(f"Path '{path}' is a folder. Will save instead to {path}_2") path = pathlib.Path(f"{path}_2") @@ -278,12 +283,11 @@ def open_path_pathlib(path: pathlib.Path, mode: str, verbose: int = 0, suffix: O warnings.warn(f"Path '{path.parent}' does not exist. 
Will create it.") path.parent.mkdir(exist_ok=True, parents=True) - # if opening was successful uses the identity function + # if opening was successful uses the open_path() function # if opening failed with IsADirectory|FileNotFound, calls open_path_pathlib # with corrections # if reading failed with FileNotFoundError, calls open_path_pathlib with suffix - - return open_path(path, mode, verbose, suffix) + return open_path_pathlib(path, mode, verbose, suffix) def save_to_zip_file( diff --git a/stable_baselines3/common/utils.py b/stable_baselines3/common/utils.py index b6fbe59be..4950822d4 100644 --- a/stable_baselines3/common/utils.py +++ b/stable_baselines3/common/utils.py @@ -19,7 +19,7 @@ try: from torch.utils.tensorboard import SummaryWriter except ImportError: - SummaryWriter = None + SummaryWriter = None # type: ignore[misc, assignment] from stable_baselines3.common.logger import Logger, configure from stable_baselines3.common.type_aliases import GymEnv, Schedule, TensorDict, TrainFreq, TrainFrequencyUnit @@ -396,13 +396,13 @@ def is_vectorized_observation(observation: Union[int, np.ndarray], observation_s for space_type, is_vec_obs_func in is_vec_obs_func_dict.items(): if isinstance(observation_space, space_type): - return is_vec_obs_func(observation, observation_space) + return is_vec_obs_func(observation, observation_space) # type: ignore[operator] else: # for-else happens if no break is called raise ValueError(f"Error: Cannot determine if the observation is vectorized with the space type {observation_space}.") -def safe_mean(arr: Union[np.ndarray, list, deque]) -> np.ndarray: +def safe_mean(arr: Union[np.ndarray, list, deque]) -> float: """ Compute the mean of an array if there is at least one element. For empty array, return NaN. It is used for logging only. 
@@ -410,7 +410,7 @@ def safe_mean(arr: Union[np.ndarray, list, deque]) -> np.ndarray: :param arr: Numpy array or list of values :return: """ - return np.nan if len(arr) == 0 else np.mean(arr) + return np.nan if len(arr) == 0 else float(np.mean(arr)) # type: ignore[arg-type] def get_parameters_by_name(model: th.nn.Module, included_names: Iterable[str]) -> List[th.Tensor]: diff --git a/stable_baselines3/common/vec_env/vec_transpose.py b/stable_baselines3/common/vec_env/vec_transpose.py index beb603961..487bd8c07 100644 --- a/stable_baselines3/common/vec_env/vec_transpose.py +++ b/stable_baselines3/common/vec_env/vec_transpose.py @@ -20,7 +20,7 @@ class VecTransposeImage(VecEnvWrapper): def __init__(self, venv: VecEnv, skip: bool = False): assert is_image_space(venv.observation_space) or isinstance( - venv.observation_space, spaces.dict.Dict + venv.observation_space, spaces.Dict ), "The observation space must be an image or dictionary observation space" self.skip = skip @@ -29,16 +29,18 @@ def __init__(self, venv: VecEnv, skip: bool = False): super().__init__(venv) return - if isinstance(venv.observation_space, spaces.dict.Dict): + if isinstance(venv.observation_space, spaces.Dict): self.image_space_keys = [] observation_space = deepcopy(venv.observation_space) for key, space in observation_space.spaces.items(): if is_image_space(space): # Keep track of which keys should be transposed later self.image_space_keys.append(key) + assert isinstance(space, spaces.Box) observation_space.spaces[key] = self.transpose_space(space, key) else: - observation_space = self.transpose_space(venv.observation_space) + assert isinstance(venv.observation_space, spaces.Box) + observation_space = self.transpose_space(venv.observation_space) # type: ignore[assignment] super().__init__(venv, observation_space=observation_space) @staticmethod @@ -57,7 +59,7 @@ def transpose_space(observation_space: spaces.Box, key: str = "") -> spaces.Box: ), f"The observation space {key} must follow the 
channel last convention" height, width, channels = observation_space.shape new_shape = (channels, height, width) - return spaces.Box(low=0, high=255, shape=new_shape, dtype=observation_space.dtype) + return spaces.Box(low=0, high=255, shape=new_shape, dtype=observation_space.dtype) # type: ignore[arg-type] @staticmethod def transpose_image(image: np.ndarray) -> np.ndarray: @@ -101,13 +103,16 @@ def step_wait(self) -> VecEnvStepReturn: if "terminal_observation" in infos[idx]: infos[idx]["terminal_observation"] = self.transpose_observations(infos[idx]["terminal_observation"]) + assert isinstance(observations, (np.ndarray, dict)) return self.transpose_observations(observations), rewards, dones, infos def reset(self) -> Union[np.ndarray, Dict]: """ Reset all environments """ - return self.transpose_observations(self.venv.reset()) + observations = self.venv.reset() + assert isinstance(observations, (np.ndarray, dict)) + return self.transpose_observations(observations) def close(self) -> None: self.venv.close() diff --git a/stable_baselines3/common/vec_env/vec_video_recorder.py b/stable_baselines3/common/vec_env/vec_video_recorder.py index 6f670054b..52faebd1f 100644 --- a/stable_baselines3/common/vec_env/vec_video_recorder.py +++ b/stable_baselines3/common/vec_env/vec_video_recorder.py @@ -22,6 +22,8 @@ class VecVideoRecorder(VecEnvWrapper): :param name_prefix: Prefix to the video name """ + video_recorder: video_recorder.VideoRecorder + def __init__( self, venv: VecEnv, @@ -50,8 +52,6 @@ def __init__( assert self.env.render_mode == "rgb_array", f"The render_mode must be 'rgb_array', not {self.env.render_mode}" self.record_video_trigger = record_video_trigger - self.video_recorder = None - self.video_folder = os.path.abspath(video_folder) # Create output folder if needed os.makedirs(self.video_folder, exist_ok=True) diff --git a/stable_baselines3/version.txt b/stable_baselines3/version.txt index 7ec1d6db4..887948350 100644 --- a/stable_baselines3/version.txt +++ 
b/stable_baselines3/version.txt @@ -1 +1 @@ -2.1.0 +2.2.0a0 diff --git a/tests/test_save_load.py b/tests/test_save_load.py index 9d0aa44e9..39b141d2a 100644 --- a/tests/test_save_load.py +++ b/tests/test_save_load.py @@ -675,8 +675,8 @@ def test_open_file(tmp_path): buff = io.BytesIO() assert buff.writable() assert buff.readable() is ("w" == "w") - _ = open_path(buff, "w") - assert _ is buff + opened_buffer = open_path(buff, "w") + assert opened_buffer is buff with pytest.raises(ValueError): buff.close() open_path(buff, "w")