diff --git a/evojax/algo/README.md b/evojax/algo/README.md
index e2c244b9..ced3b952 100644
--- a/evojax/algo/README.md
+++ b/evojax/algo/README.md
@@ -11,6 +11,7 @@ We hope this helps EvoJAX users choose the appropriate ones for their experiment
 | [CMA-ES](https://arxiv.org/abs/1604.00772) | This is a CMA-ES optimizer using JAX backend, adapted from [this](https://github.com/CyberAgentAILab/cmaes/blob/main/cmaes/_cma.py) faithful implementation of the original CMA-ES algorithm. ([source](https://github.com/google/evojax/blob/main/evojax/algo/cma_jax.py)) | EvoJAX Team | [PR](https://github.com/google/evojax/pull/32) |
 | [CR-FM-NES](https://arxiv.org/abs/2201.11422) | This is a CR-FM-NES optimizer using JAX backend, adapted from [this](https://github.com/dietmarwo/fast-cma-es/blob/master/fcmaes/crfmnes.py) implementation of the original CR-FM-NES algorithm. ([performance](https://github.com/dietmarwo/fast-cma-es/blob/master/tutorials/EvoJax.adoc)) ([source](https://github.com/google/evojax/blob/main/evojax/algo/crfmnes.py)) | [dietmarwo](https://github.com/dietmarwo) | [PR](https://github.com/google/evojax/pull/46) |
 | [CR-FM-NES](https://arxiv.org/abs/2201.11422) | This is a wrapper of the CR-FM-NES C++/Eigen implementation from [fcmaes](https://github.com/dietmarwo/fast-cma-es), using [this](https://github.com/dietmarwo/fast-cma-es/blob/master/_fcmaescpp/crfmnes.cpp) implementation of the original CR-FM-NES algorithm. ([performance](https://github.com/dietmarwo/fast-cma-es/blob/master/tutorials/EvoJax.adoc)) ([source](https://github.com/google/evojax/blob/main/evojax/algo/fcrfmc.py)) | [dietmarwo](https://github.com/dietmarwo) | [PR](https://github.com/google/evojax/pull/44)
+| [Diversifier](https://github.com/dietmarwo/fast-cma-es/blob/master/tutorials/MapElites.adoc) | This is a new QD meta-algorithm using JAX backend, a generalization of [CMA-ME](https://arxiv.org/pdf/1912.02400.pdf), based on [this](https://github.com/dietmarwo/fast-cma-es/blob/master/fcmaes/diversifier.py) implementation of the original Diversifier algorithm. ([source](https://github.com/google/evojax/blob/main/evojax/algo/diversifier.py)) | [dietmarwo](https://github.com/dietmarwo) | [PR](https://github.com/google/evojax/pull/52)
 | [OpenES](https://arxiv.org/pdf/1703.03864.pdf) | This is a wrapper of the OpenES implementation in [evosax](https://github.com/RobertTLange/evosax). ([source](https://github.com/google/evojax/blob/main/evojax/algo/open_es.py)) | [RobertTLange](https://github.com/RobertTLange) | [Table](https://github.com/google/evojax/tree/main/scripts/benchmarks#openes) | |
 | [PGPE](https://people.idsia.ch/~juergen/nn2010.pdf) | This implementation is well tested, all examples are based on this algorithm. We provide two optimizers: Adam and [ClipUp](https://github.com/nnaisense/pgpelib). ([source](https://github.com/google/evojax/blob/main/evojax/algo/pgpe.py)) | EvoJAX team | [Table](https://github.com/google/evojax/tree/main/scripts/benchmarks#pgpe) | |
 | [PGPE](https://people.idsia.ch/~juergen/nn2010.pdf) | This is a wrapper of the PGPE C++/Eigen implementation from [fcmaes](https://github.com/dietmarwo/fast-cma-es), using [this](https://github.com/dietmarwo/fast-cma-es/blob/master/_fcmaescpp/pgpe.cpp) implementation of the original PGPE algorithm. We provide one optimizer: Adam.
([performance](https://github.com/dietmarwo/fast-cma-es/blob/master/tutorials/EvoJax.adoc)) ([source](https://github.com/google/evojax/blob/main/evojax/algo/fpgpec.py)) | [dietmarwo](https://github.com/dietmarwo) | [PR](https://github.com/google/evojax/pull/51)
diff --git a/evojax/algo/__init__.py b/evojax/algo/__init__.py
index 8942d8f1..e405c693 100644
--- a/evojax/algo/__init__.py
+++ b/evojax/algo/__init__.py
@@ -28,6 +28,7 @@
 from .crfmnes import CRFMNES
 from .ars_native import ARS_native
 from .fpgpec import FPGPEC
+from .diversifier import Diversifier
 
 Strategies = {
     "CMA": CMA,
@@ -44,6 +45,7 @@
     "CRFMNES": CRFMNES,
     "ARS_native": ARS_native,
     "FPGPEC": FPGPEC,
+    "Diversifier": Diversifier,
 }
 
 __all__ = [
@@ -64,4 +66,5 @@
     "Strategies",
     "ARS_native",
     "FPGPEC",
+    "Diversifier",
 ]
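
# With the registration above, Diversifier can be looked up by name like any
# other strategy. A minimal sketch (illustrative only, using just the names
# introduced in this patch):
#
#     from evojax.algo import Strategies
#     solver_cls = Strategies["Diversifier"]  # resolves to the Diversifier class
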
diff --git a/evojax/algo/crfmnes.py b/evojax/algo/crfmnes.py
index 21cc36f3..6fae3fd6 100644
--- a/evojax/algo/crfmnes.py
+++ b/evojax/algo/crfmnes.py
@@ -61,7 +61,7 @@ def ask(self) -> jnp.ndarray:
         return self.jnp_stack(self.params)
 
     def tell(self, fitness: jnp.ndarray) -> None:
-        self.crfm.tell(-np.array(fitness))
+        self.crfm.tell(-fitness)
         self._best_params = self.crfm.x_best
 
     @property
@@ -76,9 +76,9 @@ def best_params(self, params: Union[np.ndarray, jnp.ndarray]) -> None:
 class CRFM():
 
     def __init__(self, num_dims: int, popsize: int,
-                 mean: Optional[Union[jnp.ndarray, np.ndarray]],
+                 mean: Union[jnp.ndarray, np.ndarray],
                  input_sigma: float,
-                 rng: jax.random.PRNGKey):
+                 key: jax.random.PRNGKey):
         """Fast Moving Natural Evolution Strategy
         for High-Dimensional Problems (CR-FM-NES), see https://arxiv.org/abs/2201.11422 .
         Derived from https://github.com/nomuramasahir0/crfmnes"""
@@ -87,9 +87,9 @@ def __init__(self, num_dims:
         self.lamb = popsize
         self.dim = num_dims
         self.sigma = input_sigma
-        self.rng = rng
+        self.key, key = jax.random.split(key)
         self.m = jnp.array([mean]).T
-        self.v = jax.random.normal(rng, (self.dim, 1)) / jnp.sqrt(self.dim)
+        self.v = jax.random.normal(key, (self.dim, 1)) / jnp.sqrt(self.dim)
         self.D = jnp.ones([self.dim, 1])
         self.w_rank_hat = (jnp.log(self.lamb / 2 + 1) -
                            jnp.log(jnp.arange(1, self.lamb + 1))).reshape(self.lamb, 1)
@@ -129,23 +129,30 @@ def __init__(self, num_dims:
         self.f_best = float('inf')
         self.x_best = jnp.empty(self.dim)
 
+        hshape = (self.dim, self.lamb // 2)
+
+        def generate_population(v, m, sigma, D, key):
+            zkey, key = jax.random.split(key)
+            zhalf = jax.random.normal(zkey, hshape)
+            z = jnp.hstack((zhalf, -zhalf))
+            normv = jnp.linalg.norm(v)
+            normv2 = normv * normv
+            vbar = v / normv
+            y = z + ((jnp.sqrt(1 + normv2) - 1) * jnp.dot(vbar, jnp.dot(vbar.T, z)))
+            x = m + (sigma * y) * D
+            return x, y, z, vbar, normv, normv2, key
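+        # generate_population is written as a pure function: all state it
+        # needs (v, m, sigma, D) and the PRNG key are passed in explicitly,
+        # and a fresh key is returned, so the jax.jit wrapper below is traced
+        # once for the fixed shapes used here and reused on every ask() call.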
+""" + +import logging +import numpy as np + +from typing import Union + +import jax +import jax.numpy as jnp + +from evojax.algo.base import QualityDiversityMethod, NEAlgorithm +from evojax.task.base import TaskState +from evojax.task.base import BDExtractor +from evojax.util import create_logger + +from time import time + +class Diversifier(QualityDiversityMethod): + """The Diversifier meta algorithm.""" + + def __init__(self, + solver: NEAlgorithm, + pop_size: int, + param_size: int, + bd_extractor: BDExtractor, + fitness_weight: float = 0, + seed: int = 0, + logger: logging.Logger = None): + """Initialization function. + + Args: + solver: wrapped solver, CR-FM-NES and CMA-ES work well. + pop_size - Population size. + param_size - Parameter size. + bd_extractor - A list of behavior descriptor extractors. + fitness_weight - factor applied to the original fitness. + Should be in interval [0,1]. Higher value means: + QD-score grows faster, but will stop growing earlier. + Choose fitness_weight = 1 if you are also interested in + a good global optimum. + seed - Random seed for parameters sampling. + logger - Logging utility. + """ + + if logger is None: + self._logger = create_logger(name=solver._logger.name + '-ME') + else: + self._logger = logger + + self.solver = solver + self.pop_size = abs(pop_size) + self.param_size = param_size + self.fitness_weight = fitness_weight + self.bd_names = [x[0] for x in bd_extractor.bd_spec] + self.bd_n_bins = [x[1] for x in bd_extractor.bd_spec] + self.params_lattice = jnp.zeros((np.prod(self.bd_n_bins), param_size)) + self.fitness_lattice = -float('Inf') * jnp.ones(np.prod(self.bd_n_bins)) + self.occupancy_lattice = jnp.zeros( + np.prod(self.bd_n_bins), dtype=jnp.int32) + self.population = None + self.bin_idx = jnp.zeros(self.pop_size, dtype=jnp.int32) + self.key = jax.random.PRNGKey(seed) + + def get_bin_idx(task_state): + bd_idx = [ + task_state.__dict__[name].astype(int) for name in self.bd_names] + return jnp.ravel_multi_index(bd_idx, self.bd_n_bins, mode='clip') + self._get_bin_idx = jax.jit(jax.vmap(get_bin_idx)) + + def update_fitness_and_param( + target_bin, bin_idx, + fitness, fitness_lattice, param, param_lattice): + best_ix = jnp.where( + bin_idx == target_bin, fitness, fitness_lattice.min()).argmax() + best_fitness = fitness[best_ix] + new_fitness_lattice = jnp.where( + best_fitness > fitness_lattice[target_bin], + best_fitness, fitness_lattice[target_bin]) + new_param_lattice = jnp.where( + best_fitness > fitness_lattice[target_bin], + param[best_ix], param_lattice[target_bin]) + return new_fitness_lattice, new_param_lattice + self._update_lattices = jax.jit(jax.vmap( + update_fitness_and_param, + in_axes=(0, None, None, None, None, None))) + + def get_to_tell( + fitness, lattice_fitness, fitness_weight): + improvement = fitness - lattice_fitness + max_valid = jnp.amax(jnp.where(improvement == np.inf, -np.inf, improvement)) + norm_fitness = fitness - jnp.amin(fitness) + 1E-9 + improvement = jnp.where(improvement == np.inf, max_valid + norm_fitness, improvement) + to_tell = jnp.where(fitness_weight <= 0, improvement, + improvement + fitness_weight * fitness) + return to_tell + self._get_to_tell = jax.jit(get_to_tell) + + def ask(self) -> jnp.ndarray: + self.population = self.solver.ask() + return self.population + + def observe_bd(self, task_state: TaskState) -> None: + self.bin_idx = self._get_bin_idx(task_state) + + def tell(self, fitness: Union[np.ndarray, jnp.ndarray]) -> None: + lattice_fitness = self.fitness_lattice[self.bin_idx] + 
+"""
+
+import logging
+import numpy as np
+
+from typing import Union
+
+import jax
+import jax.numpy as jnp
+
+from evojax.algo.base import QualityDiversityMethod, NEAlgorithm
+from evojax.task.base import TaskState
+from evojax.task.base import BDExtractor
+from evojax.util import create_logger
+
+
+class Diversifier(QualityDiversityMethod):
+    """The Diversifier meta-algorithm."""
+
+    def __init__(self,
+                 solver: NEAlgorithm,
+                 pop_size: int,
+                 param_size: int,
+                 bd_extractor: BDExtractor,
+                 fitness_weight: float = 0,
+                 seed: int = 0,
+                 logger: logging.Logger = None):
+        """Initialization function.
+
+        Args:
+            solver - The wrapped solver; CR-FM-NES and CMA-ES work well.
+            pop_size - Population size.
+            param_size - Parameter size.
+            bd_extractor - A behavior descriptor extractor.
+            fitness_weight - Factor applied to the original fitness, in the
+                interval [0, 1]. A higher value means the QD-score grows
+                faster but stops growing earlier. Choose fitness_weight = 1
+                if you are also interested in a good global optimum.
+            seed - Random seed for parameter sampling.
+            logger - Logging utility.
+        """
+
+        if logger is None:
+            self._logger = create_logger(name=solver._logger.name + '-ME')
+        else:
+            self._logger = logger
+
+        self.solver = solver
+        self.pop_size = abs(pop_size)
+        self.param_size = param_size
+        self.fitness_weight = fitness_weight
+        self.bd_names = [x[0] for x in bd_extractor.bd_spec]
+        self.bd_n_bins = [x[1] for x in bd_extractor.bd_spec]
+        self.params_lattice = jnp.zeros((np.prod(self.bd_n_bins), param_size))
+        self.fitness_lattice = -float('Inf') * jnp.ones(np.prod(self.bd_n_bins))
+        self.occupancy_lattice = jnp.zeros(
+            np.prod(self.bd_n_bins), dtype=jnp.int32)
+        self.population = None
+        self.bin_idx = jnp.zeros(self.pop_size, dtype=jnp.int32)
+        self.key = jax.random.PRNGKey(seed)
+
+        def get_bin_idx(task_state):
+            bd_idx = [
+                task_state.__dict__[name].astype(int) for name in self.bd_names]
+            return jnp.ravel_multi_index(bd_idx, self.bd_n_bins, mode='clip')
+        self._get_bin_idx = jax.jit(jax.vmap(get_bin_idx))
+
+        def update_fitness_and_param(
+                target_bin, bin_idx,
+                fitness, fitness_lattice, param, param_lattice):
+            best_ix = jnp.where(
+                bin_idx == target_bin, fitness, fitness_lattice.min()).argmax()
+            best_fitness = fitness[best_ix]
+            new_fitness_lattice = jnp.where(
+                best_fitness > fitness_lattice[target_bin],
+                best_fitness, fitness_lattice[target_bin])
+            new_param_lattice = jnp.where(
+                best_fitness > fitness_lattice[target_bin],
+                param[best_ix], param_lattice[target_bin])
+            return new_fitness_lattice, new_param_lattice
+        self._update_lattices = jax.jit(jax.vmap(
+            update_fitness_and_param,
+            in_axes=(0, None, None, None, None, None)))
+
+        def get_to_tell(fitness, lattice_fitness, fitness_weight):
+            improvement = fitness - lattice_fitness
+            max_valid = jnp.amax(jnp.where(
+                improvement == np.inf, -np.inf, improvement))
+            norm_fitness = fitness - jnp.amin(fitness) + 1E-9
+            improvement = jnp.where(
+                improvement == np.inf, max_valid + norm_fitness, improvement)
+            to_tell = jnp.where(fitness_weight <= 0, improvement,
+                                improvement + fitness_weight * fitness)
+            return to_tell
+        self._get_to_tell = jax.jit(get_to_tell)
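+        # Worked example (fitness_weight = 0): with fitness = [3., 5.] and
+        # lattice_fitness = [4., -inf] (second bin unoccupied), improvement
+        # is [-1., inf]; max_valid = -1. and norm_fitness = [1e-9, 2.], so
+        # the inf becomes -1. + 2. = 1. and to_tell = [-1., 1.]: candidates
+        # that fill a new bin always outrank ones that fail to improve theirs.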
+""" + +import argparse +import os +import numpy as np +import matplotlib.pyplot as plt +from PIL import Image +from functools import partial + +from evojax import Trainer +from evojax.task.brax_task import BraxTask +from evojax.task.brax_task import AntBDExtractor +from evojax.policy import MLPPolicy +from evojax.algo import Diversifier, CRFMNES +from evojax import util + + +def parse_args(): + parser = argparse.ArgumentParser() + parser.add_argument( + '--pop-size', type=int, default=512, help='population size.') + parser.add_argument( + '--num-tests', type=int, default=128, help='Number of test rollouts.') + parser.add_argument( + '--n-repeats', type=int, default=8, help='Training repetitions.') + parser.add_argument( + '--max-iter', type=int, default=3000, help='Max training iterations.') + parser.add_argument( + '--test-interval', type=int, default=50, help='Test interval.') + parser.add_argument( + '--log-interval', type=int, default=10, help='Logging interval.') + parser.add_argument( + '--seed', type=int, default=42, help='Random seed for training.') + parser.add_argument( + '--init-std', type=float, default=0.159, help='Initial std.') + parser.add_argument( + '--fitness-weight', type=float, default=0, help='Fitness weight.') + parser.add_argument( + '--gpu-id', type=str, help='GPU(s) to use.') + parser.add_argument( + '--debug', action='store_true', help='Debug mode.') + parser.add_argument( + '--save-gif', action='store_true', help='Save some GIFs.') + config, _ = parser.parse_known_args() + return config + + +def plot_figure(lattice, log_dir, title): + grid = lattice.reshape((10, 10, 10, 10)) + fig, axes = plt.subplots(10, 10, figsize=(8, 8)) + for i in range(10): + for j in range(10): + ax = axes[i][j] + ax.imshow(grid[i, j]) + ax.set_axis_off() + fig.suptitle(title, fontsize=20, fontweight='bold') + plt.savefig(os.path.join(log_dir, '{}.png'.format(title))) + + +def main(config): + log_dir = './log/ant_diversifier' + if not os.path.exists(log_dir): + os.makedirs(log_dir, exist_ok=True) + logger = util.create_logger( + name='AntDiversifier', log_dir=log_dir, debug=config.debug) + + logger.info('EvoJAX AntDiversifier Demo') + logger.info('=' * 30) + + bd_extractor = AntBDExtractor(logger=logger) + train_task = BraxTask( + env_name='ant', max_steps=500, bd_extractor=bd_extractor, test=False) + test_task = BraxTask( + env_name='ant', bd_extractor=bd_extractor, test=True) + policy = MLPPolicy( + input_dim=train_task.obs_shape[0], + hidden_dims=[32, 32, 32, 32], + output_dim=train_task.act_shape[0], + ) + + wrapped_solver = CRFMNES( + pop_size=config.pop_size, + param_size=policy.num_params, + init_stdev=config.init_std, + logger=logger, + seed=config.seed, + ) + + solver = Diversifier( + solver = wrapped_solver, + pop_size=config.pop_size, + param_size=policy.num_params, + bd_extractor=bd_extractor, + fitness_weight=config.fitness_weight, + seed=config.seed, + logger=logger, + ) + + # Train. + trainer = Trainer( + policy=policy, + solver=solver, + train_task=train_task, + test_task=test_task, + max_iter=config.max_iter, + log_interval=config.log_interval, + test_interval=config.test_interval, + n_repeats=config.n_repeats, + n_evaluations=config.num_tests, + seed=config.seed, + log_dir=log_dir, + logger=logger, + ) + trainer.run(demo_mode=False) + + # Visualize the results. 
+    qd_file = os.path.join(log_dir, 'qd_lattices.npz')
+    with open(qd_file, 'rb') as f:
+        data = np.load(f)
+        params_lattice = data['params_lattice']
+        fitness_lattice = data['fitness_lattice']
+        occupancy_lattice = data['occupancy_lattice']
+    plot_figure(occupancy_lattice, log_dir, 'occupancy')
+    plot_figure(fitness_lattice, log_dir, 'score')
+
+    # Visualize the top policies.
+    if config.save_gif:
+        import jax
+        import jax.numpy as jnp
+        from brax import envs
+        from brax.io import image
+
+        @partial(jax.jit, static_argnums=(1,))
+        def get_qp(state, ix):
+            return jax.tree_map(lambda x: x[ix], state.qp)
+
+        num_viz = 3
+        idx = fitness_lattice.argsort()[-num_viz:]
+        bins = [np.unravel_index(ix, (10, 10, 10, 10)) for ix in idx]
+        logger.info(
+            'Best {} policies: indices={}, bins={}'.format(num_viz, idx, bins))
+
+        policy_params = jnp.array(params_lattice[idx])
+        task_reset_fn = jax.jit(test_task.reset)
+        policy_reset_fn = jax.jit(policy.reset)
+        step_fn = jax.jit(test_task.step)
+        act_fn = jax.jit(policy.get_actions)
+
+        total_reward = jnp.zeros(num_viz)
+        valid_masks = jnp.ones(num_viz)
+        rollouts = {i: [] for i in range(num_viz)}
+        keys = jnp.repeat(
+            jax.random.PRNGKey(seed=42)[None, :], repeats=num_viz, axis=0)
+        task_state = task_reset_fn(key=keys)
+        policy_state = policy_reset_fn(task_state)
+
+        for step in range(test_task.max_steps):
+            for i in range(num_viz):
+                rollouts[i].append(get_qp(task_state.state, i))
+            act, policy_state = act_fn(task_state, policy_params, policy_state)
+            task_state, reward, done = step_fn(task_state, act)
+            total_reward = total_reward + reward * valid_masks
+            valid_masks = valid_masks * (1 - done)
+        logger.info('test_rewards={}'.format(total_reward))
+
+        logger.info('Start saving GIFs, this can take some time ...')
+        env_fn = envs.create_fn(env_name='ant', legacy_spring=True)
+        env = env_fn()
+        for i in range(num_viz):
+            qps = jax.tree_map(lambda x: np.array(x), rollouts[i])
+            frames = [
+                Image.fromarray(
+                    image.render_array(env.sys, qp, 320, 240, None, None, 2))
+                for qp in qps]
+            frames[0].save(
+                os.path.join(log_dir, 'bin_{}.gif'.format(bins[i])),
+                format='GIF',
+                append_images=frames[1:],
+                save_all=True,
+                duration=env.sys.config.dt * 1000,
+                loop=0)
+
+
+if __name__ == '__main__':
+    configs = parse_args()
+    if configs.gpu_id is not None:
+        os.environ['CUDA_VISIBLE_DEVICES'] = configs.gpu_id
+    main(configs)
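+
+# Variant: the Diversifier docstring notes that CMA-ES also works as the
+# wrapped solver. A hedged sketch (whether the CMA wrapper registered in
+# evojax/algo/__init__.py accepts exactly these arguments is an assumption,
+# not verified here):
+#
+#     from evojax.algo import CMA
+#     wrapped_solver = CMA(
+#         pop_size=config.pop_size, param_size=policy.num_params,
+#         init_stdev=config.init_std, seed=config.seed, logger=logger)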