@@ -11,7 +11,6 @@
from ray.rllib.core.rl_module.multi_rl_module import MultiRLModule
from ray.rllib.core.rl_module.rl_module import RLModule
from ray.rllib.utils.annotations import override
from ray.rllib.utils.deprecation import Deprecated
from ray.rllib.utils.numpy import convert_to_numpy
from ray.rllib.utils.postprocessing.zero_padding import (
create_mask_and_seq_lens,
@@ -396,11 +395,3 @@ def _get_max_seq_len(self, rl_module, module_id=None):
"model_config={'max_seq_len': [some int]})`."
)
return mod.model_config["max_seq_len"]


@Deprecated(
new="ray.rllib.utils.postprocessing.zero_padding.split_and_zero_pad()",
error=True,
)
def split_and_zero_pad_list(*args, **kwargs):
pass
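The `_get_max_seq_len` error message above ties this connector to the RLModule's `max_seq_len` model-config setting, which governs how episode data is chopped into fixed-length sequences and zero-padded for the recurrent loss. A rough, standalone illustration of that idea (plain NumPy; `chunk_and_pad` is a made-up stand-in, not the actual `split_and_zero_pad` / `create_mask_and_seq_lens` utilities imported above):

import numpy as np


def chunk_and_pad(values, max_seq_len):
    # Chop a flat list of per-timestep values into fixed-length sequences,
    # zero-pad the last (short) chunk, and return a mask marking real steps.
    sequences, mask = [], []
    for start in range(0, len(values), max_seq_len):
        chunk = list(values[start:start + max_seq_len])
        num_pad = max_seq_len - len(chunk)
        sequences.append(np.array(chunk + [0] * num_pad))
        mask.append(np.array([True] * len(chunk) + [False] * num_pad))
    return np.stack(sequences), np.stack(mask)


seqs, mask = chunk_and_pad([1, 2, 3, 4, 5, 6, 7], max_seq_len=4)
# seqs -> [[1 2 3 4], [5 6 7 0]]; mask -> [[T T T T], [T T T F]]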
81 changes: 23 additions & 58 deletions rllib/connectors/connector_v2.py
@@ -21,7 +21,6 @@
from ray.rllib.utils import force_list
from ray.rllib.utils.annotations import override, OverrideToImplementCustomLogic
from ray.rllib.utils.checkpoints import Checkpointable
from ray.rllib.utils.deprecation import Deprecated, deprecation_warning
from ray.rllib.utils.spaces.space_utils import BatchedNdArray
from ray.rllib.utils.typing import AgentID, EpisodeType, ModuleID, StateDict
from ray.util.annotations import PublicAPI
@@ -97,11 +96,29 @@ def __init__(
self._action_space = None
self._input_observation_space = None
self._input_action_space = None
self._kwargs = kwargs

self.input_action_space = input_action_space
self.input_observation_space = input_observation_space

# Store child's constructor args and kwargs for the default
# `get_ctor_args_and_kwargs` implementation (to be able to restore from a
# checkpoint).
if self.__class__.__dict__.get("__init__") is not None:
caller_frame = inspect.stack()[1].frame
arg_info = inspect.getargvalues(caller_frame)
# Separate positional arguments and keyword arguments.
caller_locals = (
arg_info.locals
) # Dictionary of all local variables in the caller
self._ctor_kwargs = {
arg: caller_locals[arg] for arg in arg_info.args if arg != "self"
}
else:
self._ctor_kwargs = {
"input_observation_space": self.input_observation_space,
"input_action_space": self.input_action_space,
}

@OverrideToImplementCustomLogic
def recompute_output_observation_space(
self,
@@ -166,25 +183,7 @@ def __call__(
The new observation space (after data has passed through this ConnectorV2
piece).
"""
# Check, whether user is still overriding the old
# `recompute_observation_space_from_input_spaces()`.
parent_source = inspect.getsource(
ConnectorV2.recompute_observation_space_from_input_spaces
)
child_source = inspect.getsource(
self.recompute_observation_space_from_input_spaces
)
if parent_source == child_source:
return self.input_observation_space
else:
deprecation_warning(
old="ConnectorV2.recompute_observation_space_from_input_spaces()",
new="ConnectorV2.recompute_output_observation_space("
"input_observation_space: gym.Space, input_action_space: gym.Space) "
"-> gym.Space",
error=False,
)
return self.recompute_observation_space_from_input_spaces()
return self.input_observation_space

@OverrideToImplementCustomLogic
def recompute_output_action_space(
@@ -213,23 +212,7 @@ def recompute_output_action_space(
The new action space (after data has passed through this ConnectorV2
piece).
"""
# Check, whether user is still overriding the old
# `recompute_action_space_from_input_spaces()`.
parent_source = inspect.getsource(
ConnectorV2.recompute_action_space_from_input_spaces
)
child_source = inspect.getsource(self.recompute_action_space_from_input_spaces)
if parent_source == child_source:
return self.input_action_space
else:
deprecation_warning(
old="ConnectorV2.recompute_action_space_from_input_spaces()",
new="ConnectorV2.recompute_output_action_space("
"input_observation_space: gym.Space, input_action_space: gym.Space) "
"-> gym.Space",
error=False,
)
return self.recompute_action_space_from_input_spaces()
return self.input_action_space

@abc.abstractmethod
def __call__(
@@ -949,8 +932,8 @@ def set_state(self, state: StateDict) -> None:
@override(Checkpointable)
def get_ctor_args_and_kwargs(self) -> Tuple[Tuple, Dict[str, Any]]:
return (
(self.input_observation_space, self.input_action_space), # *args
self._kwargs, # **kwargs
(), # *args
self._ctor_kwargs, # **kwargs
)

def reset_state(self) -> None:
@@ -1028,21 +1011,3 @@ def input_action_space(self, value):

def __str__(self, indentation: int = 0):
return " " * indentation + self.__class__.__name__

@Deprecated(
new="ConnectorV2.recompute_output_observation_space("
"input_observation_space: gym.Space, input_action_space: gym.Space) "
"-> gym.Space",
error=True,
)
def recompute_observation_space_from_input_spaces(self):
pass

@Deprecated(
new="ConnectorV2.recompute_action_observation_space("
"input_observation_space: gym.Space, input_action_space: gym.Space) "
"-> gym.Space",
error=True,
)
def recompute_action_space_from_input_spaces(self):
pass
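For context on the `__init__` change above: the base class now inspects the frame of the subclass constructor that called `super().__init__()` and records that frame's named arguments, so the default `get_ctor_args_and_kwargs()` can re-create the piece from a checkpoint without every subclass having to override it (which is also why the pass-through `__init__` of `AddNextObservationsFromEpisodesToTrainBatch` below can be dropped). A minimal, standalone sketch of the mechanism, using made-up `Base` / `FrameStack` classes and a `num_frames` argument rather than the real RLlib classes:

import inspect


class Base:
    def __init__(self, observation_space=None, action_space=None, **kwargs):
        if self.__class__.__dict__.get("__init__") is not None:
            # The instance's class defines its own __init__ -> capture the
            # named arguments of that constructor call from its frame.
            caller_frame = inspect.stack()[1].frame
            arg_info = inspect.getargvalues(caller_frame)
            self._ctor_kwargs = {
                arg: arg_info.locals[arg] for arg in arg_info.args if arg != "self"
            }
        else:
            # No custom __init__ -> the two spaces alone suffice to rebuild.
            self._ctor_kwargs = {
                "observation_space": observation_space,
                "action_space": action_space,
            }

    def get_ctor_args_and_kwargs(self):
        return (), self._ctor_kwargs


class FrameStack(Base):
    def __init__(self, observation_space=None, action_space=None, num_frames=4, **kwargs):
        super().__init__(observation_space, action_space, **kwargs)
        self.num_frames = num_frames


piece = FrameStack(num_frames=16)
args, kwargs = piece.get_ctor_args_and_kwargs()
restored = FrameStack(*args, **kwargs)  # kwargs carries num_frames=16 plus the two spaces
assert restored.num_frames == 16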
@@ -1,7 +1,5 @@
from typing import Any, Dict, List, Optional

import gymnasium as gym

from ray.rllib.core.columns import Columns
from ray.rllib.connectors.connector_v2 import ConnectorV2
from ray.rllib.core.rl_module.rl_module import RLModule
@@ -75,19 +73,6 @@ class AddNextObservationsFromEpisodesToTrainBatch(ConnectorV2):
)
"""

def __init__(
self,
input_observation_space: Optional[gym.Space] = None,
input_action_space: Optional[gym.Space] = None,
**kwargs,
):
"""Initializes a AddNextObservationsFromEpisodesToTrainBatch instance."""
super().__init__(
input_observation_space=input_observation_space,
input_action_space=input_action_space,
**kwargs,
)

@override(ConnectorV2)
def __call__(
self,
2 changes: 2 additions & 0 deletions rllib/connectors/learner/general_advantage_estimation.py
@@ -8,6 +8,7 @@
from ray.rllib.core.rl_module.apis.value_function_api import ValueFunctionAPI
from ray.rllib.core.rl_module.multi_rl_module import MultiRLModule
from ray.rllib.evaluation.postprocessing import Postprocessing
from ray.rllib.utils.annotations import override
from ray.rllib.utils.numpy import convert_to_numpy
from ray.rllib.utils.postprocessing.value_predictions import compute_value_targets
from ray.rllib.utils.postprocessing.zero_padding import (
@@ -69,6 +70,7 @@ def __init__(
# vf targets) into tensors.
self._numpy_to_tensor_connector = None

@override(ConnectorV2)
def __call__(
self,
*,
@@ -83,14 +83,17 @@
"""
import os

from ray.rllib.connectors.env_to_module import (
EnvToModulePipeline,
AddObservationsFromEpisodesToBatch,
AddStatesFromEpisodesToBatch,
BatchIndividualItems,
NumpyToTensor,
from ray.rllib.connectors.env_to_module import EnvToModulePipeline
from ray.rllib.connectors.module_to_env import ModuleToEnvPipeline
from ray.rllib.core import (
COMPONENT_ENV_RUNNER,
COMPONENT_ENV_TO_MODULE_CONNECTOR,
COMPONENT_MODULE_TO_ENV_CONNECTOR,
COMPONENT_LEARNER_GROUP,
COMPONENT_LEARNER,
COMPONENT_RL_MODULE,
DEFAULT_MODULE_ID,
)
from ray.rllib.core import DEFAULT_MODULE_ID
from ray.rllib.core.columns import Columns
from ray.rllib.core.rl_module.default_model_config import DefaultModelConfig
from ray.rllib.core.rl_module.rl_module import RLModule
@@ -144,10 +147,6 @@ def _env_creator(cfg):
if __name__ == "__main__":
args = parser.parse_args()

assert (
args.enable_new_api_stack
), "Must set --enable-new-api-stack when running this script!"

base_config = (
get_trainable_cls(args.algo)
.get_default_config()
@@ -163,50 +162,54 @@ def _env_creator(cfg):
print("Training LSTM-policy until desired reward/timesteps/iterations. ...")
results = run_rllib_example_script_experiment(base_config, args)

print("Training completed. Creating an env-loop for inference ...")

print("Env ...")
env = _env_creator(base_config.env_config)
# Get the last checkpoint from the above training run.
metric_key = metric = f"{ENV_RUNNER_RESULTS}/{EPISODE_RETURN_MEAN}"
best_result = results.get_best_result(metric=metric_key, mode="max")

# We build the env-to-module pipeline here manually, but feel also free to build it
# through the even easier:
# `env_to_module = base_config.build_env_to_module_connector(env=env)`, which will
# automatically add all default pieces necessary (for example the
# `AddStatesFromEpisodesToBatch` component b/c we are using a stateful RLModule
# here).
print("Env-to-module ConnectorV2 ...")
env_to_module = EnvToModulePipeline(
input_observation_space=env.observation_space,
input_action_space=env.action_space,
connectors=[
AddObservationsFromEpisodesToBatch(),
AddStatesFromEpisodesToBatch(),
BatchIndividualItems(multi_agent=args.num_agents > 0),
NumpyToTensor(),
],
print(
"Training completed (R="
f"{best_result.metrics[ENV_RUNNER_RESULTS][EPISODE_RETURN_MEAN]}). "
"Creating an env-loop for inference ..."
)

# Create the RLModule.
# Get the last checkpoint from the above training run.
best_result = results.get_best_result(
metric=f"{ENV_RUNNER_RESULTS}/{EPISODE_RETURN_MEAN}", mode="max"
print("Env ...", end="")
env = _env_creator(base_config.env_config)
print(" ok")

# Create the env-to-module pipeline from the checkpoint.
print("Restore env-to-module connector from checkpoint ...", end="")
env_to_module = EnvToModulePipeline.from_checkpoint(
os.path.join(
best_result.checkpoint.path,
COMPONENT_ENV_RUNNER,
COMPONENT_ENV_TO_MODULE_CONNECTOR,
)
)
print(" ok")

print("Restore RLModule from checkpoint ...", end="")
# Create RLModule from a checkpoint.
rl_module = RLModule.from_checkpoint(
os.path.join(
best_result.checkpoint.path,
"learner_group",
"learner",
"rl_module",
COMPONENT_LEARNER_GROUP,
COMPONENT_LEARNER,
COMPONENT_RL_MODULE,
DEFAULT_MODULE_ID,
)
)
print("RLModule restored ...")
print(" ok")

# For the module-to-env pipeline, we will use the convenient config utility.
print("Module-to-env ConnectorV2 ...")
module_to_env = base_config.build_module_to_env_connector(env=env)
print("Restore module-to-env connector from checkpoint ...", end="")
module_to_env = ModuleToEnvPipeline.from_checkpoint(
os.path.join(
best_result.checkpoint.path,
COMPONENT_ENV_RUNNER,
COMPONENT_MODULE_TO_ENV_CONNECTOR,
)
)
print(" ok")

# Now our setup is complete:
# [gym.Env] -> env-to-module -> [RLModule] -> module-to-env -> [gym.Env] ... repeat
@@ -246,6 +249,7 @@ def _env_creator(cfg):
# is not vectorized here, so we need to use `action[0]`.
action = to_env.pop(Columns.ACTIONS)[0]
obs, reward, terminated, truncated, _ = env.step(action)
# Keep our `SingleAgentEpisode` instance updated at all times.
episode.add_env_step(
obs,
action,