[gym_jiminy/common] Fix quantity caching issues.
duburcqa committed Jan 27, 2025
1 parent ff3d31f commit 6795e95
Showing 8 changed files with 398 additions and 367 deletions.
6 changes: 4 additions & 2 deletions python/gym_jiminy/common/gym_jiminy/common/bases/__init__.py
@@ -24,7 +24,8 @@
MixtureReward,
AbstractTerminationCondition,
QuantityTermination,
EpisodeState)
EpisodeState,
partial_hashable)
from .blocks import (BlockState,
InterfaceBlock,
BaseObserverBlock,
@@ -74,5 +75,6 @@
'QuantityCreator',
'EpisodeState',
'StateQuantity',
'DatasetTrajectoryQuantity'
'DatasetTrajectoryQuantity',
'partial_hashable'
]
53 changes: 52 additions & 1 deletion python/gym_jiminy/common/gym_jiminy/common/bases/compositions.py
@@ -6,9 +6,13 @@
This modular approach allows for standardization of usual metrics. Overall, it
greatly reduces code duplication and bugs.
"""
import inspect
from functools import partial
from abc import abstractmethod, ABCMeta
from enum import IntEnum
from typing import Tuple, Sequence, Callable, Union, Optional, Generic, TypeVar
from typing import (
Tuple, Sequence, Callable, Union, Optional, Generic, Any, TypeVar,
TYPE_CHECKING)

import numpy as np

@@ -25,6 +29,53 @@
ArrayLikeOrScalar = Union[ArrayOrScalar, Sequence[Union[Number, np.number]]]


class partial_hashable(partial):
    """Extends the standard `functools.partial` class with hash and equality
    operators.

    Two partial instances are equal if they are wrapping the exact same
    function (i.e. pointing to the same memory address as per the `id`
    built-in function), and binding the same arguments (i.e. all arguments
    are equal as per the `==` operator). Note that it does not matter whether
    the constructor arguments of `partial_hashable` itself are positional or
    keyword-based. Internally, they will be stored in an ordered list of
    keyword-only arguments for the equality check.

    .. warning::
        Trying to instantiate this class with invalid arguments for the
        method being wrapped (e.g. specifying multiple values for the same
        argument) raises a `TypeError` exception, unlike `functools.partial`,
        which would only fail when calling the resulting callable object.
    """

    if TYPE_CHECKING:
        _normalized_args: Tuple[Any, ...]

    def __new__(cls,
                func: Callable, /,
                *args: Any,
                **kwargs: Any) -> "partial_hashable":
        # Call base implementation
        self = super(partial_hashable, cls).__new__(cls, func, *args, **kwargs)

        # Pre-compute normalized arguments once and for all
        sig = inspect.signature(self.func)
        bound = sig.bind_partial(*self.args, **(self.keywords or {}))
        bound.apply_defaults()
        self._normalized_args = tuple(bound.arguments.values())

        return self

    def __eq__(self, other: Any) -> bool:
        if not isinstance(other, partial_hashable):
            return False
        return self.func == other.func and (
            self._normalized_args == other._normalized_args)

    def __hash__(self) -> int:
        return hash((self.func, self._normalized_args))
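
As an illustrative sketch of the equality and hash semantics documented above (not part of this commit): `partial_hashable` is imported from `gym_jiminy.common.bases` as exported in this changeset, while the `scaled_norm` helper is purely hypothetical.

import numpy as np
from functools import partial

from gym_jiminy.common.bases import partial_hashable


# Hypothetical helper used only for illustration
def scaled_norm(scale: float, values: np.ndarray) -> float:
    """Return the Euclidean norm of 'values' multiplied by 'scale'."""
    return scale * float(np.linalg.norm(values))


# Two standard partials wrapping the same function and arguments are distinct
# objects, so they neither compare equal nor hash identically...
assert partial(scaled_norm, 2.0) != partial(scaled_norm, 2.0)

# ...whereas two hashable partials binding the same arguments compare equal
# and hash identically, whether the binding is positional or by keyword.
key_pos = partial_hashable(scaled_norm, 2.0)
key_kw = partial_hashable(scaled_norm, scale=2.0)
assert key_pos == key_kw and hash(key_pos) == hash(key_kw)

# The wrapped callable still behaves like a regular partial
assert key_pos(np.array([3.0, 4.0])) == 10.0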


class AbstractReward(metaclass=ABCMeta):
"""Abstract class from which all reward component must derived.
119 changes: 57 additions & 62 deletions python/gym_jiminy/common/gym_jiminy/common/compositions/generic.py
@@ -14,7 +14,7 @@
from ..bases import (
InfoType, QuantityCreator, InterfaceJiminyEnv, InterfaceQuantity,
QuantityEvalMode, AbstractReward, QuantityReward,
AbstractTerminationCondition, QuantityTermination)
AbstractTerminationCondition, QuantityTermination, partial_hashable)
from ..bases.compositions import ArrayOrScalar, Number
from ..quantities import (
EnergyGenerationMode, StackedQuantity, UnaryOpQuantity, BinaryOpQuantity,
@@ -191,6 +191,23 @@ def __init__(
is_terminal=False)


@nb.jit(nopython=True, cache=True, fastmath=True, inline='always')
def compute_drift_error(delta_true: np.ndarray,
                        delta_ref: np.ndarray) -> float:
    """Compute the difference between the true and reference variation of a
    quantity over a given horizon, and return its norm, i.e. its absolute
    value for scalars or its Euclidean norm for N-dimensional arrays.

    :param delta_true: True value of the variation as a N-dimensional array.
    :param delta_ref: Reference value of the variation as a N-dimensional
                      array.
    """
    drift = delta_true - delta_ref
    if isinstance(drift, float):
        return abs(drift)
    return np.linalg.norm(drift.reshape((-1,)))  # type: ignore[return-value]
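
A small usage sketch (not part of the commit), assuming `compute_drift_error` stays importable from `gym_jiminy.common.compositions.generic` where it is now defined at module level:

import numpy as np

from gym_jiminy.common.compositions.generic import compute_drift_error

# Suppose the robot base actually moved by (1.00, 0.05, 0.0) over the horizon
# while the reference trajectory moved by (1.20, -0.25, 0.0): the drift error
# is the Euclidean norm of the mismatch, here about 0.36.
delta_true = np.array([1.00, 0.05, 0.0])
delta_ref = np.array([1.20, -0.25, 0.0])
error = compute_drift_error(delta_true, delta_ref)
assert np.isclose(error, np.linalg.norm(delta_true - delta_ref))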


class DriftTrackingQuantityTermination(QuantityTermination):
"""Base class to derive termination condition from the drift between the
current and reference values of a given quantity over a horizon.
@@ -223,8 +240,6 @@ def __init__(self,
*,
op: Callable[[ArrayLike, ArrayLike], ArrayOrScalar] = sub,
bounds_only: bool = True,
post_fn: Optional[Callable[
[ArrayOrScalar], ArrayOrScalar]] = None,
is_truncation: bool = False,
training_only: bool = False) -> None:
"""
@@ -283,24 +298,6 @@ def __init__(self,
self.op = op
self.bounds_only = bounds_only

# Jit-able method for computing the drift error
@nb.jit(nopython=True, cache=True, fastmath=True, inline='always')
def compute_drift_error(delta_true: ArrayOrScalar,
delta_ref: ArrayOrScalar) -> float:
"""Compute the difference between the true and reference variation
of a quantity over a given horizon, then apply some post-processing
on it if requested.
:param delta_true: True value of the variation as a N-dimensional
array.
:param delta_ref: Reference value of the variation as a
N-dimensional array.
"""
drift = delta_true - delta_ref
if isinstance(drift, float):
return abs(drift)
return np.linalg.norm(drift.reshape((-1,)))

# Define drift of quantity
delta_creator: Callable[[QuantityEvalMode], QuantityCreator]
if self.bounds_only:
@@ -341,6 +338,44 @@ def compute_drift_error(delta_true: ArrayOrScalar,
training_only=training_only)


@nb.jit(nopython=True, cache=True, fastmath=True)
def min_norm(values: np.ndarray) -> float:
    """Compute the minimum Euclidean norm over all timestamps of a
    multivariate time series.

    :param values: Time series as a N-dimensional array whose last dimension
                   corresponds to individual timestamps over a finite horizon.
                   The value at each timestamp will be regarded as a 1D vector
                   for computing their Euclidean norm.
    """
    num_times = values.shape[-1]
    values_squared_flat = np.square(values).reshape((-1, num_times))
    return np.sqrt(np.min(np.sum(values_squared_flat, axis=0)))


def compute_min_distance(op: Callable[[np.ndarray, np.ndarray], np.ndarray],
                         left: np.ndarray,
                         right: np.ndarray) -> float:
    """Compute the minimum time-aligned Euclidean distance between two
    multivariate time series kept in sync.

    Internally, the time-aligned difference between the two time series will
    first be computed according to the user-specified binary operator 'op'.
    The classical Euclidean norm of the difference is then computed over all
    timestamps individually and the minimum value is returned.

    :param op: Binary operator combining both time series element-wise, e.g.
               the arithmetic difference.
    :param left: Time series as a N-dimensional array whose last dimension
                 corresponds to individual timestamps over a finite horizon.
                 The value at each timestamp will be regarded as a 1D vector
                 for computing its Euclidean norm. It will be passed as
                 left-hand side of the binary operator 'op'.
    :param right: Time series as a N-dimensional array with the exact same
                  shape as 'left'. See 'left' for details. It will be passed
                  as right-hand side of the binary operator 'op'.
    """
    return min_norm(op(left, right))
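
The sketch below (not part of the commit) shows how `min_norm` and `compute_min_distance` combine, and why binding 'op' with `partial_hashable` (as done in `ShiftTrackingQuantityTermination` further down) yields a hashable, comparable operator. The import path for `compute_min_distance` is assumed from the file being modified.

import numpy as np
from operator import sub

from gym_jiminy.common.bases import partial_hashable
from gym_jiminy.common.compositions.generic import compute_min_distance

# Two synthetic time series with 3 timestamps each (time along the last axis)
left = np.array([[0.0, 1.0, 2.0],
                 [0.0, 0.0, 0.0]])
right = np.array([[0.0, 0.5, 3.0],
                  [0.0, 0.0, 1.0]])

# Per-timestamp distances are 0.0, 0.5 and sqrt(2), so the minimum is 0.0
assert compute_min_distance(sub, left, right) == 0.0

# Binding 'op' with 'partial_hashable' rather than a lambda or a bound method
# keeps the resulting callable hashable and comparable, so two terminations
# configured identically can map to the same cached quantity.
dist_op = partial_hashable(compute_min_distance, sub)
assert dist_op == partial_hashable(compute_min_distance, sub)
assert dist_op(left, right) == 0.0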


class ShiftTrackingQuantityTermination(QuantityTermination):
"""Base class to derive termination condition from the shift between the
current and reference values of a given quantity.
@@ -416,24 +451,6 @@ def __init__(self,
self.max_stack = max_stack
self.op = op

# Jit-able method computing minimum distance between two time series
@nb.jit(nopython=True, cache=True)
def min_norm(values: np.ndarray) -> float:
"""Compute the minimum Euclidean norm over all timestamps of a
multivariate time series.
:param values: Time series as a N-dimensional array whose last
dimension corresponds to individual timestamps over
a finite horizon. The value at each timestamp will
be regarded as a 1D vector for computing their
Euclidean norm.
"""
num_times = values.shape[-1]
values_squared_flat = np.square(values).reshape((-1, num_times))
return np.sqrt(np.min(np.sum(values_squared_flat, axis=0)))

self._min_norm = min_norm

# Define drift of quantity
stack_creator = lambda mode: (StackedQuantity, dict( # noqa: E731
quantity=quantity_creator(mode),
@@ -445,7 +462,7 @@ def min_norm(values: np.ndarray) -> float:
shift_tracking_quantity = (BinaryOpQuantity, dict(
quantity_left=stack_creator(QuantityEvalMode.TRUE),
quantity_right=stack_creator(QuantityEvalMode.REFERENCE),
op=self._compute_min_distance))
op=partial_hashable(compute_min_distance, op)))

# Call base implementation
super().__init__(env,
@@ -457,28 +474,6 @@
is_truncation=is_truncation,
training_only=training_only)

def _compute_min_distance(self,
left: np.ndarray,
right: np.ndarray) -> float:
"""Compute the minimum time-aligned Euclidean distance between two
multivariate time series kept in sync.
Internally, the time-aligned difference between the two time series
will first be computed according to the user-specified binary operator
'op'. The classical Euclidean norm of the difference is then computed
over all timestamps individually and the minimum value is returned.
:param left: Time series as a N-dimensional array whose first dimension
corresponds to individual timestamps over a finite
horizon. The value at each timestamp will be regarded as a
1D vector for computing their Euclidean norm. It will be
passed as left-hand side of the binary operator 'op'.
:param right: Time series as a N-dimensional array with the exact same
shape as 'left'. See 'left' for details. It will be
passed as right-hand side of the binary operator 'op'.
"""
return self._min_norm(self.op(left, right))


@dataclass(unsafe_hash=True)
class _MultiActuatedJointBoundDistance(
@@ -403,6 +403,32 @@ def __init__(self,
training_only=training_only)


@nb.jit(nopython=True, cache=True, fastmath=True)
def min_depth(positions: np.ndarray,
              heights: np.ndarray,
              normals: np.ndarray) -> float:
    """Approximate minimum distance from the ground profile among a set of
    query points.

    Internally, it uses a first-order approximation assuming zero local
    curvature around each query point.

    :param positions: Position of all the query points from which to compute
                      the distance from the ground profile, as a 2D array
                      whose first dimension gathers the 3 position coordinates
                      (X, Y, Z) while the second corresponds to the N
                      individual query points.
    :param heights: Vertical height wrt the ground profile of the N individual
                    query points in world frame as a 1D array.
    :param normals: Normal of the ground profile for the projection in world
                    plane of all the query points, as a 2D array whose first
                    dimension gathers the 3 position coordinates (X, Y, Z)
                    while the second corresponds to the N individual query
                    points.
    """
    return np.min((positions[2] - heights) * normals[2])
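
A quick numeric sketch of the first-order depth computation above (illustration only, not part of the commit; it assumes `min_depth` as defined above is in scope):

import numpy as np

# Three query points: the first is 5cm above flat ground, the second lies
# exactly on it, and the third is 2cm below a ground patch whose normal has
# a vertical component of 0.8.
positions = np.array([[0.1, 0.4, -0.3],     # X coordinates
                      [0.0, 0.2, 0.1],      # Y coordinates
                      [0.05, 0.0, -0.02]])  # Z coordinates
heights = np.array([0.0, 0.0, 0.0])
normals = np.array([[0.0, 0.0, 0.6],        # X components
                    [0.0, 0.0, 0.0],        # Y components
                    [1.0, 1.0, 0.8]])       # Z components

# The signed first-order depth of each point is (z - height) * n_z, giving
# (0.05, 0.0, -0.016); the minimum, i.e. 1.6cm of penetration, is returned.
assert np.isclose(min_depth(positions, heights, normals), -0.016)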


@dataclass(unsafe_hash=True)
class _MultiContactMinGroundDistance(InterfaceQuantity[float]):
"""Minimum distance from the ground profile among all the contact points.
@@ -440,34 +466,6 @@ def __init__(self,
))),
auto_refresh=False)

# Jit-able method computing the minimum first-order depth
@nb.jit(nopython=True, cache=True, fastmath=True)
def min_depth(positions: np.ndarray,
heights: np.ndarray,
normals: np.ndarray) -> float:
"""Approximate minimum distance from the ground profile among a set
of the query points.
Internally, it uses a first order approximation assuming zero local
curvature around each query point.
:param positions: Position of all the query points from which to
compute from the ground profile, as a 2D array
whose first dimension gathers the 3 position
coordinates (X, Y, Z) while the second correponds
to the N individual query points.
:param heights: Vertical height wrt the ground profile of the N
individual query points in world frame as 1D array.
:param normals: Normal of the ground profile for the projection in
world plane of all the query points, as a 2D array
whose first dimension gathers the 3 position
coordinates (X, Y, Z) while the second correponds
to the N individual query points.
"""
return np.min((positions[2] - heights) * normals[2])

self._min_depth = min_depth

# Reference to the heightmap function for the ongoing episode
self._heightmap = jiminy.HeightmapFunction(lambda: None)

@@ -496,7 +494,7 @@ def refresh(self) -> float:
# self._normals /= np.linalg.norm(self._normals, axis=0)

# First-order distance estimation assuming no curvature
return self._min_depth(positions, self._heights, self._normals)
return min_depth(positions, self._heights, self._normals)


class FlyingTermination(QuantityTermination):