Merge branch 'main' into develop
Lookatator committed Oct 6, 2022
2 parents 82262e5 + 2fb5619 commit ab4d4ca
Showing 25 changed files with 194 additions and 92 deletions.
README.md: 10 changes (7 additions & 3 deletions)
@@ -14,13 +14,17 @@ QDax has been developed as a research framework: it is flexible and easy to exte


## Installation

The latest stable release of QDax can be installed directly from source with:
QDax is available on PyPI and can be installed with:
```bash
pip install qdax
```
Alternatively, the latest commit of QDax can be installed directly from source with:
```bash
pip install git+https://github.com/adaptive-intelligent-robotics/QDax.git@main
```
Installing QDax via ```pip``` installs a CPU-only version of JAX by default. To use QDax with NVidia GPUs, you must first install [CUDA, CuDNN, and JAX with GPU support](https://github.com/google/jax#installation).

However, we also provide and recommend using either Docker, Singularity or conda environments to use the repository. Detailed steps to do so are available in the [documentation](https://qdax.readthedocs.io/en/latest/installation/).
However, we also provide and recommend using either Docker, Singularity or conda environments to use the repository which by default provides GPU support. Detailed steps to do so are available in the [documentation](https://qdax.readthedocs.io/en/latest/installation/).

## Basic API Usage
For a full and interactive example to see how QDax works, we recommend starting with the tutorial-style [Colab notebook](./notebooks/mapelites_example.ipynb). It is an example of the MAP-Elites algorithm used to evolve a population of controllers on a chosen Brax environment (Walker by default).
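The updated installation instructions above point out that plain `pip` installs a CPU-only JAX. For a GPU setup outside of Docker, one possible sequence is sketched below; it simply reuses the CUDA 11 / cuDNN 8.2 jaxlib wheel and wheel index pinned by this commit's dev.Dockerfile, and the exact wheel must match the locally installed CUDA toolkit.

```bash
# Hedged sketch, not an official recipe: install QDax, then swap in the
# CUDA-enabled jaxlib build that dev.Dockerfile pins in this commit.
pip install qdax
pip install --no-cache-dir jaxlib==0.3.15+cuda11.cudnn82 \
    -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
```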
dev.Dockerfile: 3 changes (1 addition & 2 deletions)
@@ -10,7 +10,6 @@ COPY requirements.txt /tmp/requirements.txt
COPY requirements-dev.txt /tmp/requirements-dev.txt
COPY environment.yaml /tmp/environment.yaml


RUN micromamba create -y --file /tmp/environment.yaml \
&& micromamba clean --all --yes \
&& find /opt/conda/ -follow -type f -name '*.pyc' -delete
@@ -41,7 +40,7 @@ ENV LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/usr/local/cuda-11.0/targets/x86_64-linux/l

ENV TZ=Europe/Paris
RUN ln -snf /usr/share/zoneinfo/$TZ /etc/localtime && echo $TZ > /etc/timezone
RUN pip --no-cache-dir install jaxlib==0.3.10+cuda11.cudnn82 \
RUN pip --no-cache-dir install jaxlib==0.3.15+cuda11.cudnn82 \
-f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html \
&& rm -rf /tmp/*

environment.yaml: 2 changes (1 addition & 1 deletion)
@@ -8,6 +8,6 @@ dependencies:
- conda>=4.9.2
- pip:
- --find-links https://storage.googleapis.com/jax-releases/jax_releases.html
- jaxlib==0.3.10
- jaxlib==0.3.15
- -r requirements.txt
- -r requirements-dev.txt
notebooks/cmamega_example.ipynb: 4 changes (2 additions & 2 deletions)
@@ -133,11 +133,11 @@
" gradients = jnp.nan_to_num(gradients)\n",
"\n",
" # Compute normalized gradients\n",
" norm_gradients = jax.tree_map(\n",
" norm_gradients = jax.tree_util.tree_map(\n",
" lambda x: jnp.linalg.norm(x, axis=1, keepdims=True),\n",
" gradients,\n",
" )\n",
" grads = jax.tree_map(\n",
" grads = jax.tree_util.tree_map(\n",
" lambda x, y: x / y, gradients, norm_gradients\n",
" )\n",
" grads = jnp.nan_to_num(grads)\n",
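For context, `jax.tree_map` was later deprecated in favour of the `jax.tree_util` namespace, which is why every call site in this commit switches to `jax.tree_util.tree_map`, starting with the notebook cell above. A minimal, self-contained sketch of the replacement call:

```python
import jax
import jax.numpy as jnp

# A small pytree of parameters (a nested container of arrays).
params = {"w": jnp.ones((2, 3)), "b": jnp.zeros((3,))}

# Same behaviour as the old jax.tree_map alias: apply a function to every
# leaf while preserving the pytree structure.
doubled = jax.tree_util.tree_map(lambda x: 2.0 * x, params)

# tree_map also accepts several pytrees with matching structure, as in the
# gradient-normalisation code in the notebook cell above.
summed = jax.tree_util.tree_map(lambda x, y: x + y, params, doubled)
print(summed["b"].shape, summed["w"].shape)  # (3,) (2, 3)
```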
qdax/baselines/dads.py: 10 changes (6 additions & 4 deletions)
@@ -25,6 +25,7 @@
update_running_mean_std,
)
from qdax.core.neuroevolution.sac_utils import generate_unroll
from qdax.environments import CompletedEvalWrapper
from qdax.types import Metrics, Params, Reward, RNGKey, Skill, StateDescriptor


@@ -144,7 +145,7 @@ def init( # type: ignore
random_key, subkey = jax.random.split(random_key)
critic_params = self._critic.init(subkey, dummy_obs, dummy_action)

target_critic_params = jax.tree_map(
target_critic_params = jax.tree_util.tree_map(
lambda x: jnp.asarray(x.copy()), critic_params
)

@@ -373,16 +374,17 @@ def eval_policy_fn(
play_step_fn=play_step_fn,
)

eval_metrics_key = CompletedEvalWrapper.STATE_INFO_KEY
true_return = (
state.info["eval_metrics"].completed_episodes_metrics["reward"]
/ state.info["eval_metrics"].completed_episodes
state.info[eval_metrics_key].completed_episodes_metrics["reward"]
/ state.info[eval_metrics_key].completed_episodes
)

transitions = get_first_episode(transitions)

true_returns = jnp.nansum(transitions.rewards, axis=0)

reshaped_transitions = jax.tree_map(
reshaped_transitions = jax.tree_util.tree_map(
lambda x: x.reshape((self._config.episode_length * env_batch_size, -1)),
transitions,
)
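The same refactor reappears in diayn.py, sac.py, and td3.py below: the hard-coded "eval_metrics" string is replaced by the STATE_INFO_KEY constant imported from qdax.environments.CompletedEvalWrapper. A hedged, self-contained sketch of the pattern, with a dummy object standing in for the Brax evaluation state:

```python
from types import SimpleNamespace

# Assumed value: the diff shows CompletedEvalWrapper.STATE_INFO_KEY replacing
# the literal "eval_metrics", so the constant is mirrored here for illustration.
class CompletedEvalWrapper:
    STATE_INFO_KEY = "eval_metrics"

# Dummy stand-in for the evaluation state; the real one carries an eval-metrics
# object with completed_episodes_metrics and completed_episodes fields.
state = SimpleNamespace(
    info={CompletedEvalWrapper.STATE_INFO_KEY: {"reward": 120.0, "completed_episodes": 10}}
)

# Look the metrics up through the class constant instead of repeating the
# "eval_metrics" string literal in every baseline file.
eval_metrics_key = CompletedEvalWrapper.STATE_INFO_KEY
metrics = state.info[eval_metrics_key]
true_return = metrics["reward"] / metrics["completed_episodes"]
print(true_return)  # 12.0
```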
qdax/baselines/diayn.py: 10 changes (6 additions & 4 deletions)
@@ -20,6 +20,7 @@
from qdax.core.neuroevolution.mdp_utils import TrainingState, get_first_episode
from qdax.core.neuroevolution.networks.diayn_networks import make_diayn_networks
from qdax.core.neuroevolution.sac_utils import generate_unroll
from qdax.environments import CompletedEvalWrapper
from qdax.types import Metrics, Params, Reward, RNGKey, Skill, StateDescriptor


@@ -141,7 +142,7 @@ def init( # type: ignore
random_key, subkey = jax.random.split(random_key)
critic_params = self._critic.init(subkey, dummy_obs, dummy_action)

target_critic_params = jax.tree_map(
target_critic_params = jax.tree_util.tree_map(
lambda x: jnp.asarray(x.copy()), critic_params
)

@@ -316,16 +317,17 @@ def eval_policy_fn(
play_step_fn=play_step_fn,
)

eval_metrics_key = CompletedEvalWrapper.STATE_INFO_KEY
true_return = (
state.info["eval_metrics"].completed_episodes_metrics["reward"]
/ state.info["eval_metrics"].completed_episodes
state.info[eval_metrics_key].completed_episodes_metrics["reward"]
/ state.info[eval_metrics_key].completed_episodes
)

transitions = get_first_episode(transitions)

true_return_per_env = jnp.nansum(transitions.rewards, axis=0)

reshaped_transitions = jax.tree_map(
reshaped_transitions = jax.tree_util.tree_map(
lambda x: x.reshape((self._config.episode_length * env_batch_size, -1)),
transitions,
)
qdax/baselines/sac.py: 10 changes (6 additions & 4 deletions)
@@ -24,6 +24,7 @@
update_running_mean_std,
)
from qdax.core.neuroevolution.sac_utils import generate_unroll
from qdax.environments import CompletedEvalWrapper
from qdax.types import Action, Metrics, Observation, Params, Reward, RNGKey


@@ -115,7 +116,7 @@ def init(
random_key, subkey = jax.random.split(random_key)
critic_params = self._critic.init(subkey, dummy_obs, dummy_action)

target_critic_params = jax.tree_map(
target_critic_params = jax.tree_util.tree_map(
lambda x: jnp.asarray(x.copy()), critic_params
)

@@ -298,9 +299,10 @@ def eval_policy_fn(
play_step_fn=play_step_fn,
)

eval_metrics_key = CompletedEvalWrapper.STATE_INFO_KEY
true_return = (
state.info["eval_metrics"].completed_episodes_metrics["reward"]
/ state.info["eval_metrics"].completed_episodes
state.info[eval_metrics_key].completed_episodes_metrics["reward"]
/ state.info[eval_metrics_key].completed_episodes
)

transitions = get_first_episode(transitions)
@@ -389,7 +391,7 @@ def _update_critic(
critic_params = optax.apply_updates(
training_state.critic_params, critic_updates
)
target_critic_params = jax.tree_map(
target_critic_params = jax.tree_util.tree_map(
lambda x1, x2: (1.0 - self._config.tau) * x1 + self._config.tau * x2,
training_state.target_critic_params,
critic_params,
qdax/baselines/td3.py: 14 changes (8 additions & 6 deletions)
@@ -19,6 +19,7 @@
get_first_episode,
)
from qdax.core.neuroevolution.networks.td3_networks import make_td3_networks
from qdax.environments import CompletedEvalWrapper
from qdax.types import Action, Metrics, Observation, Params, Reward, RNGKey


@@ -113,10 +114,10 @@ def init(
policy_params = self._policy.init(subkey_2, fake_obs)

# Initialize target networks
target_critic_params = jax.tree_map(
target_critic_params = jax.tree_util.tree_map(
lambda x: jnp.asarray(x.copy()), critic_params
)
target_policy_params = jax.tree_map(
target_policy_params = jax.tree_util.tree_map(
lambda x: jnp.asarray(x.copy()), policy_params
)

@@ -251,9 +252,10 @@ def eval_policy_fn(
play_step_fn=play_step_fn,
)

eval_metrics_key = CompletedEvalWrapper.STATE_INFO_KEY
true_return = (
state.info["eval_metrics"].completed_episodes_metrics["reward"]
/ state.info["eval_metrics"].completed_episodes
state.info[eval_metrics_key].completed_episodes_metrics["reward"]
/ state.info[eval_metrics_key].completed_episodes
)

transitions = get_first_episode(transitions)
@@ -303,7 +305,7 @@ def update(
training_state.critic_params, critic_updates
)
# Soft update of target critic network
target_critic_params = jax.tree_map(
target_critic_params = jax.tree_util.tree_map(
lambda x1, x2: (1.0 - self._config.soft_tau_update) * x1
+ self._config.soft_tau_update * x2,
training_state.target_critic_params,
@@ -325,7 +327,7 @@ def update_policy_step() -> Tuple[Params, Params, optax.OptState]:
training_state.policy_params, policy_updates
)
# Soft update of target policy
target_policy_params = jax.tree_map(
target_policy_params = jax.tree_util.tree_map(
lambda x1, x2: (1.0 - self._config.soft_tau_update) * x1
+ self._config.soft_tau_update * x2,
training_state.target_policy_params,
qdax/core/containers/ga_repertoire.py: 12 changes (7 additions & 5 deletions)
@@ -34,7 +34,7 @@ class GARepertoire(Repertoire):
@property
def size(self) -> int:
"""Gives the size of the population."""
first_leaf = jax.tree_leaves(self.genotypes)[0]
first_leaf = jax.tree_util.tree_leaves(self.genotypes)[0]
return int(first_leaf.shape[0])

def save(self, path: str = "./") -> None:
@@ -95,7 +95,7 @@ def sample(self, random_key: RNGKey, num_samples: int) -> Tuple[Genotype, RNGKey

# sample
random_key, subkey = jax.random.split(random_key)
samples = jax.tree_map(
samples = jax.tree_util.tree_map(
lambda x: jax.random.choice(
subkey, x, shape=(num_samples,), p=p, replace=False
),
@@ -122,7 +122,7 @@ def add(
"""

# gather individuals and fitnesses
candidates = jax.tree_map(
candidates = jax.tree_util.tree_map(
lambda x, y: jnp.concatenate((x, y), axis=0),
self.genotypes,
batch_of_genotypes,
@@ -138,7 +138,7 @@ def add(
survivor_indices = indices[: self.size]

# keep only the best ones
new_candidates = jax.tree_map(lambda x: x[survivor_indices], candidates)
new_candidates = jax.tree_util.tree_map(
lambda x: x[survivor_indices], candidates
)

new_repertoire = self.replace(
genotypes=new_candidates, fitnesses=candidates_fitnesses[survivor_indices]
@@ -172,7 +174,7 @@ def init( # type: ignore
)

# create default genotypes
default_genotypes = jax.tree_map(
default_genotypes = jax.tree_util.tree_map(
lambda x: jnp.zeros(shape=(population_size,) + x.shape[1:]), genotypes
)

qdax/core/containers/mapelites_repertoire.py: 6 changes (3 additions & 3 deletions)
@@ -220,7 +220,7 @@ def sample(self, random_key: RNGKey, num_samples: int) -> Tuple[Genotype, RNGKey
p = (1.0 - repertoire_empty) / jnp.sum(1.0 - repertoire_empty)

random_key, subkey = jax.random.split(random_key)
samples = jax.tree_map(
samples = jax.tree_util.tree_map(
lambda x: jax.random.choice(subkey, x, shape=(num_samples,), p=p),
self.genotypes,
)
@@ -283,7 +283,7 @@ def add(
)

# create new repertoire
new_repertoire_genotypes = jax.tree_map(
new_repertoire_genotypes = jax.tree_util.tree_map(
lambda repertoire_genotypes, new_genotypes: repertoire_genotypes.at[
batch_of_indices.squeeze(axis=-1)
].set(new_genotypes),
@@ -337,7 +337,7 @@ def init(
# Initialize repertoire with default values
num_centroids = centroids.shape[0]
default_fitnesses = -jnp.inf * jnp.ones(shape=num_centroids)
default_genotypes = jax.tree_map(
default_genotypes = jax.tree_util.tree_map(
lambda x: jnp.zeros(shape=(num_centroids,) + x.shape[1:]),
genotypes,
)
(Diffs for the remaining changed files are not shown.)
