Skip to content

Commit

Permalink
add extra scores to ME and MELS
Browse files Browse the repository at this point in the history
  • Loading branch information
Lookatator committed Feb 17, 2025
1 parent 0c354fd commit 359fee8
Show file tree
Hide file tree
Showing 2 changed files with 96 additions and 12 deletions.
69 changes: 58 additions & 11 deletions qdax/core/containers/mapelites_repertoire.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@

from __future__ import annotations

import warnings
from typing import Callable, List, Optional, Tuple, Union

import jax
Expand Down Expand Up @@ -156,6 +155,8 @@ class MapElitesRepertoire(GARepertoire):
is (num_centroids, num_descriptors).
centroids: an array that contains the centroids of the tessellation. The array
shape is (num_centroids, num_descriptors).
extra_scores: extra scores resulting from the evaluation of the genotypes
keys_extra_scores: keys of the extra scores to store in the repertoire
"""

descriptors: Descriptor
Expand Down Expand Up @@ -242,12 +243,16 @@ def add(
aforementioned genotypes. Its shape is (batch_size, num_descriptors)
batch_of_fitnesses: an array that contains the fitnesses of the
aforementioned genotypes. Its shape is (batch_size,)
batch_of_extra_scores: unused tree that contains the extra_scores of
batch_of_extra_scores: tree that contains the extra_scores of
aforementioned genotypes.
Returns:
The updated MAP-Elites repertoire.
"""
if batch_of_extra_scores is None:
batch_of_extra_scores = {}

filtered_batch_of_extra_scores = self.filter_extra_scores(batch_of_extra_scores)

batch_of_indices = get_cells_indices(batch_of_descriptors, self.centroids)
batch_of_indices = jnp.expand_dims(batch_of_indices, axis=-1)
Expand Down Expand Up @@ -297,11 +302,22 @@ def add(
batch_of_descriptors
)

# update extra scores
new_extra_scores = jax.tree.map(
lambda repertoire_scores, new_scores: repertoire_scores.at[
batch_of_indices.squeeze(axis=-1)
].set(new_scores),
self.extra_scores,
filtered_batch_of_extra_scores,
)

return MapElitesRepertoire(
genotypes=new_repertoire_genotypes,
fitnesses=new_fitnesses,
descriptors=new_descriptors,
centroids=self.centroids,
extra_scores=new_extra_scores,
keys_extra_scores=self.keys_extra_scores,
)

@classmethod
Expand All @@ -311,7 +327,10 @@ def init( # type: ignore
fitnesses: Fitness,
descriptors: Descriptor,
centroids: Centroid,
*args,
extra_scores: Optional[ExtraScores] = None,
keys_extra_scores: Tuple[str, ...] = (),
**kwargs,
) -> MapElitesRepertoire:
"""
Initialize a Map-Elites repertoire with an initial population of genotypes.
Expand All @@ -328,24 +347,33 @@ def init( # type: ignore
descriptors: descriptors of the initial genotypes
of shape (batch_size, num_descriptors)
centroids: tessellation centroids of shape (batch_size, num_descriptors)
extra_scores: unused extra_scores of the initial genotypes
extra_scores: extra scores of the initial genotypes
keys_extra_scores: keys of the extra scores to store in the repertoire
Returns:
an initialized MAP-Elite repertoire
"""
warnings.warn(
(
"This type of repertoire does not store the extra scores "
"computed by the scoring function"
),
stacklevel=2,
)

if extra_scores is None:
extra_scores = {}

extra_scores = {
key: value
for key, value in extra_scores.items()
if key in keys_extra_scores
}

# retrieve one genotype from the population
first_genotype = jax.tree.map(lambda x: x[0], genotypes)
first_extra_scores = jax.tree.map(lambda x: x[0], extra_scores)

# create a repertoire with default values
repertoire = cls.init_default(genotype=first_genotype, centroids=centroids)
repertoire = cls.init_default(
genotype=first_genotype,
centroids=centroids,
one_extra_score=first_extra_scores,
keys_extra_scores=keys_extra_scores,
)

