Skip to content

Commit

Permalink
Apply black format to tests folder
Browse files Browse the repository at this point in the history
  • Loading branch information
Toni-SM committed Nov 4, 2024
1 parent 362ed94 commit 540a0a7
Show file tree
Hide file tree
Showing 39 changed files with 576 additions and 354 deletions.
92 changes: 53 additions & 39 deletions tests/jax/test_jax_model_instantiators.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,77 +11,91 @@
from skrl.utils.model_instantiators.jax import Shape, categorical_model, deterministic_model, gaussian_model


@hypothesis.given(observation_space_size=st.integers(min_value=1, max_value=10),
action_space_size=st.integers(min_value=1, max_value=10))
@hypothesis.given(
observation_space_size=st.integers(min_value=1, max_value=10),
action_space_size=st.integers(min_value=1, max_value=10),
)
@hypothesis.settings(suppress_health_check=[hypothesis.HealthCheck.function_scoped_fixture], deadline=None)
@pytest.mark.parametrize("device", [None, "cpu", "cuda:0"])
def test_categorical_model(capsys, observation_space_size, action_space_size, device):
observation_space = gym.spaces.Box(np.array([-1] * observation_space_size), np.array([1] * observation_space_size))
action_space = gym.spaces.Discrete(action_space_size)
# TODO: randomize all parameters
model = categorical_model(observation_space=observation_space,
action_space=action_space,
device=device,
unnormalized_log_prob=True,
input_shape=Shape.STATES,
hiddens=[256, 256],
hidden_activation=["relu", "relu"],
output_shape=Shape.ACTIONS,
output_activation=None)
model = categorical_model(
observation_space=observation_space,
action_space=action_space,
device=device,
unnormalized_log_prob=True,
input_shape=Shape.STATES,
hiddens=[256, 256],
hidden_activation=["relu", "relu"],
output_shape=Shape.ACTIONS,
output_activation=None,
)
model.init_state_dict("model")

with jax.default_device(model.device):
observations = jnp.ones((10, model.num_observations))
output = model.act({"states": observations})
assert output[0].shape == (10, 1)

@hypothesis.given(observation_space_size=st.integers(min_value=1, max_value=10),
action_space_size=st.integers(min_value=1, max_value=10))

@hypothesis.given(
observation_space_size=st.integers(min_value=1, max_value=10),
action_space_size=st.integers(min_value=1, max_value=10),
)
@hypothesis.settings(suppress_health_check=[hypothesis.HealthCheck.function_scoped_fixture], deadline=None)
@pytest.mark.parametrize("device", [None, "cpu", "cuda:0"])
def test_deterministic_model(capsys, observation_space_size, action_space_size, device):
observation_space = gym.spaces.Box(np.array([-1] * observation_space_size), np.array([1] * observation_space_size))
action_space = gym.spaces.Box(np.array([-1] * action_space_size), np.array([1] * action_space_size))
# TODO: randomize all parameters
model = deterministic_model(observation_space=observation_space,
action_space=action_space,
device=device,
clip_actions=False,
input_shape=Shape.STATES,
hiddens=[256, 256],
hidden_activation=["relu", "relu"],
output_shape=Shape.ACTIONS,
output_activation=None,
output_scale=1)
model = deterministic_model(
observation_space=observation_space,
action_space=action_space,
device=device,
clip_actions=False,
input_shape=Shape.STATES,
hiddens=[256, 256],
hidden_activation=["relu", "relu"],
output_shape=Shape.ACTIONS,
output_activation=None,
output_scale=1,
)
model.init_state_dict("model")

with jax.default_device(model.device):
observations = jnp.ones((10, model.num_observations))
output = model.act({"states": observations})
assert output[0].shape == (10, model.num_actions)

@hypothesis.given(observation_space_size=st.integers(min_value=1, max_value=10),
action_space_size=st.integers(min_value=1, max_value=10))

