|
| 1 | +from __future__ import annotations |
| 2 | + |
1 | 3 | from typing import Dict, List, Optional, Union |
2 | 4 |
|
| 5 | +import gymnasium |
| 6 | +import gymnax |
3 | 7 | import jax.numpy as jnp |
4 | 8 | import numpy as np |
5 | 9 | from gymnax.environments.classic_control.acrobot import Acrobot |
6 | | -from gymnax.environments.spaces import Space, gymnax_space_to_gym_space |
7 | | -from gymnax.wrappers.gym import GymnaxToGymWrapper |
8 | 10 |
|
9 | 11 | from carl.context.selection import AbstractSelector |
10 | 12 | from carl.envs.carl_env import CARLEnv |
| 13 | +from carl.envs.gymnax.carl_gymnax_env import CARLGymnaxEnv |
11 | 14 | from carl.utils.trial_logger import TrialLogger |
12 | 15 | from carl.utils.types import Context, Contexts |
13 | 16 |
|
|
16 | 19 | "link_length_2": 1, |
17 | 20 | "link_mass_1": 1, |
18 | 21 | "link_mass_2": 1, |
19 | | - "link_com_1": 0.5, |
20 | | - "link_com_2": 0.5, |
| 22 | + "link_com_pos_1": 0.5, |
| 23 | + "link_com_pos_2": 0.5, |
21 | 24 | "link_moi": 1, |
22 | | - "max_velocity_1": 4 * jnp.pi, |
23 | | - "max_velocity_2": 9 * jnp.pi, |
| 25 | + "max_vel_1": 4 * jnp.pi, |
| 26 | + "max_vel_2": 9 * jnp.pi, |
| 27 | + "torque_noise_max": 0.0, |
| 28 | + "max_steps_in_episode": 500, |
24 | 29 | } |
25 | 30 |
|
26 | 31 | CONTEXT_BOUNDS = { |
|
36 | 41 | float, |
37 | 42 | ), # Link mass can be shrunken and grown by a factor of 10 |
38 | 43 | "link_mass_2": (0.1, 10, float), |
39 | | - "link_com_1": (0, 1, float), # Center of mass can move from one end to the other |
40 | | - "link_com_2": (0, 1, float), |
| 44 | + "link_com_pos_1": ( |
| 45 | + 0, |
| 46 | + 1, |
| 47 | + float, |
| 48 | + ), # Center of mass can move from one end to the other |
| 49 | + "link_com_pos_2": (0, 1, float), |
41 | 50 | "link_moi": ( |
42 | 51 | 0.1, |
43 | 52 | 10, |
44 | 53 | float, |
45 | 54 | ), # Moments of inertia can be shrunken and grown by a factor of 10 |
46 | | - "max_velocity_1": ( |
| 55 | + "max_vel_1": ( |
47 | 56 | 0.4 * np.pi, |
48 | 57 | 40 * np.pi, |
49 | 58 | float, |
50 | 59 | ), # Velocity can vary by a factor of 10 in either direction |
51 | | - "max_velocity_2": (0.9 * np.pi, 90 * np.pi, float), |
| 60 | + "max_vel_2": (0.9 * np.pi, 90 * np.pi, float), |
52 | 61 | "torque_noise_max": ( |
53 | 62 | -1.0, |
54 | 63 | 1.0, |
55 | 64 | float, |
56 | 65 | ), # Torque is one of {-1., 0., 1.}. Applying noise of 1. would be quite extreme |
57 | | - "initial_angle_lower": (-jnp.inf, jnp.inf, float), |
58 | | - "initial_angle_upper": (-jnp.inf, jnp.inf, float), |
59 | | - "initial_velocity_lower": (-jnp.inf, jnp.inf, float), |
60 | | - "initial_velocity_upper": (-jnp.inf, jnp.inf, float), |
| 66 | + "max_steps_in_episode": (1, jnp.inf, int), |
61 | 67 | } |
62 | 68 |
|
63 | 69 |
|
64 | | -class CustomGymnaxToGymWrapper(GymnaxToGymWrapper): |
65 | | - @property |
66 | | - def observation_space(self) -> Dict: |
67 | | - return gymnax_space_to_gym_space(self._env.observation_space(self.env_params)) |
68 | | - |
69 | | - @observation_space.setter |
70 | | - def observation_space(self, value: Space) -> None: |
71 | | - self._observation_space = value |
72 | | - |
73 | | - |
74 | | -class CARLJaxAcrobotEnv(CARLEnv): |
75 | | - def __init__( |
76 | | - self, |
77 | | - env: Acrobot = CustomGymnaxToGymWrapper(Acrobot()), |
78 | | - contexts: Contexts = {}, |
79 | | - hide_context: bool = True, |
80 | | - add_gaussian_noise_to_context: bool = False, |
81 | | - gaussian_noise_std_percentage: float = 0.01, |
82 | | - logger: Optional[TrialLogger] = None, |
83 | | - scale_context_features: str = "no", |
84 | | - default_context: Optional[Context] = DEFAULT_CONTEXT, |
85 | | - max_episode_length: int = 500, # from https://github.com/openai/gym/blob/master/gym/envs/__init__.py |
86 | | - state_context_features: Optional[List[str]] = None, |
87 | | - context_mask: Optional[List[str]] = None, |
88 | | - dict_observation_space: bool = False, |
89 | | - context_selector: Optional[ |
90 | | - Union[AbstractSelector, type[AbstractSelector]] |
91 | | - ] = None, |
92 | | - context_selector_kwargs: Optional[Dict] = None, |
93 | | - ): |
94 | | - if not contexts: |
95 | | - contexts = {0: DEFAULT_CONTEXT} |
96 | | - super().__init__( |
97 | | - env=env, |
98 | | - contexts=contexts, |
99 | | - hide_context=hide_context, |
100 | | - add_gaussian_noise_to_context=add_gaussian_noise_to_context, |
101 | | - gaussian_noise_std_percentage=gaussian_noise_std_percentage, |
102 | | - logger=logger, |
103 | | - scale_context_features=scale_context_features, |
104 | | - default_context=default_context, |
105 | | - max_episode_length=max_episode_length, |
106 | | - state_context_features=state_context_features, |
107 | | - dict_observation_space=dict_observation_space, |
108 | | - context_selector=context_selector, |
109 | | - context_selector_kwargs=context_selector_kwargs, |
110 | | - context_mask=context_mask, |
111 | | - ) |
112 | | - self.whitelist_gaussian_noise = list( |
113 | | - DEFAULT_CONTEXT.keys() |
114 | | - ) # allow to augment all values |
| 70 | +class CARLJaxAcrobotEnv(CARLGymnaxEnv): |
| 71 | + env_name: str = "Acrobot-v1" |
| 72 | + max_episode_steps: int = DEFAULT_CONTEXT["max_steps_in_episode"] |
| 73 | + DEFAULT_CONTEXT: Context = DEFAULT_CONTEXT |
115 | 74 |
|
116 | 75 | def _update_context(self) -> None: |
117 | | - self.env: Acrobot |
118 | | - self.env.LINK_LENGTH_1 = self.context["link_length_1"] |
119 | | - self.env.LINK_LENGTH_2 = self.context["link_length_2"] |
120 | | - self.env.LINK_MASS_1 = self.context["link_mass_1"] |
121 | | - self.env.LINK_MASS_2 = self.context["link_mass_2"] |
122 | | - self.env.LINK_COM_POS_1 = self.context["link_com_1"] |
123 | | - self.env.LINK_COM_POS_2 = self.context["link_com_2"] |
124 | | - self.env.LINK_MOI = self.context["link_moi"] |
125 | | - self.env.MAX_VEL_1 = self.context["max_velocity_1"] |
126 | | - self.env.MAX_VEL_2 = self.context["max_velocity_2"] |
| 76 | + content = self.env.env.env_params.__dict__ |
| 77 | + content.update(self.context) |
| 78 | + # We cannot directly set attributes of env_params because it is a frozen dataclass |
| 79 | + self.env.env.env_params = gymnax.environments.classic_control.acrobot.EnvParams( |
| 80 | + **content |
| 81 | + ) |
127 | 82 |
|
128 | 83 | high = jnp.array( |
129 | | - [1.0, 1.0, 1.0, 1.0, self.env.MAX_VEL_1, self.env.MAX_VEL_2], |
| 84 | + [ |
| 85 | + 1.0, |
| 86 | + 1.0, |
| 87 | + 1.0, |
| 88 | + 1.0, |
| 89 | + self.env.env.env_params.max_vel_1, |
| 90 | + self.env.env.env_params.max_vel_2, |
| 91 | + ], |
130 | 92 | dtype=jnp.float32, |
131 | 93 | ) |
132 | 94 | low = -high |
|
0 commit comments