[gym_jiminy/toolbox|rllib] Implement task score at scheduler level rather than task settable wrapper level.
duburcqa committed Jan 31, 2025
1 parent cb64583 commit 517759d
Showing 6 changed files with 203 additions and 68 deletions.
@@ -194,8 +194,9 @@ def pd_adapter(action: np.ndarray,
satisfied at all cost.
:param command_state_upper: Upper bound of the command state that must be
satisfied at all cost.
:param motors_velocity_deadband: Target velocity deadband for which the target
motor velocity will be cancelled out completely.
:param motors_velocity_deadband: Target velocity deadband for which the
target motor velocity will be cancelled
out completely.
:param step_dt: Time interval during which the target motor accelerations
will be held constant.
:param out: Pre-allocated memory to store the target motor accelerations.
@@ -480,7 +481,6 @@ def _setup(self) -> None:
# Reset the command state
fill(self._command_state, 0)


@property
def fieldnames(self) -> List[str]:
return [f"currentTarget{N_ORDER_DERIVATIVE_NAMES[2]}{motor.name}"
112 changes: 100 additions & 12 deletions python/gym_jiminy/rllib/gym_jiminy/rllib/curriculum.py
@@ -1,7 +1,10 @@

""" TODO: Write documentation.
"""
from typing import List, Any, Dict, Tuple, Type, cast
import math
from collections import defaultdict
from typing import (
List, Any, Dict, Tuple, Type, Optional, Callable, DefaultDict, cast)

import numpy as np
import gymnasium as gym
@@ -14,7 +17,7 @@
from ray.rllib.algorithms.callbacks import DefaultCallbacks
from ray.rllib.utils.metrics import ENV_RUNNER_RESULTS
from ray.rllib.utils.metrics.metrics_logger import MetricsLogger
from ray.rllib.utils.typing import ResultDict, EpisodeType
from ray.rllib.utils.typing import ResultDict, EpisodeID, EpisodeType

from jiminy_py import tree
from gym_jiminy.common.bases import BasePipelineWrapper
@@ -27,14 +30,50 @@
) from e


def build_task_scheduling_callback(history_length: int,
softmin_beta: float
) -> Type[DefaultCallbacks]:
def build_task_scheduling_callback(
history_length: int,
softmin_beta: float,
score_fn: Optional[Callable[
[Tuple[EpisodeType, ...], gym.vector.VectorEnv, int], float
]] = None
) -> Type[DefaultCallbacks]:
""" TODO: Write documentation.
.. warning::
To use this callback, the base environment must be wrapped with
`BaseTaskSettableWrapper` (but not necessarily as the top-most layer).
:param history_length: Number of past episodes over which the average score
is computed for each node individually. Basically a
moving average over a sliding window. This time
constant must be at least one order of magnitude
slower than the update of the policy for the RL
problem to appear stationary from its standpoint.
:param softmin_beta: Inverse temperature parameter of the softmin formula
                     used to infer sampling probabilities for each task
                     from all their scores. For a large beta, a very small
                     difference in scores is enough to induce large
                     discrepancies in probabilities between tasks. However,
                     if the distribution of tasks is very unbalanced, then
                     the policy may forget some skills that it had already
                     learned.
:param score_fn:
Function used to score each episode with signature:
.. code-block:: python
score_fn(
episode_chunks: Tuple[
ray.rllib.utils.typing.EpisodeType, ...],
env: gym.vector.VectorEnv,
env_index: int
) -> float # score
`None` to use the standardized return, i.e. the undiscounted
cumulative reward over complete episodes divided by the maximum
number of steps of an episode before truncation. This means that
its value ranges from 0.0 to 1.0 iff the reward is normalized.
Optional: `None` by default.
"""
class TaskSchedulingSamplingCallback(DefaultCallbacks):
""" TODO: Write documentation.
@@ -72,9 +111,20 @@ class TaskSchedulingSamplingCallback(DefaultCallbacks):
E[S12, S13, S14, S15, S16, S17])))
"""
def __init__(self) -> None:
# Unique ID of the ongoing episode for each environment being
# managed by the runner associated with this callback instance.
self._ongoing_episodes: Dict[int, EpisodeID] = {}

# Episodes that were started but never reached termination before
# the end of the previous sampling iteration.
self._partial_episodes: DefaultDict[
EpisodeID, List[EpisodeType]] = defaultdict(list)

# Whether to clear all task metrics at the end of the next episode
self._is_initialized = False
self._must_clear_metrics = False

self._is_initialized = False
self._max_num_steps_all: Tuple[int, ...] = ()
self._task_space = gym.spaces.Tuple([])
self._task_paths: Tuple[TaskPath, ...] = ()
self._task_names: Tuple[str, ...] = ()
@@ -101,6 +151,13 @@ def on_environment_created(self,
raise RuntimeError("Base environment must be wrapped with "
"`BaseTaskSettableWrapper`.") from e

# Get the maximum episode duration
self._max_num_steps_all = tuple(
math.ceil(simulation_duration_max / step_dt)
for simulation_duration_max, step_dt in zip(
env.unwrapped.get_attr("simulation_duration_max"),
env.unwrapped.get_attr("step_dt")))

# Pre-compute the list of all possible tasks
self._task_paths = cast(Tuple[TaskPath, ...], tuple(
(*path, i)
@@ -117,12 +174,28 @@ def on_episode_start(self,
# The callback is now fully initialized
self._is_initialized = True

def on_episode_start(self,
*,
episode: EpisodeType,
env_runner: EnvRunner,
metrics_logger: MetricsLogger,
env: gym.Env,
env_index: int,
rl_module: RLModule,
**kwargs: Any) -> None:
# Drop all partial episodes associated with the environment at hand
# when starting a fresh new one since it will never be done anyway.
if env_index in self._ongoing_episodes:
episode_id_prev = self._ongoing_episodes[env_index]
self._partial_episodes.pop(episode_id_prev, None)
self._ongoing_episodes[env_index] = episode.id_

def on_episode_end(self,
*,
episode: EpisodeType,
env_runner: EnvRunner,
metrics_logger: MetricsLogger,
env: gym.Env,
env: gym.vector.VectorEnv,
env_index: int,
rl_module: RLModule,
**kwargs: Any) -> None:
@@ -132,10 +205,19 @@ def on_episode_end(self,
metrics_logger.stats.pop("task_metrics", None)
self._must_clear_metrics = False

# Pop out task information from the episode to avoid monitoring it
task_index, score = -1, 0.0
for info in episode.get_infos():
task_index, score = info.pop("task", (task_index, score))
# Get all the chunks associated with the episode at hand
episodes = (*self._partial_episodes.pop(episode.id_, []), episode)

# Compute the score for the episode
if score_fn is None:
episode_return = sum(
episode.get_return() for episode in episodes)
score = episode_return / self._max_num_steps_all[env_index]
else:
score = score_fn(episodes, env, env_index)

# Pop task information out from the episode to avoid monitoring it
task_index = episodes[0].infos[0].pop("task_index")

# Update the score history of all the nodes from root to leaf for the
# task associated with the episode.
@@ -157,8 +239,14 @@ def on_sample_end(self,
metrics_logger: MetricsLogger,
samples: List[EpisodeType],
**kwargs: Any) -> None:
# Clear all metrics after sampling.
# Store all the partial episodes that have not reached done yet.
# See `MonitorEpisodeCallback.on_episode_end`.
for episode in samples:
if episode.is_done:
continue
self._partial_episodes[episode.id_].append(episode)

# Clear all metrics after sampling
self._must_clear_metrics = True

def on_train_result(self,
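As a usage sketch, not part of this commit: the factory above could be wired into an RLlib training pipeline roughly as follows. The `my_score_fn` helper, the PPO config and all numeric values are hypothetical and only illustrate the `score_fn` signature documented in the docstring; the closing lines recall how the softmin turns per-task scores into sampling probabilities.

from typing import Tuple

import numpy as np
import gymnasium as gym
from ray.rllib.algorithms.ppo import PPOConfig
from ray.rllib.utils.typing import EpisodeType

from gym_jiminy.rllib.curriculum import build_task_scheduling_callback


def my_score_fn(episode_chunks: Tuple[EpisodeType, ...],
                env: gym.vector.VectorEnv,
                env_index: int) -> float:
    """Hypothetical score: fraction of a fixed truncation horizon that the
    agent survived, clipped to the required [0.0, 1.0] range.
    """
    num_steps = sum(len(chunk) for chunk in episode_chunks)
    max_num_steps = 1000  # assumption: known maximum episode length
    return min(num_steps / max_num_steps, 1.0)


# Build the callback class and register it with the algorithm configuration.
TaskSchedulingCallback = build_task_scheduling_callback(
    history_length=100,    # moving average over the last 100 episodes per node
    softmin_beta=10.0,     # inverse temperature of the softmin
    score_fn=my_score_fn)  # pass None to fall back on the standardized return
config = PPOConfig().callbacks(TaskSchedulingCallback)

# For reference, the softmin maps average task scores to sampling
# probabilities as p_i = exp(-beta * s_i) / sum_j exp(-beta * s_j), so
# lower-scoring (harder) tasks are sampled more often.
scores = np.array([0.2, 0.5, 0.9])
probs = np.exp(-10.0 * scores)
probs /= np.sum(probs)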
3 changes: 2 additions & 1 deletion python/gym_jiminy/rllib/gym_jiminy/rllib/utilities.py
@@ -197,7 +198,8 @@ def on_episode_start(self,
# Drop all partial episodes associated with the environment at hand
# when starting a fresh new one since it will never be done anyway.
if env_index in self._ongoing_episodes:
self._partial_episodes.pop(self._ongoing_episodes[env_index], None)
episode_id_prev = self._ongoing_episodes[env_index]
self._partial_episodes.pop(episode_id_prev, None)
self._ongoing_episodes[env_index] = episode.id_

def on_episode_end(self,
@@ -1,10 +1,11 @@
# pylint: disable=missing-module-docstring

from .frame_rate_limiter import FrameRateLimiter
from .meta_envs import BaseTaskSettableWrapper
from .meta_envs import BaseTaskSettableWrapper, TrajectorySettableJiminyEnv


__all__ = [
"FrameRateLimiter",
"BaseTaskSettableWrapper"
"BaseTaskSettableWrapper",
"TrajectorySettableJiminyEnv"
]
120 changes: 77 additions & 43 deletions python/gym_jiminy/toolbox/gym_jiminy/toolbox/wrappers/meta_envs.py
@@ -2,8 +2,7 @@
"""
from abc import abstractmethod
from typing import (
Any, Optional, List, Tuple, Sequence, Union, Generic, SupportsFloat,
TypeVar, cast)
Any, Optional, List, Tuple, Sequence, Dict, Union, Generic, TypeVar, cast)

import numpy as np

@@ -172,10 +171,7 @@ def __init__(self,
None))

# Enable direct forwarding by default for efficiency
methods_names = ["compute_command"]
if self.augment_observation and self.num_tasks:
methods_names.append("refresh_observation")
for method_name in methods_names:
for method_name in ("compute_command", "refresh_observation"):
method_orig = getattr(BaseTaskSettableWrapper, method_name)
method = getattr(type(self), method_name)
if method_orig is method:
@@ -322,40 +318,34 @@ def compute_command(self, action: Act, command: np.ndarray) -> None:
"""
self.env.compute_command(action, command)

def step(self, # type: ignore[override]
action: Act
) -> Tuple[DataNested, SupportsFloat, bool, bool, InfoType]:
"""Run a simulation step for a given action.
This method monitors the performance of the agent for the task at hand
at the end of the episode. This information is stored under key "task"
of `info` a tuple `(task, score)`.
:param terminated: Whether the episode has reached the terminal state
of the MDP at the current step. This flag can be
used to compute a specific terminal reward.
:param info: Dictionary of extra information for monitoring.
def reset(self, # type: ignore[override]
*,
seed: Optional[int] = None,
options: Optional[Dict[str, Any]] = None
) -> Tuple[DataNested, InfoType]:
"""Reset the unified environment.
In practice, all it does on top of the original implementation is
storing the current task index under key `task_index` of the extra
`info` output. See `BasePipelineWrapper.reset` documentation for details.
:param seed: Random seed, as a positive integer.
Optional: `None` by default. If `None`, then the internal
random generator of the environment will be kept as-is,
without updating its seed.
:param options: Additional information to specify how the environment
is reset. The field 'reset_hook' is reserved for
chaining multiple `BasePipelineWrapper`. It is not
meant to be defined manually.
Optional: None by default.
"""
# Call base implementation
obs, reward, terminated, truncated, info = super().step(action)

# FIXME: Add the score of the agent at the end of the episode for the
# task at hand if any, otherwise just move on.
# Note that, at this point, the flags `terminated` and `truncated` have
# been evaluated for the top-most layer of the pipeline. As a result,
# the score is guarantee to be added at the end of the episode as long
# as no additional termination condition is triggered by some higher-
# level classical wrappers deriving from `gym.Wrapper`. Unfortunately,
# scenario is quite common, especially via `gym.wrappers.TimeLimit`.
# Because of this limitation, it is preferrable to store the score at
# every step.
if self.num_tasks: # and (terminated or truncated):
assert "task" not in info
score = float(self.get_score())
info["task"] = (int(self.task_index), score)

# Return total reward
return obs, reward, terminated, truncated, info
obs, info = super().reset(seed=seed, options=options)

# Store the current task index in extra information
info['task_index'] = int(self.task_index)

return obs, info

# methods to override:
# ----------------------------
Expand All @@ -370,10 +360,54 @@ def set_task(self, task_index: int) -> None:
"""Set the task that the agent will have to address from now on.
"""

@abstractmethod
def get_score(self) -> float:
"""Assess how well the agent is performing so far for the current task.

.. warning::
This score must be standardized between 0.0 and 1.0.
class TrajectorySettableJiminyEnv(
BaseTaskSettableWrapper[
Obs, Act, Union[np.int64, Tuple[()]], BaseObs],
Generic[Obs, Act, BaseObs]):

task_space: Union[spaces.Discrete, spaces.Tuple]

def __init__(self,
env: InterfaceJiminyEnv[BaseObs, Act],
*,
initial_proba_task_tree: Optional[Sequence[float]] = None,
augment_observation: bool = True
) -> None:
"""
:param env: Environment to extend, possibly already wrapped.
:param initial_proba_task_tree: Initial probability tree associated
with the task tree of the environment.
:param augment_observation: Whether to add the current task to the
observation of the environment.
Optional: `True` by default.
"""
# Make sure that the trajectory database is already locked
if not env.quantities.trajectory_dataset.is_locked:
raise RuntimeError(
"The trajectory dataset managed by the base environment must "
"be locked being wrapped by `TrajectorySettableJiminyEnv`.")

# Call base implementation
super().__init__(env,
initial_proba_task_tree=initial_proba_task_tree,
augment_observation=augment_observation)

# Make sure that the environment is derived from InterfaceJiminyEnv
assert isinstance(self, InterfaceJiminyEnv)

# Define proxy for fast access
self._trajectory_names = tuple(self.quantities.trajectory_dataset)
self._simulation_duration_max = (
self.env.unwrapped.simulation_duration_max)

def _initialize_task_space(self) -> None:
num_trajectories = len(self.quantities.trajectory_dataset)
if num_trajectories:
self.task_space = spaces.Discrete(num_trajectories)
else:
self.task_space = spaces.Tuple([])

def set_task(self, task_index: int) -> None:
trajectory_name = self._trajectory_names[task_index]
self.quantities.trajectory_dataset.select(trajectory_name)
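Similarly, a minimal sketch of a custom task-settable wrapper following the same pattern as `TrajectorySettableJiminyEnv` above, assuming that overriding `_initialize_task_space` and `set_task` is sufficient; the gravity-scaling task list is hypothetical and stands in for any discrete set of environment variations.

from gymnasium import spaces

from gym_jiminy.toolbox.wrappers import BaseTaskSettableWrapper


class GravityScaleSettableEnv(BaseTaskSettableWrapper):
    """Hypothetical wrapper exposing a discrete set of gravity scales as tasks."""

    # Candidate scale factors. Purely illustrative.
    GRAVITY_SCALES = (0.5, 1.0, 1.5)

    def _initialize_task_space(self) -> None:
        # One discrete task per candidate gravity scale.
        self.task_space = spaces.Discrete(len(self.GRAVITY_SCALES))

    def set_task(self, task_index: int) -> None:
        # Record the selected variation. Applying it to the simulator is
        # environment-specific and deliberately omitted here.
        self._gravity_scale = self.GRAVITY_SCALES[task_index]

Wrapped this way, the environment advertises its variations through `task_space`, and the task scheduling callback above can rebalance how often each one is sampled based on the recorded scores.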
