Skip to content

Commit

Permalink
[gym_jiminy/common] Add generic 'DeltaQuantity' and 'DriftTrackingQuantity'. (#874)
Browse files Browse the repository at this point in the history

* [gym_jiminy/common] Argument 'horizon' of 'BaseJiminyEnv.evaluate' is now in seconds.
* [gym_jiminy/common] 'TrackingBaseHeightReward' now relies on relative base height rather than absolute.
* [gym_jiminy/common] Rename 'ReferencePositionWithTrueOdometryPose' in 'ReferencePositionVector' and add option to disable odometry pose override.
* [gym_jiminy/common] Add generic 'DeltaQuantity'.
* [gym_jiminy/common] Add 'DeltaBaseOdometryPosition' and 'DeltaBaseOdometryOrientation' quantities.
* [gym_jiminy/common] Add 'DriftTrackingBaseOdometryPoseReward'.
* [gym_jiminy/common] Inline 'InterfaceQuantity.refresh' if possible for efficiency.
  • Loading branch information
duburcqa authored Jan 29, 2025
1 parent 49240ff commit aee83a2
Show file tree
Hide file tree
Showing 12 changed files with 512 additions and 210 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -438,7 +438,7 @@ def replay(self, **kwargs: Any) -> None:
def evaluate(self,
policy_fn: PolicyCallbackFun,
seed: Optional[int] = None,
horizon: Optional[int] = None,
horizon: Optional[float] = None,
enable_stats: bool = True,
enable_replay: Optional[bool] = None,
**kwargs: Any) -> Tuple[List[float], List[InfoType]]:
Expand All @@ -465,9 +465,9 @@ def evaluate(self,
the policy.
Optional: `None` by default. If not specified, then a
strongly random seed will be generated by gym.
:param horizon: Horizon of the simulation, namely maximum number of
env steps before termination. `None` to disable.
Optional: Disabled by default.
:param horizon: Horizon of the simulation before early termination.
`None` to disable.
Optional: `None` by default.
:param enable_stats: Whether to print high-level statistics after the
simulation.
Optional: Enabled by default.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -432,7 +432,7 @@ def replay(self, **kwargs: Any) -> None:
def evaluate(self,
policy_fn: PolicyCallbackFun,
seed: Optional[int] = None,
horizon: Optional[int] = None,
horizon: Optional[float] = None,
enable_stats: bool = True,
enable_replay: Optional[bool] = None,
**kwargs: Any) -> Tuple[List[float], List[InfoType]]:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
TrackingFootPositionsReward,
TrackingFootOrientationsReward,
TrackingFootForceDistributionReward,
DriftTrackingBaseOdometryPoseReward,
DriftTrackingBaseOdometryPositionTermination,
DriftTrackingBaseOdometryOrientationTermination,
ShiftTrackingFootOdometryPositionsTermination,
Expand Down Expand Up @@ -48,6 +49,7 @@
"TrackingFootPositionsReward",
"TrackingFootOrientationsReward",
"TrackingFootForceDistributionReward",
"DriftTrackingBaseOdometryPoseReward",
"DriftTrackingQuantityTermination",
"DriftTrackingBaseOdometryPositionTermination",
"DriftTrackingBaseOdometryOrientationTermination",
Expand Down
77 changes: 29 additions & 48 deletions python/gym_jiminy/common/gym_jiminy/common/compositions/generic.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
and the application (locomotion, grasping...).
"""
from functools import partial
from operator import sub, itemgetter
from operator import sub
from dataclasses import dataclass
from typing import Optional, Callable, Tuple, Sequence, Union, TypeVar

Expand All @@ -17,7 +17,7 @@
AbstractTerminationCondition, QuantityTermination, partial_hashable)
from ..bases.compositions import ArrayOrScalar, Number
from ..quantities import (
EnergyGenerationMode, StackedQuantity, UnaryOpQuantity, BinaryOpQuantity,
EnergyGenerationMode, StackedQuantity, BinaryOpQuantity, DeltaQuantity,
MultiActuatedJointKinematic, MechanicalPowerConsumption,
AverageMechanicalPowerConsumption)

Expand Down Expand Up @@ -192,8 +192,8 @@ def __init__(


@nb.jit(nopython=True, cache=True, fastmath=True, inline='always')
def compute_drift_error(delta_true: np.ndarray,
delta_ref: np.ndarray) -> float:
def compute_drift_error(delta_true: Union[np.ndarray, float],
delta_ref: Union[np.ndarray, float]) -> float:
"""Compute the difference between the true and reference variation of a
quantity over a given horizon, then apply some post-processing on it if
requested.
Expand All @@ -213,11 +213,8 @@ class DriftTrackingQuantityTermination(QuantityTermination):
current and reference values of a given quantity over a horizon.
The drift is defined as the difference between the current and reference
variation of the quantity over a variable-length horizon bounded by
'max_stack'. This variation is computed from the whole history of values
corresponding to this horizon, which is basically a sliding window. For
Euclidean spaces, this variation is simply computed as the difference
between most recent and oldest values stored in the history.
variation of the quantity over a sliding window of length 'horizon'. See
`DeltaQuantity` quantity for details.
In practice, no bound check is applied on the drift directly, which may be
multi-variate at this point. Instead, the L2-norm is used as metric in the
Expand Down Expand Up @@ -290,52 +287,36 @@ def __init__(self,
"""
# pylint: disable=unnecessary-lambda-assignment

# Convert horizon in stack length, assuming constant env timestep
max_stack = max(int(np.ceil(horizon / env.step_dt)), 1)

# Backup user argument(s)
self.max_stack = max_stack
self.horizon = horizon
self.op = op
self.bounds_only = bounds_only

# Define drift of quantity
delta_creator: Callable[[QuantityEvalMode], QuantityCreator]
if self.bounds_only:
stack_creator = lambda mode: (StackedQuantity, dict( # noqa: E731
quantity=quantity_creator(mode),
max_stack=max_stack))
delta_creator = lambda mode: (BinaryOpQuantity, dict( # noqa: E731
quantity_left=(UnaryOpQuantity, dict(
quantity=stack_creator(mode),
op=itemgetter(-1))),
quantity_right=(UnaryOpQuantity, dict(
quantity=stack_creator(mode),
op=itemgetter(0))),
op=op))
else:
delta_creator = lambda mode: (UnaryOpQuantity, dict( # noqa: E731
quantity=(StackedQuantity, dict(
quantity=quantity_creator(mode),
max_stack=max_stack,
is_wrapping=False,
as_array=True)),
op=op))
# Define variation quantity
delta_creator: Callable[
[QuantityEvalMode], QuantityCreator[ArrayOrScalar]]
delta_creator = lambda mode: (DeltaQuantity, dict( # noqa: E731
quantity=quantity_creator(mode),
horizon=horizon,
op=op,
bounds_only=bounds_only))

# Add drift quantity to the set of quantities managed by environment
# Define drift quantity
drift_tracking_quantity = (BinaryOpQuantity, dict(
quantity_left=delta_creator(QuantityEvalMode.TRUE),
quantity_right=delta_creator(QuantityEvalMode.REFERENCE),
op=compute_drift_error))

# Call base implementation
super().__init__(env,
name,
drift_tracking_quantity, # type: ignore[arg-type]
None,
np.array(thr),
grace_period,
is_truncation=is_truncation,
training_only=training_only)
super().__init__(
env,
name,
drift_tracking_quantity, # type: ignore[arg-type]
None,
np.array(thr),
grace_period,
is_truncation=is_truncation,
training_only=training_only)


@nb.jit(nopython=True, cache=True, fastmath=True)
Expand Down Expand Up @@ -444,13 +425,13 @@ def __init__(self,
"""
# pylint: disable=unnecessary-lambda-assignment

# Convert horizon in stack length, assuming constant env timestep
max_stack = max(int(np.ceil(horizon / env.step_dt)), 1)

# Backup user argument(s)
self.max_stack = max_stack
self.horizon = horizon
self.op = op

# Convert horizon in stack length, assuming constant env timestep
max_stack = max(int(np.ceil(horizon / env.step_dt)), 1)

# Define drift of quantity
stack_creator = lambda mode: (StackedQuantity, dict( # noqa: E731
quantity=quantity_creator(mode),
Expand Down
142 changes: 64 additions & 78 deletions python/gym_jiminy/common/gym_jiminy/common/compositions/locomotion.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
"""Rewards mainly relevant for locomotion tasks on floating-base robots.
"""
from functools import partial
from operator import attrgetter
from dataclasses import dataclass
from typing import Optional, Union, Sequence, Literal, Callable, cast

Expand All @@ -12,26 +11,28 @@
import pinocchio as pin

from ..bases import (
InterfaceJiminyEnv, StateQuantity, InterfaceQuantity, QuantityEvalMode,
QuantityReward)
InterfaceJiminyEnv, InterfaceQuantity, QuantityEvalMode, QuantityReward)
from ..bases.compositions import ArrayOrScalar, ArrayLikeOrScalar
from ..quantities import (
OrientationType, MaskedQuantity, UnaryOpQuantity, FrameOrientation,
BaseRelativeHeight, BaseOdometryPose, BaseOdometryAverageVelocity,
CapturePoint, MultiFramePosition, MultiFootRelativeXYZQuat,
MultiContactNormalizedSpatialForce, MultiFootNormalizedForceVertical,
MultiFootCollisionDetection, AverageBaseMomentum)
OrientationType, MaskedQuantity, UnaryOpQuantity, ConcatenatedQuantity,
FrameOrientation, BaseRelativeHeight, BaseOdometryPose,
DeltaBaseOdometryPosition, DeltaBaseOdometryOrientation,
BaseOdometryAverageVelocity, CapturePoint, MultiFramePosition,
MultiFootRelativeXYZQuat, MultiContactNormalizedSpatialForce,
MultiFootNormalizedForceVertical, MultiFootCollisionDetection,
AverageBaseMomentum)
from ..utils import quat_difference, quat_to_yaw

from .generic import (
TrackingQuantityReward, QuantityTermination,
DriftTrackingQuantityTermination, ShiftTrackingQuantityTermination)
from ..quantities.locomotion import angle_total
from .mixin import radial_basis_function


class TrackingBaseHeightReward(TrackingQuantityReward):
"""Reward the agent for tracking the height of the floating base of the
robot wrt some reference trajectory.
robot relative to lowest contact point wrt some reference trajectory.
.. seealso::
See `TrackingQuantityReward` documentation for technical details.
Expand All @@ -46,14 +47,7 @@ def __init__(self,
super().__init__(
env,
"reward_tracking_base_height",
lambda mode: (MaskedQuantity, dict(
quantity=(UnaryOpQuantity, dict(
quantity=(StateQuantity, dict(
update_kinematics=False,
mode=mode)),
op=attrgetter("q"))),
axis=0,
keys=(2,))),
lambda mode: (BaseRelativeHeight, dict(mode=mode)),
cutoff)


Expand All @@ -78,6 +72,38 @@ def __init__(self,
cutoff)


class DriftTrackingBaseOdometryPoseReward(TrackingQuantityReward):
    """Reward the agent for tracking the drift of the odometry pose over a
    horizon wrt some reference trajectory.

    .. seealso::
        See `DeltaBaseOdometryPosition`, `DeltaBaseOdometryOrientation` and
        `TrackingQuantityReward` documentations for technical details.
    """
    def __init__(self,
                 env: InterfaceJiminyEnv,
                 cutoff: float,
                 horizon: float) -> None:
        """
        :param env: Base or wrapped jiminy environment.
        :param cutoff: Cutoff threshold for the RBF kernel transform.
        :param horizon: Horizon over which values of the quantity will be
                        stacked before computing the drift.
        """
        # Concatenate the position and orientation drifts in a single vector,
        # so that the tracking error is computed over the full odometry pose.
        def quantity_creator(mode):
            delta_position = (DeltaBaseOdometryPosition, dict(
                horizon=horizon,
                mode=mode))
            delta_orientation = (DeltaBaseOdometryOrientation, dict(
                horizon=horizon,
                mode=mode))
            return (ConcatenatedQuantity, dict(
                quantities=(delta_position, delta_orientation)))

        # Call base implementation
        super().__init__(
            env,
            "reward_tracking_odometry_pose",
            quantity_creator,
            cutoff)


class TrackingCapturePointReward(TrackingQuantityReward):
"""Reward the agent for tracking the capture point wrt some reference
trajectory.
Expand Down Expand Up @@ -347,7 +373,8 @@ def __init__(self,
super().__init__(
env,
"termination_base_height",
(BaseRelativeHeight, {}), # type: ignore[arg-type]
(BaseRelativeHeight, dict( # type: ignore[arg-type]
mode=QuantityEvalMode.TRUE)),
min_base_height,
None,
grace_period,
Expand Down Expand Up @@ -640,66 +667,6 @@ def __init__(self,
training_only=training_only)


@nb.jit(nopython=True, cache=True, fastmath=True, inline='always')
def angle_difference(delta: ArrayOrScalar) -> ArrayOrScalar:
    """Compute the signed element-wise difference (aka. oriented angle) between
    two batches of angles.

    The oriented angle is defined as the smallest angle in absolute value
    between right and left angles (ignoring multi-turns), signed in accordance
    with the angle going from right to left angles.

    .. seealso::
        This proposed implementation is the most efficient one for batch size
        of 1000. See this posts for reference about other implementations:
        https://stackoverflow.com/a/7869457/4820605

    :param delta: Pre-computed difference between left and right angles.
    """
    # Wrap the raw difference back to [-pi, pi) by subtracting the number of
    # full turns it contains.
    two_pi = 2 * np.pi
    return delta - np.floor((delta + np.pi) / two_pi) * two_pi


@nb.jit(nopython=True, cache=True, fastmath=True, inline='always')
def angle_distance(angle_left: ArrayOrScalar,
                   angle_right: ArrayOrScalar) -> ArrayOrScalar:
    """Compute the element-wise distance between two batches of angles.

    The distance is defined as the smallest angle in absolute value between
    right and left angles (ignoring multi-turns).

    .. seealso::
        See `angle_difference` documentation for details.

    :param angle_left: Left-hand side angles.
    :param angle_right: Right-hand side angles.
    """
    # Fold the raw difference into [0, 2*pi), then mirror around pi so that
    # the result is the unsigned shortest angular distance in [0, pi].
    two_pi = 2 * np.pi
    wrapped = angle_left - angle_right
    wrapped = wrapped - np.floor(wrapped / two_pi) * two_pi
    return np.pi - np.abs(wrapped - np.pi)


@nb.jit(nopython=True, cache=True, fastmath=True)
def angle_total(angles: np.ndarray) -> float:
    """Compute the total signed multi-turn angle from start to end of a
    time-series of angles.

    The method is fully compliant with individual angles restricted between
    [-pi, pi], but it requires the distance between the angles at successive
    timesteps to be smaller than pi.

    .. seealso::
        See `angle_difference` documentation for details.

    :param angles: Temporal sequence of angles as a multi-dimensional array
                   whose last dimension gathers all the successive timesteps.

    :return: Sum over timesteps of the signed wrapped differences, as a
             floating-point scalar.
    """
    # Note that `angle_difference` has been manually inlined as it results in
    # about 50% speedup, which is surprising.
    delta = angles[..., 1:] - angles[..., :-1]
    delta -= np.floor((delta + np.pi) / (2.0 * np.pi)) * (2 * np.pi)
    # `np.sum` reduces over all axes, hence the scalar return type (the
    # original `np.ndarray` annotation was inaccurate).
    return np.sum(delta)


class DriftTrackingBaseOdometryOrientationTermination(
DriftTrackingQuantityTermination):
"""Terminate the episode if the current base odometry orientation is
Expand Down Expand Up @@ -808,6 +775,25 @@ def __init__(self,
training_only=training_only)


@nb.jit(nopython=True, cache=True, fastmath=True, inline='always')
def angle_distance(angle_left: ArrayOrScalar,
                   angle_right: ArrayOrScalar) -> ArrayOrScalar:
    """Compute the element-wise distance between two batches of angles.

    The distance is defined as the smallest angle in absolute value between
    right and left angles (ignoring multi-turns).

    .. seealso::
        See `angle_difference` documentation for details.

    :param angle_left: Left-hand side angles.
    :param angle_right: Right-hand side angles.
    """
    # Normalize the raw difference to [0, 2*pi), then reflect about pi to get
    # the unsigned shortest angular distance in [0, pi].
    span = 2 * np.pi
    diff = angle_left - angle_right
    diff = diff - np.floor(diff / span) * span
    return np.pi - np.abs(diff - np.pi)


class ShiftTrackingFootOdometryOrientationsTermination(
ShiftTrackingQuantityTermination):
"""Terminate the episode if the selected reference trajectory is not
Expand Down
Loading

0 comments on commit aee83a2

Please sign in to comment.