Skip to content

Commit

Permalink
Added an entry for ActuatorDynamic to the FormatAction wrapper (#1836)
Browse files Browse the repository at this point in the history
* Added an entry for ActuatorDynamic to the FormatAction wrapper in smarts/env/wrappers/format_action.py.

* Update smarts/env/wrappers/format_action.py

* Update smarts/env/wrappers/format_action.py
  • Loading branch information
ajlangley authored Feb 7, 2023
1 parent 780d5f8 commit b0d6cdb
Showing 1 changed file with 16 additions and 3 deletions.
19 changes: 16 additions & 3 deletions smarts/env/wrappers/format_action.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,8 +31,9 @@ class FormatAction(gym.ActionWrapper):
Note:
(a) Only ``ActionSpaceType.Continuous``, ``ActionSpaceType.Lane``, and
``ActionSpaceType.TargetPose`` are supported by this wrapper now.
(a) Only ``ActionSpaceType.Continuous``, ``ActionSpaceType.Lane``,
``ActionSpaceType.ActuatorDynamic``, and `ActionSpaceType.TargetPose``
are supported by this wrapper now.
(b) All agents should have the same action space.
"""
Expand All @@ -50,6 +51,7 @@ def __init__(self, env: gym.Env, space: ActionSpaceType):
space_map = {
"Continuous": _continuous,
"Lane": _lane,
"ActuatorDynamic": _actuator_dynamic,
"TargetPose": _target_pose,
}
self._wrapper, action_space = space_map.get(space.name)()
Expand Down Expand Up @@ -96,6 +98,17 @@ def wrapper(action: Dict[str, int]) -> Dict[str, str]:
return wrapper, space


def _actuator_dynamic() -> Tuple[Callable[[Dict[str, np.ndarray]], Dict[str, np.ndarray]], gym.Space]:
space = gym.spaces.Box(
low=np.array([0.0, 0.0, -1.0]), high=np.array([1.0, 1.0, 1.0]), dtype=np.float32
)

def wrapper(action: Dict[str, np.ndarray]) -> Dict[str, np.ndarray]:
return {k: v.astype(np.float32) for k, v in action.items()}

return wrapper, space


def _target_pose() -> Tuple[
Callable[[Dict[str, np.ndarray]], Dict[str, np.ndarray]], gym.Space
]:
Expand All @@ -108,4 +121,4 @@ def _target_pose() -> Tuple[
def wrapper(action: Dict[str, np.ndarray]) -> Dict[str, np.ndarray]:
return {k: v.astype(np.float32) for k, v in action.items()}

return wrapper, space
return wrapper, space

0 comments on commit b0d6cdb

Please sign in to comment.