Add Diversifier QD Meta Algorithm - JAX backend #52

Open
wants to merge 8 commits into
base: main
1 change: 1 addition & 0 deletions evojax/algo/README.md
@@ -11,6 +11,7 @@ We hope this helps EvoJAX users choose the appropriate ones for their experiment
 | [CMA-ES](https://arxiv.org/abs/1604.00772) | This is a CMA-ES optimizer using JAX backend, adapted from [this](https://github.com/CyberAgentAILab/cmaes/blob/main/cmaes/_cma.py) faithful implementation of the original CMA-ES algorithm. ([source](https://github.com/google/evojax/blob/main/evojax/algo/cma_jax.py)) | EvoJAX Team | [PR](https://github.com/google/evojax/pull/32) |
 | [CR-FM-NES](https://arxiv.org/abs/2201.11422) | This is a CR-FM-NES optimizer using JAX backend, adapted from [this](https://github.com/dietmarwo/fast-cma-es/blob/master/fcmaes/crfmnes.py) implementation of the original CR-FM-NES algorithm. ([performance](https://github.com/dietmarwo/fast-cma-es/blob/master/tutorials/EvoJax.adoc)) ([source](https://github.com/google/evojax/blob/main/evojax/algo/crfmnes.py)) | [dietmarwo](https://github.com/dietmarwo) | [PR](https://github.com/google/evojax/pull/46)
 | [CR-FM-NES](https://arxiv.org/abs/2201.11422) | This is a wrapper of the CR-FM-NES C++/Eigen implementation from [fcmaes](https://github.com/dietmarwo/fast-cma-es), using [this](https://github.com/dietmarwo/fast-cma-es/blob/master/_fcmaescpp/crfmnes.cpp) implementation of the original CR-FM-NES algorithm. ([performance](https://github.com/dietmarwo/fast-cma-es/blob/master/tutorials/EvoJax.adoc)) ([source](https://github.com/google/evojax/blob/main/evojax/algo/fcrfmc.py)) | [dietmarwo](https://github.com/dietmarwo) | [PR](https://github.com/google/evojax/pull/44)
+| [Diversifier](https://github.com/dietmarwo/fast-cma-es/blob/master/tutorials/MapElites.adoc) | This is a new QD meta-algorithm using JAX backend, a generalization of [CMA-ME](https://arxiv.org/pdf/1912.02400.pdf), based on [this](https://github.com/dietmarwo/fast-cma-es/blob/master/fcmaes/diversifier.py) implementation of the original Diversifier algorithm. ([source](https://github.com/google/evojax/blob/main/evojax/algo/diversifier.py)) | [dietmarwo](https://github.com/dietmarwo) | [PR](https://github.com/google/evojax/pull/52)
 | [OpenES](https://arxiv.org/pdf/1703.03864.pdf) | This is a wrapper of the OpenES implementation in [evosax](https://github.com/RobertTLange/evosax). ([source](https://github.com/google/evojax/blob/main/evojax/algo/open_es.py)) | [RobertTLange](https://github.com/RobertTLange) | [Table](https://github.com/google/evojax/tree/main/scripts/benchmarks#openes) | |
 | [PGPE](https://people.idsia.ch/~juergen/nn2010.pdf) | This implementation is well tested, all examples are based on this algorithm. We provide two optimizers: Adam and [ClipUp](https://github.com/nnaisense/pgpelib). ([source](https://github.com/google/evojax/blob/main/evojax/algo/pgpe.py)) | EvoJAX team | [Table](https://github.com/google/evojax/tree/main/scripts/benchmarks#pgpe) | |
 | [PGPE](https://people.idsia.ch/~juergen/nn2010.pdf) | This is a wrapper of the PGPE C++/Eigen implementation from [fcmaes](https://github.com/dietmarwo/fast-cma-es), using [this](https://github.com/dietmarwo/fast-cma-es/blob/master/_fcmaescpp/pgpe.cpp) implementation of the original PGPE algorithm. We provide one optimizer: Adam. ([performance](https://github.com/dietmarwo/fast-cma-es/blob/master/tutorials/EvoJax.adoc)) ([source](https://github.com/google/evojax/blob/main/evojax/algo/fpgpec.py)) | [dietmarwo](https://github.com/dietmarwo) | [PR](https://github.com/google/evojax/pull/51)
3 changes: 3 additions & 0 deletions evojax/algo/__init__.py
@@ -28,6 +28,7 @@
 from .crfmnes import CRFMNES
 from .ars_native import ARS_native
 from .fpgpec import FPGPEC
+from .diversifier import Diversifier

 Strategies = {
     "CMA": CMA,
@@ -44,6 +45,7 @@
     "CRFMNES": CRFMNES,
     "ARS_native": ARS_native,
     "FPGPEC": FPGPEC,
+    "Diversifier": Diversifier,
 }

 __all__ = [
@@ -64,4 +66,5 @@
     "Strategies",
     "ARS_native",
     "FPGPEC",
+    "Diversifier",
 ]
55 changes: 24 additions & 31 deletions evojax/algo/crfmnes.py
@@ -61,7 +61,7 @@ def ask(self) -> jnp.ndarray:
         return self.jnp_stack(self.params)

     def tell(self, fitness: jnp.ndarray) -> None:
-        self.crfm.tell(-np.array(fitness))
+        self.crfm.tell(-fitness)
         self._best_params = self.crfm.x_best

     @property
@@ -76,9 +76,9 @@ def best_params(self, params: Union[np.ndarray, jnp.ndarray]) -> None:
 class CRFM():
     def __init__(self, num_dims:
                  int, popsize: int,
-                 mean: Optional[Union[jnp.ndarray, np.ndarray]],
+                 mean: Union[jnp.ndarray, np.ndarray],
                  input_sigma: float,
-                 rng: jax.random.PRNGKey):
+                 key: jax.random.PRNGKey):
         """Fast Moving Natural Evolution Strategy
         for High-Dimensional Problems (CR-FM-NES), see https://arxiv.org/abs/2201.11422 .
         Derived from https://github.com/nomuramasahir0/crfmnes"""
@@ -87,9 +87,9 @@ def __init__(self, num_dims:
         self.lamb = popsize
         self.dim = num_dims
         self.sigma = input_sigma
-        self.rng = rng
+        self.key, key = jax.random.split(key)
         self.m = jnp.array([mean]).T
-        self.v = jax.random.normal(rng, (self.dim, 1)) / jnp.sqrt(self.dim)
+        self.v = jax.random.normal(key, (self.dim, 1)) / jnp.sqrt(self.dim)
         self.D = jnp.ones([self.dim, 1])

         self.w_rank_hat = (jnp.log(self.lamb / 2 + 1) - jnp.log(jnp.arange(1, self.lamb + 1))).reshape(self.lamb, 1)
@@ -129,23 +129,30 @@ def __init__(self, num_dims:
         self.f_best = float('inf')
         self.x_best = jnp.empty(self.dim)

+        hshape = (self.dim, self.lamb // 2)
+
+        def generate_population(v, m, sigma, D, key) -> jnp.ndarray:
+            zkey, key = jax.random.split(key)
+            zhalf = jax.random.normal(zkey, hshape)
+            z = jnp.hstack((zhalf, -zhalf))
+            normv = jnp.linalg.norm(v)
+            normv2 = normv * normv
+            vbar = v / normv
+            y = z + ((jnp.sqrt(1 + normv2) - 1) * jnp.dot(vbar, jnp.dot(vbar.T, z)))
+            x = m + (sigma * y) * D
+            return x, y, z, vbar, normv, normv2, key
+        self._generate_population = jax.jit(generate_population)
+
     def set_m(self, params: jnp.ndarray):
         self.m = jnp.array(params).reshape((self.dim, 1))

     def ask(self) -> jnp.ndarray:
-        key, self.rng = jax.random.split(self.rng)
-        zhalf = jax.random.normal(key, (self.dim, int(self.lamb / 2)))
-        self.z = self.z.at[:, self.idxp].set(zhalf)
-        self.z = self.z.at[:, self.idxm].set(-zhalf)
-        self.normv = jnp.linalg.norm(self.v)
-        self.normv2 = self.normv ** 2
-        self.vbar = self.v / self.normv
-        self.y = self.z + ((jnp.sqrt(1 + self.normv2) - 1) * jnp.dot(self.vbar, jnp.dot(self.vbar.T, self.z)))
-        self.x = self.m + (self.sigma * self.y) * self.D
+        self.x, self.y, self.z, self.vbar, self.normv, self.normv2, self.key = \
+            self._generate_population(self.v, self.m, self.sigma, self.D, self.key)
         return self.x.T

-    def tell(self, evals_no_sort: np.ndarray) -> None:
-        sorted_indices = sort_indices_by(evals_no_sort, self.z)
+    def tell(self, evals_no_sort: jnp.ndarray) -> None:
+        sorted_indices = jnp.argsort(evals_no_sort)
         best_eval_id = sorted_indices[0]
         f_best = evals_no_sort[best_eval_id]
         x_best = self.x[:, best_eval_id]
@@ -156,8 +163,7 @@ def tell(self, evals_no_sort: np.ndarray) -> None:
         self.g += 1
         if f_best < self.f_best:
             self.f_best = f_best
-            self.x_best = x_best
-
+            self.x_best = x_best
         # This operation assumes that if the solution is infeasible, infinity comes in as input.
         lambF = jnp.sum(evals_no_sort < jnp.finfo(float).max)
         # evolution path p_sigma
@@ -224,16 +230,3 @@ def get_h_inv(dim: int) -> float:

 def exp(a: float) -> float:
     return math.exp(min(100, a))  # avoid overflow
-
-def sort_indices_by(evals: np.ndarray, z: jnp.ndarray) -> jnp.ndarray:
-    lam = len(evals)
-    sorted_indices = np.argsort(evals)
-    sorted_evals = evals[sorted_indices]
-    no_of_feasible_solutions = np.where(sorted_evals != jnp.inf)[0].size
-    if no_of_feasible_solutions != lam:
-        infeasible_z = z[:, np.where(evals == jnp.inf)[0]]
-        distances = np.sum(infeasible_z ** 2, axis=0)
-        infeasible_indices = sorted_indices[no_of_feasible_solutions:]
-        indices_sorted_by_distance = np.argsort(distances)
-        sorted_indices = sorted_indices.at[no_of_feasible_solutions:].set(infeasible_indices[indices_sorted_by_distance])
-    return sorted_indices
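Note on the `crfmnes.py` change: the new `generate_population` moves the mirrored sampling into a single jitted function and threads the PRNG key through it explicitly. The following is a minimal standalone sketch of that pattern, not the PR's code. The factory `make_sampler` and the toy values for `dim`, `popsize`, `m`, `D` and `sigma` are assumptions made for illustration only.

```python
import jax
import jax.numpy as jnp


def make_sampler(dim: int, popsize: int):
    """Build a jitted sampler that draws `popsize` mirrored candidates (hypothetical helper)."""
    half_shape = (dim, popsize // 2)  # static shape, closed over so jit sees a constant

    def sample(v, m, sigma, D, key):
        # Split the key: `zkey` is consumed here, the other half is handed back to the caller.
        zkey, key = jax.random.split(key)
        zhalf = jax.random.normal(zkey, half_shape)
        # Antithetic pairs: column i + popsize // 2 is the negation of column i.
        z = jnp.hstack((zhalf, -zhalf))
        normv = jnp.linalg.norm(v)
        vbar = v / normv
        y = z + (jnp.sqrt(1 + normv * normv) - 1) * jnp.dot(vbar, jnp.dot(vbar.T, z))
        x = m + (sigma * y) * D
        return x, key

    return jax.jit(sample)


dim, popsize = 4, 6
key = jax.random.PRNGKey(0)
key, vkey = jax.random.split(key)
v = jax.random.normal(vkey, (dim, 1)) / jnp.sqrt(dim)
m = jnp.zeros((dim, 1))
D = jnp.ones((dim, 1))

sample = make_sampler(dim, popsize)
x1, key1 = sample(v, m, 0.5, D, key)   # first generation
x2, key2 = sample(v, m, 0.5, D, key1)  # second generation uses the returned key
```

Because the function returns the unused half of the split key, each call draws fresh noise without mutating state inside the jitted code. With `m` set to zero, the second half of the columns of `x1` is the negation of the first half, which is the mirrored structure that the old `idxp`/`idxm` indexing produced.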
155 changes: 155 additions & 0 deletions evojax/algo/diversifier.py
@@ -0,0 +1,155 @@
# Copyright 2022 The EvoJAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Implementation of the Diversifier algorithm in JAX.

Diversifier is a JAX-based meta-algorithm that generalizes CMA-ME
(see https://arxiv.org/pdf/1912.02400.pdf). It uses a MAP-Elites
archive not to generate solution candidates, but only to modify
the fitness values passed (via tell) to the wrapped algorithm.
This modification changes the fitness ranking of the population
to favor exploration over exploitation. It has been tested with
CR-FM-NES and CMA-ES, but other wrapped algorithms may work as well.
Based on https://github.com/dietmarwo/fast-cma-es/blob/master/fcmaes/diversifier.py
(see https://github.com/dietmarwo/fast-cma-es/blob/master/tutorials/MapElites.adoc).

On Brax Ant, CR-FM-NES-ME (Diversifier wrapping CR-FM-NES) reaches a higher
QD-score (the sum of the fitness values of all elites in the map) than
MAP-Elites when run for many iterations. Use MAP-Elites instead if the
evaluation budget is low or if you want to maximize the number of occupied niches.
"""

import logging
import numpy as np

from typing import Union

import jax
import jax.numpy as jnp

from evojax.algo.base import QualityDiversityMethod, NEAlgorithm
from evojax.task.base import TaskState
from evojax.task.base import BDExtractor
from evojax.util import create_logger

from time import time

class Diversifier(QualityDiversityMethod):
    """The Diversifier meta algorithm."""

    def __init__(self,
                 solver: NEAlgorithm,
                 pop_size: int,
                 param_size: int,
                 bd_extractor: BDExtractor,
                 fitness_weight: float = 0,
                 seed: int = 0,
                 logger: logging.Logger = None):
        """Initialization function.

        Args:
            solver - Wrapped solver; CR-FM-NES and CMA-ES work well.
            pop_size - Population size.
            param_size - Parameter size.
            bd_extractor - Behavior descriptor extractor; its bd_spec gives
                the name and number of bins of each descriptor.
            fitness_weight - Factor applied to the original fitness.
                Should be in the interval [0, 1]. With a higher value the
                QD-score grows faster but stops growing earlier.
                Choose fitness_weight = 1 if you are also interested in
                a good global optimum.
            seed - Random seed for parameters sampling.
            logger - Logging utility.
        """

        if logger is None:
            self._logger = create_logger(name=solver._logger.name + '-ME')
        else:
            self._logger = logger

        self.solver = solver
        self.pop_size = abs(pop_size)
        self.param_size = param_size
        self.fitness_weight = fitness_weight
        self.bd_names = [x[0] for x in bd_extractor.bd_spec]
        self.bd_n_bins = [x[1] for x in bd_extractor.bd_spec]
        self.params_lattice = jnp.zeros((np.prod(self.bd_n_bins), param_size))
        self.fitness_lattice = -float('Inf') * jnp.ones(np.prod(self.bd_n_bins))
        self.occupancy_lattice = jnp.zeros(
            np.prod(self.bd_n_bins), dtype=jnp.int32)
        self.population = None
        self.bin_idx = jnp.zeros(self.pop_size, dtype=jnp.int32)
        self.key = jax.random.PRNGKey(seed)

        def get_bin_idx(task_state):
            bd_idx = [
                task_state.__dict__[name].astype(int) for name in self.bd_names]
            return jnp.ravel_multi_index(bd_idx, self.bd_n_bins, mode='clip')
        self._get_bin_idx = jax.jit(jax.vmap(get_bin_idx))

        def update_fitness_and_param(
                target_bin, bin_idx,
                fitness, fitness_lattice, param, param_lattice):
            best_ix = jnp.where(
                bin_idx == target_bin, fitness, fitness_lattice.min()).argmax()
            best_fitness = fitness[best_ix]
            new_fitness_lattice = jnp.where(
                best_fitness > fitness_lattice[target_bin],
                best_fitness, fitness_lattice[target_bin])
            new_param_lattice = jnp.where(
                best_fitness > fitness_lattice[target_bin],
                param[best_ix], param_lattice[target_bin])
            return new_fitness_lattice, new_param_lattice
        self._update_lattices = jax.jit(jax.vmap(
            update_fitness_and_param,
            in_axes=(0, None, None, None, None, None)))

        def get_to_tell(
                fitness, lattice_fitness, fitness_weight):
            improvement = fitness - lattice_fitness
            max_valid = jnp.amax(jnp.where(improvement == np.inf, -np.inf, improvement))
            norm_fitness = fitness - jnp.amin(fitness) + 1E-9
            improvement = jnp.where(improvement == np.inf, max_valid + norm_fitness, improvement)
            to_tell = jnp.where(fitness_weight <= 0, improvement,
                                improvement + fitness_weight * fitness)
            return to_tell
        self._get_to_tell = jax.jit(get_to_tell)

    def ask(self) -> jnp.ndarray:
        self.population = self.solver.ask()
        return self.population

    def observe_bd(self, task_state: TaskState) -> None:
        self.bin_idx = self._get_bin_idx(task_state)

    def tell(self, fitness: Union[np.ndarray, jnp.ndarray]) -> None:
        lattice_fitness = self.fitness_lattice[self.bin_idx]
        to_tell = self._get_to_tell(fitness, lattice_fitness, self.fitness_weight)
        # tell the wrapped solver the modified fitness including improvement relative to the lattice
        self.solver.tell(to_tell)
        # update lattice
        unique_bins = jnp.unique(self.bin_idx)
        fitness_lattice, params_lattice = self._update_lattices(
            unique_bins, self.bin_idx,
            fitness, self.fitness_lattice,
            self.population, self.params_lattice)
        self.occupancy_lattice = self.occupancy_lattice.at[unique_bins].set(1)
        self.fitness_lattice = self.fitness_lattice.at[unique_bins].set(
            fitness_lattice)
        self.params_lattice = self.params_lattice.at[unique_bins].set(
            params_lattice)

    @property
    def best_params(self) -> jnp.ndarray:
        ix = jnp.argmax(self.fitness_lattice, axis=0)
        return self.params_lattice[ix]
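Note on `diversifier.py`: to make the fitness modification in `tell` concrete, here is a toy walk-through of one step on a hypothetical 2 x 3 behavior grid. The `get_to_tell` body is copied from the diff. The grid size, the pre-filled elite, the descriptor values and the fitness values are all made up for illustration. The final archive update uses a scatter-max (`.at[...].max`) as a simplified stand-in for the PR's vmapped `update_fitness_and_param`, and it tracks fitness only, not parameters.

```python
import jax.numpy as jnp

n_bins = (2, 3)  # hypothetical grid: 2 bins on the first descriptor, 3 on the second
fitness_lattice = jnp.full(6, -jnp.inf)           # empty archive
fitness_lattice = fitness_lattice.at[4].set(5.0)  # pretend bin (1, 1) already holds an elite

# Behavior descriptors of three candidates; the last one is out of range on the
# second axis and gets clipped into the grid by mode='clip'.
bd = (jnp.array([0, 1, 1]), jnp.array([2, 1, 7]))
bin_idx = jnp.ravel_multi_index(bd, n_bins, mode='clip')
fitness = jnp.array([1.0, 7.0, 3.0])


def get_to_tell(fitness, lattice_fitness, fitness_weight):
    # Improvement over the elite currently stored in each candidate's bin.
    improvement = fitness - lattice_fitness
    # Candidates landing in empty bins have infinite improvement; replace it with
    # the largest finite improvement plus their shifted raw fitness.
    max_valid = jnp.amax(jnp.where(improvement == jnp.inf, -jnp.inf, improvement))
    norm_fitness = fitness - jnp.amin(fitness) + 1e-9
    improvement = jnp.where(improvement == jnp.inf, max_valid + norm_fitness, improvement)
    return jnp.where(fitness_weight <= 0, improvement,
                     improvement + fitness_weight * fitness)


to_tell = get_to_tell(fitness, fitness_lattice[bin_idx], 0.0)

# Simplified archive update (scatter-max, not the PR's vmapped update).
new_lattice = fitness_lattice.at[bin_idx].max(fitness)
```

In this toy run the candidates fall into bins 2, 4 and 5, and the values told to the wrapped solver are approximately `[2, 2, 4]`. The candidate that opens a new bin with the higher raw fitness ranks first, ahead of the one that improves an existing elite by 2. One observation from reading the formula: if every candidate lands in an empty bin, `max_valid` is `-inf` and all told values become `-inf`. The diff does not show how the wrapped solver reacts to that, so it may be worth checking the first generation.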