@hypothesis.given(
observation_space_size=st.integers(min_value=1, max_value=10),
action_space_size=st.integers(min_value=1, max_value=10),
)
@hypothesis.settings(suppress_health_check=[hypothesis.HealthCheck.function_scoped_fixture], deadline=None)
@pytest.mark.parametrize("device", [None, "cpu", "cuda:0"])
def test_gaussian_model(capsys, observation_space_size, action_space_size, device):
observation_space = gym.spaces.Box(np.array([-1] * observation_space_size), np.array([1] * observation_space_size))
action_space = gym.spaces.Box(np.array([-1] * action_space_size), np.array([1] * action_space_size))
# TODO: randomize all parameters
model = gaussian_model(observation_space=observation_space,
action_space=action_space,
device=device,
clip_actions=False,
clip_log_std=True,
min_log_std=-20,
max_log_std=2,
initial_log_std=0,
input_shape=Shape.STATES,
hiddens=[256, 256],
hidden_activation=["relu", "relu"],
output_shape=Shape.ACTIONS,
output_activation=None,
output_scale=1)
model = gaussian_model(
observation_space=observation_space,
action_space=action_space,
device=device,
clip_actions=False,
clip_log_std=True,
min_log_std=-20,
max_log_std=2,
initial_log_std=0,
input_shape=Shape.STATES,
hiddens=[256, 256],
hidden_activation=["relu", "relu"],
output_shape=Shape.ACTIONS,
output_activation=None,
output_scale=1,
)
model.init_state_dict("model")

with jax.default_device(model.device):
Expand Down
53 changes: 23 additions & 30 deletions tests/jax/test_jax_model_instantiators_definition.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ def test_get_activation_function(capsys):
assert activation is not None, f"{item} -> None"
exec(f"{activation}(x)", _globals, {})


def test_parse_input(capsys):
# check for Shape enum (compatibility with prior versions)
for input in [Shape.STATES, Shape.OBSERVATIONS, Shape.ACTIONS, Shape.STATES_ACTIONS]:
Expand All @@ -43,6 +44,7 @@ def test_parse_input(capsys):
output = _parse_input(str(input))
assert output.replace("'", '"') == statement, f"'{output}' != '{statement}'"


def test_generate_modules(capsys):
_globals = {"nn": flax.linen}

Expand Down Expand Up @@ -138,6 +140,7 @@ def test_generate_modules(capsys):
assert isinstance(container, flax.linen.Sequential)
assert len(container.layers) == 2


def test_gaussian_model(capsys):
device = "cpu"
observation_space = gym.spaces.Box(np.array([-1] * 5), np.array([1] * 5))
Expand All @@ -161,19 +164,15 @@ def test_gaussian_model(capsys):
"""
content = yaml.safe_load(content)
# source
model = gaussian_model(observation_space=observation_space,
action_space=action_space,
device=device,
return_source=True,
**content)
model = gaussian_model(
observation_space=observation_space, action_space=action_space, device=device, return_source=True, **content
)
with capsys.disabled():
print(model)
# instance
model = gaussian_model(observation_space=observation_space,
action_space=action_space,
device=device,
return_source=False,
**content)
model = gaussian_model(
observation_space=observation_space, action_space=action_space, device=device, return_source=False, **content
)
model.init_state_dict("model")
with capsys.disabled():
print(model)
Expand All @@ -182,6 +181,7 @@ def test_gaussian_model(capsys):
output = model.act({"states": observations})
assert output[0].shape == (10, 2)


def test_deterministic_model(capsys):
device = "cpu"
observation_space = gym.spaces.Box(np.array([-1] * 5), np.array([1] * 5))
Expand All @@ -202,19 +202,15 @@ def test_deterministic_model(capsys):
"""
content = yaml.safe_load(content)
# source
model = deterministic_model(observation_space=observation_space,
action_space=action_space,
device=device,
return_source=True,
**content)
model = deterministic_model(
observation_space=observation_space, action_space=action_space, device=device, return_source=True, **content
)
with capsys.disabled():
print(model)
# instance
model = deterministic_model(observation_space=observation_space,
action_space=action_space,
device=device,
return_source=False,
**content)
model = deterministic_model(
observation_space=observation_space, action_space=action_space, device=device, return_source=False, **content
)
model.init_state_dict("model")
with capsys.disabled():
print(model)
Expand All @@ -223,6 +219,7 @@ def test_deterministic_model(capsys):
output = model.act({"states": observations})
assert output[0].shape == (10, 3)


