# setup.py -- packaging script for sbx (repository forked from araffin/sbx).
import os
from setuptools import find_packages, setup
# Read the package version from sbx/version.txt (the same file is listed in
# package_data below), so the version string is defined in a single place.
# The path is relative to the current working directory, so this script is
# expected to be run from the directory that contains the "sbx" folder.
# The encoding is given explicitly so the result does not depend on the
# platform's locale default.
with open(os.path.join("sbx", "version.txt"), encoding="utf-8") as file_handler:
    __version__ = file_handler.read().strip()
# Markdown package description, passed to setup() as `long_description`
# together with long_description_content_type="text/markdown".
long_description = """
# Stable Baselines Jax (SB3 + JAX = SBX)
See https://github.com/araffin/sbx
Proof of concept version of [Stable-Baselines3](https://github.com/DLR-RM/stable-baselines3) in Jax.
Implemented algorithms:
- [Soft Actor-Critic (SAC)](https://arxiv.org/abs/1801.01290) and [SAC-N](https://arxiv.org/abs/2110.01548)
- [Truncated Quantile Critics (TQC)](https://arxiv.org/abs/2005.04269)
- [Dropout Q-Functions for Doubly Efficient Reinforcement Learning (DroQ)](https://openreview.net/forum?id=xCVJMsPv3RT)
- [Proximal Policy Optimization (PPO)](https://arxiv.org/abs/1707.06347)
- [Deep Q Network (DQN)](https://arxiv.org/abs/1312.5602)
- [Twin Delayed DDPG (TD3)](https://arxiv.org/abs/1802.09477)
- [Deep Deterministic Policy Gradient (DDPG)](https://arxiv.org/abs/1509.02971)
- [Batch Normalization in Deep Reinforcement Learning (CrossQ)](https://openreview.net/forum?id=PczQtTsTIX)
## Example
```python
from sbx import DDPG, DQN, PPO, SAC, TD3, TQC, CrossQ
model = TQC("MlpPolicy", "Pendulum-v1", verbose=1)
model.learn(total_timesteps=10_000, progress_bar=True)
```
"""
# Declare the "sbx-rl" distribution: packages to ship, runtime and test
# dependencies, and the descriptive metadata for the package index.
setup(
    name="sbx-rl",
    # Keep only the discovered packages whose name starts with "sbx".
    packages=[package for package in find_packages() if package.startswith("sbx")],
    # Data files bundled inside the sbx package.
    package_data={"sbx": ["py.typed", "version.txt"]},
    install_requires=[
        "stable_baselines3>=2.4.0a4,<3.0",
        "jax",
        "jaxlib",
        "flax",
        # optax is unconstrained on Python >= 3.9 and capped below 0.1.8 on
        # older interpreters (environment markers select one of the two lines).
        'optax; python_version >= "3.9.0"',
        # See https://github.com/google-deepmind/optax/issues/711
        'optax<0.1.8; python_version < "3.9.0"',
        "tqdm",
        "rich",
        "tensorflow_probability",
    ],
    extras_require={
        "tests": [
            # Run tests and coverage
            "pytest",
            "pytest-cov",
            "pytest-env",
            "pytest-xdist",
            # Type check
            "mypy",
            # Lint code
            "ruff>=0.3.1",
            # Reformat
            "black>=24.2.0,<25",
        ],
    },
    description="Jax version of Stable Baselines, implementations of reinforcement learning algorithms.",
    author="Antonin Raffin",
    url="https://github.com/araffin/sbx",
    # NOTE(review): "[email protected]" looks like a redaction placeholder left
    # by a web scrape, not a real address -- restore the maintainer's actual
    # email before publishing.
    author_email="[email protected]",
    # Two adjacent string literals are concatenated into one
    # space-separated keyword string.
    keywords="reinforcement-learning-algorithms reinforcement-learning machine-learning "
    "gym openai stable baselines toolbox python data-science",
    license="MIT",
    long_description=long_description,
    long_description_content_type="text/markdown",
    version=__version__,
    python_requires=">=3.8",
    # PyPI package information.
    classifiers=[
        "Programming Language :: Python :: 3",
        "Programming Language :: Python :: 3.8",
        "Programming Language :: Python :: 3.9",
        "Programming Language :: Python :: 3.10",
        "Programming Language :: Python :: 3.11",
    ],
)