Commit

Merge branch 'DLR-RM:master' into master
RaikoPipe authored Nov 11, 2024
2 parents 7aac265 + e4f4f12 commit 3fe7c0d
Showing 25 changed files with 206 additions and 146 deletions.
27 changes: 17 additions & 10 deletions .github/workflows/ci.yml
@@ -21,7 +21,12 @@ jobs:
strategy:
matrix:
python-version: ["3.8", "3.9", "3.10", "3.11"]

include:
# Default version
- gymnasium-version: "1.0.0"
# Add a new config to test gym<1.0
- python-version: "3.10"
gymnasium-version: "0.29.1"
steps:
- uses: actions/checkout@v3
- name: Set up Python ${{ matrix.python-version }}
@@ -31,18 +36,20 @@ jobs:
- name: Install dependencies
run: |
python -m pip install --upgrade pip
# Use uv for faster downloads
pip install uv
# cpu version of pytorch
pip install torch==2.3.1 --index-url https://download.pytorch.org/whl/cpu
# Install Atari Roms
pip install autorom
wget https://gist.githubusercontent.com/jjshoots/61b22aefce4456920ba99f2c36906eda/raw/00046ac3403768bfe45857610a3d333b8e35e026/Roms.tar.gz.b64
base64 Roms.tar.gz.b64 --decode &> Roms.tar.gz
AutoROM --accept-license --source-file Roms.tar.gz
# See https://github.com/astral-sh/uv/issues/1497
uv pip install --system torch==2.3.1+cpu --index https://download.pytorch.org/whl/cpu
pip install .[extra_no_roms,tests,docs]
uv pip install --system .[extra,tests,docs]
# Use headless version
pip install opencv-python-headless
uv pip install --system opencv-python-headless
- name: Install specific version of gym
run: |
uv pip install --system gymnasium==${{ matrix.gymnasium-version }}
# Only run for python 3.10, downgrade gym to 0.29.1
if: matrix.gymnasium-version != '1.0.0'
- name: Lint with ruff
run: |
make lint
4 changes: 2 additions & 2 deletions .readthedocs.yml
@@ -16,6 +16,6 @@ conda:
environment: docs/conda_env.yml

build:
os: ubuntu-22.04
os: ubuntu-24.04
tools:
python: "mambaforge-22.9"
python: "mambaforge-23.11"
2 changes: 1 addition & 1 deletion CONTRIBUTING.md
@@ -6,7 +6,7 @@ into two categories:
- Create an issue about your intended feature, and we shall discuss the design and
implementation. Once we agree that the plan looks good, go ahead and implement it.
2. You want to implement a feature or bug-fix for an outstanding issue
- Look at the outstanding issues here: https://github.com/DLR-RM/stable-baselines3/issues
- Look at the outstanding issues here: https://github.com/DLR-RM/stable-baselines3/labels/help%20wanted
- Pick an issue or feature and comment on the task that you want to work on this feature.
- If you need more context on a particular issue, please ask, and we shall provide.

32 changes: 21 additions & 11 deletions README.md
@@ -1,6 +1,6 @@
<!-- [![pipeline status](https://gitlab.com/araffin/stable-baselines3/badges/master/pipeline.svg)](https://gitlab.com/araffin/stable-baselines3/-/commits/master) -->
![CI](https://github.com/DLR-RM/stable-baselines3/workflows/CI/badge.svg)
[![Documentation Status](https://readthedocs.org/projects/stable-baselines/badge/?version=master)](https://stable-baselines3.readthedocs.io/en/master/?badge=master) [![coverage report](https://gitlab.com/araffin/stable-baselines3/badges/master/coverage.svg)](https://gitlab.com/araffin/stable-baselines3/-/commits/master)
[![CI](https://github.com/DLR-RM/stable-baselines3/workflows/CI/badge.svg)](https://github.com/DLR-RM/stable-baselines3/actions/workflows/ci.yml)
[![Documentation Status](https://readthedocs.org/projects/stable-baselines/badge/?version=master)](https://stable-baselines3.readthedocs.io/en/master/?badge=master) [![coverage report](https://gitlab.com/araffin/stable-baselines3/badges/master/coverage.svg)](https://github.com/DLR-RM/stable-baselines3/actions/workflows/ci.yml)
[![codestyle](https://img.shields.io/badge/code%20style-black-000000.svg)](https://github.com/psf/black)


@@ -22,6 +22,8 @@ These algorithms will make it easier for the research community and industry to
**The performance of each algorithm was tested** (see *Results* section in their respective page),
you can take a look at the issues [#48](https://github.com/DLR-RM/stable-baselines3/issues/48) and [#49](https://github.com/DLR-RM/stable-baselines3/issues/49) for more details.

We also provide detailed logs and reports on the [OpenRL Benchmark](https://wandb.ai/openrlbenchmark/sb3) platform.


| **Features** | **Stable-Baselines3** |
| --------------------------- | ----------------------|
@@ -41,7 +43,13 @@ you can take a look at the issues [#48](https://github.com/DLR-RM/stable-baselin

### Planned features

Please take a look at the [Roadmap](https://github.com/DLR-RM/stable-baselines3/issues/1) and [Milestones](https://github.com/DLR-RM/stable-baselines3/milestones).
Since most of the features from the [original roadmap](https://github.com/DLR-RM/stable-baselines3/issues/1) have been implemented, there are no major changes planned for SB3, it is now *stable*.
If you want to contribute, you can search in the issues for the ones where [help is welcomed](https://github.com/DLR-RM/stable-baselines3/labels/help%20wanted) and the other [proposed enhancements](https://github.com/DLR-RM/stable-baselines3/labels/enhancement).

While SB3 development is now focused on bug fixes and maintenance (doc update, user experience, ...), there is more active development going on in the associated repositories:
- newer algorithms are regularly added to the [SB3 Contrib](https://github.com/Stable-Baselines-Team/stable-baselines3-contrib) repository
- faster variants are developed in the [SBX (SB3 + Jax)](https://github.com/araffin/sbx) repository
- the training framework for SB3, the RL Zoo, has an active [roadmap](https://github.com/DLR-RM/rl-baselines3-zoo/issues/299)

## Migration guide: from Stable-Baselines (SB2) to Stable-Baselines3 (SB3)

@@ -79,7 +87,7 @@ Documentation: https://rl-baselines3-zoo.readthedocs.io/en/master/

We implement experimental features in a separate contrib repository: [SB3-Contrib](https://github.com/Stable-Baselines-Team/stable-baselines3-contrib)

This allows SB3 to maintain a stable and compact core, while still providing the latest features, like Recurrent PPO (PPO LSTM), Truncated Quantile Critics (TQC), Quantile Regression DQN (QR-DQN) or PPO with invalid action masking (Maskable PPO).
This allows SB3 to maintain a stable and compact core, while still providing the latest features, like Recurrent PPO (PPO LSTM), CrossQ, Truncated Quantile Critics (TQC), Quantile Regression DQN (QR-DQN) or PPO with invalid action masking (Maskable PPO).

Documentation is available online: [https://sb3-contrib.readthedocs.io/](https://sb3-contrib.readthedocs.io/)

@@ -97,17 +105,16 @@ It provides a minimal number of features compared to SB3 but can be much faster
### Prerequisites
Stable Baselines3 requires Python 3.8+.

#### Windows 10
#### Windows

To install stable-baselines on Windows, please look at the [documentation](https://stable-baselines3.readthedocs.io/en/master/guide/install.html#prerequisites).


### Install using pip
Install the Stable Baselines3 package:
```sh
pip install 'stable-baselines3[extra]'
```
pip install stable-baselines3[extra]
```
**Note:** Some shells such as Zsh require quotation marks around brackets, i.e. `pip install 'stable-baselines3[extra]'` ([More Info](https://stackoverflow.com/a/30539963)).

This includes an optional dependencies like Tensorboard, OpenCV or `ale-py` to train on atari games. If you do not need those, you can use:
```sh
@@ -177,6 +184,7 @@ All the following examples can be executed online using Google Colab notebooks:
| ------------------- | ------------------ | ------------------ | ------------------ | ------------------- | ------------------ | --------------------------------- |
| ARS<sup>[1](#f1)</sup> | :x: | :heavy_check_mark: | :heavy_check_mark: | :x: | :x: | :heavy_check_mark: |
| A2C | :x: | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: |
| CrossQ<sup>[1](#f1)</sup> | :x: | :heavy_check_mark: | :x: | :x: | :x: | :heavy_check_mark: |
| DDPG | :x: | :heavy_check_mark: | :x: | :x: | :x: | :heavy_check_mark: |
| DQN | :x: | :x: | :heavy_check_mark: | :x: | :x: | :heavy_check_mark: |
| HER | :x: | :heavy_check_mark: | :heavy_check_mark: | :x: | :x: | :heavy_check_mark: |
@@ -191,7 +199,7 @@ All the following examples can be executed online using Google Colab notebooks:

<b id="f1">1</b>: Implemented in [SB3 Contrib](https://github.com/Stable-Baselines-Team/stable-baselines3-contrib) GitHub repository.

Actions `gym.spaces`:
Actions `gymnasium.spaces`:
* `Box`: A N-dimensional box that contains every point in the action space.
* `Discrete`: A list of possible actions, where each timestep only one of the actions can be used.
* `MultiDiscrete`: A list of possible actions, where each timestep only one action of each discrete set can be used.
@@ -218,9 +226,9 @@ To run a single test:
python3 -m pytest -v -k 'test_check_env_dict_action'
```

You can also do a static type check using `pytype` and `mypy`:
You can also do a static type check using `mypy`:
```sh
pip install pytype mypy
pip install mypy
make type
```

@@ -252,6 +260,8 @@ To cite this repository in publications:
}
```

Note: If you need to refer to a specific version of SB3, you can also use the [Zenodo DOI](https://doi.org/10.5281/zenodo.8123988).

## Maintainers

Stable-Baselines3 is currently maintained by [Ashley Hill](https://github.com/hill-a) (aka @hill-a), [Antonin Raffin](https://araffin.github.io/) (aka [@araffin](https://github.com/araffin)), [Maximilian Ernestus](https://github.com/ernestum) (aka @ernestum), [Adam Gleave](https://github.com/adamgleave) (@AdamGleave), [Anssi Kanervisto](https://github.com/Miffyli) (@Miffyli) and [Quentin Gallouédec](https://gallouedec.com/) (@qgallouedec).
14 changes: 7 additions & 7 deletions docs/conda_env.yml
@@ -1,19 +1,19 @@
name: root
channels:
- pytorch
- defaults
- conda-forge
dependencies:
- cpuonly=1.0=0
- pip=22.3.1
- python=3.8
- pytorch=1.13.0=py3.8_cpu_0
- pip=24.2
- python=3.11
- pytorch=2.5.0=py3.11_cpu_0
- pip:
- gymnasium
- gymnasium>=0.29.1,<1.1.0
- cloudpickle
- opencv-python-headless
- pandas
- numpy
- numpy>=1.20,<2.0
- matplotlib
- sphinx>=5,<8
- sphinx>=5,<9
- sphinx_rtd_theme>=1.3.0
- sphinx_copybutton
1 change: 1 addition & 0 deletions docs/guide/algos.rst
@@ -10,6 +10,7 @@ Name ``Box`` ``Discrete`` ``MultiDiscrete`` ``MultiBinary``
=================== =========== ============ ================= =============== ================
ARS [#f1]_ ✔️ ✔️ ❌ ❌ ✔️
A2C ✔️ ✔️ ✔️ ✔️ ✔️
CrossQ [#f1]_ ✔️ ❌ ❌ ❌ ✔️
DDPG ✔️ ❌ ❌ ❌ ✔️
DQN ❌ ✔️ ❌ ❌ ✔️
HER ✔️ ✔️ ❌ ❌ ✔️
1 change: 1 addition & 0 deletions docs/guide/sb3_contrib.rst
@@ -42,6 +42,7 @@ See documentation for the full list of included features.
- `PPO with recurrent policy (RecurrentPPO aka PPO LSTM) <https://ppo-details.cleanrl.dev//2021/11/05/ppo-implementation-details/>`_
- `Truncated Quantile Critics (TQC)`_
- `Trust Region Policy Optimization (TRPO) <https://arxiv.org/abs/1502.05477>`_
- `Batch Normalization in Deep Reinforcement Learning (CrossQ) <https://openreview.net/forum?id=PczQtTsTIX>`_


**Gym Wrappers**:
4 changes: 3 additions & 1 deletion docs/index.rst
@@ -113,12 +113,14 @@ To cite this project in publications:
url = {http://jmlr.org/papers/v22/20-1364.html}
}
Note: If you need to refer to a specific version of SB3, you can also use the `Zenodo DOI <https://doi.org/10.5281/zenodo.8123988>`_.

Contributing
------------

To any interested in making the rl baselines better, there are still some improvements
that need to be done.
You can check issues in the `repo <https://github.com/DLR-RM/stable-baselines3/issues>`_.
You can check issues in the `repository <https://github.com/DLR-RM/stable-baselines3/labels/help%20wanted>`_.

If you want to contribute, please read `CONTRIBUTING.md <https://github.com/DLR-RM/stable-baselines3/blob/master/CONTRIBUTING.md>`_ first.

16 changes: 15 additions & 1 deletion docs/misc/changelog.rst
@@ -3,9 +3,11 @@
Changelog
==========

Release 2.4.0a10 (WIP)
Release 2.4.0a11 (WIP)
--------------------------

**New algorithm: CrossQ in SB3 Contrib, Gymnasium v1.0 support**

.. note::

DQN (and QR-DQN) models saved with SB3 < 2.4.0 will show a warning about
@@ -22,12 +24,14 @@ Release 2.4.0a10 (WIP)

Breaking Changes:
^^^^^^^^^^^^^^^^^
- Increase minimum required version of Gymnasium to 0.29.1

New Features:
^^^^^^^^^^^^^
- Added support for ``pre_linear_modules`` and ``post_linear_modules`` in ``create_mlp`` (useful for adding normalization layers, like in DroQ or CrossQ)
- Enabled np.ndarray logging for TensorBoardOutputFormat as histogram (see GH#1634) (@iwishwasaneagle)
- Updated env checker to warn users when using multi-dim array to define `MultiDiscrete` spaces
- Added support for Gymnasium v1.0

Bug Fixes:
^^^^^^^^^^
@@ -43,6 +47,10 @@ Bug Fixes:

`SB3-Contrib`_
^^^^^^^^^^^^^^
- Added ``CrossQ`` algorithm, from "Batch Normalization in Deep Reinforcement Learning" paper (@danielpalen)
- Added ``BatchRenorm`` PyTorch layer used in ``CrossQ`` (@danielpalen)
- Updated QR-DQN optimizer input to only include quantile_net parameters (@corentinlger)
- Fixed loading QRDQN changes `target_update_interval` (@jak3122)

`RL Zoo`_
^^^^^^^^^
@@ -51,6 +59,7 @@ Bug Fixes:
`SBX`_ (SB3 + Jax)
^^^^^^^^^^^^^^^^^^
- Added CNN support for DQN
- Bug fix for SAC and related algorithms, optimize log of ent coeff to be consistent with SB3

Deprecations:
^^^^^^^^^^^^^
@@ -61,13 +70,18 @@ Others:
- Remove unnecessary SDE noise resampling in PPO update (@brn-dev)
- Updated PyTorch version on CI to 2.3.1
- Added a warning to recommend using CPU with on policy algorithms (A2C/PPO) and ``MlpPolicy``
- Switched to uv to download packages faster on GitHub CI
- Updated dependencies for read the doc
- Removed unnecessary ``copy_obs_dict`` method for ``SubprocVecEnv``, remove the use of ordered dict and rename ``flatten_obs`` to ``stack_obs``

Bug Fixes:
^^^^^^^^^^

Documentation:
^^^^^^^^^^^^^^
- Updated PPO doc to recommend using CPU with ``MlpPolicy``
- Clarified documentation about planned features and citing software
- Added a note about the fact we are optimizing log of ent coeff for SAC

Release 2.3.2 (2024-04-27)
--------------------------
1 change: 1 addition & 0 deletions docs/modules/dqn.rst
@@ -25,6 +25,7 @@ Notes

- Original paper: https://arxiv.org/abs/1312.5602
- Further reference: https://www.nature.com/articles/nature14236
- Tutorial "From Tabular Q-Learning to DQN": https://github.com/araffin/rlss23-dqn-tutorial

.. note::
This implementation provides only vanilla Deep Q-Learning and has no extensions such as Double-DQN, Dueling-DQN and Prioritized Experience Replay.
3 changes: 3 additions & 0 deletions docs/modules/sac.rst
@@ -35,6 +35,9 @@ Notes
which is the equivalent to the inverse of reward scale in the original SAC paper.
The main reason is that it avoids having too high errors when updating the Q functions.

.. note::
When automatically adjusting the temperature (alpha/entropy coefficient), we optimize the logarithm of the entropy coefficient instead of the entropy coefficient itself. This is consistent with the original implementation and has proven to be more stable
(see issues `GH#36 <https://github.com/DLR-RM/stable-baselines3/issues/36>`_, `#55 <https://github.com/araffin/sbx/issues/55>`_ and others).
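The note above can be illustrated with a minimal, framework-free sketch. It assumes the usual SAC temperature loss, `-log_alpha * mean(log_prob + target_entropy)`, and applies plain gradient descent to `log_alpha`. The function name `update_log_ent_coef`, the sample log-probabilities and the learning rate are hypothetical and chosen only for illustration; this is not the SB3 implementation.

```python
import math


def update_log_ent_coef(log_ent_coef, log_probs, target_entropy, learning_rate):
    """One gradient-descent step on the temperature loss, taken w.r.t. log(alpha).

    Assumed loss: L = -log_alpha * mean(log_prob + target_entropy)
    so that dL/dlog_alpha = -mean(log_prob + target_entropy).
    """
    mean_term = sum(lp + target_entropy for lp in log_probs) / len(log_probs)
    grad = -mean_term
    return log_ent_coef - learning_rate * grad


# Hypothetical values: the policy entropy (about 5) is above the target (-2),
# so the entropy coefficient should shrink.
log_ent_coef = math.log(1.0)
target_entropy = -2.0
log_probs = [-5.0, -4.0, -6.0]

for _ in range(100):
    log_ent_coef = update_log_ent_coef(log_ent_coef, log_probs, target_entropy, 0.01)

# Exponentiating keeps the coefficient strictly positive without any clipping.
ent_coef = math.exp(log_ent_coef)
```

Because the optimized variable is the logarithm, the update can move freely over the real line while `ent_coef` itself never becomes zero or negative, which is one plausible reason for the stability mentioned in the note.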

.. note::

1 change: 0 additions & 1 deletion pyproject.toml
@@ -18,7 +18,6 @@ ignore = ["B028", "RUF013"]
# ClassVar, implicit optional check not needed for tests
"./tests/*.py" = ["RUF012", "RUF013"]


[tool.ruff.lint.mccabe]
# Unlike Flake8, default to a complexity level of 10.
max-complexity = 15
45 changes: 17 additions & 28 deletions setup.py
@@ -70,37 +70,13 @@
""" # noqa:E501

# Atari Games download is sometimes problematic:
# https://github.com/Farama-Foundation/AutoROM/issues/39
# That's why we define extra packages without it.
extra_no_roms = [
# For render
"opencv-python",
"pygame",
# Tensorboard support
"tensorboard>=2.9.1",
# Checking memory taken by replay buffer
"psutil",
# For progress bar callback
"tqdm",
"rich",
# For atari games,
"shimmy[atari]~=1.3.0",
"pillow",
]

extra_packages = extra_no_roms + [ # noqa: RUF005
# For atari roms,
"autorom[accept-rom-license]~=0.6.1",
]


setup(
name="stable_baselines3",
packages=[package for package in find_packages() if package.startswith("stable_baselines3")],
package_data={"stable_baselines3": ["py.typed", "version.txt"]},
install_requires=[
"gymnasium>=0.28.1,<0.30",
"gymnasium>=0.29.1,<1.1.0",
"numpy>=1.20,<2.0", # PyTorch not compatible https://github.com/pytorch/pytorch/issues/107302
"torch>=1.13",
# For saving models
@@ -125,16 +101,29 @@
"black>=24.2.0,<25",
],
"docs": [
"sphinx>=5,<8",
"sphinx>=5,<9",
"sphinx-autobuild",
"sphinx-rtd-theme>=1.3.0",
# For spelling
"sphinxcontrib.spelling",
# Copy button for code snippets
"sphinx_copybutton",
],
"extra": extra_packages,
"extra_no_roms": extra_no_roms,
"extra": [
# For render
"opencv-python",
"pygame",
# Tensorboard support
"tensorboard>=2.9.1",
# Checking memory taken by replay buffer
"psutil",
# For progress bar callback
"tqdm",
"rich",
# For atari games,
"ale-py>=0.9.0",
"pillow",
],
},
description="Pytorch version of Stable Baselines, implementations of reinforcement learning algorithms.",
author="Antonin Raffin",