Skip to content

Commit

Permalink
[gym_jiminy/common] Inline 'InterfaceQuantity.refresh' if possible fo…
Browse files Browse the repository at this point in the history
…r efficiency.
  • Loading branch information
duburcqa committed Jan 28, 2025
1 parent d355878 commit f1cc63f
Show file tree
Hide file tree
Showing 2 changed files with 27 additions and 11 deletions.
23 changes: 12 additions & 11 deletions python/gym_jiminy/common/gym_jiminy/common/quantities/generic.py
Original file line number Diff line number Diff line change
Expand Up @@ -507,8 +507,12 @@ def initialize(self) -> None:
self.data.reset(reset_tracking=True)

def refresh(self) -> np.ndarray:
value = self.data.get()
return value[self.frame_name]
# Return a slice of batched data.
# Note that mapping from frame names to frame index in batched data
# cannot be pre-computed as it may changed dynamically.
# Note that avoiding defining a temporary variable to store the current
# value of the quantity slightly improves performance.
return self.data.get()[self.frame_name]


@dataclass(unsafe_hash=True)
Expand Down Expand Up @@ -588,11 +592,7 @@ def initialize(self) -> None:
self.data.reset(reset_tracking=True)

def refresh(self) -> np.ndarray:
# Return a slice of batched data.
# Note that mapping from frame names to frame index in batched data
# cannot be pre-computed as it may changed dynamically.
value = self.data.get()
return value[self.frame_names]
return self.data.get()[self.frame_names]


@dataclass(unsafe_hash=True)
Expand Down Expand Up @@ -742,8 +742,7 @@ def initialize(self) -> None:
self.data.reset(reset_tracking=True)

def refresh(self) -> np.ndarray:
value = self.data.get()
return value[self.frame_name]
return self.data.get()[self.frame_name]


@dataclass(unsafe_hash=True)
Expand Down Expand Up @@ -806,8 +805,7 @@ def initialize(self) -> None:
self.data.reset(reset_tracking=True)

def refresh(self) -> np.ndarray:
value = self.data.get()
return value[self.frame_names]
return self.data.get()[self.frame_names]


@dataclass(unsafe_hash=True)
Expand Down Expand Up @@ -1807,6 +1805,9 @@ def __init__(
compute_power, int(self.generator_mode))))),
auto_refresh=False)

# Enable direct forwarding (inlining) for efficiency
self.__dict__["refresh"] = self.data.get

def refresh(self) -> float:
return self.data.get()

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -333,6 +333,9 @@ def __init__(self,
keys=(0, 1, 5)))),
auto_refresh=False)

# Enable direct forwarding (inlining) for efficiency
self.__dict__["refresh"] = self.data.get

def refresh(self) -> np.ndarray:
return self.data.get()

Expand Down Expand Up @@ -468,6 +471,9 @@ def __init__(self,
mode=mode))),
auto_refresh=False)

# Enable direct forwarding (inlining) for efficiency
self.__dict__["refresh"] = self.data.get

def refresh(self) -> np.ndarray:
return self.data.get()

Expand Down Expand Up @@ -666,6 +672,9 @@ def __init__(self,
data=position_vector),
auto_refresh=False)

# Enable direct forwarding (inlining) for efficiency
self.__dict__["refresh"] = self.data.get

def refresh(self) -> np.ndarray:
return self.data.get()

Expand Down Expand Up @@ -1623,6 +1632,9 @@ def __init__(self,
bounds_only=True))),
auto_refresh=False)

# Enable direct forwarding (inlining) for efficiency
self.__dict__["refresh"] = self.data.get

def refresh(self) -> ArrayOrScalar:
return self.data.get()

Expand Down Expand Up @@ -1685,5 +1697,8 @@ def __init__(self,
bounds_only=False))),
auto_refresh=False)

# Enable direct forwarding (inlining) for efficiency
self.__dict__["refresh"] = self.data.get

def refresh(self) -> ArrayOrScalar:
return self.data.get()

0 comments on commit f1cc63f

Please sign in to comment.