Skip to content

Commit

Permalink
[gym_jiminy/common] Remove LRU cache for trajectory state getter as i…
Browse files Browse the repository at this point in the history
…t never hits in practice.
  • Loading branch information
duburcqa committed Feb 2, 2025
1 parent cc86beb commit 011ccb1
Showing 1 changed file with 39 additions and 69 deletions.
108 changes: 39 additions & 69 deletions python/jiminy_py/src/jiminy_py/dynamics.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,8 @@
# pylint: disable=invalid-name,no-member
import logging
from bisect import bisect_left
from functools import lru_cache
from dataclasses import dataclass, fields
from typing import (
List, Union, Optional, Tuple, Sequence, Dict, Callable, Literal)
from typing import List, Union, Optional, Tuple, Sequence, Callable, Literal

import numpy as np

Expand Down Expand Up @@ -233,9 +231,6 @@ def __init__(self,
fields_.append(field)
self._fields = tuple(fields_)

# Hacky way to enable argument-based function caching at instance-level
self.__dict__['_get'] = lru_cache(maxsize=None)(self._get)

@property
def has_data(self) -> bool:
"""Whether the trajectory has data, ie the state sequence is not empty.
Expand Down Expand Up @@ -296,16 +291,46 @@ def time_interval(self) -> Tuple[float, float]:
"State sequence is empty. Time interval undefined.")
return (self._times[0], self._times[-1])

def _get(self, t: float) -> Dict[str, np.ndarray]:
def get(self, t: float, mode: TrajectoryTimeMode = 'raise') -> State:
"""Query the state at a given timestamp.
.. note::
This method is used internally by `get`. It is not meant to be
called manually.
Internally, the nearest neighbor states are linearly interpolated,
taking into account the corresponding Lie Group of all state attributes
that are available.
:param t: Time of the state to extract from the trajectory.
:param mode: Fallback strategy when the query time is not in the time
interval 'time_interval' of the trajectory. 'raise' raises
an exception if the query time is out-of-bound wrt the
underlying state sequence of the selected trajectory.
'clip' forces clipping of the query time before
interpolation of the state sequence. 'wrap' wraps around
the query time wrt the time span of the trajectory. This
is useful to store periodic trajectories as finite state
sequences.
"""
# pylint: disable=possibly-used-before-assignment
# Raise exception if state sequence is empty
if not self.has_data:
raise RuntimeError(
"State sequence is empty. Impossible to interpolate data.")

# Backup the original query time
t_orig = t

# Handling of the desired mode
n_steps = 0.0
t_start, t_end = self.time_interval
if mode == "raise":
if t - t_end > TRAJ_INTERP_TOL or t_start - t > TRAJ_INTERP_TOL:
raise RuntimeError("Time is out-of-range.")
elif mode == "wrap":
if t_end > t_start:
n_steps, t_rel = divmod(t - t_start, t_end - t_start)
t = t_rel + t_start
else:
t = t_start
else:
t = min(max(t, t_start), t_end)

# Get nearest neighbors timesteps for linear interpolation.
# Note that the left and right data points may be associated with the
Expand All @@ -330,7 +355,7 @@ def _get(self, t: float) -> Dict[str, np.ndarray]:
return_right = t_right - t < TRAJ_INTERP_TOL
alpha = (t - t_left) / (t_right - t_left)

# Interpolate state
# Interpolate state data
if return_left:
position = s_left.q.copy()
elif return_right:
Expand All @@ -350,69 +375,14 @@ def _get(self, t: float) -> Dict[str, np.ndarray]:
else:
data[field] = value_left + alpha * (value_right - value_left)

# Make sure that data are immutable.
# This is essential to make sure that cached values cannot be altered.
for arr in data.values():
arr.setflags(write=False)

return data

def get(self, t: float, mode: TrajectoryTimeMode = 'raise') -> State:
"""Query the state at a given timestamp.
Internally, the nearest neighbor states are linearly interpolated,
taking into account the corresponding Lie Group of all state attributes
that are available.
:param t: Time of the state to extract from the trajectory.
:param mode: Fallback strategy when the query time is not in the time
interval 'time_interval' of the trajectory. 'raise' raises
an exception if the query time is out-of-bound wrt the
underlying state sequence of the selected trajectory.
'clip' forces clipping of the query time before
interpolation of the state sequence. 'wrap' wraps around
the query time wrt the time span of the trajectory. This
is useful to store periodic trajectories as finite state
sequences.
"""
# Raise exception if state sequence is empty
if not self.has_data:
raise RuntimeError(
"State sequence is empty. Impossible to interpolate data.")

# Backup the original query time
t_orig = t

# Handling of the desired mode
n_steps = 0.0
t_start, t_end = self.time_interval
if mode == "raise":
if t - t_end > TRAJ_INTERP_TOL or t_start - t > TRAJ_INTERP_TOL:
raise RuntimeError("Time is out-of-range.")
elif mode == "wrap":
if t_end > t_start:
n_steps, t_rel = divmod(t - t_start, t_end - t_start)
t = t_rel + t_start
else:
t = t_start
else:
t = min(max(t, t_start), t_end)

# Rounding time to avoid cache miss issues
# Note that `int(x + 0.5)` is faster than `round(x)`.
t = int(t / TRAJ_INTERP_TOL + 0.5) * TRAJ_INTERP_TOL

# Interpolate state at the desired time
state = State(t=t_orig, **self._get(t))

# Perform odometry if time is wrapping
if self._stride_offset_log6 is not None and n_steps:
state.q = position = state.q.copy()
stride_offset = pin.exp6(n_steps * self._stride_offset_log6)
ff_xyzquat = stride_offset * pin.XYZQUATToSE3(position[:7])
position[:7] = pin.SE3ToXYZQUAT(ff_xyzquat)

return state
# Return state instances bundling all data
return State(t=t_orig, **data)


# #####################################################################
Expand Down

0 comments on commit 011ccb1

Please sign in to comment.