def test_categorical_model(capsys):
device = "cpu"
observation_space = gym.spaces.Box(np.array([-1] * 5), np.array([1] * 5))
Expand All @@ -242,19 +239,15 @@ def test_categorical_model(capsys):
"""
content = yaml.safe_load(content)
# source
model = categorical_model(observation_space=observation_space,
action_space=action_space,
device=device,
return_source=True,
**content)
model = categorical_model(
observation_space=observation_space, action_space=action_space, device=device, return_source=True, **content
)
with capsys.disabled():
print(model)
# instance
model = categorical_model(observation_space=observation_space,
action_space=action_space,
device=device,
return_source=False,
**content)
model = categorical_model(
observation_space=observation_space, action_space=action_space, device=device, return_source=False, **content
)
model.init_state_dict("model")
with capsys.disabled():
print(model)
Expand Down
13 changes: 11 additions & 2 deletions tests/jax/test_jax_utils_spaces.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
sample_space,
tensorize_space,
unflatten_tensorized_space,
untensorize_space
untensorize_space,
)

from ..stategies import gym_space_stategy, gymnasium_space_stategy
Expand All @@ -29,6 +29,7 @@ def _check_backend(x, backend):
else:
raise ValueError(f"Invalid backend type: {backend}")


def check_sampled_space(space, x, n, backend):
if isinstance(space, gymnasium.spaces.Box):
_check_backend(x, backend)
Expand Down Expand Up @@ -66,6 +67,7 @@ def occupied_size(s):
space_size = compute_space_size(space, occupied_size=True)
assert space_size == occupied_size(space)


@hypothesis.given(space=gymnasium_space_stategy())
@hypothesis.settings(suppress_health_check=[hypothesis.HealthCheck.function_scoped_fixture], deadline=None)
def test_tensorize_space(capsys, space: gymnasium.spaces.Space):
Expand Down Expand Up @@ -97,6 +99,7 @@ def check_tensorized_space(s, x, n):
tensorized_space = tensorize_space(space, sampled_space)
check_tensorized_space(space, tensorized_space, 5)


@hypothesis.given(space=gymnasium_space_stategy())
@hypothesis.settings(suppress_health_check=[hypothesis.HealthCheck.function_scoped_fixture], deadline=None)
def test_untensorize_space(capsys, space: gymnasium.spaces.Space):
Expand All @@ -108,7 +111,9 @@ def check_untensorized_space(s, x, squeeze_batch_dimension):
assert isinstance(x, (np.ndarray, int))
assert isinstance(x, int) if squeeze_batch_dimension else x.shape == (1, 1)
elif isinstance(s, gymnasium.spaces.MultiDiscrete):
assert isinstance(x, np.ndarray) and x.shape == s.nvec.shape if squeeze_batch_dimension else (1, *s.nvec.shape)
assert (
isinstance(x, np.ndarray) and x.shape == s.nvec.shape if squeeze_batch_dimension else (1, *s.nvec.shape)
)
elif isinstance(s, gymnasium.spaces.Dict):
list(map(check_untensorized_space, s.values(), x.values(), [squeeze_batch_dimension] * len(s)))
elif isinstance(s, gymnasium.spaces.Tuple):
Expand All @@ -124,6 +129,7 @@ def check_untensorized_space(s, x, squeeze_batch_dimension):
untensorized_space = untensorize_space(space, tensorized_space, squeeze_batch_dimension=True)
check_untensorized_space(space, untensorized_space, squeeze_batch_dimension=True)


@hypothesis.given(space=gymnasium_space_stategy(), batch_size=st.integers(min_value=1, max_value=10))
@hypothesis.settings(suppress_health_check=[hypothesis.HealthCheck.function_scoped_fixture], deadline=None)
def test_sample_space(capsys, space: gymnasium.spaces.Space, batch_size: int):
Expand All @@ -134,6 +140,7 @@ def test_sample_space(capsys, space: gymnasium.spaces.Space, batch_size: int):
sampled_space = sample_space(space, batch_size, backend="jax")
check_sampled_space(space, sampled_space, batch_size, backend="jax")


@hypothesis.given(space=gymnasium_space_stategy())
@hypothesis.settings(suppress_health_check=[hypothesis.HealthCheck.function_scoped_fixture], deadline=None)
def test_flatten_tensorized_space(capsys, space: gymnasium.spaces.Space):
Expand All @@ -147,6 +154,7 @@ def test_flatten_tensorized_space(capsys, space: gymnasium.spaces.Space):
flattened_space = flatten_tensorized_space(tensorized_space)
assert flattened_space.shape == (5, space_size)


@hypothesis.given(space=gymnasium_space_stategy())
@hypothesis.settings(suppress_health_check=[hypothesis.HealthCheck.function_scoped_fixture], deadline=None)
def test_unflatten_tensorized_space(capsys, space: gymnasium.spaces.Space):
Expand All @@ -160,6 +168,7 @@ def test_unflatten_tensorized_space(capsys, space: gymnasium.spaces.Space):
unflattened_space = unflatten_tensorized_space(space, flattened_space)
check_sampled_space(space, unflattened_space, 5, backend="jax")


@hypothesis.given(space=gym_space_stategy())
@hypothesis.settings(suppress_health_check=[hypothesis.HealthCheck.function_scoped_fixture], deadline=None)
def test_convert_gym_space(capsys, space: gym.spaces.Space):
Expand Down
1 change: 1 addition & 0 deletions tests/jax/test_jax_wrapper_gym.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@ def test_env(capsys: pytest.CaptureFixture, backend: str):

env.close()


@pytest.mark.parametrize("backend", ["jax", "numpy"])
@pytest.mark.parametrize("vectorization_mode", ["async", "sync"])
def test_vectorized_env(capsys: pytest.CaptureFixture, backend: str, vectorization_mode: str):
Expand Down
1 change: 1 addition & 0 deletions tests/jax/test_jax_wrapper_gymnasium.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@ def test_env(capsys: pytest.CaptureFixture, backend: str):

env.close()


@pytest.mark.parametrize("backend", ["jax", "numpy"])
@pytest.mark.parametrize("vectorization_mode", ["async", "sync"])
def test_vectorized_env(capsys: pytest.CaptureFixture, backend: str, vectorization_mode: str):
Expand Down
2 changes: 1 addition & 1 deletion tests/jax/test_jax_wrapper_isaacgym.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ def __init__(self, num_states) -> None:

self.state_space = gym.spaces.Box(np.ones(self.num_states) * -np.Inf, np.ones(self.num_states) * np.Inf)
self.observation_space = gym.spaces.Box(np.ones(self.num_obs) * -np.Inf, np.ones(self.num_obs) * np.Inf)
self.action_space = gym.spaces.Box(np.ones(self.num_actions) * -1., np.ones(self.num_actions) * 1.)
self.action_space = gym.spaces.Box(np.ones(self.num_actions) * -1.0, np.ones(self.num_actions) * 1.0)

def reset(self) -> Dict[str, torch.Tensor]:
obs_dict = {}
Expand Down
Loading

0 comments on commit 540a0a7

Please sign in to comment.