Skip to content

Commit

Permalink
[python/dynamics] Fix bug making 'Trajectory.get' extremely inefficient.
Browse files Browse the repository at this point in the history
  • Loading branch information
duburcqa committed Feb 7, 2025
1 parent ec07134 commit 42c0e7b
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 22 deletions.
3 changes: 0 additions & 3 deletions python/gym_jiminy/common/gym_jiminy/common/utils/spaces.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,9 +52,6 @@ def _array_clip(value: np.ndarray,
:param low: Optional lower bound.
:param high: Optional upper bound.
"""
# Note that in-place clipping is actually slower than out-of-place in
# Numba when 'fastmath' compilation flag is set.

# Short circuit if there is neither low or high bounds
if low is None and high is None:
return value.copy()
Expand Down
39 changes: 20 additions & 19 deletions python/jiminy_py/src/jiminy_py/dynamics.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,8 @@
import logging
from bisect import bisect_left
from dataclasses import dataclass, fields
from typing import List, Union, Optional, Tuple, Sequence, Callable, Literal
from typing import (
List, Union, Optional, Tuple, Sequence, Callable, Dict, Literal)

import numpy as np

Expand Down Expand Up @@ -215,7 +216,7 @@ def __init__(self,
self._t_prev = 0.0
self._index_prev = 1

# List of optional state fields that are provided
# List of optional state fields that have been specified.
# Note that looking for keys in such a small set is not worth the
# hassle of using Python `set`, which breaks ordering and index access.
fields_: List[str] = []
Expand All @@ -227,7 +228,7 @@ def __init__(self,
raise ValueError(
"The state information being set must be the same "
"for all the timesteps of a given trajectory.")
else:
elif field not in fields_:
fields_.append(field)
self._fields = tuple(fields_)

Expand Down Expand Up @@ -286,10 +287,11 @@ def time_interval(self) -> Tuple[float, float]:
It raises an exception if no data is available.
"""
if not self.has_data:
raise RuntimeError(
try:
return (self._times[0], self._times[-1])
except IndexError:
raise RuntimeError( # pylint: disable=raise-missing-from
"State sequence is empty. Time interval undefined.")
return (self._times[0], self._times[-1])

def get(self, t: float, mode: TrajectoryTimeMode = 'raise') -> State:
"""Query the state at a given timestamp.
Expand All @@ -311,20 +313,19 @@ def get(self, t: float, mode: TrajectoryTimeMode = 'raise') -> State:
"""
# pylint: disable=possibly-used-before-assignment

# Raise exception if state sequence is empty
if not self.has_data:
raise RuntimeError(
"State sequence is empty. Impossible to interpolate data.")

# Backup the original query time
t_orig = t

# Handling of the desired mode
n_steps = 0.0
t_start, t_end = self.time_interval
try:
t_start, t_end = self._times[0], self._times[-1]
except IndexError:
raise RuntimeError( # pylint: disable=raise-missing-from
"State sequence is empty. Impossible to interpolate data.")
if mode == "raise":
if t - t_end > TRAJ_INTERP_TOL or t_start - t > TRAJ_INTERP_TOL:
raise RuntimeError("Time is out-of-range.")
raise RuntimeError("Query time out-of-range.")
elif mode == "wrap":
if t_end > t_start:
n_steps, t_rel = divmod(t - t_start, t_end - t_start)
Expand Down Expand Up @@ -365,26 +366,26 @@ def get(self, t: float, mode: TrajectoryTimeMode = 'raise') -> State:
else:
position = pin.interpolate(
self._pinocchio_model, s_left.q, s_right.q, alpha)
data = {"q": position}
state: Dict[str, Union[float, np.ndarray]] = dict(t=t_orig, q=position)
for field in self._fields:
value_left = getattr(s_left, field)
if return_left:
data[field] = value_left.copy()
state[field] = value_left.copy()
continue
value_right = getattr(s_right, field)
if return_right:
data[field] = value_right.copy()
state[field] = value_right.copy()
else:
data[field] = value_left + alpha * (value_right - value_left)
state[field] = value_left + alpha * (value_right - value_left)

# Perform odometry if time is wrapping
if self._stride_offset_log6 is not None and n_steps:
stride_offset = pin.exp6(n_steps * self._stride_offset_log6)
ff_xyzquat = stride_offset * pin.XYZQUATToSE3(position[:7])
position[:7] = pin.SE3ToXYZQUAT(ff_xyzquat)

# Return state instances bundling all data
return State(t=t_orig, **data)
# Return a State object
return State(**state) # type: ignore[arg-type]


# #####################################################################
Expand Down

0 comments on commit 42c0e7b

Please sign in to comment.