# add initial population to the repertoire
new_repertoire = repertoire.add(genotypes, descriptors, fitnesses, extra_scores)
Expand All @@ -357,6 +385,8 @@ def init_default(
cls,
genotype: Genotype,
centroids: Centroid,
one_extra_score: Optional[ExtraScores] = None,
keys_extra_scores: Tuple[str, ...] = (),
) -> MapElitesRepertoire:
"""Initialize a Map-Elites repertoire with an initial population of
genotypes. Requires the definition of centroids that can be computed
Expand All @@ -368,10 +398,19 @@ def init_default(
Args:
genotype: the typical genotype that will be stored.
centroids: the centroids of the repertoire
keys_extra_scores: keys of the extra scores to store in the repertoire
Returns:
A repertoire filled with default values.
"""
if one_extra_score is None:
one_extra_score = {}

one_extra_score = {
key: value
for key, value in one_extra_score.items()
if key in keys_extra_scores
}

# get number of centroids
num_centroids = centroids.shape[0]
Expand All @@ -388,9 +427,17 @@ def init_default(
# default descriptor is all zeros
default_descriptors = jnp.zeros_like(centroids)

# default extra scores is empty dict
default_extra_scores = jax.tree.map(
lambda x: jnp.zeros(shape=(num_centroids,) + x.shape, dtype=x.dtype),
one_extra_score,
)

return cls(
genotypes=default_genotypes,
fitnesses=default_fitnesses,
descriptors=default_descriptors,
centroids=centroids,
extra_scores=default_extra_scores,
keys_extra_scores=keys_extra_scores,
)
39 changes: 38 additions & 1 deletion qdax/core/containers/mels_repertoire.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

from __future__ import annotations

from typing import Callable, Optional
from typing import Callable, Optional, Tuple

import jax
import jax.numpy as jnp
Expand Down Expand Up @@ -186,6 +186,12 @@ def add(
Returns:
The updated repertoire.
"""

if batch_of_extra_scores is None:
batch_of_extra_scores = {}

filtered_batch_of_extra_scores = self.filter_extra_scores(batch_of_extra_scores)

batch_size, num_samples = batch_of_fitnesses.shape

# Compute indices/cells of all descriptors.
Expand Down Expand Up @@ -260,9 +266,20 @@ def add(
batch_of_spreads.squeeze(axis=-1)
)

# update extra scores
new_extra_scores = jax.tree.map(
lambda repertoire_scores, new_scores: repertoire_scores.at[
batch_of_indices.squeeze(axis=-1)
].set(new_scores),
self.extra_scores,
filtered_batch_of_extra_scores,
)

return MELSRepertoire(
genotypes=new_repertoire_genotypes,
fitnesses=new_fitnesses,
extra_scores=new_extra_scores,
keys_extra_scores=self.keys_extra_scores,
descriptors=new_descriptors,
centroids=self.centroids,
spreads=new_spreads,
Expand All @@ -273,6 +290,8 @@ def init_default(
cls,
genotype: Genotype,
centroids: Centroid,
one_extra_score: Optional[ExtraScores] = None,
keys_extra_scores: Tuple[str, ...] = (),
) -> MELSRepertoire:
"""Initialize a MAP-Elites Low-Spread repertoire with an initial population of
genotypes. Requires the definition of centroids that can be computed with any
Expand All @@ -284,10 +303,20 @@ def init_default(
Args:
genotype: the typical genotype that will be stored.
centroids: the centroids of the repertoire.
extra_scores: extra scores to store in the repertoire
keys_extra_scores: keys of the extra scores to store in the repertoire
Returns:
A repertoire filled with default values.
"""
if one_extra_score is None:
one_extra_score = {}

one_extra_score = {
key: value
for key, value in one_extra_score.items()
if key in keys_extra_scores
}

# get number of centroids
num_centroids = centroids.shape[0]
Expand All @@ -307,10 +336,18 @@ def init_default(
# default spread is inf so that any spread will be less
default_spreads = jnp.full(shape=num_centroids, fill_value=jnp.inf)

# default extra scores is empty dict
default_extra_scores = jax.tree.map(
lambda x: jnp.zeros(shape=(num_centroids,) + x.shape, dtype=x.dtype),
one_extra_score,
)

return cls(
genotypes=default_genotypes,
fitnesses=default_fitnesses,
descriptors=default_descriptors,
centroids=centroids,
spreads=default_spreads,
extra_scores=default_extra_scores,
keys_extra_scores=keys_extra_scores,
)

0 comments on commit 359fee8

Please sign in to comment.