diff --git a/evojax/algo/__init__.py b/evojax/algo/__init__.py index 8942d8f1..69c0781e 100644 --- a/evojax/algo/__init__.py +++ b/evojax/algo/__init__.py @@ -15,6 +15,7 @@ from .base import NEAlgorithm from .base import QualityDiversityMethod from .cma_wrapper import CMA +from .cosyne import CoSyNE from .pgpe import PGPE from .ars import ARS from .simple_ga import SimpleGA @@ -30,6 +31,7 @@ from .fpgpec import FPGPEC Strategies = { + "CoSyNE": CoSyNE, "CMA": CMA, "PGPE": PGPE, "SimpleGA": SimpleGA, diff --git a/evojax/algo/cosyne.py b/evojax/algo/cosyne.py new file mode 100644 index 00000000..5bb73669 --- /dev/null +++ b/evojax/algo/cosyne.py @@ -0,0 +1,206 @@ +import logging +import numpy as np +from typing import Union +from typing import Tuple + +import jax +import jax.numpy as jnp + +from evojax.algo.base import NEAlgorithm +from evojax.util import create_logger + + +class CoSyNE(NEAlgorithm): + """Cooperative Synapse Neuroevoluiton (CoSyNE) algorithm. + + Attempts to reproduce CoSyNE as described by Gomes et al. in: + https://people.idsia.ch/~juergen/gomez08a.pdf + + TLDR: Cull bottom half of population, and replace with Cauchy-mutated multi-point-crossover offspring of the top + quarter of the population, and then shuffle the survivors' columns + """ + + def __init__(self, + pop_size: int, + param_size: int, + alpha: float = 2, + prob_mutate: float = 0.3, + segment_size: int = 411, + reset_prob: float = 0.0001, + seed: int = 0, + shuffle: bool = True, + logger: logging.Logger = create_logger(name='CoSyNE')): + """Initialization function. + + Args: + param_size - Parameter size. + pop_size - Population size. + alpha - Learning rate multiplier of cauchy distribution of mutation. + prob_mutate - Probability of mutation. + segment_size - Parameter segment size, each to include a single random crossover point + seed - Random seed for parameters sampling. + logger - Logger + """ + + self.logger = logger + self.param_size = param_size + self.n = self.param_size + self.prob_mutate = prob_mutate + assert pop_size % 4 == 0, 'Population size should be a multiple of 4, set to {}'.format(pop_size) + self.pop_size = abs(pop_size) + self.m = self.pop_size + self.reset_prob = reset_prob + + self.segment_size = segment_size + + self.alpha = alpha + + next_key, sample_key = jax.random.split(jax.random.PRNGKey(seed=seed)) + self.params = jax.random.uniform(sample_key, (pop_size, param_size), minval=-alpha, maxval=alpha) + self._best_params = None + + self.rand_key = jax.random.PRNGKey(seed=seed) + + self.jnp_array = jax.jit(jnp.array) + + def ask_fn(key: jnp.ndarray, + params: Union[np.ndarray, + jnp.ndarray]) -> Tuple[jnp.ndarray, + Union[np.ndarray, + jnp.ndarray]]: + # determine survivors + survivors = params[0:self.pop_size // 2] + + # determine parents + parents = params[0:self.pop_size // 4] + + # determine offspring + + # determine offspring: crossover once per segment at a random position + num_segments = self.n // self.segment_size + quarter_pop = self.m // 4 + next_key, sample_key = jax.random.split(key=key, num=2) + crossovers = jax.random.randint(key=sample_key, + shape=(quarter_pop, num_segments), + minval=0, + maxval=self.segment_size) + + # determine offspring: mates are uniformly chosen from fitter ranks + def for_mate_ranks(parent_index_and_next_key, _): + parent_rank = parent_index_and_next_key[0] + next_key, sample_key = jax.random.split(parent_index_and_next_key[1], 2) + mate_rank = jax.random.randint(key=sample_key, + shape=(1,), + minval=0, + maxval=parent_rank)[0] + + return (parent_index_and_next_key[0] + 1, next_key), mate_rank + + next_key, sample_key = jax.random.split(key=key, num=2) + parent_index_and_next_key, mate_ranks = jax.lax.scan(for_mate_ranks, (0, next_key), parents) + + next_key, sample_key = jax.random.split(key=key, num=2) + reset_chances = jax.random.uniform(key=sample_key, shape=(self.m // 2, param_size)) + + def offspring_elem(i, j): + """ + Creates offspring matrix element (i, j) as being either a copy of a parent or their mate, conditional on + whether they are offspring A or offspring B, and whether the genomes are currently crossed + """ + i = i.astype(int) + j = j.astype(int) + is_offspring_a = i < (self.m // 4) + is_offspring_b = i >= (self.m // 4) + parent_rank = i % (self.m // 4) + mate_rank = mate_ranks[parent_rank] + crossover_file = j // self.segment_size + crossover_pos = crossovers[parent_rank, crossover_file] + segment_pos = j % self.segment_size + crossed = jnp.any(jnp.array([jnp.all(jnp.array([crossover_file % 2 == 1, segment_pos < crossover_pos])), + jnp.all( + jnp.array([crossover_file % 2 == 0, segment_pos >= crossover_pos]))])) + + conds = jnp.array([ + reset_chances[i, j] > (1 - self.reset_prob), + jnp.all(jnp.array([is_offspring_a, crossed])), + jnp.all(jnp.array([is_offspring_b, crossed])), + is_offspring_a, + is_offspring_b + ]) + + index = jnp.argmax(conds) + branches = [ + lambda _x, _y: 0.0, + lambda x, y: parents[mate_rank, y], + lambda x, y: parents[parent_rank, y], + lambda x, y: parents[parent_rank, y], + lambda x, y: survivors[mate_rank, y] + ] + + return jax.lax.switch(index, branches, i, j) + + clones = jnp.fromfunction(offspring_elem, (self.m // 2, self.param_size)) + + # prob_mutate elements of unmutated_offspring should be mutated... + next_key, sample_key = jax.random.split(next_key, 2) + mutate = jax.random.choice(key=sample_key, + a=jnp.array([0, 1]), + p=jnp.array([1 - self.prob_mutate, self.prob_mutate]), + shape=(self.pop_size // 2, self.param_size)) + + # ...by a Cauchy distribution sample amount multiplied by an alpha learning rate + # NB: alpha can be very task specific (eg. 0.0001 for MNIST and 0.3 for cartpole_easy) + next_key, sample_key = jax.random.split(next_key, 2) + mutation = jax.random.cauchy(key=sample_key, + shape=(self.m // 2, + self.param_size)) * alpha + + offspring = clones + mutation * mutate + + if shuffle: + def shuffle_scan_body(carry, x): + next_key, sample_key = jax.random.split(carry[1], 2) + y = jax.random.permutation(key=sample_key, x=x) + return (carry[0] + 1, next_key), y + + next_key, sample_key = jax.random.split(next_key, 2) + carry, shuffled = jax.lax.scan(shuffle_scan_body, (0, next_key), jnp.transpose(survivors)) + _, next_key = carry + + new_params = jnp.vstack(( + jnp.transpose(shuffled), + offspring + )) + else: + new_params = jnp.vstack(( + survivors, + offspring + )) + + return next_key, new_params + + self.ask_fn = jax.jit(ask_fn) + + def tell_fn(fitness: Union[np.ndarray, jnp.ndarray], + params: Union[np.ndarray, jnp.ndarray]) -> Union[np.ndarray, jnp.ndarray]: + # sort population in fitness descending order + return params[(-fitness).argsort(axis=0)] + + self.tell_fn = jax.jit(tell_fn) + + def ask(self) -> jnp.ndarray: + self.rand_key, self.params = self.ask_fn(self.rand_key, self.params) + return self.params + + def tell(self, fitness: Union[np.ndarray, jnp.ndarray]) -> None: + self.params = self.tell_fn(fitness, self.params) + self._best_params = self.params[0] + + @property + def best_params(self) -> jnp.ndarray: + return self.jnp_array(self._best_params) + + @best_params.setter + def best_params(self, params: Union[np.ndarray, jnp.ndarray]) -> None: + self.params = jnp.repeat(params[None, :], self.pop_size, axis=0) + self._best_params = jnp.array(params, copy=True) diff --git a/scripts/benchmarks/Readme.md b/scripts/benchmarks/Readme.md index 08ae9bcc..85a88a07 100644 --- a/scripts/benchmarks/Readme.md +++ b/scripts/benchmarks/Readme.md @@ -105,6 +105,19 @@ Brax Ant | 3000 (max_iter=1200) |[Link](configs/PGPE/brax_ant.yaml)| 6054.3887 | Waterworld | 6 (max_iter=1000) | [Link](configs/PGPE/waterworld.yaml)| 11.6400 | Waterworld (MA) | 2 (max_iter=2000) | [Link](configs/PGPE/waterworld_ma.yaml) | 2.0625 | + +### CoSyNE + +| | Benchmarks | Parameters | Results (Avg) | +|------------------|-----------------------|-------------------------------------------|--------------| +| CartPole (easy) | 900 (max_iter=1000) | [Link](configs/CoSyNE/cartpole_easy.yaml) | 935.5311 | +| CartPole (hard) | 600 (max_iter=1000) | [Link](configs/CoSyNE/cartpole_hard.yaml) | 687.4745 | +| MNIST | 0.90 (max_iter=2000) | [Link](configs/CoSyNE/mnist.yaml) | 0.9522 | +| Waterworld | 6 (max_iter=1000) | [Link](configs/CoSyNE/waterworld.yaml) | 7.54 | +| Brax Ant | 3000 (max_iter=1200) | - | - | +| Waterworld (MA) | 2 (max_iter=2000) | - | - | + + ### CMA-ES-JAX | | Benchmarks | Parameters | Results (Avg) | diff --git a/scripts/benchmarks/configs/CoSyNE/cartpole_easy.yaml b/scripts/benchmarks/configs/CoSyNE/cartpole_easy.yaml new file mode 100644 index 00000000..44bf137b --- /dev/null +++ b/scripts/benchmarks/configs/CoSyNE/cartpole_easy.yaml @@ -0,0 +1,17 @@ +es_name: "CoSyNE" +problem_type: "cartpole_easy" +normalize: false +es_config: + pop_size: 64 + alpha: 0.3 + prob_mutate: 1.0 + segment_size: 411 +hidden_size: 64 +num_tests: 100 +n_repeats: 16 +max_iter: 1000 +test_interval: 100 +log_interval: 20 +seed: 42 +gpu_id: [0, 1, 2, 3] +debug: false \ No newline at end of file diff --git a/scripts/benchmarks/configs/CoSyNE/cartpole_hard.yaml b/scripts/benchmarks/configs/CoSyNE/cartpole_hard.yaml new file mode 100644 index 00000000..0ead0983 --- /dev/null +++ b/scripts/benchmarks/configs/CoSyNE/cartpole_hard.yaml @@ -0,0 +1,17 @@ +es_name: "CoSyNE" +problem_type: "cartpole_hard" +normalize: false +es_config: + pop_size: 64 + alpha: 0.3 + prob_mutate: 1.0 + segment_size: 411 +hidden_size: 64 +num_tests: 100 +n_repeats: 16 +max_iter: 1000 +test_interval: 100 +log_interval: 20 +seed: 42 +gpu_id: [0, 1, 2, 3] +debug: false \ No newline at end of file diff --git a/scripts/benchmarks/configs/CoSyNE/mnist.yaml b/scripts/benchmarks/configs/CoSyNE/mnist.yaml new file mode 100644 index 00000000..07dc6c97 --- /dev/null +++ b/scripts/benchmarks/configs/CoSyNE/mnist.yaml @@ -0,0 +1,18 @@ +es_name: "CoSyNE" +problem_type: "mnist" +normalize: false +es_config: + pop_size: 64 + alpha: 0.000085 + prob_mutate: 1.0 + segment_size: 1879 +hidden_size: 50 +batch_size: 1024 +max_iter: 2000 +test_interval: 500 +log_interval: 100 +num_tests: 1 +n_repeats: 1 +seed: 42 +gpu_id: [0, 1, 2, 3] +debug: false \ No newline at end of file diff --git a/scripts/benchmarks/configs/CoSyNE/waterworld.yaml b/scripts/benchmarks/configs/CoSyNE/waterworld.yaml new file mode 100644 index 00000000..96609d5e --- /dev/null +++ b/scripts/benchmarks/configs/CoSyNE/waterworld.yaml @@ -0,0 +1,17 @@ +es_name: "CoSyNE" +problem_type: "waterworld" +normalize: false +es_config: + pop_size: 64 + alpha: 0.3 + prob_mutate: 1.0 + segment_size: 411 +hidden_size: 100 +num_tests: 100 +n_repeats: 32 +max_iter: 1000 +test_interval: 50 +log_interval: 10 +seed: 42 +gpu_id: [0, 1, 2, 3] +debug: false \ No newline at end of file