xpag ("exploring agents") is a modular reinforcement learning library with JAX agents, currently in beta.
Option 1: conda (preferred option)
This option is preferred because it relies mainly on conda-forge packages (which among other things simplifies the installation of JAX).
git clone https://github.com/perrin-isir/xpag.git
cd xpag
conda update conda
Install micromamba if you don't already have it. You can also simply use conda: below, replace micromamba create, micromamba update and micromamba activate with conda env create, conda env update and conda activate respectively. This will make the installation significantly slower.
conda install -c conda-forge micromamba
Choose an environment name, for instance xpagenv.
The following command creates the xpagenv environment with the requirements listed in environment.yaml:
micromamba create --name xpagenv --file environment.yaml
If you prefer to update an existing environment (existing_env):
micromamba update --name existing_env --file environment.yaml
Then, activate the xpagenv environment:
micromamba activate xpagenv
Finally, install the xpag library in the activated environment:
pip install -e .
Option 2: pip
For the pip install, you need to properly install JAX yourself. Otherwise, if JAX is installed automatically as a pip dependency of xpag, it will probably not work as desired (e.g. it will not be GPU-compatible). So you should install it beforehand, following these guidelines:
https://github.com/google/jax#installation
Then, install xpag with:
pip install xpag
JAX
To verify that the JAX installation went well, check the backend used by JAX with the following command:
python -c "import jax; print(jax.default_backend())"
It will print "cpu", "gpu" or "tpu" depending on the platform JAX is using.
Tutorials
The following libraries are not required by xpag, but the tutorials need them:
- MuJoCo (pip install mujoco): see https://github.com/deepmind/mujoco
- imageio (pip install imageio): see https://github.com/imageio/imageio
The xpag-tutorials repository contains a list of tutorials (colab notebooks) for xpag:
https://github.com/perrin-isir/xpag-tutorials
xpag: a platform for RL, goal-conditioned RL, and more.
xpag allows standard reinforcement learning, but it has been designed with goal-conditioned reinforcement learning (GCRL) in mind (check out the train_gmazes.ipynb tutorial for a simple example of GCRL).
In GCRL, agents have a goal, which is part of the input they take, and the reward mainly depends on the degree of achievement of that goal. Beyond the usual modules in RL platforms (environment, agent, buffer/sampler), xpag introduces a module called "setter" which, among other things, can help to set and manage goals (for example modifying the goal several times in a single episode). Although the setter is largely similar to an environment wrapper, it is separated from the environment because in some cases it should be considered as an independent entity (e.g. a teacher), or as a part of the agent itself.
xpag relies on a single reinforcement learning loop, the learn() function in xpag/tools/learn.py. In this loop, the environment, the agent, the buffer and the setter interact (see below).
The first 3 arguments of the learn() function are returned by gym_vec_env() and brax_vec_env():
- env: the training environment, which runs 1 or more rollouts in parallel.
- eval_env: the evaluation environment. It is identical to env except that it runs a single rollout.
- env_info: a dictionary containing information about the environment:
  - env_info["env_type"]: the type of environment. For the moment, xpag distinguishes 3 types of environments: "Brax" environments, "Mujoco" environments, and "Gym" environments. This information is used to adapt the way episodes are saved and replayed.
  - env_info["name"]: the name of the environment.
  - env_info["is_goalenv"]: whether the environment is a goal-based environment or not.
  - env_info["num_envs"]: the number of parallel rollouts in env.
  - env_info["max_episode_steps"]: the maximum number of steps in episodes (xpag does not allow potentially infinite episodes).
  - env_info["action_space"]: the action space (of type gym.spaces.Space) that takes parallel rollouts into account. It is useful for sampling random actions.
  - env_info["single_action_space"]: the action space (of type gym.spaces.Space) for single rollouts.
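The sketch below makes the env_info contract concrete with a small helper. The helper is hypothetical and is not part of xpag. It checks a hand-built env_info dictionary, which could help if one assembled env, eval_env and env_info for a custom environment without gym_vec_env() or brax_vec_env(). The example values are made up, including the environment name. The action spaces are left as None to avoid a gym dependency; in practice they would be gym.spaces.Space instances.

```python
from typing import Any, Dict

# Keys documented above for the env_info dictionary.
REQUIRED_ENV_INFO_KEYS = (
    "env_type",
    "name",
    "is_goalenv",
    "num_envs",
    "max_episode_steps",
    "action_space",
    "single_action_space",
)

# Environment types currently distinguished by xpag, as listed above.
KNOWN_ENV_TYPES = {"Brax", "Mujoco", "Gym"}


def check_env_info(env_info: Dict[str, Any]) -> None:
    """Hypothetical sanity check for a hand-built env_info dictionary."""
    missing = [k for k in REQUIRED_ENV_INFO_KEYS if k not in env_info]
    if missing:
        raise KeyError(f"env_info is missing keys: {missing}")
    if env_info["env_type"] not in KNOWN_ENV_TYPES:
        raise ValueError(f"unknown env_type: {env_info['env_type']!r}")
    if env_info["num_envs"] < 1:
        raise ValueError("num_envs must be at least 1")
    # xpag does not allow potentially infinite episodes.
    if env_info["max_episode_steps"] is None or env_info["max_episode_steps"] < 1:
        raise ValueError("max_episode_steps must be a finite positive integer")


# Illustrative values only (hypothetical environment name, placeholder spaces).
example_env_info = {
    "env_type": "Gym",
    "name": "MyCustomEnv-v0",
    "is_goalenv": False,
    "num_envs": 4,
    "max_episode_steps": 200,
    "action_space": None,
    "single_action_space": None,
}
check_env_info(example_env_info)
```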
learn() also takes the agent, the buffer, the setter and various parameters as input. Detailed information about the arguments of learn() is in the code documentation (see xpag/tools/learn.py).
The components that interact during learning are:
the environment (env)
In xpag, environments must allow parallel rollouts, and xpag keeps the same API even in the case of a single rollout, i.e. when the number of "parallel environments" is 1. Basically, all environments are "vector environments".
- env.reset(seed: Optional[Union[int, List[int]]], options: Optional[dict]) -> observation: Union[np.array, jax.numpy.array], info: dict
  Following the gym Vector API (see https://www.gymlibrary.dev/api/vector/#vectorenv), environments have a reset() function. It returns an observation and an optional dictionary info (see https://www.gymlibrary.dev/api/vector/#reset). The observation is actually a batch of observations for all the parallel rollouts. We expect observation to be a numpy array or a jax.numpy array whose first dimension selects between parallel rollouts, so observation[i] is the observation in the i-th rollout. In the case of a single rollout, observation[0] is the observation in this rollout.
- env.step(action: Union[np.array, jax.numpy.array]) -> observation, reward, terminated, truncated, info
  Again following the gym Vector API, environments have a step() function. It takes an action as input (actually a batch of actions, one per rollout) and returns observation, reward, terminated, truncated and info (cf. https://www.gymlibrary.dev/api/vector/#step). There are slight differences with the gym Vector API. First, in xpag this API also covers the case of a single rollout. Second, xpag assumes that reward, terminated and truncated have shape (n, 1), not (n,), where n is the number of parallel rollouts. More broadly, xpag never squeezes single-dimensional entries, whether they come from a single rollout or from unidimensional elements. Third, in xpag, info is a dictionary, not a tuple of dictionaries (however, its entries may be tuples).
- env.reset_done(done, seed: Optional[Union[int, List[int]]], options: Optional[dict]) -> observation, info
  The most significant difference with the gym Vector API is that xpag requires a reset_done() function. It takes a done array of Booleans as input and resets the i-th rollout if and only if done[i] evaluates to True. Besides done, the arguments of reset_done() are the same as those of reset() (seed and options), and so are its outputs (observation and info). For rollouts that are not reset, the returned observation is the same as the observation returned by the last step(). reset() must be called once for the initial reset; afterwards, only reset_done() should be used.
  Auto-resets (automatic resets after terminal transitions) are not allowed in xpag. With auto-resets, terminal transitions must be special and contain additional information; with reset_done(), this is no longer necessary. This is the main reason to prefer reset_done() to auto-resets. Furthermore, modifying the done array obtained after a step of the environment makes it easy to force the termination of an episode. It also makes it possible to force an episode to continue despite reaching a terminal transition, but this must be done with caution.
- gym_vec_env(env_name: str, num_envs: int, wrap_function: Callable = None) -> env, eval_env, env_info: dict
  brax_vec_env(env_name: str, num_envs: int, wrap_function: Callable = None, *, force_cpu_backend: bool = False) -> env, eval_env, env_info: dict
  The gym_vec_env() and brax_vec_env() functions (see tutorials) call wrappers that automatically add the reset_done() function to Gym and Brax environments and make the wrapped environments fit the xpag API.
- Goal-based environments:
  Goal-based environments (for GCRL) must have an interface similar to the one defined in the Gym-Robotics library (see GoalEnv in core.py), with minor differences. Their observation spaces are of type gym.spaces.Dict, with the following keys in the observation dictionaries: "observation", "achieved_goal", and "desired_goal". Goal-based environments must also have a compute_reward() function as an attribute, which computes rewards.
  In xpag, the inputs of compute_reward() can differ from those of the original GoalEnv class. For example, the GoalEnvWrapper class can be used to turn standard environments into goal-based environments. In this class, the arguments of compute_reward() are assumed to be:
  - achieved_goal (the goal achieved after step())
  - desired_goal (the desired goal before step())
  - action
  - observation (the observation after step())
  - reward (the reward of the base environment)
  - terminated, truncated and info (the outputs of the step() function)
  The version of HER in xpag (cf. https://arxiv.org/pdf/1707.01495.pdf) assumes that compute_reward() depends only on achieved_goal, desired_goal, action and observation.
  In goal-based environments, the multiple observations from parallel rollouts are concatenated as in the gym function concatenate() (cf. https://github.com/openai/gym/blob/master/gym/vector/utils/numpy_utils.py). As a result, the batched observations are always single dictionaries. In these dictionaries, the entries "observation", "achieved_goal" and "desired_goal" are arrays of observations, achieved goals and desired goals.
- info
  In goal-based environments, xpag assumes that the info dictionary returned by step() always contains info["is_success"]. This is an array of Booleans, one per rollout. Each entry is True if the corresponding transition successfully achieves the desired goal, and False otherwise. Success does not need to coincide with episode termination.
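The environment API described above can be illustrated with a minimal sketch. ToyVecEnv is a hypothetical numpy-only environment, not part of xpag, and it makes the following choices:
- Observations are batched along the first dimension.
- reward, terminated and truncated keep shape (n, 1).
- info is a single dictionary.
- reset_done() resets only the rollouts whose done entry is True.
- The loop calls reset() once, then only reset_done().

Computing done as the logical OR of terminated and truncated is an assumption made for illustration.

```python
from typing import List, Optional, Union

import numpy as np

Seed = Optional[Union[int, List[int]]]


class ToyVecEnv:
    """Hypothetical vector environment following the xpag conventions above.

    Each of the n rollouts is a point on a line; actions move the point and an
    episode terminates when the point is within 0.1 of the origin.
    """

    def __init__(self, num_envs: int, max_episode_steps: int = 50):
        self.num_envs = num_envs
        self.max_episode_steps = max_episode_steps
        self.rng = np.random.default_rng()
        self.pos = np.zeros((num_envs, 1))
        self.steps = np.zeros((num_envs, 1), dtype=int)

    def _maybe_seed(self, seed: Seed) -> None:
        if seed is not None:
            # Simplification: a list of per-rollout seeds is reduced to its first entry.
            self.rng = np.random.default_rng(seed[0] if isinstance(seed, list) else seed)

    def _sample_start(self) -> np.ndarray:
        return self.rng.uniform(-1.0, 1.0, size=(self.num_envs, 1))

    def reset(self, seed: Seed = None, options: Optional[dict] = None):
        self._maybe_seed(seed)
        self.pos = self._sample_start()
        self.steps[:] = 0
        # Batched observation: obs[i] is the observation of rollout i.
        return self.pos.copy(), {}

    def step(self, action: np.ndarray):
        self.pos = self.pos + np.clip(action, -0.1, 0.1)
        self.steps += 1
        # reward, terminated and truncated keep shape (n, 1); nothing is squeezed.
        reward = -np.abs(self.pos)
        terminated = np.abs(self.pos) < 0.1
        truncated = self.steps >= self.max_episode_steps
        # info is a single dictionary, not a tuple of per-rollout dictionaries.
        return self.pos.copy(), reward, terminated, truncated, {}

    def reset_done(self, done, seed: Seed = None, options: Optional[dict] = None):
        self._maybe_seed(seed)
        mask = np.asarray(done, dtype=bool).reshape(self.num_envs, 1)
        # Only rollouts with done[i] == True restart; the others keep the
        # observation returned by the last step().
        self.pos = np.where(mask, self._sample_start(), self.pos)
        self.steps = np.where(mask, 0, self.steps)
        return self.pos.copy(), {}


env = ToyVecEnv(num_envs=3)
obs, info = env.reset(seed=0)  # called once, for the initial reset
for _ in range(100):
    action = -0.1 * np.sign(obs)  # move each point toward the origin
    obs, reward, terminated, truncated, info = env.step(action)
    # Assumption: done combines termination and truncation. Editing this
    # array before reset_done() would force or prevent resets.
    done = np.logical_or(terminated, truncated)
    obs, info = env.reset_done(done)  # no auto-reset; reset() is not called again
```

Because resets happen only through reset_done(), terminal transitions returned by step() need no special handling.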
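The goal-based conventions above can be illustrated with a second sketch. This is numpy only and is not xpag code. It shows three things:
- Per-rollout dictionary observations are stacked into one dictionary of arrays, in the way gym's concatenate() stacks Box entries.
- compute_reward() uses only achieved_goal, desired_goal, action and observation, as xpag's HER assumes.
- info["is_success"] holds one Boolean per rollout.

The sparse reward, the GOAL_TOLERANCE threshold, the helper names and the (n, 1) shape of is_success are hypothetical choices.

```python
import numpy as np

GOAL_TOLERANCE = 0.05  # hypothetical success threshold
GOAL_KEYS = ("observation", "achieved_goal", "desired_goal")


def batch_goal_observations(per_rollout_obs):
    """Stack per-rollout dict observations into a single dict of arrays.

    This mirrors what gym's concatenate() does for a Dict space of Box
    entries: each key maps to an array whose first dimension indexes rollouts.
    """
    return {k: np.stack([o[k] for o in per_rollout_obs]) for k in GOAL_KEYS}


def goal_distance(achieved_goal, desired_goal):
    # keepdims=True keeps one value per rollout in shape (n, 1).
    return np.linalg.norm(achieved_goal - desired_goal, axis=-1, keepdims=True)


def compute_reward(achieved_goal, desired_goal, action, observation):
    """Sparse reward: 0.0 when the goal is reached, -1.0 otherwise.

    It depends only on achieved_goal, desired_goal, action and observation
    (here action and observation are unused), matching the HER assumption.
    """
    reached = goal_distance(achieved_goal, desired_goal) <= GOAL_TOLERANCE
    return reached.astype(np.float32) - 1.0


def is_success(achieved_goal, desired_goal):
    # Boolean array with one entry per rollout, shape (n, 1).
    return goal_distance(achieved_goal, desired_goal) <= GOAL_TOLERANCE


obs = batch_goal_observations([
    {"observation": np.zeros(4), "achieved_goal": np.array([0.0, 0.0]),
     "desired_goal": np.array([1.0, 0.0])},
    {"observation": np.ones(4), "achieved_goal": np.array([0.5, 0.5]),
     "desired_goal": np.array([0.5, 0.52])},
])
action = np.zeros((2, 1))
reward = compute_reward(obs["achieved_goal"], obs["desired_goal"], action, obs["observation"])
info = {"is_success": is_success(obs["achieved_goal"], obs["desired_goal"])}

# HER-style relabeling: pretend the achieved goal was the desired one and
# recompute the reward, which is possible because of the restricted inputs.
relabeled_reward = compute_reward(obs["achieved_goal"], obs["achieved_goal"], action, obs["observation"])
```

Because the reward does not depend on the base environment's reward, terminated, truncated or info, HER can recompute it after substituting an achieved goal for the desired goal, as the last line does.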
the agent (agent)
xpag only considers off-policy agents. (TODO)
the buffer (buffer)
TODO
the sampler (sampler)
TODO
the setter (setter)
TODO
The figure below summarizes the RL loop and the interactions between the components: (TODO)
- Maintainer and main contributor:
  - Nicolas Perrin-Gilbert (CNRS, ISIR)
- Other people who contributed to xpag:
  - Olivier Serris (ISIR)
  - Alexandre Chenu (ISIR)
  - Stéphane Caron (Inria)
  - Fabian Schramm (Inria)
- There is an interface to agents from the RLJAX library (see rljax_interface.py). It provides haiku versions of DDPG, TD3, TQC, SAC and SAC with DisCor.
- The flax version of the SAC agent is based on the implementation of SAC in JAXRL, and some elements of the flax version of the TQC agent come from the implementation of TQC in RLJAX.
To cite this repository in publications:
@misc{xpag,
author = {Perrin-Gilbert, Nicolas},
title = {xpag: a modular reinforcement learning library with JAX agents},
year = {2022},
url = {https://github.com/perrin-isir/xpag}
}