Skip to content

Commit 9de0e29

Browse files
author
Carolin Benjamins
committed
Fix gymnax
1 parent 92458e0 commit 9de0e29

File tree

9 files changed

+224
-427
lines changed

9 files changed

+224
-427
lines changed

carl/envs/gymnax/__init__.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -13,19 +13,19 @@
1313
)
1414
from carl.envs.gymnax.carl_jax_cartpole import CARLJaxCartPoleEnv # noqa: F401
1515
from carl.envs.gymnax.carl_jax_mountaincar import ( # noqa: F401
16+
CONTEXT_BOUNDS as CARLJaxMountainCarContinuousEnv_bounds,
17+
)
18+
from carl.envs.gymnax.carl_jax_mountaincar import (
1619
CONTEXT_BOUNDS as CARLJaxMountainCarEnv_bounds,
1720
)
1821
from carl.envs.gymnax.carl_jax_mountaincar import ( # noqa: F401
22+
DEFAULT_CONTEXT as CARLJaxMountainCarContinuousEnv_defaults,
23+
)
24+
from carl.envs.gymnax.carl_jax_mountaincar import (
1925
DEFAULT_CONTEXT as CARLJaxMountainCarEnv_defaults,
2026
)
2127
from carl.envs.gymnax.carl_jax_mountaincar import CARLJaxMountainCarEnv # noqa: F401
22-
from carl.envs.gymnax.carl_jax_mountaincarcontinuous import ( # noqa: F401
23-
CONTEXT_BOUNDS as CARLJaxMountainCarContinuousEnv_bounds,
24-
)
25-
from carl.envs.gymnax.carl_jax_mountaincarcontinuous import ( # noqa: F401
26-
DEFAULT_CONTEXT as CARLJaxMountainCarContinuousEnv_defaults,
27-
)
28-
from carl.envs.gymnax.carl_jax_mountaincarcontinuous import ( # noqa: F401
28+
from carl.envs.gymnax.carl_jax_mountaincar import ( # noqa: F401
2929
CARLJaxMountainCarContinuousEnv,
3030
)
3131
from carl.envs.gymnax.carl_jax_pendulum import ( # noqa: F401
Lines changed: 83 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,83 @@
1+
from typing import Any, Dict, List, Optional, Union
2+
3+
import gymnasium
4+
5+
from carl.context.selection import AbstractSelector
6+
from carl.envs.carl_env import CARLEnv
7+
from carl.envs.gymnax.wrappers import make_gymnax_env
8+
from carl.utils.trial_logger import TrialLogger
9+
from carl.utils.types import Context, Contexts
10+
11+
12+
class CARLGymnaxEnv(CARLEnv):
13+
env_name: str
14+
DEFAULT_CONTEXT: Context
15+
max_episode_steps: int
16+
17+
def __init__(
18+
self,
19+
env: gymnasium.Env | None = None,
20+
contexts: Contexts = {},
21+
hide_context: bool = True,
22+
add_gaussian_noise_to_context: bool = False,
23+
gaussian_noise_std_percentage: float = 0.01,
24+
logger: Optional[TrialLogger] = None,
25+
scale_context_features: str = "no",
26+
default_context: Optional[Context] = None,
27+
state_context_features: Optional[List[str]] = None,
28+
context_mask: Optional[List[str]] = None,
29+
dict_observation_space: bool = False,
30+
context_selector: Optional[
31+
Union[AbstractSelector, type[AbstractSelector]]
32+
] = None,
33+
context_selector_kwargs: Optional[Dict] = None,
34+
):
35+
"""
36+
Max torque is not a context feature because it changes the action space.
37+
38+
Parameters
39+
----------
40+
env
41+
contexts
42+
instance_mode
43+
hide_context
44+
add_gaussian_noise_to_context
45+
gaussian_noise_std_percentage
46+
"""
47+
if env is None:
48+
env = make_gymnax_env(env_name=self.env_name)
49+
50+
if not contexts:
51+
contexts = {0: self.DEFAULT_CONTEXT}
52+
53+
if not default_context:
54+
default_context = self.DEFAULT_CONTEXT
55+
56+
super().__init__(
57+
env=env,
58+
contexts=contexts,
59+
hide_context=hide_context,
60+
add_gaussian_noise_to_context=add_gaussian_noise_to_context,
61+
gaussian_noise_std_percentage=gaussian_noise_std_percentage,
62+
logger=logger,
63+
scale_context_features=scale_context_features,
64+
default_context=default_context,
65+
max_episode_length=self.max_episode_steps,
66+
state_context_features=state_context_features,
67+
dict_observation_space=dict_observation_space,
68+
context_selector=context_selector,
69+
context_selector_kwargs=context_selector_kwargs,
70+
context_mask=context_mask,
71+
)
72+
self.whitelist_gaussian_noise = list(
73+
self.DEFAULT_CONTEXT.keys()
74+
) # allow to augment all values
75+
76+
def _update_context(self) -> None:
77+
raise NotImplementedError
78+
79+
def __getattr__(self, name: str) -> Any:
80+
if name in ["sys", "__getstate__"]:
81+
return getattr(self.env._environment, name)
82+
else:
83+
return getattr(self, name)

carl/envs/gymnax/carl_jax_acrobot.py

Lines changed: 38 additions & 76 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,16 @@
1+
from __future__ import annotations
2+
13
from typing import Dict, List, Optional, Union
24

5+
import gymnasium
6+
import gymnax
37
import jax.numpy as jnp
48
import numpy as np
59
from gymnax.environments.classic_control.acrobot import Acrobot
6-
from gymnax.environments.spaces import Space, gymnax_space_to_gym_space
7-
from gymnax.wrappers.gym import GymnaxToGymWrapper
810

911
from carl.context.selection import AbstractSelector
1012
from carl.envs.carl_env import CARLEnv
13+
from carl.envs.gymnax.carl_gymnax_env import CARLGymnaxEnv
1114
from carl.utils.trial_logger import TrialLogger
1215
from carl.utils.types import Context, Contexts
1316

@@ -16,11 +19,13 @@
1619
"link_length_2": 1,
1720
"link_mass_1": 1,
1821
"link_mass_2": 1,
19-
"link_com_1": 0.5,
20-
"link_com_2": 0.5,
22+
"link_com_pos_1": 0.5,
23+
"link_com_pos_2": 0.5,
2124
"link_moi": 1,
22-
"max_velocity_1": 4 * jnp.pi,
23-
"max_velocity_2": 9 * jnp.pi,
25+
"max_vel_1": 4 * jnp.pi,
26+
"max_vel_2": 9 * jnp.pi,
27+
"torque_noise_max": 0.0,
28+
"max_steps_in_episode": 500,
2429
}
2530

2631
CONTEXT_BOUNDS = {
@@ -36,97 +41,54 @@
3641
float,
3742
), # Link mass can be shrunken and grown by a factor of 10
3843
"link_mass_2": (0.1, 10, float),
39-
"link_com_1": (0, 1, float), # Center of mass can move from one end to the other
40-
"link_com_2": (0, 1, float),
44+
"link_com_pos_1": (
45+
0,
46+
1,
47+
float,
48+
), # Center of mass can move from one end to the other
49+
"link_com_pos_2": (0, 1, float),
4150
"link_moi": (
4251
0.1,
4352
10,
4453
float,
4554
), # Moments on inertia can be shrunken and grown by a factor of 10
46-
"max_velocity_1": (
55+
"max_vel_1": (
4756
0.4 * np.pi,
4857
40 * np.pi,
4958
float,
5059
), # Velocity can vary by a factor of 10 in either direction
51-
"max_velocity_2": (0.9 * np.pi, 90 * np.pi, float),
60+
"max_vel_2": (0.9 * np.pi, 90 * np.pi, float),
5261
"torque_noise_max": (
5362
-1.0,
5463
1.0,
5564
float,
5665
), # torque is either {-1., 0., 1}. Applying noise of 1. would be quite extreme
57-
"initial_angle_lower": (-jnp.inf, jnp.inf, float),
58-
"initial_angle_upper": (-jnp.inf, jnp.inf, float),
59-
"initial_velocity_lower": (-jnp.inf, jnp.inf, float),
60-
"initial_velocity_upper": (-jnp.inf, jnp.inf, float),
66+
"max_steps_in_episode": (1, jnp.inf, int),
6167
}
6268

6369

64-
class CustomGymnaxToGymWrapper(GymnaxToGymWrapper):
65-
@property
66-
def observation_space(self) -> Dict:
67-
return gymnax_space_to_gym_space(self._env.observation_space(self.env_params))
68-
69-
@observation_space.setter
70-
def observation_space(self, value: Space) -> None:
71-
self._observation_space = value
72-
73-
74-
class CARLJaxAcrobotEnv(CARLEnv):
75-
def __init__(
76-
self,
77-
env: Acrobot = CustomGymnaxToGymWrapper(Acrobot()),
78-
contexts: Contexts = {},
79-
hide_context: bool = True,
80-
add_gaussian_noise_to_context: bool = False,
81-
gaussian_noise_std_percentage: float = 0.01,
82-
logger: Optional[TrialLogger] = None,
83-
scale_context_features: str = "no",
84-
default_context: Optional[Context] = DEFAULT_CONTEXT,
85-
max_episode_length: int = 500, # from https://github.com/openai/gym/blob/master/gym/envs/__init__.py
86-
state_context_features: Optional[List[str]] = None,
87-
context_mask: Optional[List[str]] = None,
88-
dict_observation_space: bool = False,
89-
context_selector: Optional[
90-
Union[AbstractSelector, type[AbstractSelector]]
91-
] = None,
92-
context_selector_kwargs: Optional[Dict] = None,
93-
):
94-
if not contexts:
95-
contexts = {0: DEFAULT_CONTEXT}
96-
super().__init__(
97-
env=env,
98-
contexts=contexts,
99-
hide_context=hide_context,
100-
add_gaussian_noise_to_context=add_gaussian_noise_to_context,
101-
gaussian_noise_std_percentage=gaussian_noise_std_percentage,
102-
logger=logger,
103-
scale_context_features=scale_context_features,
104-
default_context=default_context,
105-
max_episode_length=max_episode_length,
106-
state_context_features=state_context_features,
107-
dict_observation_space=dict_observation_space,
108-
context_selector=context_selector,
109-
context_selector_kwargs=context_selector_kwargs,
110-
context_mask=context_mask,
111-
)
112-
self.whitelist_gaussian_noise = list(
113-
DEFAULT_CONTEXT.keys()
114-
) # allow to augment all values
70+
class CARLJaxAcrobotEnv(CARLGymnaxEnv):
71+
env_name: str = "Acrobot-v1"
72+
max_episode_steps: int = DEFAULT_CONTEXT["max_steps_in_episode"]
73+
DEFAULT_CONTEXT: Context = DEFAULT_CONTEXT
11574

11675
def _update_context(self) -> None:
117-
self.env: Acrobot
118-
self.env.LINK_LENGTH_1 = self.context["link_length_1"]
119-
self.env.LINK_LENGTH_2 = self.context["link_length_2"]
120-
self.env.LINK_MASS_1 = self.context["link_mass_1"]
121-
self.env.LINK_MASS_2 = self.context["link_mass_2"]
122-
self.env.LINK_COM_POS_1 = self.context["link_com_1"]
123-
self.env.LINK_COM_POS_2 = self.context["link_com_2"]
124-
self.env.LINK_MOI = self.context["link_moi"]
125-
self.env.MAX_VEL_1 = self.context["max_velocity_1"]
126-
self.env.MAX_VEL_2 = self.context["max_velocity_2"]
76+
content = self.env.env.env_params.__dict__
77+
content.update(self.context)
78+
# We cannot directly set attributes of env_params because it is a frozen dataclass
79+
self.env.env.env_params = gymnax.environments.classic_control.acrobot.EnvParams(
80+
**content
81+
)
12782

12883
high = jnp.array(
129-
[1.0, 1.0, 1.0, 1.0, self.env.MAX_VEL_1, self.env.MAX_VEL_2],
84+
[
85+
1.0,
86+
1.0,
87+
1.0,
88+
1.0,
89+
self.env.env.env_params.max_vel_1,
90+
self.env.env.env_params.max_vel_2,
91+
],
13092
dtype=jnp.float32,
13193
)
13294
low = -high

carl/envs/gymnax/carl_jax_cartpole.py

Lines changed: 22 additions & 65 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,10 @@
1-
from typing import Dict, List, Optional, Union
1+
from __future__ import annotations
22

3+
import gymnax
34
import jax.numpy as jnp
45
from gymnax.environments.classic_control.cartpole import CartPole
5-
from gymnax.environments.spaces import Space, gymnax_space_to_gym_space
6-
from gymnax.wrappers.gym import GymnaxToGymWrapper
76

8-
from carl.context.selection import AbstractSelector
9-
from carl.envs.carl_env import CARLEnv
10-
from carl.utils.trial_logger import TrialLogger
7+
from carl.envs.gymnax.carl_gymnax_env import CARLGymnaxEnv
118
from carl.utils.types import Context, Contexts
129

1310
DEFAULT_CONTEXT = {
@@ -17,6 +14,9 @@
1714
"length": 0.5,
1815
"force_mag": 10.0,
1916
"tau": 0.02,
17+
"polemass_length": None,
18+
"total_mass": None,
19+
"max_steps_in_episode": 500,
2020
}
2121

2222
CONTEXT_BOUNDS = {
@@ -26,75 +26,32 @@
2626
"length": (0.25, 1.0, float),
2727
"force_mag": (5.0, 15.0, float),
2828
"tau": (0.01, 0.05, float),
29+
"polemass_length": (0, jnp.inf, float),
30+
"total_mass": (0, jnp.inf, float),
31+
"max_steps_in_episode": (1, jnp.inf, int),
2932
}
3033

3134

32-
class CustomGymnaxToGymWrapper(GymnaxToGymWrapper):
33-
@property
34-
def observation_space(self) -> Dict:
35-
return gymnax_space_to_gym_space(self._env.observation_space(self.env_params))
35+
class CARLJaxCartPoleEnv(CARLGymnaxEnv):
36+
env_name: str = "CartPole-v1"
37+
max_episode_steps: int = DEFAULT_CONTEXT["max_steps_in_episode"]
38+
DEFAULT_CONTEXT: Context = DEFAULT_CONTEXT
3639

37-
@observation_space.setter
38-
def observation_space(self, value: Space) -> None:
39-
self._observation_space = value
40-
41-
42-
class CARLJaxCartPoleEnv(CARLEnv):
43-
def __init__(
44-
self,
45-
env: CartPole = CustomGymnaxToGymWrapper(CartPole()),
46-
contexts: Contexts = {},
47-
hide_context: bool = True,
48-
add_gaussian_noise_to_context: bool = False,
49-
gaussian_noise_std_percentage: float = 0.01,
50-
logger: Optional[TrialLogger] = None,
51-
scale_context_features: str = "no",
52-
default_context: Optional[Context] = DEFAULT_CONTEXT,
53-
max_episode_length: int = 500, # from https://github.com/openai/gym/blob/master/gym/envs/__init__.py
54-
state_context_features: Optional[List[str]] = None,
55-
context_mask: Optional[List[str]] = None,
56-
dict_observation_space: bool = False,
57-
context_selector: Optional[
58-
Union[AbstractSelector, type[AbstractSelector]]
59-
] = None,
60-
context_selector_kwargs: Optional[Dict] = None,
61-
):
62-
if not contexts:
63-
contexts = {0: DEFAULT_CONTEXT}
64-
super().__init__(
65-
env=env,
66-
contexts=contexts,
67-
hide_context=hide_context,
68-
add_gaussian_noise_to_context=add_gaussian_noise_to_context,
69-
gaussian_noise_std_percentage=gaussian_noise_std_percentage,
70-
logger=logger,
71-
scale_context_features=scale_context_features,
72-
default_context=default_context,
73-
max_episode_length=max_episode_length,
74-
state_context_features=state_context_features,
75-
dict_observation_space=dict_observation_space,
76-
context_selector=context_selector,
77-
context_selector_kwargs=context_selector_kwargs,
78-
context_mask=context_mask,
40+
def _update_context(self) -> None:
41+
self.context["polemass_length"] = (
42+
self.context["masspole"] * self.context["length"]
7943
)
80-
self.whitelist_gaussian_noise = list(
81-
DEFAULT_CONTEXT.keys()
82-
) # allow to augment all values
44+
self.context["total_mass"] = self.context["masscart"] + self.context["masspole"]
8345

84-
def _update_context(self) -> None:
85-
self.env: CartPole
86-
self.env.gravity = self.context["gravity"]
87-
self.env.masscart = self.context["masscart"]
88-
self.env.masspole = self.context["masspole"]
89-
self.env.length = self.context["length"]
90-
self.env.force_mag = self.context["force_mag"]
91-
self.env.tau = self.context["tau"]
46+
self.env.env.env_params = (
47+
gymnax.environments.classic_control.cartpole.EnvParams(**self.context)
48+
)
9249

9350
high = jnp.array(
9451
[
95-
self.env.x_threshold * 2,
52+
self.env.env.env_params.x_threshold * 2,
9653
jnp.finfo(jnp.float32).max,
97-
self.env.theta_threshold_radians * 2,
54+
self.env.env.env_params.theta_threshold_radians * 2,
9855
jnp.finfo(jnp.float32).max,
9956
],
10057
dtype=jnp.float32,

0 commit comments

Comments
 (0)