Default Scoring Functions for Sphere, Rastrigin, Arm and Brax environments #73

Merged: 56 commits from feat/default_tasks into develop, Oct 13, 2022

Commits
a75970b
first commit - push a draft of the structure and arm and rastrigin tasks
limbryan Aug 1, 2022
fb5c5a2
complete arm and standard functions - rastrigin, sphere, rastrigin-proj
limbryan Aug 1, 2022
e314310
adding create_brax_scoring_function_fn to make scoring functions for …
Lookatator Aug 1, 2022
0aa9d55
Merge branch 'feat/default_tasks' of github.com:adaptive-intelligent-…
Lookatator Aug 1, 2022
418d2bc
add test for default standard functions
limbryan Aug 2, 2022
c28ec64
add test for default arm task and make better docstring for arm
limbryan Aug 2, 2022
ab7edd0
add rastrigin proj test as well
limbryan Aug 3, 2022
4f38cb4
fix weird styling issues to do with types for the tests script
limbryan Aug 3, 2022
ff4847c
adding function for init controllers, and a function for creating def…
Lookatator Aug 4, 2022
34ac363
simpler metrics_fn + fix type of bd_extractor
Lookatator Aug 4, 2022
15b931c
fix catch of return of init_population_controllers
Lookatator Aug 4, 2022
2ce365e
add Docstrings
Lookatator Aug 8, 2022
a24ff3c
fix metadata notebook
Lookatator Aug 9, 2022
7323ce7
fix import scoring function in notebook
Lookatator Aug 9, 2022
3c83225
add init file into tasks directory, create an examples folder for scr…
limbryan Aug 11, 2022
5d5ded3
update README to be a usable readme that uses the arm function instea…
limbryan Aug 11, 2022
d62baa6
upload new benchmark functions from QDbenchmark workshop with corresp…
limbryan Aug 19, 2022
c750ae1
add README for summary of tasks in the directory along with some desc…
limbryan Aug 19, 2022
c7f5f64
fix implementation hypervolume functions
Lookatator Aug 21, 2022
b1fb1f1
adding run_me() example function
Lookatator Aug 21, 2022
3a04dc4
adding basis functions qd_suite benchmarking
Lookatator Aug 22, 2022
632ff41
fix key management in noisy arm and add option to add noise on params …
limbryan Aug 22, 2022
e44c7b8
Merge branch 'feat/default_tasks' of https://github.com/adaptive-inte…
limbryan Aug 22, 2022
93fdadd
add draft of tasks summary README
Aug 22, 2022
338d7c2
refactoring benchmarking functions
Lookatator Aug 22, 2022
fa1acf0
fix key splitting in noisy arm task and references to hypervolume fun…
Aug 22, 2022
03a7702
improve plotting plot_multidimensional_map_elites_grid when using hig…
Lookatator Aug 22, 2022
b4a5a31
add default tasks qd_suite
Lookatator Aug 22, 2022
53c206a
Merge branch 'feat/default_tasks' of github.com:adaptive-intelligent-…
Lookatator Aug 22, 2022
460b5a8
completing description of QD Suite functions
Lookatator Aug 22, 2022
22ffba8
move all qd_suite tasks
Lookatator Aug 22, 2022
3639982
fix README latex
Lookatator Aug 22, 2022
3e38f11
fix README latex
Lookatator Aug 23, 2022
c40b939
move default tasks to qd_suite __init__
Lookatator Aug 23, 2022
b865988
add example usage qd_suite tasks
Lookatator Aug 23, 2022
6dab609
add example usage hypervolume functions
Lookatator Aug 23, 2022
57d0d51
add examples for standard function
Lookatator Aug 23, 2022
3c1419c
example BRAX usage
Lookatator Aug 23, 2022
41831a5
add test for qd suite tasks
Aug 23, 2022
7371d35
Merge branch 'feat/default_tasks' of https://github.com/adaptive-inte…
Aug 23, 2022
5be8b0d
add test for qd suite tasks
Aug 23, 2022
f33ea17
remove type aliases
Lookatator Sep 7, 2022
4b1f333
specify type of grid_shape
Lookatator Sep 7, 2022
6e0baa7
Merge remote-tracking branch 'origin/develop' into feat/default_tasks
Lookatator Sep 7, 2022
73d32ce
fix styling issue
Lookatator Sep 7, 2022
75a456c
add Docstrings default scores
Lookatator Sep 8, 2022
c299e4f
Mention QDax tasks doc in the main README
Lookatator Sep 8, 2022
1908362
stochastic arm -> noisy arm
Lookatator Sep 9, 2022
169d1ac
add task specific test for arm
limbryan Oct 11, 2022
f5b88fb
Merge branch 'develop' into feat/default_tasks
Lookatator Oct 11, 2022
ed52ea6
Merge branch 'feat/default_tasks' of github.com:adaptive-intelligent-…
Lookatator Oct 11, 2022
57e0caa
reformat tests/default_tasks_test/arm_test.py
Lookatator Oct 11, 2022
1f08215
complete missing descriptor bounds
felixchalumeau Oct 12, 2022
fd98af0
QDSuiteTask inherits from abc.ABC instead of using ABCMeta
Lookatator Oct 12, 2022
76cc8a5
Merge branch 'feat/default_tasks' of github.com:adaptive-intelligent-…
Lookatator Oct 12, 2022
1de5311
create examples folder with notebooks and scripts
Lookatator Oct 13, 2022
README.md (68 changes: 65 additions & 3 deletions)
@@ -27,14 +27,71 @@ For a full and interactive example to see how QDax works, we recommend starting

However, a summary of the main API usage is provided below:
```python
import jax
import functools
from qdax.core.map_elites import MAPElites
from qdax.core.containers.mapelites_repertoire import compute_euclidean_centroids
from qdax.tasks.arm import arm_scoring_function
from qdax.core.emitters.mutation_operators import isoline_variation
from qdax.core.emitters.standard_emitters import MixingEmitter
from qdax.utils.metrics import default_qd_metrics

seed = 42
num_param_dimensions = 100  # num DoF arm
init_batch_size = 100
batch_size = 1024
num_iterations = 50
grid_shape = (100, 100)
min_param = 0.0
max_param = 1.0
min_bd = 0.0
max_bd = 1.0

# Init a random key
random_key = jax.random.PRNGKey(seed)

# Init population of controllers
random_key, subkey = jax.random.split(random_key)
init_variables = jax.random.uniform(
    subkey,
    shape=(init_batch_size, num_param_dimensions),
    minval=min_param,
    maxval=max_param,
)

# Define emitter
variation_fn = functools.partial(
    isoline_variation,
    iso_sigma=0.05,
    line_sigma=0.1,
    minval=min_param,
    maxval=max_param,
)
mixing_emitter = MixingEmitter(
    mutation_fn=lambda x, y: (x, y),
    variation_fn=variation_fn,
    variation_percentage=1.0,
    batch_size=batch_size,
)

# Define a metrics function
metrics_fn = functools.partial(
    default_qd_metrics,
    qd_offset=0.0,
)

# Instantiate MAP-Elites
map_elites = MAPElites(
    scoring_function=arm_scoring_function,
    emitter=mixing_emitter,
    metrics_function=metrics_fn,
)

# Compute the centroids
centroids = compute_euclidean_centroids(
    grid_shape=grid_shape,
    minval=min_bd,
    maxval=max_bd,
)

# Initializes repertoire and emitter state
```

@@ -77,6 +77,134 @@ The QDax library also provides implementations for some useful baseline algorithms:
| [NSGA2](https://ieeexplore.ieee.org/document/996017) | [![Open All Collab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/adaptive-intelligent-robotics/QDax/blob/main/notebooks/nsga2_spea2_example.ipynb) |
| [SPEA2](https://www.semanticscholar.org/paper/SPEA2%3A-Improving-the-strength-pareto-evolutionary-Zitzler-Laumanns/b13724cb54ae4171916f3f969d304b9e9752a57f) | [![Open All Collab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/adaptive-intelligent-robotics/QDax/blob/main/notebooks/nsga2_spea2_example.ipynb) |

## QDax Tasks
The QDax library also provides implementations of several standard Quality-Diversity tasks.

All of these implementations, along with their descriptions, are provided in the [tasks directory](./qdax/tasks).
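
As a quick illustration of how these tasks plug in, here is a minimal sketch of calling the arm task's scoring function directly, outside of any QD loop — assuming `arm_scoring_function` keeps the `(genotypes, random_key)` interface used in the example above, with a 2D end-effector descriptor:

```python
import jax

from qdax.tasks.arm import arm_scoring_function

random_key = jax.random.PRNGKey(0)
random_key, subkey = jax.random.split(random_key)

# A batch of 10 random genotypes for an 8-DoF arm, parameters in [0, 1]
genotypes = jax.random.uniform(subkey, shape=(10, 8), minval=0.0, maxval=1.0)

fitnesses, descriptors, extra_scores, random_key = arm_scoring_function(
    genotypes, random_key
)
# Expected: fitnesses has shape (10,), descriptors has shape (10, 2)
```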

## Contributing
Issues and contributions are welcome. Please refer to the [contribution guide](https://qdax.readthedocs.io/en/latest/guides/CONTRIBUTING/) in the documentation for more details.

examples/me_example.py (103 changes: 103 additions & 0 deletions)
@@ -0,0 +1,103 @@
import functools

import jax
import matplotlib.pyplot as plt

from qdax.core.containers.mapelites_repertoire import compute_euclidean_centroids
from qdax.core.emitters.mutation_operators import isoline_variation
from qdax.core.emitters.standard_emitters import MixingEmitter
from qdax.core.map_elites import MAPElites
from qdax.tasks.arm import arm_scoring_function
from qdax.utils.metrics import default_qd_metrics
from qdax.utils.plotting import plot_2d_map_elites_repertoire


def run_me() -> None:
    seed = 42
    num_param_dimensions = 8  # num DoF arm
    init_batch_size = 100
    batch_size = 2048
    num_evaluations = int(1e6)
    num_iterations = num_evaluations // batch_size
    grid_shape = (100, 100)
    min_param = 0.0
    max_param = 1.0
    min_bd = 0.0
    max_bd = 1.0

    # Init a random key
    random_key = jax.random.PRNGKey(seed)

    # Init population of controllers
    random_key, subkey = jax.random.split(random_key)
    init_variables = jax.random.uniform(
        subkey,
        shape=(init_batch_size, num_param_dimensions),
        minval=min_param,
        maxval=max_param,
    )

    # Define emitter
    variation_fn = functools.partial(
        isoline_variation,
        iso_sigma=0.005,
        line_sigma=0,
        minval=min_param,
        maxval=max_param,
    )
    mixing_emitter = MixingEmitter(
        mutation_fn=lambda x, y: (x, y),
        variation_fn=variation_fn,
        variation_percentage=1.0,
        batch_size=batch_size,
    )

    # Define a metrics function
    metrics_fn = functools.partial(
        default_qd_metrics,
        qd_offset=0.0,
    )

    # Instantiate MAP-Elites
    map_elites = MAPElites(
        scoring_function=arm_scoring_function,
        emitter=mixing_emitter,
        metrics_function=metrics_fn,
    )

    # Compute the centroids
    centroids = compute_euclidean_centroids(
        grid_shape=grid_shape,
        minval=min_bd,
        maxval=max_bd,
    )

    # Initializes repertoire and emitter state
    repertoire, emitter_state, random_key = map_elites.init(
        init_variables, centroids, random_key
    )

    # Run MAP-Elites loop
    for _ in range(num_iterations):
        (repertoire, emitter_state, metrics, random_key,) = map_elites.update(
            repertoire,
            emitter_state,
            random_key,
        )

    # plot archive
    fig, axes = plot_2d_map_elites_repertoire(
        centroids=repertoire.centroids,
        repertoire_fitnesses=repertoire.fitnesses,
        minval=min_bd,
        maxval=max_bd,
        repertoire_descriptors=repertoire.descriptors,
        # vmin=-0.2,
        # vmax=0.0,
    )

    plt.show()


if __name__ == "__main__":
    run_me()
notebooks/mapelites_example.ipynb (2 changes: 1 addition & 1 deletion)

@@ -62,7 +62,7 @@
"from qdax.core.map_elites import MAPElites\n",
"from qdax.core.containers.mapelites_repertoire import compute_cvt_centroids, MapElitesRepertoire\n",
"from qdax import environments\n",
"from qdax.core.neuroevolution.mdp_utils import scoring_function\n",
"from qdax.tasks.brax_envs import scoring_function_brax_envs as scoring_function\n",
"from qdax.core.neuroevolution.buffers.buffer import QDTransition\n",
"from qdax.core.neuroevolution.networks.networks import MLP\n",
"from qdax.core.emitters.mutation_operators import isoline_variation\n",
notebooks/pgame_example.ipynb (2 changes: 1 addition & 1 deletion)

@@ -61,7 +61,7 @@
"from qdax.core.map_elites import MAPElites\n",
"from qdax.core.containers.mapelites_repertoire import compute_cvt_centroids\n",
"from qdax import environments\n",
"from qdax.core.neuroevolution.mdp_utils import scoring_function\n",
"from qdax.tasks.brax_envs import scoring_function_brax_envs as scoring_function\n",
"from qdax.core.neuroevolution.buffers.buffer import QDTransition\n",
"from qdax.core.neuroevolution.networks.networks import MLP\n",
"from qdax.core.emitters.mutation_operators import isoline_variation\n",
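
Both notebooks make the same change: the Brax scoring function moves out of `mdp_utils` into the new tasks module. For anyone updating their own scripts, the migration is a one-line import swap, and aliasing the new name keeps downstream code unchanged:

```python
# Before this PR:
# from qdax.core.neuroevolution.mdp_utils import scoring_function

# After this PR:
from qdax.tasks.brax_envs import scoring_function_brax_envs as scoring_function
```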
qdax/core/neuroevolution/mdp_utils.py (155 changes: 32 additions & 123 deletions)
@@ -1,26 +1,15 @@
from functools import partial
from typing import Any, Callable, Tuple

import brax
import brax.envs
import flax.linen as nn
import jax
import jax.numpy as jnp
from brax.envs import State as EnvState
from flax.struct import PyTreeNode

from qdax.core.neuroevolution.buffers.buffer import (
    QDTransition,
    ReplayBuffer,
    Transition,
)
from qdax.types import (
    Descriptor,
    ExtraScores,
    Fitness,
    Genotype,
    Metrics,
    Params,
    RNGKey,
)
from qdax.core.neuroevolution.buffers.buffer import ReplayBuffer, Transition
from qdax.types import Genotype, Metrics, Params, RNGKey


class TrainingState(PyTreeNode):
@@ -125,114 +114,6 @@ def _scan_play_step_fn(
    return state, transitions


@partial(
    jax.jit,
    static_argnames=(
        "episode_length",
        "play_step_fn",
        "behavior_descriptor_extractor",
    ),
)
def scoring_function(
    policies_params: Genotype,
    random_key: RNGKey,
    init_states: brax.envs.State,
    episode_length: int,
    play_step_fn: Callable[
        [EnvState, Params, RNGKey, brax.envs.Env],
        Tuple[EnvState, Params, RNGKey, QDTransition],
    ],
    behavior_descriptor_extractor: Callable[[QDTransition, jnp.ndarray], Descriptor],
) -> Tuple[Fitness, Descriptor, ExtraScores, RNGKey]:
    """Evaluates policies contained in policies_params in parallel in
    deterministic or pseudo-deterministic environments.

    This rollout is only deterministic when all the init states are the same.
    If the init states are fixed but different, a policy is not necessarily
    evaluated in the same environment every time, so the rollout won't be
    deterministic. When the init states are different, this is not purely
    stochastic either.
    """

    # Perform rollouts with each policy
    random_key, subkey = jax.random.split(random_key)
    unroll_fn = partial(
        generate_unroll,
        episode_length=episode_length,
        play_step_fn=play_step_fn,
        random_key=subkey,
    )

    _final_state, data = jax.vmap(unroll_fn)(init_states, policies_params)

    # create a mask to extract data properly
    is_done = jnp.clip(jnp.cumsum(data.dones, axis=1), 0, 1)
    mask = jnp.roll(is_done, 1, axis=1)
    mask = mask.at[:, 0].set(0)

    # Scores - add offset to ensure positive fitness (through positive rewards)
    fitnesses = jnp.sum(data.rewards * (1.0 - mask), axis=1)
    descriptors = behavior_descriptor_extractor(data, mask)

    return (
        fitnesses,
        descriptors,
        {
            "transitions": data,
        },
        random_key,
    )


@partial(
    jax.jit,
    static_argnames=(
        "episode_length",
        "play_reset_fn",
        "play_step_fn",
        "behavior_descriptor_extractor",
    ),
)
def reset_based_scoring_function(
    policies_params: Genotype,
    random_key: RNGKey,
    episode_length: int,
    play_reset_fn: Callable[[RNGKey], brax.envs.State],
    play_step_fn: Callable[
        [brax.envs.State, Params, RNGKey, brax.envs.Env],
        Tuple[brax.envs.State, Params, RNGKey, QDTransition],
    ],
    behavior_descriptor_extractor: Callable[[QDTransition, jnp.ndarray], Descriptor],
) -> Tuple[Fitness, Descriptor, ExtraScores, RNGKey]:
    """Evaluates policies contained in policies_params in parallel.
    The play_reset_fn function allows for a more general scoring function that
    can be called with any batch size, not only a batch size matching the
    number of init states.

    To define purely stochastic environments, using the reset function from the
    environment, use "play_reset_fn = env.reset".

    To define purely deterministic environments, as in "scoring_function",
    generate a single init_state using "init_state = env.reset(random_key)",
    then use "play_reset_fn = lambda random_key: init_state".
    """

    random_key, subkey = jax.random.split(random_key)
    keys = jax.random.split(subkey, jax.tree_leaves(policies_params)[0].shape[0])
    reset_fn = jax.vmap(play_reset_fn)
    init_states = reset_fn(keys)

    fitnesses, descriptors, extra_scores, random_key = scoring_function(
        policies_params=policies_params,
        random_key=random_key,
        init_states=init_states,
        episode_length=episode_length,
        play_step_fn=play_step_fn,
        behavior_descriptor_extractor=behavior_descriptor_extractor,
    )

    return (fitnesses, descriptors, extra_scores, random_key)
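
The two reset patterns the docstring describes, written out explicitly — a minimal sketch, assuming a Brax environment created through `qdax.environments`; the env name is illustrative, and this mirrors the quoted guidance rather than adding new API:

```python
import jax

from qdax import environments

env = environments.create("pointmaze", episode_length=100)  # hypothetical env choice

# Purely stochastic evaluation: every rollout resets with its own key
play_reset_fn_stochastic = env.reset

# Purely deterministic evaluation: every rollout starts from one fixed state
random_key = jax.random.PRNGKey(0)
init_state = env.reset(random_key)
play_reset_fn_deterministic = lambda random_key: init_state
```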


@partial(
    jax.jit,
    static_argnames=(

@@ -315,3 +196,31 @@ def mask_episodes(x: jnp.ndarray) -> jnp.ndarray:
        return jnp.where(mask.T, x.T, jnp.nan * jnp.ones_like(x).T).T

    return jax.tree_map(mask_episodes, transition)  # type: ignore


def init_population_controllers(
    policy_network: nn.Module,
    env: brax.envs.Env,
    batch_size: int,
    random_key: RNGKey,
) -> Tuple[Genotype, RNGKey]:
    """
    Initializes the population of controllers using a policy_network.

    Args:
        policy_network: The policy network structure used for creating policy
            controllers.
        env: the BRAX environment.
        batch_size: the number of environments we play simultaneously.
        random_key: a JAX random key.

    Returns:
        A tuple of the initial population and the new random key.
    """
    random_key, subkey = jax.random.split(random_key)

    keys = jax.random.split(subkey, num=batch_size)
    fake_batch = jnp.zeros(shape=(batch_size, env.observation_size))
    init_variables = jax.vmap(policy_network.init)(keys, fake_batch)

    return init_variables, random_key
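
A minimal usage sketch for the new `init_population_controllers` helper — assuming the `MLP` network from `qdax.core.neuroevolution.networks.networks` (imported in the notebooks above) and a Brax environment from `qdax.environments`; the env name and layer sizes are illustrative:

```python
import jax
import jax.numpy as jnp

from qdax import environments
from qdax.core.neuroevolution.mdp_utils import init_population_controllers
from qdax.core.neuroevolution.networks.networks import MLP

# Hypothetical environment and policy sizes, for illustration only
env = environments.create("pointmaze", episode_length=100)
policy_network = MLP(
    layer_sizes=(64, 64, env.action_size),
    final_activation=jnp.tanh,
)

random_key = jax.random.PRNGKey(0)
init_variables, random_key = init_population_controllers(
    policy_network=policy_network,
    env=env,
    batch_size=128,
    random_key=random_key,
)
# init_variables now holds 128 sets of policy parameters, one per controller
```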