diff --git a/docs/misc/changelog.rst b/docs/misc/changelog.rst
index bb1f7b04..f9be4898 100644
--- a/docs/misc/changelog.rst
+++ b/docs/misc/changelog.rst
@@ -15,6 +15,7 @@ Breaking Changes:
 
 New Features:
 ^^^^^^^^^^^^^
+- Added ``progress_bar`` argument in the ``learn()`` method, displayed using TQDM and rich packages
 
 Bug Fixes:
 ^^^^^^^^^^
diff --git a/sb3_contrib/ars/ars.py b/sb3_contrib/ars/ars.py
index 241a814b..f5ac48aa 100644
--- a/sb3_contrib/ars/ars.py
+++ b/sb3_contrib/ars/ars.py
@@ -315,6 +315,7 @@ def learn(
         eval_log_path: Optional[str] = None,
         reset_num_timesteps: bool = True,
         async_eval: Optional[AsyncEval] = None,
+        progress_bar: bool = False,
     ) -> ARSSelf:
         """
         Return a trained model.
@@ -333,11 +334,20 @@ def learn(
         :param eval_log_path: Path to a folder where the evaluations will be saved
         :param reset_num_timesteps: whether or not to reset the current timestep number (used in logging)
         :param async_eval: The object for asynchronous evaluation of candidates.
+        :param progress_bar: Display a progress bar using tqdm and rich.
         :return: the trained model
         """
 
         total_steps, callback = self._setup_learn(
-            total_timesteps, eval_env, callback, eval_freq, n_eval_episodes, eval_log_path, reset_num_timesteps, tb_log_name
+            total_timesteps,
+            eval_env,
+            callback,
+            eval_freq,
+            n_eval_episodes,
+            eval_log_path,
+            reset_num_timesteps,
+            tb_log_name,
+            progress_bar,
         )
 
         callback.on_training_start(locals(), globals())
diff --git a/sb3_contrib/ppo_mask/ppo_mask.py b/sb3_contrib/ppo_mask/ppo_mask.py
index 0cedc4c9..62f92f62 100644
--- a/sb3_contrib/ppo_mask/ppo_mask.py
+++ b/sb3_contrib/ppo_mask/ppo_mask.py
@@ -10,7 +10,7 @@
 from gym import spaces
 from stable_baselines3.common import utils
 from stable_baselines3.common.buffers import RolloutBuffer
-from stable_baselines3.common.callbacks import BaseCallback, CallbackList, ConvertCallback
+from stable_baselines3.common.callbacks import BaseCallback, CallbackList, ConvertCallback, ProgressBarCallback
 from stable_baselines3.common.on_policy_algorithm import OnPolicyAlgorithm
 from stable_baselines3.common.policies import BasePolicy
 from stable_baselines3.common.type_aliases import GymEnv, MaybeCallback, Schedule
@@ -184,6 +184,7 @@ def _init_callback(
         n_eval_episodes: int = 5,
         log_path: Optional[str] = None,
         use_masking: bool = True,
+        progress_bar: bool = False,
     ) -> BaseCallback:
         """
         :param callback: Callback(s) called at every step with state of the algorithm.
@@ -196,6 +197,7 @@ def _init_callback(
         :param n_eval_episodes: How many episodes to play per evaluation
         :param log_path: Path to a folder where the evaluations will be saved
         :param use_masking: Whether or not to use invalid action masks during evaluation
+        :param progress_bar: Display a progress bar using tqdm and rich.
         :return: A hybrid callback calling `callback` and performing evaluation.
         """
         # Convert a list of callbacks into a callback
@@ -206,6 +208,10 @@ def _init_callback(
         if not isinstance(callback, BaseCallback):
             callback = ConvertCallback(callback)
 
+        # Add progress bar callback
+        if progress_bar:
+            callback = CallbackList([callback, ProgressBarCallback()])
+
         # Create eval callback in charge of the evaluation
         if eval_env is not None:
             # Avoid circular import error
@@ -236,6 +242,7 @@ def _setup_learn(
         reset_num_timesteps: bool = True,
         tb_log_name: str = "run",
         use_masking: bool = True,
+        progress_bar: bool = False,
     ) -> Tuple[int, BaseCallback]:
         """
         Initialize different variables needed for training.
@@ -253,6 +260,7 @@ def _setup_learn(
         :param reset_num_timesteps: Whether to reset or not the ``num_timesteps`` attribute
         :param tb_log_name: the name of the run for tensorboard log
         :param use_masking: Whether or not to use invalid action masks during training
+        :param progress_bar: Display a progress bar using tqdm and rich.
         :return:
         """
 
@@ -299,7 +307,7 @@ def _setup_learn(
             self._logger = utils.configure_logger(self.verbose, self.tensorboard_log, tb_log_name, reset_num_timesteps)
 
         # Create eval callback if needed
-        callback = self._init_callback(callback, eval_env, eval_freq, n_eval_episodes, log_path, use_masking)
+        callback = self._init_callback(callback, eval_env, eval_freq, n_eval_episodes, log_path, use_masking, progress_bar)
 
         return total_timesteps, callback
 
@@ -563,6 +571,7 @@ def learn(
         eval_log_path: Optional[str] = None,
         reset_num_timesteps: bool = True,
         use_masking: bool = True,
+        progress_bar: bool = False,
     ) -> MaskablePPOSelf:
         iteration = 0
 
@@ -576,6 +585,7 @@ def learn(
             reset_num_timesteps,
             tb_log_name,
             use_masking,
+            progress_bar,
         )
 
         callback.on_training_start(locals(), globals())
diff --git a/sb3_contrib/ppo_recurrent/ppo_recurrent.py b/sb3_contrib/ppo_recurrent/ppo_recurrent.py
index 68d501fe..965e0080 100644
--- a/sb3_contrib/ppo_recurrent/ppo_recurrent.py
+++ b/sb3_contrib/ppo_recurrent/ppo_recurrent.py
@@ -1,7 +1,7 @@
 import sys
 import time
 from copy import deepcopy
-from typing import Any, Dict, Optional, Tuple, Type, TypeVar, Union
+from typing import Any, Dict, Optional, Type, TypeVar, Union
 
 import gym
 import numpy as np
@@ -198,47 +198,6 @@ def _setup_model(self) -> None:
 
             self.clip_range_vf = get_schedule_fn(self.clip_range_vf)
 
-    def _setup_learn(
-        self,
-        total_timesteps: int,
-        eval_env: Optional[GymEnv],
-        callback: MaybeCallback = None,
-        eval_freq: int = 10000,
-        n_eval_episodes: int = 5,
-        log_path: Optional[str] = None,
-        reset_num_timesteps: bool = True,
-        tb_log_name: str = "RecurrentPPO",
-    ) -> Tuple[int, BaseCallback]:
-        """
-        Initialize different variables needed for training.
-
-        :param total_timesteps: The total number of samples (env steps) to train on
-        :param eval_env: Environment to use for evaluation.
-            Caution, this parameter is deprecated and will be removed in the future.
-            Please use `EvalCallback` or a custom Callback instead.
-        :param callback: Callback(s) called at every step with state of the algorithm.
-        :param eval_freq: Evaluate the agent every ``eval_freq`` timesteps (this may vary a little).
-            Caution, this parameter is deprecated and will be removed in the future.
-            Please use `EvalCallback` or a custom Callback instead.
-        :param n_eval_episodes: How many episodes to play per evaluation
-        :param log_path: Path to a folder where the evaluations will be saved
-        :param reset_num_timesteps: Whether to reset or not the ``num_timesteps`` attribute
-        :param tb_log_name: the name of the run for tensorboard log
-        :return:
-        """
-
-        total_timesteps, callback = super()._setup_learn(
-            total_timesteps,
-            eval_env,
-            callback,
-            eval_freq,
-            n_eval_episodes,
-            log_path,
-            reset_num_timesteps,
-            tb_log_name,
-        )
-        return total_timesteps, callback
-
     def collect_rollouts(
         self,
         env: VecEnv,
@@ -500,11 +459,20 @@ def learn(
         tb_log_name: str = "RecurrentPPO",
         eval_log_path: Optional[str] = None,
         reset_num_timesteps: bool = True,
+        progress_bar: bool = False,
     ) -> RecurrentPPOSelf:
         iteration = 0
 
         total_timesteps, callback = self._setup_learn(
-            total_timesteps, eval_env, callback, eval_freq, n_eval_episodes, eval_log_path, reset_num_timesteps, tb_log_name
+            total_timesteps,
+            eval_env,
+            callback,
+            eval_freq,
+            n_eval_episodes,
+            eval_log_path,
+            reset_num_timesteps,
+            tb_log_name,
+            progress_bar,
         )
 
         callback.on_training_start(locals(), globals())
diff --git a/sb3_contrib/qrdqn/qrdqn.py b/sb3_contrib/qrdqn/qrdqn.py
index 38f32ff1..f521219c 100644
--- a/sb3_contrib/qrdqn/qrdqn.py
+++ b/sb3_contrib/qrdqn/qrdqn.py
@@ -262,6 +262,7 @@ def learn(
         tb_log_name: str = "QRDQN",
         eval_log_path: Optional[str] = None,
         reset_num_timesteps: bool = True,
+        progress_bar: bool = False,
     ) -> QRDQNSelf:
 
         return super().learn(
@@ -274,6 +275,7 @@ def learn(
             tb_log_name=tb_log_name,
             eval_log_path=eval_log_path,
             reset_num_timesteps=reset_num_timesteps,
+            progress_bar=progress_bar,
         )
 
     def _excluded_save_params(self) -> List[str]:
diff --git a/sb3_contrib/tqc/tqc.py b/sb3_contrib/tqc/tqc.py
index b38b53c6..df65496a 100644
--- a/sb3_contrib/tqc/tqc.py
+++ b/sb3_contrib/tqc/tqc.py
@@ -299,6 +299,7 @@ def learn(
         tb_log_name: str = "TQC",
         eval_log_path: Optional[str] = None,
         reset_num_timesteps: bool = True,
+        progress_bar: bool = False,
     ) -> TQCSelf:
 
         return super().learn(
@@ -311,6 +312,7 @@ def learn(
             tb_log_name=tb_log_name,
             eval_log_path=eval_log_path,
             reset_num_timesteps=reset_num_timesteps,
+            progress_bar=progress_bar,
         )
 
     def _excluded_save_params(self) -> List[str]:
diff --git a/sb3_contrib/trpo/trpo.py b/sb3_contrib/trpo/trpo.py
index a66d9e12..7d931308 100644
--- a/sb3_contrib/trpo/trpo.py
+++ b/sb3_contrib/trpo/trpo.py
@@ -415,6 +415,7 @@ def learn(
         tb_log_name: str = "TRPO",
         eval_log_path: Optional[str] = None,
         reset_num_timesteps: bool = True,
+        progress_bar: bool = False,
     ) -> TRPOSelf:
 
         return super().learn(
@@ -427,4 +428,5 @@ def learn(
             tb_log_name=tb_log_name,
             eval_log_path=eval_log_path,
             reset_num_timesteps=reset_num_timesteps,
+            progress_bar=progress_bar,
         )
diff --git a/tests/test_invalid_actions.py b/tests/test_invalid_actions.py
index 35099148..30f2cff5 100644
--- a/tests/test_invalid_actions.py
+++ b/tests/test_invalid_actions.py
@@ -190,7 +190,7 @@ def test_callback(tmp_path):
     model = MaskablePPO("MlpPolicy", env, n_steps=64, gamma=0.4, seed=32, verbose=1)
     model.learn(100, callback=MaskableEvalCallback(eval_env, eval_freq=100, warn=False, log_path=tmp_path))
-    model.learn(100, callback=MaskableEvalCallback(Monitor(eval_env), eval_freq=100, warn=False))
+    model.learn(100, callback=MaskableEvalCallback(Monitor(eval_env), eval_freq=100, warn=False), progress_bar=True)
 
 
 def test_child_callback():
diff --git a/tests/test_run.py b/tests/test_run.py
index c409e4c4..09a00904 100644
--- a/tests/test_run.py
+++ b/tests/test_run.py
@@ -20,7 +20,7 @@ def test_tqc(ent_coef):
         create_eval_env=True,
         ent_coef=ent_coef,
     )
-    model.learn(total_timesteps=300, eval_freq=250)
+    model.learn(total_timesteps=300, eval_freq=250, progress_bar=True)
 
 
 @pytest.mark.parametrize("n_critics", [1, 3])
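
Usage note: the patch threads a ``progress_bar`` flag through every algorithm's ``learn()``; when it is ``True``, ``_init_callback`` wraps the user callback together with SB3's ``ProgressBarCallback`` in a ``CallbackList``. A minimal usage sketch, assuming sb3-contrib is installed and that tqdm and rich are available (e.g. via ``pip install stable_baselines3[extra]``):

from sb3_contrib import QRDQN

# Assumption: tqdm and rich are installed; they back the progress bar display.
model = QRDQN("MlpPolicy", "CartPole-v1", verbose=0)

# progress_bar=True adds a ProgressBarCallback next to any user callback,
# showing the number of timesteps completed out of total_timesteps.
model.learn(total_timesteps=10_000, progress_bar=True)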