Skip to content

Commit

Permalink
Merge branch 'feat/default_tasks' of github.com:adaptive-intelligent-…
Browse files Browse the repository at this point in the history
…robotics/QDax into feat/default_tasks
  • Loading branch information
Lookatator committed Oct 11, 2022
2 parents f5b88fb + 169d1ac commit ed52ea6
Showing 1 changed file with 38 additions and 0 deletions.
38 changes: 38 additions & 0 deletions tests/default_tasks_test/arm_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import functools

import jax
import jax.numpy as jnp
import pytest

from qdax.core.containers.mapelites_repertoire import compute_euclidean_centroids
Expand Down Expand Up @@ -105,5 +106,42 @@ def test_arm(task_name: str, batch_size: int) -> None:
pytest.assume(repertoire is not None)



def test_arm_scoring_function() -> None:

# Init a random key
seed = 42
random_key = jax.random.PRNGKey(seed)

# arm has xy BD centered at 0.5 0.5 and min max range is [0,1]
# 0 params of first genotype is horizontal and points towards negative x axis
# angles move in anticlockwise direction
genotypes_1 = jnp.ones(shape=(1, 4))*0.5 # 0.5
genotypes_2 = jnp.zeros(shape=(1, 6)) # zeros - this folds upon itself (if even number ends up at origin)
genotypes_3 = jnp.ones(shape=(1, 10)) # ones - this also folds upon itself (if even number ends up at origin)
genotypes_4 = jnp.array([[0, 0.5]])
genotypes_5 = jnp.array([[0.25, 0.5]])
genotypes_6 = jnp.array([[0.5, 0.5]])
genotypes_7 = jnp.array([[0.75, 0.5]])

fitness_1, descriptors_1, _, random_key = arm_scoring_function(genotypes_1, random_key)
fitness_2, descriptors_2, _, random_key = arm_scoring_function(genotypes_2, random_key)
fitness_3, descriptors_3, _, random_key = arm_scoring_function(genotypes_3, random_key)
fitness_4, descriptors_4, _, random_key = arm_scoring_function(genotypes_4, random_key)
fitness_5, descriptors_5, _, random_key = arm_scoring_function(genotypes_5, random_key)
fitness_6, descriptors_6, _, random_key = arm_scoring_function(genotypes_6, random_key)
fitness_7, descriptors_7, _, random_key = arm_scoring_function(genotypes_7, random_key)

# use rounding to avoid some numerical floating point errors
pytest.assume(jnp.array_equal(jnp.around(descriptors_1, decimals=1), jnp.array([[1.0, 0.5]])))
pytest.assume(jnp.array_equal(jnp.around(descriptors_2, decimals=1), jnp.array([[0.5, 0.5]])))
pytest.assume(jnp.array_equal(jnp.around(descriptors_3, decimals=1), jnp.array([[0.5, 0.5]])))
pytest.assume(jnp.array_equal(jnp.around(descriptors_4, decimals=1), jnp.array([[0.0, 0.5]])))
pytest.assume(jnp.array_equal(jnp.around(descriptors_5, decimals=1), jnp.array([[0.5, 0.0]])))
pytest.assume(jnp.array_equal(jnp.around(descriptors_6, decimals=1), jnp.array([[1.0, 0.5]])))
pytest.assume(jnp.array_equal(jnp.around(descriptors_7, decimals=1), jnp.array([[0.5, 1.0]])))


if __name__ == "__main__":
test_arm(task_name="arm", batch_size=128)
test_arm_scoring_function()

0 comments on commit ed52ea6

Please sign in to comment.