From bff39da344d7642df1fe2488a84e8e3ed70902de Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Fri, 26 Jun 2020 09:04:44 +0000 Subject: [PATCH 01/48] chore(deps-dev): bump watchdog from 0.10.2 to 0.10.3 Bumps [watchdog](https://github.com/gorakhargosh/watchdog) from 0.10.2 to 0.10.3. - [Release notes](https://github.com/gorakhargosh/watchdog/releases) - [Changelog](https://github.com/gorakhargosh/watchdog/blob/master/changelog.rst) - [Commits](https://github.com/gorakhargosh/watchdog/compare/v0.10.2...v0.10.3) Signed-off-by: dependabot[bot] --- poetry.lock | 6 +++--- pyproject.toml | 2 +- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/poetry.lock b/poetry.lock index 4e20abd7..34ef6590 100644 --- a/poetry.lock +++ b/poetry.lock @@ -1959,7 +1959,7 @@ description = "Filesystem events monitoring" name = "watchdog" optional = false python-versions = "*" -version = "0.10.2" +version = "0.10.3" [package.dependencies] pathtools = ">=0.1.1" @@ -2028,7 +2028,7 @@ docs = ["sphinx", "jaraco.packaging (>=3.2)", "rst.linker (>=1.9)"] testing = ["jaraco.itertools", "func-timeout"] [metadata] -content-hash = "07c41546f29a2277543366994ebd8a11e660438d3c515a08863b3af5bb90769e" +content-hash = "b618d227df18d652560cbdb25aea9452000d671d5d87c679dfdb45f4f208b033" python-versions = "^3.7" [metadata.files] @@ -3110,7 +3110,7 @@ virtualenv = [ {file = "virtualenv-20.0.23.tar.gz", hash = "sha256:5102fbf1ec57e80671ef40ed98a84e980a71194cedf30c87c2b25c3a9e0b0107"}, ] watchdog = [ - {file = "watchdog-0.10.2.tar.gz", hash = "sha256:c560efb643faed5ef28784b2245cf8874f939569717a4a12826a173ac644456b"}, + {file = "watchdog-0.10.3.tar.gz", hash = "sha256:4214e1379d128b0588021880ccaf40317ee156d4603ac388b9adcf29165e0c04"}, ] wcwidth = [ {file = "wcwidth-0.2.4-py2.py3-none-any.whl", hash = "sha256:79375666b9954d4a1a10739315816324c3e73110af9d0e102d906fdb0aec009f"}, diff --git a/pyproject.toml b/pyproject.toml index e72e0cdd..fad2097d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -27,7 +27,7 @@ opencv-python = "^4.2.0" [tool.poetry.dev-dependencies] flake8 = "^3.8.3" pylint = "^2.5.3" -watchdog = "^0.10.2" +watchdog = "^0.10.3" black = "^19.10b0" tox = "^3.15.2" sphinx = "^3.1.1" From f95068bfb7b43d90c548025005399528c54b7021 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Mon, 29 Jun 2020 09:04:00 +0000 Subject: [PATCH 02/48] chore(deps): bump cachetools from 4.1.0 to 4.1.1 Bumps [cachetools](https://github.com/tkem/cachetools) from 4.1.0 to 4.1.1. 
- [Release notes](https://github.com/tkem/cachetools/releases) - [Changelog](https://github.com/tkem/cachetools/blob/master/CHANGELOG.rst) - [Commits](https://github.com/tkem/cachetools/compare/v4.1.0...v4.1.1) Signed-off-by: dependabot[bot] --- poetry.lock | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/poetry.lock b/poetry.lock index 4e20abd7..3ae4867a 100644 --- a/poetry.lock +++ b/poetry.lock @@ -298,7 +298,7 @@ description = "Extensible memoizing collections and decorators" name = "cachetools" optional = false python-versions = "~=3.5" -version = "4.1.0" +version = "4.1.1" [[package]] category = "main" @@ -2145,8 +2145,8 @@ cached-property = [ {file = "cached_property-1.5.1-py2.py3-none-any.whl", hash = "sha256:3a026f1a54135677e7da5ce819b0c690f156f37976f3e30c5430740725203d7f"}, ] cachetools = [ - {file = "cachetools-4.1.0-py3-none-any.whl", hash = "sha256:de5d88f87781602201cde465d3afe837546663b168e8b39df67411b0bf10cefc"}, - {file = "cachetools-4.1.0.tar.gz", hash = "sha256:1d057645db16ca7fe1f3bd953558897603d6f0b9c51ed9d11eb4d071ec4e2aab"}, + {file = "cachetools-4.1.1-py3-none-any.whl", hash = "sha256:513d4ff98dd27f85743a8dc0e92f55ddb1b49e060c2d5961512855cda2c01a98"}, + {file = "cachetools-4.1.1.tar.gz", hash = "sha256:bbaa39c3dede00175df2dc2b03d0cf18dd2d32a7de7beb68072d13043c9edb20"}, ] certifi = [ {file = "certifi-2020.6.20-py2.py3-none-any.whl", hash = "sha256:8fc0819f1f30ba15bdb34cceffb9ef04d99f420f68eb75d901e9560b8749fc41"}, From f2638af0a00ce699a46a3115ec1a1b5f1f65f13f Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Mon, 29 Jun 2020 22:10:13 +0000 Subject: [PATCH 03/48] chore(deps-dev): bump ipython from 7.15.0 to 7.16.1 Bumps [ipython](https://github.com/ipython/ipython) from 7.15.0 to 7.16.1. 
- [Release notes](https://github.com/ipython/ipython/releases) - [Commits](https://github.com/ipython/ipython/compare/7.15.0...7.16.1) Signed-off-by: dependabot[bot] --- poetry.lock | 8 ++++---- pyproject.toml | 2 +- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/poetry.lock b/poetry.lock index bc7fbdf0..0b4ce09b 100644 --- a/poetry.lock +++ b/poetry.lock @@ -619,7 +619,7 @@ description = "IPython: Productive Interactive Computing" name = "ipython" optional = false python-versions = ">=3.6" -version = "7.15.0" +version = "7.16.1" [package.dependencies] appnope = "*" @@ -2028,7 +2028,7 @@ docs = ["sphinx", "jaraco.packaging (>=3.2)", "rst.linker (>=1.9)"] testing = ["jaraco.itertools", "func-timeout"] [metadata] -content-hash = "b618d227df18d652560cbdb25aea9452000d671d5d87c679dfdb45f4f208b033" +content-hash = "0f2608e43534071b40ebce1fb702ef9fdc0f2533eced20f9cf186d38285895ab" python-versions = "^3.7" [metadata.files] @@ -2337,8 +2337,8 @@ ipykernel = [ {file = "ipykernel-5.3.0.tar.gz", hash = "sha256:731adb3f2c4ebcaff52e10a855ddc87670359a89c9c784d711e62d66fccdafae"}, ] ipython = [ - {file = "ipython-7.15.0-py3-none-any.whl", hash = "sha256:1b85d65632211bf5d3e6f1406f3393c8c429a47d7b947b9a87812aa5bce6595c"}, - {file = "ipython-7.15.0.tar.gz", hash = "sha256:0ef1433879816a960cd3ae1ae1dc82c64732ca75cec8dab5a4e29783fb571d0e"}, + {file = "ipython-7.16.1-py3-none-any.whl", hash = "sha256:2dbcc8c27ca7d3cfe4fcdff7f45b27f9a8d3edfa70ff8024a71c7a8eb5f09d64"}, + {file = "ipython-7.16.1.tar.gz", hash = "sha256:9f4fcb31d3b2c533333893b9172264e4821c1ac91839500f31bd43f2c59b3ccf"}, ] ipython-genutils = [ {file = "ipython_genutils-0.2.0-py2.py3-none-any.whl", hash = "sha256:72dd37233799e619666c9f639a9da83c34013a73e8bbc79a7a6348d93c61fab8"}, diff --git a/pyproject.toml b/pyproject.toml index fad2097d..42e329a6 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -37,7 +37,7 @@ pre-commit = "^2.5.1" reorder-python-imports = "^2.3.1" mypy = "^0.782" coverage = "^5.1" -ipython = "^7.15.0" +ipython = "^7.16.1" poetry-version = "^0.1.5" pytest-mock = "^3.1.1" pytest-sugar = "^0.9.3" From 7a54442fc2e30b9f19b5568c0d4b7e16ec5ad076 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Mon, 29 Jun 2020 22:22:13 +0000 Subject: [PATCH 04/48] chore(deps-dev): bump tox from 3.15.2 to 3.16.1 Bumps [tox](https://github.com/tox-dev/tox) from 3.15.2 to 3.16.1. 
- [Release notes](https://github.com/tox-dev/tox/releases) - [Changelog](https://github.com/tox-dev/tox/blob/master/docs/changelog.rst) - [Commits](https://github.com/tox-dev/tox/compare/3.15.2...3.16.1) Signed-off-by: dependabot[bot] --- poetry.lock | 8 ++++---- pyproject.toml | 2 +- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/poetry.lock b/poetry.lock index 0b4ce09b..fc47babf 100644 --- a/poetry.lock +++ b/poetry.lock @@ -1828,7 +1828,7 @@ description = "tox is a generic virtualenv management and test command line tool name = "tox" optional = false python-versions = "!=3.0.*,!=3.1.*,!=3.2.*,!=3.3.*,!=3.4.*,>=2.7" -version = "3.15.2" +version = "3.16.1" [package.dependencies] colorama = ">=0.4.1" @@ -2028,7 +2028,7 @@ docs = ["sphinx", "jaraco.packaging (>=3.2)", "rst.linker (>=1.9)"] testing = ["jaraco.itertools", "func-timeout"] [metadata] -content-hash = "0f2608e43534071b40ebce1fb702ef9fdc0f2533eced20f9cf186d38285895ab" +content-hash = "a3ba29e6b9d210c4f16c69d0acb0470c21c10a039325b78cba7e1de32d0c231c" python-versions = "^3.7" [metadata.files] @@ -3054,8 +3054,8 @@ tornado = [ {file = "tornado-6.0.4.tar.gz", hash = "sha256:0fe2d45ba43b00a41cd73f8be321a44936dc1aba233dee979f17a042b83eb6dc"}, ] tox = [ - {file = "tox-3.15.2-py2.py3-none-any.whl", hash = "sha256:50a188b8e17580c1fb931f494a754e6507d4185f54fb18aca5ba3e12d2ffd55e"}, - {file = "tox-3.15.2.tar.gz", hash = "sha256:c696d36cd7c6a28ada2da780400e44851b20ee19ef08cfe73344a1dcebbbe9f3"}, + {file = "tox-3.16.1-py2.py3-none-any.whl", hash = "sha256:60c3793f8ab194097ec75b5a9866138444f63742b0f664ec80be1222a40687c5"}, + {file = "tox-3.16.1.tar.gz", hash = "sha256:9a746cda9cadb9e1e05c7ab99f98cfcea355140d2ecac5f97520be94657c3bc7"}, ] traitlets = [ {file = "traitlets-4.3.3-py2.py3-none-any.whl", hash = "sha256:70b4c6a1d9019d7b4f6846832288f86998aa3b9207c6821f3578a6a6a467fe44"}, diff --git a/pyproject.toml b/pyproject.toml index 42e329a6..484a9570 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -29,7 +29,7 @@ flake8 = "^3.8.3" pylint = "^2.5.3" watchdog = "^0.10.3" black = "^19.10b0" -tox = "^3.15.2" +tox = "^3.16.1" sphinx = "^3.1.1" pytest = "^5.4.3" gym-cartpole-swingup = "^0.1.0" From 3da6ddcee919f5e36ff8fd1a17755d84c90c0f33 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Tue, 30 Jun 2020 09:03:09 +0000 Subject: [PATCH 05/48] chore(deps): bump streamlit from 0.62.0 to 0.62.1 Bumps [streamlit](https://github.com/streamlit/streamlit) from 0.62.0 to 0.62.1. 
- [Release notes](https://github.com/streamlit/streamlit/releases) - [Changelog](https://github.com/streamlit/streamlit/blob/develop/docs/changelog.md) - [Commits](https://github.com/streamlit/streamlit/compare/0.62.0...0.62.1) Signed-off-by: dependabot[bot] --- poetry.lock | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/poetry.lock b/poetry.lock index fc47babf..c06d692a 100644 --- a/poetry.lock +++ b/poetry.lock @@ -1688,7 +1688,7 @@ description = "Frontend library for machine learning engineers" name = "streamlit" optional = false python-versions = ">=3.6" -version = "0.62.0" +version = "0.62.1" [package.dependencies] altair = ">=3.2.0" @@ -2997,7 +2997,7 @@ sphinxcontrib-serializinghtml = [ {file = "sphinxcontrib_serializinghtml-1.1.4-py2.py3-none-any.whl", hash = "sha256:f242a81d423f59617a8e5cf16f5d4d74e28ee9a66f9e5b637a18082991db5a9a"}, ] streamlit = [ - {file = "streamlit-0.62.0-py2.py3-none-any.whl", hash = "sha256:114403a3f10885979744eb4cd5d9dccd4f3f606be55dd81baf6f0b2f47e4de06"}, + {file = "streamlit-0.62.1-py2.py3-none-any.whl", hash = "sha256:5b5e5219b103276bd2ad6005659de9381ed044e116bbaf5bdac7da5424f1a309"}, ] stringcase = [ {file = "stringcase-1.2.0.tar.gz", hash = "sha256:48a06980661908efe8d9d34eab2b6c13aefa2163b3ced26972902e3bdfd87008"}, From aac6f03524de203c0e1e36b5da8858cf57bef4d1 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=C3=82ngelo=20Lovatto?= Date: Thu, 25 Jun 2020 19:33:38 -0300 Subject: [PATCH 06/48] test(mapo): use mocker fixture MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Ângelo Lovatto --- tests/agents/mapo/test_policy.py | 17 ++++++++--------- 1 file changed, 8 insertions(+), 9 deletions(-) diff --git a/tests/agents/mapo/test_policy.py b/tests/agents/mapo/test_policy.py index 0bc5f025..19352d3e 100644 --- a/tests/agents/mapo/test_policy.py +++ b/tests/agents/mapo/test_policy.py @@ -1,6 +1,5 @@ # pylint: disable=missing-docstring,redefined-outer-name,protected-access import copy -from unittest import mock import numpy as np import pytest @@ -79,11 +78,11 @@ def test_learn_on_batch(policy, sample_batch): assert all(not torch.allclose(o, n) for o, n in zip(old_params, new_params)) -def test_compile(policy): - with mock.patch("raylab.losses.MAPO.compile") as mapo, mock.patch( - "raylab.losses.SPAML.compile" - ) as spaml: - policy.compile() - assert isinstance(policy.module, torch.jit.ScriptModule) - assert mapo.called - assert spaml.called +def test_compile(policy, mocker): + mapo = mocker.patch("raylab.losses.MAPO.compile") + spaml = mocker.patch("raylab.losses.SPAML.compile") + + policy.compile() + assert isinstance(policy.module, torch.jit.ScriptModule) + assert mapo.called + assert spaml.called From cff173f5c2a062464a9aafdd8d678f61a7fd79ae Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=C3=82ngelo=20Lovatto?= Date: Fri, 26 Jun 2020 15:55:58 -0300 Subject: [PATCH 07/48] feat(modules): add DDPG network MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Ângelo Lovatto --- raylab/modules/ddpg.py | 207 ++++++++++++++++++ raylab/modules/networks/action_value.py | 138 ++++++++++++ raylab/modules/networks/policy/__init__.py | 1 + .../modules/networks/policy/deterministic.py | 164 ++++++++++++++ raylab/pytorch/nn/init.py | 2 +- raylab/utils/annotations.py | 2 + 6 files changed, 513 insertions(+), 1 deletion(-) create mode 100644 raylab/modules/ddpg.py create mode 100644 raylab/modules/networks/action_value.py create mode 100644 
raylab/modules/networks/policy/__init__.py create mode 100644 raylab/modules/networks/policy/deterministic.py diff --git a/raylab/modules/ddpg.py b/raylab/modules/ddpg.py new file mode 100644 index 00000000..f90dd565 --- /dev/null +++ b/raylab/modules/ddpg.py @@ -0,0 +1,207 @@ +"""NN architecture used in Deep Deterministic Policy Gradients.""" +from dataclasses import dataclass +from dataclasses import field +from typing import List +from typing import Tuple + +import torch.nn as nn +from dataclasses_json import DataClassJsonMixin +from gym.spaces import Box +from ray.rllib import SampleBatch +from torch import Tensor + +from raylab.utils.annotations import TensorDict + +from .networks.action_value import ForkedQValueEnsemble +from .networks.action_value import MLPQValue +from .networks.action_value import QValueEnsemble +from .networks.action_value import StateActionMLPSpec +from .networks.policy.deterministic import DeterministicPolicy +from .networks.policy.deterministic import MLPDeterministicPolicy +from .networks.policy.deterministic import StateMLPSpec + + +@dataclass +class DDPGActorSpec(DataClassJsonMixin): + """Specifications for policy, behavior, and target policy. + + Args: + encoder: Specifications for creating the multilayer perceptron mapping + states to pre-action linear features + norm_beta: Maximum l1 norm of the unconstrained actions. If None, won't + normalize actions before squashing function + behavior: Type of behavior policy. Either 'gaussian', 'parameter_noise', + or 'deterministic' + smooth_target_policy: Whether to use a noisy target policy for + Q-Learning + target_gaussian_sigma: Gaussian standard deviation for noisy target + policy + separate_target_policy: Whether to use separate parameters for the + target policy. Intended for use with polyak averaging + """ + + encoder: StateMLPSpec = field(default_factory=StateMLPSpec) + norm_beta: float = 1.2 + behavior: str = "gaussian" + smooth_target_policy: bool = True + target_gaussian_sigma: float = 0.3 + separate_target_policy: bool = False + + def __post_init__(self): + cls_name = type(self).__name__ + assert self.norm_beta > 0, f"{cls_name}.norm_beta must be positive" + valid_behaviors = {"gaussian", "parameter_noise", "deterministic"} + assert ( + self.behavior in valid_behaviors + ), f"{cls_name}.behavior must be one of {valid_behaviors}" + assert ( + self.target_gaussian_sigma > 0 + ), f"{cls_name}.target_gaussian_sigma must be positive" + + +@dataclass +class DDPGCriticSpec(DataClassJsonMixin): + """Specifications for action-value estimators. + + Args: + encoder: Specifications for creating the multilayer perceptron mapping + states and actions to pre-value function linear features + double_q: Whether to create two Q-value estimators instead of one. + Defaults to True + parallelize: Whether to evaluate Q-values in parallel. Defaults to + False. + """ + + encoder: StateActionMLPSpec = field(default_factory=StateActionMLPSpec) + double_q: bool = True + parallelize: bool = False + + +@dataclass +class DDPGSpec(DataClassJsonMixin): + """Specifications for DDPG modules. 
+ + Args: + actor: Specifications for policy, behavior, and target policy + critic: Specifications for action-value estimators + initializer: Optional dictionary with mandatory `type` key corresponding + to the initializer function name in `torch.nn.init` and optional + keyword arguments.configuration dictionary for parameter + """ + + actor: DDPGActorSpec = field(default_factory=DDPGActorSpec) + critic: DDPGCriticSpec = field(default_factory=DDPGCriticSpec) + initializer: dict = field(default_factory=dict) + + +class DDPG(nn.Module): + """NN module for DDPG-like algorithms. + + Since it is common to use clipped double Q-Learning, critic is implemented as + a ModuleList of action-value functions. + + Uses `raylab.pytorch.nn.init.initialize_` to create an initializer + function for the parameters. + + Args: + obs_space: Observation space + action_space: Action space + spec: Specifications for DDPG modules + + Attributes: + actor: The deterministic policy to be learned + behavior: The policy for exploration + target_actor: The policy used for estimating the arg max in Q-Learning + critics: The action-value estimators to be learned + target_critics: The action-value estimators used for bootstrapping in + Q-Learning + forward_batch_keys: Keys in the input tensor dict that will be accessed + in the main forward pass. Useful for the caller to convert the + necessary inputs to tensors + """ + + actor: DeterministicPolicy + behavior: DeterministicPolicy + target_actor: DeterministicPolicy + critics: nn.ModuleList + target_critics: nn.ModuleList + forward_batch_keys: Tuple[str] = (SampleBatch.CUR_OBS,) + + def __init__(self, obs_space: Box, action_space: Box, spec: DDPGSpec): + super().__init__() + # Build actor + self.actor, self.behavior, self.target_actor = self._make_actor( + obs_space, action_space, spec.actor, spec.initializer + ) + + # Build critic + self.critics, self.target_critics = self._make_critic( + obs_space, action_space, spec.critic, spec.initializer + ) + + def forward( + self, input_dict: TensorDict, state: List[Tensor], seq_lens: Tensor + ) -> Tuple[TensorDict, List[Tensor]]: + """Maps input tensors to action distribution parameters. 
+ + Args: + input_dict: Tensor dictionary with mandatory `forward_batch_keys` + contained within + state: List of RNN state tensors + seq_lens: 1D tensor holding input sequence lengths + + Returns: + A tuple containg an input dictionary to the policy's `dist_class` + and a list of RNN state tensors + """ + # pylint:disable=unused-argument,arguments-differ + return {"obs": input_dict["obs"]}, state + + @staticmethod + def _make_actor( + obs_space: Box, action_space: Box, spec: DDPGActorSpec, initializer_spec: dict + ) -> Tuple[MLPDeterministicPolicy, MLPDeterministicPolicy, MLPDeterministicPolicy]: + def make_policy(): + return MLPDeterministicPolicy( + obs_space, action_space, spec.actor.encoder, spec.actor.norm_beta + ) + + actor = make_policy() + actor.initialize_parameters(initializer_spec) + + behavior = actor + if spec.behavior == "parameter_noise": + behavior = make_policy() + behavior.load_state_dict(actor.state_dict()) + + target_actor = actor + if spec.separate_target_policy: + target_actor = make_policy() + target_actor.load_state_dict(actor.state_dict()) + if spec.smooth_target_policy: + target_actor = DeterministicPolicy.add_gaussian_noise( + target_actor, noise_stddev=spec.target_gaussian_sigma + ) + + return actor, behavior, target_actor + + @staticmethod + def _make_critic( + obs_space: Box, action_space: Box, spec: DDPGCriticSpec, initializer_spec: dict + ) -> Tuple[QValueEnsemble]: + def make_critic(): + return MLPQValue(obs_space, action_space, spec.critic.encoder) + + def make_critic_ensemble(): + n_critics = 2 if spec.critic.double_q else 1 + critics = [make_critic() for _ in range(n_critics)] + + if spec.critic.parallelize: + return ForkedQValueEnsemble(critics) + return QValueEnsemble(critics) + + critics = make_critic_ensemble() + critics.initialize_parameters(initializer_spec) + target_critics = make_critic_ensemble() + target_critics.load_state_dict(critics) + return critics, target_critics diff --git a/raylab/modules/networks/action_value.py b/raylab/modules/networks/action_value.py new file mode 100644 index 00000000..baaf86fd --- /dev/null +++ b/raylab/modules/networks/action_value.py @@ -0,0 +1,138 @@ +"""Parameterized action-value estimators.""" +from dataclasses import dataclass +from typing import List +from typing import Optional + +import torch +import torch.nn as nn +from dataclasses_json import DataClassJsonMixin +from gym.spaces import Box +from torch import Tensor + +import raylab.pytorch.nn as nnx +from raylab.pytorch.nn.init import initialize_ + + +class QValue(nn.Module): + """Neural network module emulating a Q value function. + + Args: + encoder: NN module mapping states to 1D features. Must have an + `out_features` attribute with the size of the output features + """ + + def __init__(self, encoder: nn.Module): + super().__init__() + self.encoder = encoder + self.value_linear = nn.Linear(self.encoder.out_features, 1) + + def forward(self, obs: Tensor, action: Tensor) -> Tensor: + """Main forward pass mapping obs and actions to Q-values. + + Note: + The output tensor has a last singleton dimension, i.e., for a batch + of 10 obs-action pairs, the output will have shape (10, 1). + """ + # pylint:disable=arguments-differ + features = self.encoder(obs, action) + return self.value_linear(features) + + +@dataclass +class StateActionMLPSpec(DataClassJsonMixin): + """Specifications for building an MLP with state and action inputs. 
+ + Args: + units: Number of units in each hidden layer + activation: Nonlinearity following each linear layer + delay_action: Whether to apply an initial preprocessing layer on the + observation before concatenating the action to the input. + """ + + units: List[int] + activation: Optional[str] + delay_action: bool + + +class MLPQValue(QValue): + """Q-value function with a multilayer perceptron encoder. + + Args: + obs_space: Observation space + action_space: Action space + mlp_spec: Multilayer perceptron specifications + """ + + def __init__(self, obs_space: Box, action_space: Box, mlp_spec: StateActionMLPSpec): + obs_size = obs_space.shape[0] + action_size = action_space.shape[0] + + encoder = nnx.StateActionEncoder( + obs_size, + action_size, + units=mlp_spec.units, + activation=mlp_spec.activation, + delay_action=mlp_spec.delay_action, + ) + super().__init__(encoder) + self.mlp_spec = mlp_spec.activation + + def initialize_parameters(self, initializer_spec: dict): + """Initialize all Linear models in the encoder. + + Uses `raylab.pytorch.nn.init.initialize_` to create an initializer + function. + + Args: + initializer_spec: Dictionary with mandatory `type` key corresponding + to the initializer function name in `torch.nn.init` and optional + keyword arguments. + """ + initializer = initialize_( + activation=self.mlp_spec.activation, **initializer_spec + ) + self.encoder.apply(initializer) + + +class QValueEnsemble(nn.ModuleList): + """A static list of Q-value estimators. + + Args: + q_values: A list of QValue modules + """ + + def __init__(self, q_values): + assert all( + isinstance(q, QValue) for q in q_values + ), """All modules in QValueEnsemble must be instances of QValue.""" + super().__init__(q_values) + + def forward(self, obs: Tensor, action: Tensor, clip: bool = False) -> Tensor: + """Evaluate each Q estimator in the ensemble. + + Args: + obs: The observation tensor + action: The action tensor + clip: Whether to output the minimum of the action-values. 
Preserves + output dimensions + + Returns: + A tensor of shape `(*, N)`, where `N` is the ensemble size + """ + # pylint:disable=arguments-differ + action_values = torch.cat([m(obs, action) for m in self], dim=-1) + if clip: + action_values, _ = action_values.min(keepdim=True, dim=-1) + return action_values + + +class ForkedQValueEnsemble(QValueEnsemble): + """Ensemble of Q-value estimators with parallelized forward pass.""" + + def forward(self, obs: Tensor, action: Tensor, clip: bool = False) -> Tensor: + # pylint:disable=protected-access + futures = [torch.jit._fork(m, (obs, action)) for m in self] + action_values = torch.cat([torch.jit._wait(f) for f in futures], dim=-1) + if clip: + action_values, _ = action_values.min(keepdim=True, dim=-1) + return action_values diff --git a/raylab/modules/networks/policy/__init__.py b/raylab/modules/networks/policy/__init__.py new file mode 100644 index 00000000..2964694e --- /dev/null +++ b/raylab/modules/networks/policy/__init__.py @@ -0,0 +1 @@ +"""Policies as neural network modules.""" diff --git a/raylab/modules/networks/policy/deterministic.py b/raylab/modules/networks/policy/deterministic.py new file mode 100644 index 00000000..5bf40747 --- /dev/null +++ b/raylab/modules/networks/policy/deterministic.py @@ -0,0 +1,164 @@ +"""Parameterized deterministic policies.""" +import warnings +from dataclasses import dataclass +from typing import Callable +from typing import List +from typing import Optional + +import torch +import torch.nn as nn +from gym.spaces import Box +from torch import Tensor + +import raylab.pytorch.nn as nnx +from raylab.pytorch.nn.init import initialize_ + + +class DeterministicPolicy(nn.Module): + """Continuous action deterministic policy as a sequence of modules. + + If a noise module is passed, it is evaluated on unconstrained actions before + the squashing module. + + Args: + encoder: NN module mapping states to 1D features + action_linear: Linear module mapping features to unconstrained actions + squashing: Invertible module mapping unconstrained actions to bounded + action space + noise: Optional stochastic module adding noise to unconstrained actions + """ + + def __init__( + self, + encoder: nn.Module, + action_linear: nn.Module, + squashing: nn.Module, + noise: Optional[nn.Module] = None, + ): + super().__init__() + self.encoder = encoder + self.action_linear = action_linear + self.squashing = squashing + self.noise = noise + + def forward(self, obs: Tensor) -> Tensor: # pylint:disable=arguments-differ + """Main forward pass mapping observations to actions.""" + unconstrained_action = self.unconstrained_action(obs) + return self.squashing(unconstrained_action) + + @torch.jit.export + def unconstrained_action(self, obs: Tensor) -> Tensor: + """Forward pass with no squashing at the end""" + features = self.encoder(obs) + unconstrained_action = self.action_linear(features) + if self.noise: + unconstrained_action = self.noise(unconstrained_action) + return unconstrained_action + + @torch.jit.export + def unsquash_action(self, action: Tensor) -> Tensor: + """Returns the unconstrained action which generated the given action.""" + return self.squashing(action, reverse=True) + + def initialize_(self, initializer: Callable[[nn.Module], None]): + """Apply initializer to encoder layers.""" + self.encoder.apply(initializer) + + @classmethod + def add_gaussian_noise(cls, policy, noise_stddev: float): + """Adds a zero-mean Gaussian noise module to a DeterministicPolicy. + + Args: + policy: The deterministic policy. 
+ noise_stddev: Standard deviation of the Gaussian noise + + Returns: + A deterministic policy sharing all paremeters with the input one and + with additional noise module before squashing. + """ + if policy.noise is not None: + warnings.warn( + "Adding Gaussian noise to already noisy policy. Are you sure you" + " called `add_gaussian_noise` on the right policy?" + ) + noise = nnx.GaussianNoise(noise_stddev) + return cls(policy.encoder, policy.action_linear, policy.squashing, noise=noise) + + +@dataclass +class StateMLPSpec: + """Specifications for creating a multilayer perceptron. + + Args: + units: Number of units in each hidden layer + activation: Nonlinearity following each linear layer + layer_norm: Whether to apply layer normalization between each linear layer + and following activation + """ + + units: List[int] + activation: Optional[str] + layer_norm: bool + + +class MLPDeterministicPolicy(DeterministicPolicy): + """DeterministicPolicy with multilayer perceptron encoder. + + The final Linear layer is initialized so that actions are near the origin + point. + + Args: + obs_space: Observation space + action_space: Action space + mlp_spec: Multilayer perceptron specifications + norm_beta: Maximum l1 norm of the unconstrained actions. If None, won't + normalize actions before squashing function. + """ + + def __init__( + self, + obs_space: Box, + action_space: Box, + mlp_spec: StateMLPSpec, + norm_beta: float, + ): + obs_size = obs_space.shape[0] + action_size = action_space.shape[0] + action_low, action_high = map( + torch.as_tensor, (action_space.low, action_space.high) + ) + + encoder = nnx.FullyConnected( + obs_size, + mlp_spec.units, + mlp_spec.activation, + layer_norm=mlp_spec.layer_norm, + ) + + if norm_beta: + action_linear = nnx.NormalizedLinear( + encoder.out_features, action_size, beta=norm_beta + ) + else: + action_linear = nn.Linear(encoder.out_features, action_size) + + squash = nnx.TanhSquash(action_low, action_high) + + super().__init__(encoder, action_linear, squash) + self.mlp_spec = mlp_spec + + def initialize_parameters(self, initializer_spec: dict): + """Initialize all Linear models in the encoder. + + Uses `raylab.pytorch.nn.init.initialize_` to create an initializer + function. + + Args: + initializer_spec: Dictionary with mandatory `type` key corresponding + to the initializer function name in `torch.nn.init` and optional + keyword arguments. + """ + initializer = initialize_( + activation=self.mlp_spec.activation, **initializer_spec + ) + self.encoder.apply(initializer) diff --git a/raylab/pytorch/nn/init.py b/raylab/pytorch/nn/init.py index 3f1f43c7..d4168ce2 100644 --- a/raylab/pytorch/nn/init.py +++ b/raylab/pytorch/nn/init.py @@ -31,7 +31,7 @@ def initialize_(name, activation=None, **options): """Return a callable to apply an initializer with the given name and options. If `gain` is part of the initializer's argspec and is not specified in options, - the recommended value from `nn.init.calculate_gain` is used. + the recommended value from `torch.nn.init.calculate_gain` is used. 
Arguments: name (str): name of initializer function diff --git a/raylab/utils/annotations.py b/raylab/utils/annotations.py index 1a124fe0..c19b5e02 100644 --- a/raylab/utils/annotations.py +++ b/raylab/utils/annotations.py @@ -1,9 +1,11 @@ """Collection of type annotations.""" from typing import Callable +from typing import Dict from typing import Tuple from torch import Tensor +TensorDict = Dict[str, Tensor] RewardFn = Callable[[Tensor, Tensor, Tensor], Tensor] TerminationFn = Callable[[Tensor, Tensor, Tensor], Tensor] DynamicsFn = Callable[[Tensor, Tensor], Tuple[Tensor, Tensor]] From 1f19ec1625e5f1d4b4731b508ee2072707e2e51d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=C3=82ngelo=20Lovatto?= Date: Sun, 28 Jun 2020 08:49:49 -0300 Subject: [PATCH 08/48] feat(modules): add refactored DDPG MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Ângelo Lovatto --- raylab/modules/catalog.py | 11 +- raylab/modules/ddpg.py | 183 +++--------------- raylab/modules/networks/actor/__init__.py | 0 .../modules/networks/actor/deterministic.py | 109 +++++++++++ .../modules/networks/actor/policy/__init__.py | 0 .../{ => actor}/policy/deterministic.py | 10 +- raylab/modules/networks/critic/__init__.py | 0 .../modules/networks/critic/action_value.py | 81 ++++++++ .../{action_value.py => critic/q_value.py} | 8 +- raylab/modules/networks/policy/__init__.py | 1 - 10 files changed, 235 insertions(+), 168 deletions(-) create mode 100644 raylab/modules/networks/actor/__init__.py create mode 100644 raylab/modules/networks/actor/deterministic.py create mode 100644 raylab/modules/networks/actor/policy/__init__.py rename raylab/modules/networks/{ => actor}/policy/deterministic.py (96%) create mode 100644 raylab/modules/networks/critic/__init__.py create mode 100644 raylab/modules/networks/critic/action_value.py rename raylab/modules/networks/{action_value.py => critic/q_value.py} (96%) delete mode 100644 raylab/modules/networks/policy/__init__.py diff --git a/raylab/modules/catalog.py b/raylab/modules/catalog.py index 880fa0c6..99c0e7aa 100644 --- a/raylab/modules/catalog.py +++ b/raylab/modules/catalog.py @@ -1,5 +1,6 @@ """Registry of modules for PyTorch policies.""" +from .ddpg import DDPG from .ddpg_module import DDPGModule from .maxent_model_based import MaxEntModelBased from .model_based_ddpg import ModelBasedDDPG @@ -32,9 +33,15 @@ "SVGRealNVPActor": SVGRealNVPActor, } +MODULESv2 = {k.__name__: k for k in [DDPG]} + def get_module(obs_space, action_space, config): """Retrieve and construct module of given name.""" type_ = config.pop("type") - module = MODULES[type_](obs_space, action_space, config) - return module + if type_ in MODULES: + return MODULES[type_](obs_space, action_space, config) + + cls = MODULESv2[type_] + spec = cls.spec_cls.from_dict(config) + return cls(obs_space, action_space, spec) diff --git a/raylab/modules/ddpg.py b/raylab/modules/ddpg.py index f90dd565..5dabe9a3 100644 --- a/raylab/modules/ddpg.py +++ b/raylab/modules/ddpg.py @@ -1,80 +1,18 @@ """NN architecture used in Deep Deterministic Policy Gradients.""" from dataclasses import dataclass from dataclasses import field -from typing import List -from typing import Tuple import torch.nn as nn from dataclasses_json import DataClassJsonMixin from gym.spaces import Box -from ray.rllib import SampleBatch -from torch import Tensor -from raylab.utils.annotations import TensorDict +from .networks.actor.deterministic import DeterministicActor +from .networks.actor.policy.deterministic import 
DeterministicPolicy +from .networks.critic.action_value import ActionValueCritic +from .networks.critic.q_value import QValueEnsemble -from .networks.action_value import ForkedQValueEnsemble -from .networks.action_value import MLPQValue -from .networks.action_value import QValueEnsemble -from .networks.action_value import StateActionMLPSpec -from .networks.policy.deterministic import DeterministicPolicy -from .networks.policy.deterministic import MLPDeterministicPolicy -from .networks.policy.deterministic import StateMLPSpec - - -@dataclass -class DDPGActorSpec(DataClassJsonMixin): - """Specifications for policy, behavior, and target policy. - - Args: - encoder: Specifications for creating the multilayer perceptron mapping - states to pre-action linear features - norm_beta: Maximum l1 norm of the unconstrained actions. If None, won't - normalize actions before squashing function - behavior: Type of behavior policy. Either 'gaussian', 'parameter_noise', - or 'deterministic' - smooth_target_policy: Whether to use a noisy target policy for - Q-Learning - target_gaussian_sigma: Gaussian standard deviation for noisy target - policy - separate_target_policy: Whether to use separate parameters for the - target policy. Intended for use with polyak averaging - """ - - encoder: StateMLPSpec = field(default_factory=StateMLPSpec) - norm_beta: float = 1.2 - behavior: str = "gaussian" - smooth_target_policy: bool = True - target_gaussian_sigma: float = 0.3 - separate_target_policy: bool = False - - def __post_init__(self): - cls_name = type(self).__name__ - assert self.norm_beta > 0, f"{cls_name}.norm_beta must be positive" - valid_behaviors = {"gaussian", "parameter_noise", "deterministic"} - assert ( - self.behavior in valid_behaviors - ), f"{cls_name}.behavior must be one of {valid_behaviors}" - assert ( - self.target_gaussian_sigma > 0 - ), f"{cls_name}.target_gaussian_sigma must be positive" - - -@dataclass -class DDPGCriticSpec(DataClassJsonMixin): - """Specifications for action-value estimators. - - Args: - encoder: Specifications for creating the multilayer perceptron mapping - states and actions to pre-value function linear features - double_q: Whether to create two Q-value estimators instead of one. - Defaults to True - parallelize: Whether to evaluate Q-values in parallel. Defaults to - False. - """ - - encoder: StateActionMLPSpec = field(default_factory=StateActionMLPSpec) - double_q: bool = True - parallelize: bool = False +ActorSpec = DeterministicActor.spec_cls +CriticSpec = ActionValueCritic.spec_cls @dataclass @@ -86,23 +24,17 @@ class DDPGSpec(DataClassJsonMixin): critic: Specifications for action-value estimators initializer: Optional dictionary with mandatory `type` key corresponding to the initializer function name in `torch.nn.init` and optional - keyword arguments.configuration dictionary for parameter + keyword arguments. """ - actor: DDPGActorSpec = field(default_factory=DDPGActorSpec) - critic: DDPGCriticSpec = field(default_factory=DDPGCriticSpec) + actor: ActorSpec = field(default_factory=ActorSpec) + critic: CriticSpec = field(default_factory=CriticSpec) initializer: dict = field(default_factory=dict) class DDPG(nn.Module): """NN module for DDPG-like algorithms. - Since it is common to use clipped double Q-Learning, critic is implemented as - a ModuleList of action-value functions. - - Uses `raylab.pytorch.nn.init.initialize_` to create an initializer - function for the parameters. 
- Args: obs_space: Observation space action_space: Action space @@ -115,93 +47,32 @@ class DDPG(nn.Module): critics: The action-value estimators to be learned target_critics: The action-value estimators used for bootstrapping in Q-Learning - forward_batch_keys: Keys in the input tensor dict that will be accessed - in the main forward pass. Useful for the caller to convert the - necessary inputs to tensors + spec_cls: Expected class of `spec` init argument """ + # pylint:disable=abstract-method actor: DeterministicPolicy behavior: DeterministicPolicy target_actor: DeterministicPolicy - critics: nn.ModuleList - target_critics: nn.ModuleList - forward_batch_keys: Tuple[str] = (SampleBatch.CUR_OBS,) + critics: QValueEnsemble + target_critics: QValueEnsemble + spec_cls = DDPGSpec def __init__(self, obs_space: Box, action_space: Box, spec: DDPGSpec): super().__init__() + # Top-level initializer options take precedence over individual + # component's options + if spec.initializer: + spec.actor.initializer = spec.initializer + spec.critic.initializer = spec.initializer + # Build actor - self.actor, self.behavior, self.target_actor = self._make_actor( - obs_space, action_space, spec.actor, spec.initializer - ) + actor = DeterministicActor(obs_space, action_space, spec.actor) + self.actor = actor.policy + self.behavior = actor.behavior + self.target_actor = actor.target_policy # Build critic - self.critics, self.target_critics = self._make_critic( - obs_space, action_space, spec.critic, spec.initializer - ) - - def forward( - self, input_dict: TensorDict, state: List[Tensor], seq_lens: Tensor - ) -> Tuple[TensorDict, List[Tensor]]: - """Maps input tensors to action distribution parameters. - - Args: - input_dict: Tensor dictionary with mandatory `forward_batch_keys` - contained within - state: List of RNN state tensors - seq_lens: 1D tensor holding input sequence lengths - - Returns: - A tuple containg an input dictionary to the policy's `dist_class` - and a list of RNN state tensors - """ - # pylint:disable=unused-argument,arguments-differ - return {"obs": input_dict["obs"]}, state - - @staticmethod - def _make_actor( - obs_space: Box, action_space: Box, spec: DDPGActorSpec, initializer_spec: dict - ) -> Tuple[MLPDeterministicPolicy, MLPDeterministicPolicy, MLPDeterministicPolicy]: - def make_policy(): - return MLPDeterministicPolicy( - obs_space, action_space, spec.actor.encoder, spec.actor.norm_beta - ) - - actor = make_policy() - actor.initialize_parameters(initializer_spec) - - behavior = actor - if spec.behavior == "parameter_noise": - behavior = make_policy() - behavior.load_state_dict(actor.state_dict()) - - target_actor = actor - if spec.separate_target_policy: - target_actor = make_policy() - target_actor.load_state_dict(actor.state_dict()) - if spec.smooth_target_policy: - target_actor = DeterministicPolicy.add_gaussian_noise( - target_actor, noise_stddev=spec.target_gaussian_sigma - ) - - return actor, behavior, target_actor - - @staticmethod - def _make_critic( - obs_space: Box, action_space: Box, spec: DDPGCriticSpec, initializer_spec: dict - ) -> Tuple[QValueEnsemble]: - def make_critic(): - return MLPQValue(obs_space, action_space, spec.critic.encoder) - - def make_critic_ensemble(): - n_critics = 2 if spec.critic.double_q else 1 - critics = [make_critic() for _ in range(n_critics)] - - if spec.critic.parallelize: - return ForkedQValueEnsemble(critics) - return QValueEnsemble(critics) - - critics = make_critic_ensemble() - critics.initialize_parameters(initializer_spec) - 
target_critics = make_critic_ensemble() - target_critics.load_state_dict(critics) - return critics, target_critics + critic = ActionValueCritic(obs_space, action_space, spec.critic) + self.critics = critic.q_values + self.target_critics = critic.target_q_values diff --git a/raylab/modules/networks/actor/__init__.py b/raylab/modules/networks/actor/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/raylab/modules/networks/actor/deterministic.py b/raylab/modules/networks/actor/deterministic.py new file mode 100644 index 00000000..422808ae --- /dev/null +++ b/raylab/modules/networks/actor/deterministic.py @@ -0,0 +1,109 @@ +"""Network and configurations for modules with deterministic policies.""" +import warnings +from dataclasses import dataclass +from dataclasses import field + +import torch.nn as nn +from dataclasses_json import DataClassJsonMixin +from gym.spaces import Box + +from .policy.deterministic import MLPDeterministicPolicy + +MLPSpec = MLPDeterministicPolicy.spec_cls + + +@dataclass +class DeterministicActorSpec(DataClassJsonMixin): + """Specifications for policy, behavior, and target policy. + + Args: + encoder: Specifications for creating the multilayer perceptron mapping + states to pre-action linear features + norm_beta: Maximum l1 norm of the unconstrained actions. If None, won't + normalize actions before squashing function + behavior: Type of behavior policy. Either 'gaussian', 'parameter_noise', + or 'deterministic' + smooth_target_policy: Whether to use a noisy target policy for + Q-Learning + target_gaussian_sigma: Gaussian standard deviation for noisy target + policy + separate_target_policy: Whether to use separate parameters for the + target policy. Intended for use with polyak averaging + initializer: Optional dictionary with mandatory `type` key corresponding + to the initializer function name in `torch.nn.init` and optional + keyword arguments. + """ + + encoder: MLPSpec = field(default_factory=MLPSpec) + norm_beta: float = 1.2 + behavior: str = "gaussian" + smooth_target_policy: bool = True + target_gaussian_sigma: float = 0.3 + separate_target_policy: bool = False + initializer: dict = field(default_factory=dict) + + def __post_init__(self): + cls_name = type(self).__name__ + assert self.norm_beta > 0, f"{cls_name}.norm_beta must be positive" + valid_behaviors = {"gaussian", "parameter_noise", "deterministic"} + assert ( + self.behavior in valid_behaviors + ), f"{cls_name}.behavior must be one of {valid_behaviors}" + assert ( + self.target_gaussian_sigma > 0 + ), f"{cls_name}.target_gaussian_sigma must be positive" + + +class DeterministicActor(nn.Module): + """NN with deterministic policies. 
+ + Args: + obs_space: Observation space + action_space: Action space + spec: Specifications for policy, behavior, and target policy + + Attributes: + policy: The deterministic policy to be learned + behavior: The policy for exploration + target_policy: The policy used for estimating the arg max in Q-Learning + spec_cls: Expected class of `spec` init argument + """ + + # pylint:disable=abstract-method + spec_cls = DeterministicActorSpec + + def __init__( + self, obs_space: Box, action_space: Box, spec: DeterministicActorSpec, + ): + super().__init__() + + def make_policy(): + return MLPDeterministicPolicy( + obs_space, action_space, spec.actor.encoder, spec.actor.norm_beta + ) + + policy = make_policy() + policy.initialize_parameters(spec.initializer) + + behavior = policy + if spec.behavior == "parameter_noise": + if not spec.encoder.layer_norm: + warnings.warn( + f"Behavior is set to {spec.behavior} but layer normalization is " + "deactivated. Use layer normalization for better stability." + ) + behavior = make_policy() + behavior.load_state_dict(policy.state_dict()) + + target_policy = policy + if spec.separate_target_policy: + target_policy = make_policy() + target_policy.load_state_dict(policy.state_dict()) + if spec.smooth_target_policy: + target_policy = MLPDeterministicPolicy.add_gaussian_noise( + target_policy, noise_stddev=spec.target_gaussian_sigma + ) + + self.policy = policy + self.behavior = behavior + self.target_policy = target_policy diff --git a/raylab/modules/networks/actor/policy/__init__.py b/raylab/modules/networks/actor/policy/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/raylab/modules/networks/policy/deterministic.py b/raylab/modules/networks/actor/policy/deterministic.py similarity index 96% rename from raylab/modules/networks/policy/deterministic.py rename to raylab/modules/networks/actor/policy/deterministic.py index 5bf40747..843837c8 100644 --- a/raylab/modules/networks/policy/deterministic.py +++ b/raylab/modules/networks/actor/policy/deterministic.py @@ -1,7 +1,6 @@ """Parameterized deterministic policies.""" import warnings from dataclasses import dataclass -from typing import Callable from typing import List from typing import Optional @@ -60,10 +59,6 @@ def unsquash_action(self, action: Tensor) -> Tensor: """Returns the unconstrained action which generated the given action.""" return self.squashing(action, reverse=True) - def initialize_(self, initializer: Callable[[nn.Module], None]): - """Apply initializer to encoder layers.""" - self.encoder.apply(initializer) - @classmethod def add_gaussian_noise(cls, policy, noise_stddev: float): """Adds a zero-mean Gaussian noise module to a DeterministicPolicy. @@ -113,8 +108,13 @@ class MLPDeterministicPolicy(DeterministicPolicy): mlp_spec: Multilayer perceptron specifications norm_beta: Maximum l1 norm of the unconstrained actions. If None, won't normalize actions before squashing function. 
+ + Attributes: + spec_cls: Expected class of `spec` init argument """ + spec_cls = StateMLPSpec + def __init__( self, obs_space: Box, diff --git a/raylab/modules/networks/critic/__init__.py b/raylab/modules/networks/critic/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/raylab/modules/networks/critic/action_value.py b/raylab/modules/networks/critic/action_value.py new file mode 100644 index 00000000..d3ea04ed --- /dev/null +++ b/raylab/modules/networks/critic/action_value.py @@ -0,0 +1,81 @@ +"""Network and configurations for modules with Q-value critics.""" +from dataclasses import dataclass +from dataclasses import field + +import torch.nn as nn +from dataclasses_json import DataClassJsonMixin +from gym.spaces import Box + +from .q_value import ForkedQValueEnsemble +from .q_value import MLPQValue +from .q_value import QValueEnsemble + + +QValueSpec = MLPQValue.spec_cls + + +@dataclass +class ActionValueCriticSpec(DataClassJsonMixin): + """Specifications for action-value estimators. + + Args: + encoder: Specifications for creating the multilayer perceptron mapping + states and actions to pre-value function linear features + double_q: Whether to create two Q-value estimators instead of one. + Defaults to True + parallelize: Whether to evaluate Q-values in parallel. Defaults to + False. + initializer: Optional dictionary with mandatory `type` key corresponding + to the initializer function name in `torch.nn.init` and optional + keyword arguments. + """ + + encoder: QValueSpec = field(default_factory=QValueSpec) + double_q: bool = True + parallelize: bool = False + initializer: dict = field(default_factory=dict) + + +class ActionValueCritic(nn.Module): + """NN with Q-value estimators. + + Since it is common to use clipped double Q-Learning, `q_values` is a + ModuleList of Q-value functions. 
+ + Args: + obs_space: Observation space + action_space: Action space + spec: Specifications for action-value estimators + + Attributes: + q_values: The action-value estimators to be learned + target_q_values: The action-value estimators used for bootstrapping in + Q-Learning + spec_cls: Expected class of `spec` init argument + """ + + # pylint:disable=abstract-method + spec_cls = ActionValueCriticSpec + + def __init__(self, obs_space: Box, action_space: Box, spec: ActionValueCriticSpec): + super().__init__() + + def make_q_value(): + return MLPQValue(obs_space, action_space, spec.q_value.encoder) + + def make_q_value_ensemble(): + n_q_values = 2 if spec.q_value.double_q else 1 + q_values = [make_q_value() for _ in range(n_q_values)] + + if spec.q_value.parallelize: + return ForkedQValueEnsemble(q_values) + return QValueEnsemble(q_values) + + q_values = make_q_value_ensemble() + q_values.initialize_parameters(spec.initializer) + + target_q_values = make_q_value_ensemble() + target_q_values.load_state_dict(q_values) + + self.q_values = q_values + self.target_q_values = target_q_values diff --git a/raylab/modules/networks/action_value.py b/raylab/modules/networks/critic/q_value.py similarity index 96% rename from raylab/modules/networks/action_value.py rename to raylab/modules/networks/critic/q_value.py index baaf86fd..99c900ab 100644 --- a/raylab/modules/networks/action_value.py +++ b/raylab/modules/networks/critic/q_value.py @@ -63,6 +63,8 @@ class MLPQValue(QValue): mlp_spec: Multilayer perceptron specifications """ + spec_cls = StateActionMLPSpec + def __init__(self, obs_space: Box, action_space: Box, mlp_spec: StateActionMLPSpec): obs_size = obs_space.shape[0] action_size = action_space.shape[0] @@ -75,7 +77,7 @@ def __init__(self, obs_space: Box, action_space: Box, mlp_spec: StateActionMLPSp delay_action=mlp_spec.delay_action, ) super().__init__(encoder) - self.mlp_spec = mlp_spec.activation + self.spec = mlp_spec.activation def initialize_parameters(self, initializer_spec: dict): """Initialize all Linear models in the encoder. @@ -88,9 +90,7 @@ def initialize_parameters(self, initializer_spec: dict): to the initializer function name in `torch.nn.init` and optional keyword arguments. 
""" - initializer = initialize_( - activation=self.mlp_spec.activation, **initializer_spec - ) + initializer = initialize_(activation=self.spec.activation, **initializer_spec) self.encoder.apply(initializer) diff --git a/raylab/modules/networks/policy/__init__.py b/raylab/modules/networks/policy/__init__.py deleted file mode 100644 index 2964694e..00000000 --- a/raylab/modules/networks/policy/__init__.py +++ /dev/null @@ -1 +0,0 @@ -"""Policies as neural network modules.""" From 12c30b911f3ecd81d2dd48e638f95c4baf7da470 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=C3=82ngelo=20Lovatto?= Date: Sun, 28 Jun 2020 10:24:57 -0300 Subject: [PATCH 09/48] fix(networks): set encoder defaults for actor and critic MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Ângelo Lovatto --- .../networks/actor/policy/deterministic.py | 7 ++-- raylab/modules/networks/critic/q_value.py | 7 ++-- tests/modules/test_ddpg.py | 34 +++++++++++++++++++ 3 files changed, 42 insertions(+), 6 deletions(-) create mode 100644 tests/modules/test_ddpg.py diff --git a/raylab/modules/networks/actor/policy/deterministic.py b/raylab/modules/networks/actor/policy/deterministic.py index 843837c8..9703a4b8 100644 --- a/raylab/modules/networks/actor/policy/deterministic.py +++ b/raylab/modules/networks/actor/policy/deterministic.py @@ -1,6 +1,7 @@ """Parameterized deterministic policies.""" import warnings from dataclasses import dataclass +from dataclasses import field from typing import List from typing import Optional @@ -91,9 +92,9 @@ class StateMLPSpec: and following activation """ - units: List[int] - activation: Optional[str] - layer_norm: bool + units: List[int] = field(default_factory=list) + activation: Optional[str] = None + layer_norm: bool = False class MLPDeterministicPolicy(DeterministicPolicy): diff --git a/raylab/modules/networks/critic/q_value.py b/raylab/modules/networks/critic/q_value.py index 99c900ab..e850b366 100644 --- a/raylab/modules/networks/critic/q_value.py +++ b/raylab/modules/networks/critic/q_value.py @@ -1,5 +1,6 @@ """Parameterized action-value estimators.""" from dataclasses import dataclass +from dataclasses import field from typing import List from typing import Optional @@ -49,9 +50,9 @@ class StateActionMLPSpec(DataClassJsonMixin): observation before concatenating the action to the input. 
""" - units: List[int] - activation: Optional[str] - delay_action: bool + units: List[int] = field(default_factory=list) + activation: Optional[str] = None + delay_action: bool = False class MLPQValue(QValue): diff --git a/tests/modules/test_ddpg.py b/tests/modules/test_ddpg.py new file mode 100644 index 00000000..1545fb80 --- /dev/null +++ b/tests/modules/test_ddpg.py @@ -0,0 +1,34 @@ +# pylint: disable=missing-docstring,redefined-outer-name,protected-access +import pytest +import torch +import torch.nn as nn + +from raylab.modules.ddpg import DDPG + + +@pytest.fixture +def spec_cls(): + return DDPG.spec_cls + + +@pytest.fixture +def module(obs_space, action_space, spec_cls): + return DDPG(obs_space, action_space, spec_cls()) + + +def test_spec(spec_cls): + default_config = spec_cls().to_dict() + + for key in ["actor", "critic", "initializer"]: + assert key in default_config + + +def test_init(module): + assert isinstance(module, nn.Module) + + for attr in ["actor", "behavior", "target_actor", "critics", "target_critics"]: + assert hasattr(module, attr) + + +def test_script(module): + torch.jit.script(module) From 105a0d10e33e15592aba2756d02faf0c764577a9 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=C3=82ngelo=20Lovatto?= Date: Sun, 28 Jun 2020 12:00:21 -0300 Subject: [PATCH 10/48] refactor(modules): move old NNs to v0 submodule MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Ângelo Lovatto --- raylab/modules/catalog.py | 81 +++++++++++-------- raylab/modules/v0/__init__.py | 0 raylab/modules/{ => v0}/abstract.py | 0 raylab/modules/{ => v0}/ddpg_module.py | 0 raylab/modules/{ => v0}/maxent_model_based.py | 0 raylab/modules/{ => v0}/mixins/__init__.py | 0 .../{ => v0}/mixins/action_value_mixin.py | 0 .../mixins/deterministic_actor_mixin.py | 0 .../mixins/normalizing_flow_actor_mixin.py | 2 +- .../mixins/normalizing_flow_model_mixin.py | 2 +- .../{ => v0}/mixins/state_value_mixin.py | 0 .../{ => v0}/mixins/stochastic_actor_mixin.py | 0 .../{ => v0}/mixins/stochastic_model_mixin.py | 0 .../{ => v0}/mixins/svg_model_mixin.py | 2 +- raylab/modules/{ => v0}/model_based_ddpg.py | 0 raylab/modules/{ => v0}/model_based_sac.py | 0 raylab/modules/{ => v0}/naf_module.py | 0 raylab/modules/{ => v0}/nfmbrl.py | 0 raylab/modules/{ => v0}/off_policy_nfac.py | 0 .../{ => v0}/on_policy_actor_critic.py | 0 raylab/modules/{ => v0}/on_policy_nfac.py | 0 raylab/modules/{ => v0}/sac_module.py | 0 raylab/modules/{ => v0}/simple_model_based.py | 0 raylab/modules/{ => v0}/svg_module.py | 0 raylab/modules/{ => v0}/svg_realnvp_actor.py | 0 raylab/modules/{ => v0}/trpo_tang2018.py | 2 +- tests/modules/v0/__init__.py | 0 .../{ => v0}/test_action_value_mixin.py | 2 +- .../test_deterministic_actor_mixin.py | 2 +- tests/modules/{ => v0}/test_naf_module.py | 2 +- .../test_normalizing_flow_actor_mixin.py | 2 +- .../test_normalizing_flow_model_mixin.py | 2 +- .../{ => v0}/test_state_value_mixin.py | 2 +- .../{ => v0}/test_stochastic_actor_mixin.py | 2 +- .../{ => v0}/test_stochastic_model_mixin.py | 2 +- tests/modules/{ => v0}/test_svg_module.py | 2 +- .../modules/{ => v0}/test_trpo_extensions.py | 0 tests/modules/{ => v0}/utils.py | 0 tests/{ => pytorch/nn}/modules/test_made.py | 0 39 files changed, 59 insertions(+), 48 deletions(-) create mode 100644 raylab/modules/v0/__init__.py rename raylab/modules/{ => v0}/abstract.py (100%) rename raylab/modules/{ => v0}/ddpg_module.py (100%) rename raylab/modules/{ => v0}/maxent_model_based.py (100%) rename raylab/modules/{ => 
v0}/mixins/__init__.py (100%) rename raylab/modules/{ => v0}/mixins/action_value_mixin.py (100%) rename raylab/modules/{ => v0}/mixins/deterministic_actor_mixin.py (100%) rename raylab/modules/{ => v0}/mixins/normalizing_flow_actor_mixin.py (99%) rename raylab/modules/{ => v0}/mixins/normalizing_flow_model_mixin.py (99%) rename raylab/modules/{ => v0}/mixins/state_value_mixin.py (100%) rename raylab/modules/{ => v0}/mixins/stochastic_actor_mixin.py (100%) rename raylab/modules/{ => v0}/mixins/stochastic_model_mixin.py (100%) rename raylab/modules/{ => v0}/mixins/svg_model_mixin.py (95%) rename raylab/modules/{ => v0}/model_based_ddpg.py (100%) rename raylab/modules/{ => v0}/model_based_sac.py (100%) rename raylab/modules/{ => v0}/naf_module.py (100%) rename raylab/modules/{ => v0}/nfmbrl.py (100%) rename raylab/modules/{ => v0}/off_policy_nfac.py (100%) rename raylab/modules/{ => v0}/on_policy_actor_critic.py (100%) rename raylab/modules/{ => v0}/on_policy_nfac.py (100%) rename raylab/modules/{ => v0}/sac_module.py (100%) rename raylab/modules/{ => v0}/simple_model_based.py (100%) rename raylab/modules/{ => v0}/svg_module.py (100%) rename raylab/modules/{ => v0}/svg_realnvp_actor.py (100%) rename raylab/modules/{ => v0}/trpo_tang2018.py (99%) create mode 100644 tests/modules/v0/__init__.py rename tests/modules/{ => v0}/test_action_value_mixin.py (96%) rename tests/modules/{ => v0}/test_deterministic_actor_mixin.py (98%) rename tests/modules/{ => v0}/test_naf_module.py (92%) rename tests/modules/{ => v0}/test_normalizing_flow_actor_mixin.py (98%) rename tests/modules/{ => v0}/test_normalizing_flow_model_mixin.py (98%) rename tests/modules/{ => v0}/test_state_value_mixin.py (96%) rename tests/modules/{ => v0}/test_stochastic_actor_mixin.py (98%) rename tests/modules/{ => v0}/test_stochastic_model_mixin.py (98%) rename tests/modules/{ => v0}/test_svg_module.py (96%) rename tests/modules/{ => v0}/test_trpo_extensions.py (100%) rename tests/modules/{ => v0}/utils.py (100%) rename tests/{ => pytorch/nn}/modules/test_made.py (100%) diff --git a/raylab/modules/catalog.py b/raylab/modules/catalog.py index 99c0e7aa..6d8dad00 100644 --- a/raylab/modules/catalog.py +++ b/raylab/modules/catalog.py @@ -1,47 +1,58 @@ """Registry of modules for PyTorch policies.""" +import torch.nn as nn +from gym.spaces import Space from .ddpg import DDPG -from .ddpg_module import DDPGModule -from .maxent_model_based import MaxEntModelBased -from .model_based_ddpg import ModelBasedDDPG -from .model_based_sac import ModelBasedSAC -from .naf_module import NAFModule -from .nfmbrl import NFMBRL -from .off_policy_nfac import OffPolicyNFAC -from .on_policy_actor_critic import OnPolicyActorCritic -from .on_policy_nfac import OnPolicyNFAC -from .sac_module import SACModule -from .simple_model_based import SimpleModelBased -from .svg_module import SVGModule -from .svg_realnvp_actor import SVGRealNVPActor -from .trpo_tang2018 import TRPOTang2018 +from .v0.ddpg_module import DDPGModule +from .v0.maxent_model_based import MaxEntModelBased +from .v0.model_based_ddpg import ModelBasedDDPG +from .v0.model_based_sac import ModelBasedSAC +from .v0.naf_module import NAFModule +from .v0.nfmbrl import NFMBRL +from .v0.off_policy_nfac import OffPolicyNFAC +from .v0.on_policy_actor_critic import OnPolicyActorCritic +from .v0.on_policy_nfac import OnPolicyNFAC +from .v0.sac_module import SACModule +from .v0.simple_model_based import SimpleModelBased +from .v0.svg_module import SVGModule +from .v0.svg_realnvp_actor import SVGRealNVPActor +from 
.v0.trpo_tang2018 import TRPOTang2018 -MODULES = { - "NAFModule": NAFModule, - "DDPGModule": DDPGModule, - "SACModule": SACModule, - "SimpleModelBased": SimpleModelBased, - "SVGModule": SVGModule, - "MaxEntModelBased": MaxEntModelBased, - "ModelBasedDDPG": ModelBasedDDPG, - "ModelBasedSAC": ModelBasedSAC, - "NFMBRL": NFMBRL, - "OnPolicyActorCritic": OnPolicyActorCritic, - "OnPolicyNFAC": OnPolicyNFAC, - "OffPolicyNFAC": OffPolicyNFAC, - "TRPOTang2018": TRPOTang2018, - "SVGRealNVPActor": SVGRealNVPActor, +MODULESv0 = { + cls.__name__: cls + for cls in ( + NAFModule, + DDPGModule, + SACModule, + SimpleModelBased, + SVGModule, + MaxEntModelBased, + ModelBasedDDPG, + ModelBasedSAC, + NFMBRL, + OnPolicyActorCritic, + OnPolicyNFAC, + OffPolicyNFAC, + TRPOTang2018, + SVGRealNVPActor, + ) } -MODULESv2 = {k.__name__: k for k in [DDPG]} +MODULESv1 = {cls.__name__: cls for cls in (DDPG,)} -def get_module(obs_space, action_space, config): - """Retrieve and construct module of given name.""" +def get_module(obs_space: Space, action_space: Space, config: dict) -> nn.Module: + """Retrieve and construct module of given name. + + Args: + obs_space: Observation space + action_space: Action space + config: Configurations for module construction and initialization + """ type_ = config.pop("type") - if type_ in MODULES: - return MODULES[type_](obs_space, action_space, config) + if type_ in MODULESv0: + return MODULESv0[type_](obs_space, action_space, config) - cls = MODULESv2[type_] + cls = MODULESv1[type_] spec = cls.spec_cls.from_dict(config) return cls(obs_space, action_space, spec) diff --git a/raylab/modules/v0/__init__.py b/raylab/modules/v0/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/raylab/modules/abstract.py b/raylab/modules/v0/abstract.py similarity index 100% rename from raylab/modules/abstract.py rename to raylab/modules/v0/abstract.py diff --git a/raylab/modules/ddpg_module.py b/raylab/modules/v0/ddpg_module.py similarity index 100% rename from raylab/modules/ddpg_module.py rename to raylab/modules/v0/ddpg_module.py diff --git a/raylab/modules/maxent_model_based.py b/raylab/modules/v0/maxent_model_based.py similarity index 100% rename from raylab/modules/maxent_model_based.py rename to raylab/modules/v0/maxent_model_based.py diff --git a/raylab/modules/mixins/__init__.py b/raylab/modules/v0/mixins/__init__.py similarity index 100% rename from raylab/modules/mixins/__init__.py rename to raylab/modules/v0/mixins/__init__.py diff --git a/raylab/modules/mixins/action_value_mixin.py b/raylab/modules/v0/mixins/action_value_mixin.py similarity index 100% rename from raylab/modules/mixins/action_value_mixin.py rename to raylab/modules/v0/mixins/action_value_mixin.py diff --git a/raylab/modules/mixins/deterministic_actor_mixin.py b/raylab/modules/v0/mixins/deterministic_actor_mixin.py similarity index 100% rename from raylab/modules/mixins/deterministic_actor_mixin.py rename to raylab/modules/v0/mixins/deterministic_actor_mixin.py diff --git a/raylab/modules/mixins/normalizing_flow_actor_mixin.py b/raylab/modules/v0/mixins/normalizing_flow_actor_mixin.py similarity index 99% rename from raylab/modules/mixins/normalizing_flow_actor_mixin.py rename to raylab/modules/v0/mixins/normalizing_flow_actor_mixin.py index fb70318f..5c0bce2c 100644 --- a/raylab/modules/mixins/normalizing_flow_actor_mixin.py +++ b/raylab/modules/v0/mixins/normalizing_flow_actor_mixin.py @@ -6,11 +6,11 @@ import torch.nn as nn from ray.rllib.utils import override +import raylab.modules.networks as networks 
import raylab.pytorch.nn as nnx import raylab.pytorch.nn.distributions as ptd from raylab.utils.dictionaries import deep_merge -from .. import networks from .stochastic_actor_mixin import StochasticPolicy diff --git a/raylab/modules/mixins/normalizing_flow_model_mixin.py b/raylab/modules/v0/mixins/normalizing_flow_model_mixin.py similarity index 99% rename from raylab/modules/mixins/normalizing_flow_model_mixin.py rename to raylab/modules/v0/mixins/normalizing_flow_model_mixin.py index 4f73ae7c..62015ea7 100644 --- a/raylab/modules/mixins/normalizing_flow_model_mixin.py +++ b/raylab/modules/v0/mixins/normalizing_flow_model_mixin.py @@ -5,11 +5,11 @@ import torch.nn as nn from ray.rllib.utils import override +import raylab.modules.networks as networks import raylab.pytorch.nn as nnx import raylab.pytorch.nn.distributions as ptd from raylab.utils.dictionaries import deep_merge -from .. import networks from .stochastic_model_mixin import StochasticModel from .stochastic_model_mixin import StochasticModelMixin diff --git a/raylab/modules/mixins/state_value_mixin.py b/raylab/modules/v0/mixins/state_value_mixin.py similarity index 100% rename from raylab/modules/mixins/state_value_mixin.py rename to raylab/modules/v0/mixins/state_value_mixin.py diff --git a/raylab/modules/mixins/stochastic_actor_mixin.py b/raylab/modules/v0/mixins/stochastic_actor_mixin.py similarity index 100% rename from raylab/modules/mixins/stochastic_actor_mixin.py rename to raylab/modules/v0/mixins/stochastic_actor_mixin.py diff --git a/raylab/modules/mixins/stochastic_model_mixin.py b/raylab/modules/v0/mixins/stochastic_model_mixin.py similarity index 100% rename from raylab/modules/mixins/stochastic_model_mixin.py rename to raylab/modules/v0/mixins/stochastic_model_mixin.py diff --git a/raylab/modules/mixins/svg_model_mixin.py b/raylab/modules/v0/mixins/svg_model_mixin.py similarity index 95% rename from raylab/modules/mixins/svg_model_mixin.py rename to raylab/modules/v0/mixins/svg_model_mixin.py index cb17c754..5adbb3b7 100644 --- a/raylab/modules/mixins/svg_model_mixin.py +++ b/raylab/modules/v0/mixins/svg_model_mixin.py @@ -54,7 +54,7 @@ def make_param(in_features): kwargs = dict(event_size=1, input_dependent_scale=False) return nnx.NormalParams(in_features, **kwargs) - self.params = nn.ModuleList([make_param(l.out_features) for l in self.logits]) + self.params = nn.ModuleList([make_param(m.out_features) for m in self.logits]) @override(nn.Module) def forward(self, obs, act): # pylint: disable=arguments-differ diff --git a/raylab/modules/model_based_ddpg.py b/raylab/modules/v0/model_based_ddpg.py similarity index 100% rename from raylab/modules/model_based_ddpg.py rename to raylab/modules/v0/model_based_ddpg.py diff --git a/raylab/modules/model_based_sac.py b/raylab/modules/v0/model_based_sac.py similarity index 100% rename from raylab/modules/model_based_sac.py rename to raylab/modules/v0/model_based_sac.py diff --git a/raylab/modules/naf_module.py b/raylab/modules/v0/naf_module.py similarity index 100% rename from raylab/modules/naf_module.py rename to raylab/modules/v0/naf_module.py diff --git a/raylab/modules/nfmbrl.py b/raylab/modules/v0/nfmbrl.py similarity index 100% rename from raylab/modules/nfmbrl.py rename to raylab/modules/v0/nfmbrl.py diff --git a/raylab/modules/off_policy_nfac.py b/raylab/modules/v0/off_policy_nfac.py similarity index 100% rename from raylab/modules/off_policy_nfac.py rename to raylab/modules/v0/off_policy_nfac.py diff --git a/raylab/modules/on_policy_actor_critic.py 
b/raylab/modules/v0/on_policy_actor_critic.py similarity index 100% rename from raylab/modules/on_policy_actor_critic.py rename to raylab/modules/v0/on_policy_actor_critic.py diff --git a/raylab/modules/on_policy_nfac.py b/raylab/modules/v0/on_policy_nfac.py similarity index 100% rename from raylab/modules/on_policy_nfac.py rename to raylab/modules/v0/on_policy_nfac.py diff --git a/raylab/modules/sac_module.py b/raylab/modules/v0/sac_module.py similarity index 100% rename from raylab/modules/sac_module.py rename to raylab/modules/v0/sac_module.py diff --git a/raylab/modules/simple_model_based.py b/raylab/modules/v0/simple_model_based.py similarity index 100% rename from raylab/modules/simple_model_based.py rename to raylab/modules/v0/simple_model_based.py diff --git a/raylab/modules/svg_module.py b/raylab/modules/v0/svg_module.py similarity index 100% rename from raylab/modules/svg_module.py rename to raylab/modules/v0/svg_module.py diff --git a/raylab/modules/svg_realnvp_actor.py b/raylab/modules/v0/svg_realnvp_actor.py similarity index 100% rename from raylab/modules/svg_realnvp_actor.py rename to raylab/modules/v0/svg_realnvp_actor.py diff --git a/raylab/modules/trpo_tang2018.py b/raylab/modules/v0/trpo_tang2018.py similarity index 99% rename from raylab/modules/trpo_tang2018.py rename to raylab/modules/v0/trpo_tang2018.py index a93e3957..d76cef70 100644 --- a/raylab/modules/trpo_tang2018.py +++ b/raylab/modules/v0/trpo_tang2018.py @@ -7,6 +7,7 @@ from ray.rllib.utils import merge_dicts from ray.rllib.utils import override +import raylab.modules.networks as networks from raylab.pytorch.nn import FullyConnected from raylab.pytorch.nn.distributions import flows from raylab.pytorch.nn.distributions import Independent @@ -16,7 +17,6 @@ from raylab.pytorch.nn.distributions.flows import TanhSquashTransform from raylab.pytorch.nn.init import initialize_ -from . 
import networks from .abstract import AbstractActorCritic from .mixins import StateValueMixin from .mixins import StochasticActorMixin diff --git a/tests/modules/v0/__init__.py b/tests/modules/v0/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/modules/test_action_value_mixin.py b/tests/modules/v0/test_action_value_mixin.py similarity index 96% rename from tests/modules/test_action_value_mixin.py rename to tests/modules/v0/test_action_value_mixin.py index b5a3e97e..a230050d 100644 --- a/tests/modules/test_action_value_mixin.py +++ b/tests/modules/v0/test_action_value_mixin.py @@ -6,7 +6,7 @@ import torch.nn as nn from ray.rllib import SampleBatch -from raylab.modules.mixins import ActionValueMixin +from raylab.modules.v0.mixins import ActionValueMixin class DummyModule(ActionValueMixin, nn.ModuleDict): diff --git a/tests/modules/test_deterministic_actor_mixin.py b/tests/modules/v0/test_deterministic_actor_mixin.py similarity index 98% rename from tests/modules/test_deterministic_actor_mixin.py rename to tests/modules/v0/test_deterministic_actor_mixin.py index db583b75..37ed59e5 100644 --- a/tests/modules/test_deterministic_actor_mixin.py +++ b/tests/modules/v0/test_deterministic_actor_mixin.py @@ -5,7 +5,7 @@ from ray.rllib import SampleBatch from ray.rllib.utils import merge_dicts -from raylab.modules.mixins import DeterministicActorMixin +from raylab.modules.v0.mixins import DeterministicActorMixin BASE_CONFIG = { diff --git a/tests/modules/test_naf_module.py b/tests/modules/v0/test_naf_module.py similarity index 92% rename from tests/modules/test_naf_module.py rename to tests/modules/v0/test_naf_module.py index ae1b80aa..95889065 100644 --- a/tests/modules/test_naf_module.py +++ b/tests/modules/v0/test_naf_module.py @@ -4,7 +4,7 @@ import pytest import torch -from raylab.modules.naf_module import NAFModule +from raylab.modules.v0.naf_module import NAFModule @pytest.fixture(params=(True, False), ids=("Double Q", "Single Q")) diff --git a/tests/modules/test_normalizing_flow_actor_mixin.py b/tests/modules/v0/test_normalizing_flow_actor_mixin.py similarity index 98% rename from tests/modules/test_normalizing_flow_actor_mixin.py rename to tests/modules/v0/test_normalizing_flow_actor_mixin.py index 856d686d..c61f0e2e 100644 --- a/tests/modules/test_normalizing_flow_actor_mixin.py +++ b/tests/modules/v0/test_normalizing_flow_actor_mixin.py @@ -6,7 +6,7 @@ from gym.spaces import Box from ray.rllib import SampleBatch -from raylab.modules.mixins import NormalizingFlowActorMixin +from raylab.modules.v0.mixins import NormalizingFlowActorMixin from .utils import make_batch from .utils import make_module diff --git a/tests/modules/test_normalizing_flow_model_mixin.py b/tests/modules/v0/test_normalizing_flow_model_mixin.py similarity index 98% rename from tests/modules/test_normalizing_flow_model_mixin.py rename to tests/modules/v0/test_normalizing_flow_model_mixin.py index 5a2d6386..98e590af 100644 --- a/tests/modules/test_normalizing_flow_model_mixin.py +++ b/tests/modules/v0/test_normalizing_flow_model_mixin.py @@ -7,7 +7,7 @@ from gym.spaces import Box from ray.rllib import SampleBatch -from raylab.modules.mixins import NormalizingFlowModelMixin +from raylab.modules.v0.mixins import NormalizingFlowModelMixin from .utils import make_batch from .utils import make_module diff --git a/tests/modules/test_state_value_mixin.py b/tests/modules/v0/test_state_value_mixin.py similarity index 96% rename from tests/modules/test_state_value_mixin.py rename to 
tests/modules/v0/test_state_value_mixin.py index 229ea3fc..2b69f255 100644 --- a/tests/modules/test_state_value_mixin.py +++ b/tests/modules/v0/test_state_value_mixin.py @@ -6,7 +6,7 @@ import torch.nn as nn from ray.rllib import SampleBatch -from raylab.modules.mixins import StateValueMixin +from raylab.modules.v0.mixins import StateValueMixin class DummyModule(StateValueMixin, nn.ModuleDict): diff --git a/tests/modules/test_stochastic_actor_mixin.py b/tests/modules/v0/test_stochastic_actor_mixin.py similarity index 98% rename from tests/modules/test_stochastic_actor_mixin.py rename to tests/modules/v0/test_stochastic_actor_mixin.py index 65025e60..346303ba 100644 --- a/tests/modules/test_stochastic_actor_mixin.py +++ b/tests/modules/v0/test_stochastic_actor_mixin.py @@ -7,7 +7,7 @@ from gym.spaces import Discrete from ray.rllib import SampleBatch -from raylab.modules.mixins import StochasticActorMixin +from raylab.modules.v0.mixins import StochasticActorMixin from .utils import make_batch from .utils import make_module diff --git a/tests/modules/test_stochastic_model_mixin.py b/tests/modules/v0/test_stochastic_model_mixin.py similarity index 98% rename from tests/modules/test_stochastic_model_mixin.py rename to tests/modules/v0/test_stochastic_model_mixin.py index ba2fe89f..43cb9d9e 100644 --- a/tests/modules/test_stochastic_model_mixin.py +++ b/tests/modules/v0/test_stochastic_model_mixin.py @@ -4,7 +4,7 @@ import torch.nn as nn from ray.rllib import SampleBatch -from raylab.modules.mixins import StochasticModelMixin +from raylab.modules.v0.mixins import StochasticModelMixin class DummyModule(StochasticModelMixin, nn.ModuleDict): diff --git a/tests/modules/test_svg_module.py b/tests/modules/v0/test_svg_module.py similarity index 96% rename from tests/modules/test_svg_module.py rename to tests/modules/v0/test_svg_module.py index 9595b078..22c7b456 100644 --- a/tests/modules/test_svg_module.py +++ b/tests/modules/v0/test_svg_module.py @@ -3,7 +3,7 @@ import torch from ray.rllib import SampleBatch -from raylab.modules.svg_module import SVGModule +from raylab.modules.v0.svg_module import SVGModule @pytest.fixture diff --git a/tests/modules/test_trpo_extensions.py b/tests/modules/v0/test_trpo_extensions.py similarity index 100% rename from tests/modules/test_trpo_extensions.py rename to tests/modules/v0/test_trpo_extensions.py diff --git a/tests/modules/utils.py b/tests/modules/v0/utils.py similarity index 100% rename from tests/modules/utils.py rename to tests/modules/v0/utils.py diff --git a/tests/modules/test_made.py b/tests/pytorch/nn/modules/test_made.py similarity index 100% rename from tests/modules/test_made.py rename to tests/pytorch/nn/modules/test_made.py From 09abd40da78710e2809e1384ea7677054bfd51b8 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=C3=82ngelo=20Lovatto?= Date: Sun, 28 Jun 2020 12:37:57 -0300 Subject: [PATCH 11/48] fix(networks): use deterministic actor spec MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Ângelo Lovatto --- raylab/modules/networks/actor/deterministic.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/raylab/modules/networks/actor/deterministic.py b/raylab/modules/networks/actor/deterministic.py index 422808ae..5d36465d 100644 --- a/raylab/modules/networks/actor/deterministic.py +++ b/raylab/modules/networks/actor/deterministic.py @@ -79,7 +79,7 @@ def __init__( def make_policy(): return MLPDeterministicPolicy( - obs_space, action_space, spec.actor.encoder, spec.actor.norm_beta + 
obs_space, action_space, spec.encoder, spec.norm_beta ) policy = make_policy() From 9bd649e4bbe45fb87e423a40fc8ee76a7bc7d201 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=C3=82ngelo=20Lovatto?= Date: Sun, 28 Jun 2020 19:43:35 -0300 Subject: [PATCH 12/48] feat(pytorch): add no-op initializer MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Ângelo Lovatto --- raylab/pytorch/nn/init.py | 22 ++++++++++++++++------ 1 file changed, 16 insertions(+), 6 deletions(-) diff --git a/raylab/pytorch/nn/init.py b/raylab/pytorch/nn/init.py index d4168ce2..fe0f1a5f 100644 --- a/raylab/pytorch/nn/init.py +++ b/raylab/pytorch/nn/init.py @@ -1,16 +1,23 @@ """Utilities for module initialization.""" import functools import inspect +from typing import Callable +from typing import Optional +from typing import Union import torch.nn as nn +from torch import Tensor -def get_initializer(name): +def get_initializer(name: Optional[str]) -> Callable[[Tensor], None]: """Return initializer function given its name. Arguments: - name (str): the initializer function's name + name: The initializer function's name. If None, returns a no-op callable """ + if name is None: + return lambda _: None + name_ = name + "_" if name in dir(nn.init) and name_ in dir(nn.init): func = getattr(nn.init, name_) @@ -27,16 +34,19 @@ def get_initializer(name): } -def initialize_(name, activation=None, **options): +def initialize_( + name: Optional[str] = None, activation: Union[str, dict] = None, **options +) -> Callable[[nn.Module], None]: """Return a callable to apply an initializer with the given name and options. If `gain` is part of the initializer's argspec and is not specified in options, the recommended value from `torch.nn.init.calculate_gain` is used. Arguments: - name (str): name of initializer function - activation (str, dict): activation function following linear layer, optional - **options: keyword arguments to be passed to the initializer + name: Initializer function name + activation: Optional specification of the activation function that + follows linear layers + **options: Keyword arguments to pass to the initializer Returns: A callable to be used with `nn.Module.apply`. 
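Note on the no-op initializer introduced above: making `name` optional lets module constructors apply whatever initializer the spec provides without special-casing the "no initializer" case, since `get_initializer(None)` now returns a callable that leaves the tensor untouched. A minimal usage sketch follows; it assumes `initialize_` resolves a missing name through `get_initializer`'s no-op path and that the returned callable only touches layers it recognizes. The small network is illustrative only and not part of raylab.

    import torch.nn as nn
    from raylab.pytorch.nn.init import get_initializer, initialize_

    net = nn.Sequential(nn.Linear(4, 32), nn.Tanh(), nn.Linear(32, 2))

    # Named initializer; per the docstring, `gain` defaults to
    # torch.nn.init.calculate_gain("tanh") when the chosen nn.init
    # function accepts a gain argument.
    net.apply(initialize_("orthogonal", activation="tanh"))

    # No name given: the returned callable is a no-op, so parameters
    # keep their default initialization.
    net.apply(initialize_())

    # The raw lookup behaves the same way for a missing name.
    noop = get_initializer(None)
    noop(net[0].weight)  # does nothing
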
From dc18a0f3ba30eb65494a5cb4464ecc8f76880c41 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=C3=82ngelo=20Lovatto?= Date: Sun, 28 Jun 2020 19:44:59 -0300 Subject: [PATCH 13/48] fix(networks): add Gaussian noise if requested MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Ângelo Lovatto --- raylab/modules/networks/actor/deterministic.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/raylab/modules/networks/actor/deterministic.py b/raylab/modules/networks/actor/deterministic.py index 5d36465d..41b4a609 100644 --- a/raylab/modules/networks/actor/deterministic.py +++ b/raylab/modules/networks/actor/deterministic.py @@ -7,6 +7,7 @@ from dataclasses_json import DataClassJsonMixin from gym.spaces import Box +from .policy.deterministic import DeterministicPolicy from .policy.deterministic import MLPDeterministicPolicy MLPSpec = MLPDeterministicPolicy.spec_cls @@ -100,7 +101,7 @@ def make_policy(): target_policy = make_policy() target_policy.load_state_dict(policy.state_dict()) if spec.smooth_target_policy: - target_policy = MLPDeterministicPolicy.add_gaussian_noise( + target_policy = DeterministicPolicy.add_gaussian_noise( target_policy, noise_stddev=spec.target_gaussian_sigma ) From ff62d98f2dc7c612540ffbb71529f40dd1db5a2f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=C3=82ngelo=20Lovatto?= Date: Sun, 28 Jun 2020 20:27:46 -0300 Subject: [PATCH 14/48] feat(pytorch): add reverse mode to TanhSquash MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Ângelo Lovatto --- raylab/pytorch/nn/modules/tanh_squash.py | 14 +++++--- tests/pytorch/nn/modules/test_tanh_squash.py | 36 +++++++++++++------- 2 files changed, 33 insertions(+), 17 deletions(-) diff --git a/raylab/pytorch/nn/modules/tanh_squash.py b/raylab/pytorch/nn/modules/tanh_squash.py index 3bbfd968..909c283b 100644 --- a/raylab/pytorch/nn/modules/tanh_squash.py +++ b/raylab/pytorch/nn/modules/tanh_squash.py @@ -1,16 +1,22 @@ # pylint: disable=missing-docstring +import torch import torch.nn as nn -from ray.rllib.utils import override +from torch import Tensor class TanhSquash(nn.Module): """Neural network module squashing vectors to specified range using Tanh.""" - def __init__(self, low, high): + def __init__(self, low: Tensor, high: Tensor): super().__init__() self.register_buffer("loc", (high + low) / 2) self.register_buffer("scale", (high - low) / 2) - @override(nn.Module) - def forward(self, inputs): # pylint: disable=arguments-differ + def forward(self, inputs: Tensor, reverse: bool = False) -> Tensor: + # pylint: disable=arguments-differ + if reverse: + inputs = (inputs - self.loc) / self.scale + to_log1 = torch.clamp(1 + inputs, min=1.1754943508222875e-38) + to_log2 = torch.clamp(1 - inputs, min=1.1754943508222875e-38) + return (torch.log(to_log1) - torch.log(to_log2)) / 2 return self.loc + inputs.tanh() * self.scale diff --git a/tests/pytorch/nn/modules/test_tanh_squash.py b/tests/pytorch/nn/modules/test_tanh_squash.py index b1913aef..cc10cc32 100644 --- a/tests/pytorch/nn/modules/test_tanh_squash.py +++ b/tests/pytorch/nn/modules/test_tanh_squash.py @@ -13,29 +13,39 @@ def low_high(request): @pytest.fixture -def maker(torch_script): - def factory(*args, **kwargs): - module = TanhSquash(*args, **kwargs) - return torch.jit.script(module) if torch_script else module +def squash(low_high): + low, high = low_high + return TanhSquash(low, high) + + +@pytest.fixture +def module(squash, torch_script): + if torch_script: + 
return torch.jit.script(squash) + return squash + - return factory +@pytest.fixture +def inputs(low_high): + low, _ = low_high + return torch.randn(10, *low.shape) -def test_squash_to_range(maker, low_high): +def test_squash_to_range(module, low_high, inputs): low, high = low_high - module = maker(low, high) - inputs = torch.randn(10, *low.shape) output = module(inputs) assert (output <= high).all() assert (output >= low).all() -def test_propagates_gradients(maker, low_high): - low, high = low_high - module = maker(low, high) - - inputs = torch.randn(10, *low.shape, requires_grad=True) +def test_propagates_gradients(module, inputs): + inputs.requires_grad_() module(inputs).mean().backward() assert inputs.grad is not None assert (inputs.grad != 0).any() + + +def test_reverse(module, inputs): + squashed = module(inputs) + assert torch.allclose(module(squashed, reverse=True), inputs, atol=1e-6) From 459fdca8f31a84cdbcf1752b0e5fd6ad22a76316 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=C3=82ngelo=20Lovatto?= Date: Sun, 28 Jun 2020 20:32:22 -0300 Subject: [PATCH 15/48] fix(modules): create and script new DDPG MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Ângelo Lovatto --- raylab/modules/ddpg.py | 20 +++++++----------- .../networks/actor/policy/deterministic.py | 2 +- .../modules/networks/critic/action_value.py | 8 +++---- raylab/modules/networks/critic/q_value.py | 21 ++++++++++++++----- 4 files changed, 28 insertions(+), 23 deletions(-) diff --git a/raylab/modules/ddpg.py b/raylab/modules/ddpg.py index 5dabe9a3..3e45a509 100644 --- a/raylab/modules/ddpg.py +++ b/raylab/modules/ddpg.py @@ -7,9 +7,7 @@ from gym.spaces import Box from .networks.actor.deterministic import DeterministicActor -from .networks.actor.policy.deterministic import DeterministicPolicy from .networks.critic.action_value import ActionValueCritic -from .networks.critic.q_value import QValueEnsemble ActorSpec = DeterministicActor.spec_cls CriticSpec = ActionValueCritic.spec_cls @@ -41,21 +39,17 @@ class DDPG(nn.Module): spec: Specifications for DDPG modules Attributes: - actor: The deterministic policy to be learned - behavior: The policy for exploration - target_actor: The policy used for estimating the arg max in Q-Learning - critics: The action-value estimators to be learned - target_critics: The action-value estimators used for bootstrapping in - Q-Learning + actor (DeterministicPolicy): The deterministic policy to be learned + behavior (DeterministicPolicy): The policy for exploration + target_actor (DeterministicPolicy): The policy used for estimating the + arg max in Q-Learning + critics (QValueEnsemble): The action-value estimators to be learned + target_critics (QValueEnsemble): The action-value estimators used for + bootstrapping in Q-Learning spec_cls: Expected class of `spec` init argument """ # pylint:disable=abstract-method - actor: DeterministicPolicy - behavior: DeterministicPolicy - target_actor: DeterministicPolicy - critics: QValueEnsemble - target_critics: QValueEnsemble spec_cls = DDPGSpec def __init__(self, obs_space: Box, action_space: Box, spec: DDPGSpec): diff --git a/raylab/modules/networks/actor/policy/deterministic.py b/raylab/modules/networks/actor/policy/deterministic.py index 9703a4b8..4159e6b4 100644 --- a/raylab/modules/networks/actor/policy/deterministic.py +++ b/raylab/modules/networks/actor/policy/deterministic.py @@ -51,7 +51,7 @@ def unconstrained_action(self, obs: Tensor) -> Tensor: """Forward pass with no squashing at the end""" 
features = self.encoder(obs) unconstrained_action = self.action_linear(features) - if self.noise: + if self.noise is not None: unconstrained_action = self.noise(unconstrained_action) return unconstrained_action diff --git a/raylab/modules/networks/critic/action_value.py b/raylab/modules/networks/critic/action_value.py index d3ea04ed..9953a8a3 100644 --- a/raylab/modules/networks/critic/action_value.py +++ b/raylab/modules/networks/critic/action_value.py @@ -61,13 +61,13 @@ def __init__(self, obs_space: Box, action_space: Box, spec: ActionValueCriticSpe super().__init__() def make_q_value(): - return MLPQValue(obs_space, action_space, spec.q_value.encoder) + return MLPQValue(obs_space, action_space, spec.encoder) def make_q_value_ensemble(): - n_q_values = 2 if spec.q_value.double_q else 1 + n_q_values = 2 if spec.double_q else 1 q_values = [make_q_value() for _ in range(n_q_values)] - if spec.q_value.parallelize: + if spec.parallelize: return ForkedQValueEnsemble(q_values) return QValueEnsemble(q_values) @@ -75,7 +75,7 @@ def make_q_value_ensemble(): q_values.initialize_parameters(spec.initializer) target_q_values = make_q_value_ensemble() - target_q_values.load_state_dict(q_values) + target_q_values.load_state_dict(q_values.state_dict()) self.q_values = q_values self.target_q_values = target_q_values diff --git a/raylab/modules/networks/critic/q_value.py b/raylab/modules/networks/critic/q_value.py index e850b366..accf86f1 100644 --- a/raylab/modules/networks/critic/q_value.py +++ b/raylab/modules/networks/critic/q_value.py @@ -66,19 +66,19 @@ class MLPQValue(QValue): spec_cls = StateActionMLPSpec - def __init__(self, obs_space: Box, action_space: Box, mlp_spec: StateActionMLPSpec): + def __init__(self, obs_space: Box, action_space: Box, spec: StateActionMLPSpec): obs_size = obs_space.shape[0] action_size = action_space.shape[0] encoder = nnx.StateActionEncoder( obs_size, action_size, - units=mlp_spec.units, - activation=mlp_spec.activation, - delay_action=mlp_spec.delay_action, + units=spec.units, + activation=spec.activation, + delay_action=spec.delay_action, ) super().__init__(encoder) - self.spec = mlp_spec.activation + self.spec = spec def initialize_parameters(self, initializer_spec: dict): """Initialize all Linear models in the encoder. @@ -126,6 +126,17 @@ def forward(self, obs: Tensor, action: Tensor, clip: bool = False) -> Tensor: action_values, _ = action_values.min(keepdim=True, dim=-1) return action_values + def initialize_parameters(self, initializer_spec: dict): + """Initialize each Q estimator in the ensemble. + + Args: + initializer_spec: Dictionary with mandatory `type` key corresponding + to the initializer function name in `torch.nn.init` and optional + keyword arguments. 
+ """ + for q_value in self: + q_value.initialize_parameters(initializer_spec) + class ForkedQValueEnsemble(QValueEnsemble): """Ensemble of Q-value estimators with parallelized forward pass.""" From d746d616e03bacb8183dfc1e3a009ee6ac51ef39 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=C3=82ngelo=20Lovatto?= Date: Sun, 28 Jun 2020 21:05:55 -0300 Subject: [PATCH 16/48] fix(losses): replace module mixin imports MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Ângelo Lovatto --- raylab/losses/mle.py | 2 +- raylab/losses/policy_gradient.py | 2 +- raylab/losses/svg.py | 4 ++-- tests/losses/conftest.py | 8 ++++---- 4 files changed, 8 insertions(+), 8 deletions(-) diff --git a/raylab/losses/mle.py b/raylab/losses/mle.py index 6a5524b3..585393d4 100644 --- a/raylab/losses/mle.py +++ b/raylab/losses/mle.py @@ -8,7 +8,7 @@ from ray.rllib import SampleBatch from torch import Tensor -from raylab.modules.mixins.stochastic_model_mixin import StochasticModel +from raylab.modules.v0.mixins.stochastic_model_mixin import StochasticModel from raylab.utils.dictionaries import get_keys from .abstract import Loss diff --git a/raylab/losses/policy_gradient.py b/raylab/losses/policy_gradient.py index 4a37832e..620a0d68 100644 --- a/raylab/losses/policy_gradient.py +++ b/raylab/losses/policy_gradient.py @@ -8,7 +8,7 @@ from ray.rllib import SampleBatch from torch import Tensor -from raylab.modules.mixins.stochastic_actor_mixin import StochasticPolicy +from raylab.modules.v0.mixins.stochastic_actor_mixin import StochasticPolicy from raylab.utils.annotations import DetPolicy from raylab.utils.annotations import DynamicsFn from raylab.utils.annotations import RewardFn diff --git a/raylab/losses/svg.py b/raylab/losses/svg.py index a749feb9..cbd97f1b 100644 --- a/raylab/losses/svg.py +++ b/raylab/losses/svg.py @@ -10,8 +10,8 @@ from ray.rllib.utils import override from torch import Tensor -from raylab.modules.mixins.stochastic_actor_mixin import StochasticPolicy -from raylab.modules.mixins.stochastic_model_mixin import StochasticModel +from raylab.modules.v0.mixins.stochastic_actor_mixin import StochasticPolicy +from raylab.modules.v0.mixins.stochastic_model_mixin import StochasticModel from raylab.utils.annotations import RewardFn from raylab.utils.annotations import StateValue from raylab.utils.dictionaries import get_keys diff --git a/tests/losses/conftest.py b/tests/losses/conftest.py index a9623a81..c740c662 100644 --- a/tests/losses/conftest.py +++ b/tests/losses/conftest.py @@ -5,10 +5,10 @@ import raylab.pytorch.nn as nnx import raylab.pytorch.nn.distributions as ptd -from raylab.modules.mixins.action_value_mixin import ActionValueFunction -from raylab.modules.mixins.deterministic_actor_mixin import DeterministicPolicy -from raylab.modules.mixins.stochastic_actor_mixin import StochasticPolicy -from raylab.modules.mixins.stochastic_model_mixin import StochasticModelMixin +from raylab.modules.v0.mixins.action_value_mixin import ActionValueFunction +from raylab.modules.v0.mixins.deterministic_actor_mixin import DeterministicPolicy +from raylab.modules.v0.mixins.stochastic_actor_mixin import StochasticPolicy +from raylab.modules.v0.mixins.stochastic_model_mixin import StochasticModelMixin from raylab.utils.debug import fake_batch From be05c52a10d7899ef7f86aa1ff9006bcea05653b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=C3=82ngelo=20Lovatto?= Date: Sun, 28 Jun 2020 21:06:25 -0300 Subject: [PATCH 17/48] test(networks): add deterministic actor MIME-Version: 1.0 
Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Ângelo Lovatto --- tests/modules/networks/actor/__init__.py | 0 .../modules/networks/actor/policy/__init__.py | 0 .../actor/policy/test_deterministic.py | 38 +++++++++++ .../networks/actor/test_deterministic.py | 63 +++++++++++++++++++ tests/modules/networks/conftest.py | 11 ++++ 5 files changed, 112 insertions(+) create mode 100644 tests/modules/networks/actor/__init__.py create mode 100644 tests/modules/networks/actor/policy/__init__.py create mode 100644 tests/modules/networks/actor/policy/test_deterministic.py create mode 100644 tests/modules/networks/actor/test_deterministic.py create mode 100644 tests/modules/networks/conftest.py diff --git a/tests/modules/networks/actor/__init__.py b/tests/modules/networks/actor/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/modules/networks/actor/policy/__init__.py b/tests/modules/networks/actor/policy/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/modules/networks/actor/policy/test_deterministic.py b/tests/modules/networks/actor/policy/test_deterministic.py new file mode 100644 index 00000000..43682da0 --- /dev/null +++ b/tests/modules/networks/actor/policy/test_deterministic.py @@ -0,0 +1,38 @@ +# pylint:disable=missing-docstring,redefined-outer-name,protected-access +import pytest +import torch +from ray.rllib import SampleBatch + + +@pytest.fixture(scope="module") +def module_cls(): + from raylab.modules.networks.actor.policy.deterministic import ( + MLPDeterministicPolicy, + ) + + return MLPDeterministicPolicy + + +@pytest.fixture(params=(0.1, 1.2), ids=lambda x: f"NormBeta({x})") +def norm_beta(request): + return request.param + + +@pytest.fixture +def spec(module_cls): + return module_cls.spec_cls() + + +@pytest.fixture +def module(module_cls, obs_space, action_space, spec, norm_beta): + return module_cls(obs_space, action_space, spec, norm_beta) + + +def test_unconstrained_action(module, batch, action_space, norm_beta): + action_dim = action_space.shape[0] + + policy_out = module.unconstrained_action(batch[SampleBatch.CUR_OBS]) + norms = policy_out.norm(p=1, dim=-1, keepdim=True) / action_dim + assert policy_out.shape[-1] == action_dim + assert policy_out.dtype == torch.float32 + assert (norms <= (norm_beta + torch.finfo(torch.float32).eps)).all() diff --git a/tests/modules/networks/actor/test_deterministic.py b/tests/modules/networks/actor/test_deterministic.py new file mode 100644 index 00000000..e71da01f --- /dev/null +++ b/tests/modules/networks/actor/test_deterministic.py @@ -0,0 +1,63 @@ +# pylint: disable=missing-docstring,redefined-outer-name,protected-access +import pytest +import torch +from ray.rllib import SampleBatch + + +@pytest.fixture(scope="module") +def module_cls(): + from raylab.modules.networks.actor.deterministic import DeterministicActor + + return DeterministicActor + + +@pytest.fixture(params=(True, False), ids=lambda x: f"SeparateTargetPolicy({x})") +def separate_target_policy(request): + return request.param + + +@pytest.fixture(params="gaussian deterministic parameter_noise".split()) +def behavior(request): + return request.param + + +@pytest.fixture +def spec(module_cls, behavior, separate_target_policy): + return module_cls.spec_cls( + behavior=behavior, separate_target_policy=separate_target_policy + ) + + +@pytest.fixture +def module(module_cls, obs_space, action_space, spec): + return module_cls(obs_space, action_space, spec) + + +def test_module_creation(module): 
+ for attr in "policy behavior target_policy".split(): + assert hasattr(module, attr) + + policy, target_policy = module.policy, module.target_policy + assert all( + torch.allclose(p, p_) + for p, p_ in zip(policy.parameters(), target_policy.parameters()) + ) + + +def test_separate_target_policy(module, spec): + policy, target = module.policy, module.target_policy + + if spec.separate_target_policy: + assert all(p is not t for p, t in zip(policy.parameters(), target.parameters())) + else: + assert all(p is t for p, t in zip(policy.parameters(), target.parameters())) + + +def test_behavior(module, batch, spec): + action = batch[SampleBatch.ACTIONS] + + samples = module.behavior(batch[SampleBatch.CUR_OBS]) + samples_ = module.behavior(batch[SampleBatch.CUR_OBS]) + assert samples.shape == action.shape + assert samples.dtype == torch.float32 + assert spec.behavior == "gaussian" or torch.allclose(samples, samples_) diff --git a/tests/modules/networks/conftest.py b/tests/modules/networks/conftest.py new file mode 100644 index 00000000..23005dc6 --- /dev/null +++ b/tests/modules/networks/conftest.py @@ -0,0 +1,11 @@ +# pylint: disable=missing-docstring,redefined-outer-name,protected-access +import pytest +import torch + + +@pytest.fixture(scope="module") +def batch(obs_space, action_space): + from raylab.utils.debug import fake_batch + + samples = fake_batch(obs_space, action_space, batch_size=32) + return {k: torch.from_numpy(v) for k, v in samples.items()} From 5b237db891bcf8d4ff3abe565661de8ca828fac4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=C3=82ngelo=20Lovatto?= Date: Mon, 29 Jun 2020 07:32:55 -0300 Subject: [PATCH 18/48] test(networks): add action-value critic MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Ângelo Lovatto --- tests/modules/networks/critic/__init__.py | 0 .../networks/critic/test_action_value.py | 50 +++++++++++++++++++ 2 files changed, 50 insertions(+) create mode 100644 tests/modules/networks/critic/__init__.py create mode 100644 tests/modules/networks/critic/test_action_value.py diff --git a/tests/modules/networks/critic/__init__.py b/tests/modules/networks/critic/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/modules/networks/critic/test_action_value.py b/tests/modules/networks/critic/test_action_value.py new file mode 100644 index 00000000..fe4a0510 --- /dev/null +++ b/tests/modules/networks/critic/test_action_value.py @@ -0,0 +1,50 @@ +# pylint: disable=missing-docstring,redefined-outer-name,protected-access +import pytest +import torch +from ray.rllib import SampleBatch + + +@pytest.fixture(scope="module") +def module_cls(): + from raylab.modules.networks.critic.action_value import ActionValueCritic + + return ActionValueCritic + + +@pytest.fixture(params=(True, False), ids="DoubleQ SingleQ".split()) +def double_q(request): + return request.param + + +@pytest.fixture +def spec(module_cls, double_q): + return module_cls.spec_cls(double_q=double_q) + + +@pytest.fixture +def module(module_cls, obs_space, action_space, spec): + return module_cls(obs_space, action_space, spec) + + +def test_module_creation(module, batch, spec): + double_q = spec.double_q + + for attr in "q_values target_q_values".split(): + assert hasattr(module, attr) + expected_n_critics = 2 if double_q else 1 + assert len(module.q_values) == expected_n_critics + + q_values, targets = module.q_values, module.target_q_values + vals = [ + m(batch[SampleBatch.CUR_OBS], batch[SampleBatch.ACTIONS]) + for ensemble in 
(q_values, targets) + for m in ensemble + ] + for val in vals: + assert val.shape[-1] == 1 + assert val.dtype == torch.float32 + + assert all( + torch.allclose(p, t) + for p, t in zip(q_values.parameters(), targets.parameters()) + ) From 9998d9fc8d9e6e56118ab16d882003c4c4e1beed Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=C3=82ngelo=20Lovatto?= Date: Mon, 29 Jun 2020 07:52:11 -0300 Subject: [PATCH 19/48] test(networks): add stochastic actor MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Ângelo Lovatto --- .../modules/networks/actor/test_stochastic.py | 176 ++++++++++++++++++ 1 file changed, 176 insertions(+) create mode 100644 tests/modules/networks/actor/test_stochastic.py diff --git a/tests/modules/networks/actor/test_stochastic.py b/tests/modules/networks/actor/test_stochastic.py new file mode 100644 index 00000000..7312732e --- /dev/null +++ b/tests/modules/networks/actor/test_stochastic.py @@ -0,0 +1,176 @@ +# pylint: disable=missing-docstring,redefined-outer-name,protected-access +import pytest +import torch +from gym.spaces import Box +from gym.spaces import Discrete +from ray.rllib import SampleBatch + +from raylab.utils.debug import fake_batch + +pytest.skip(reason="Not implemented") + + +@pytest.fixture(scope="module") +def module_cls(): + from raylab.modules.networks.actor.stochastic import StochasticActor + + return StochasticActor + + +@pytest.fixture(params=(True, False), ids=lambda x: f"InputDependentScale({x})") +def input_dependent_scale(request): + return request.param + + +@pytest.fixture +def cont_spec(module_cls, input_dependent_scale): + return module_cls.spec_cls(input_dependent_scale=input_dependent_scale) + + +@pytest.fixture +def spec(module_cls): + return module_cls.spec_cls() + + +DISC_SPACES = (Discrete(2), Discrete(8)) +CONT_SPACES = (Box(-1, 1, shape=(1,)), Box(-1, 1, shape=(3,))) +ACTION_SPACES = CONT_SPACES + DISC_SPACES + + +@pytest.fixture(params=DISC_SPACES, ids=(repr(a) for a in DISC_SPACES)) +def disc_space(request): + return request.param + + +@pytest.fixture(params=CONT_SPACES, ids=(repr(a) for a in CONT_SPACES)) +def cont_space(request): + return request.param + + +@pytest.fixture(params=ACTION_SPACES, ids=(repr(a) for a in ACTION_SPACES)) +def action_space(request): + return request.param + + +@pytest.fixture +def disc_module(module_cls, obs_space, disc_space, spec, torch_script): + mod = module_cls(obs_space, disc_space, spec) + return torch.jit.script(mod) if torch_script else mod + + +@pytest.fixture +def cont_module(module_cls, obs_space, cont_space, spec, torch_script): + mod = module_cls(obs_space, cont_space, spec) + return torch.jit.script(mod) if torch_script else mod + + +@pytest.fixture +def module(module_cls, obs_space, action_space, spec, torch_script): + mod = module_cls(obs_space, action_space, spec) + return torch.jit.script(mod) if torch_script else mod + + +@pytest.fixture +def disc_batch(obs_space, disc_space): + samples = fake_batch(obs_space, disc_space, batch_size=32) + return {k: torch.from_numpy(v) for k, v in samples.items()} + + +@pytest.fixture +def cont_batch(obs_space, cont_space): + samples = fake_batch(obs_space, cont_space, batch_size=32) + return {k: torch.from_numpy(v) for k, v in samples.items()} + + +@pytest.fixture +def batch(obs_space, action_space): + samples = fake_batch(obs_space, action_space, batch_size=32) + return {k: torch.from_numpy(v) for k, v in samples.items()} + + +def test_discrete_sampler(disc_module, disc_batch): + module, batch = 
disc_module, disc_batch + action = batch[SampleBatch.ACTIONS] + + sampler = module.actor.sample + samples, logp = sampler(batch[SampleBatch.CUR_OBS]) + samples_, _ = sampler(batch[SampleBatch.CUR_OBS]) + assert samples.shape == action.shape + assert samples.dtype == action.dtype + assert logp.shape == batch[SampleBatch.REWARDS].shape + assert logp.dtype == batch[SampleBatch.REWARDS].dtype + assert not torch.allclose(samples, samples_) + + +def test_continuous_sampler(cont_module, cont_batch): + module = cont_module + batch = cont_batch + action = batch[SampleBatch.ACTIONS] + + sampler = module.actor.rsample + samples, logp = sampler(batch[SampleBatch.CUR_OBS]) + samples_, _ = sampler(batch[SampleBatch.CUR_OBS]) + assert samples.shape == action.shape + assert samples.dtype == action.dtype + assert logp.shape == batch[SampleBatch.REWARDS].shape + assert logp.dtype == batch[SampleBatch.REWARDS].dtype + assert not torch.allclose(samples, samples_) + + +def test_discrete_params(disc_module, disc_batch): + module, batch = disc_module, disc_batch + + params = module.actor(batch[SampleBatch.CUR_OBS]) + assert "logits" in params + logits = params["logits"] + assert logits.shape[-1] == disc_space.n + + pi_params = set(module.actor.parameters()) + for par in pi_params: + par.grad = None + logits.mean().backward() + assert any(p.grad is not None for p in pi_params) + assert all(p.grad is None for p in set(module.parameters()) - pi_params) + + +def test_continuous_params(cont_module, cont_batch): + module, batch = cont_module, cont_batch + params = module.actor(batch[SampleBatch.CUR_OBS]) + assert "loc" in params + assert "scale" in params + + loc, scale = params["loc"], params["scale"] + action = batch[SampleBatch.ACTIONS] + assert loc.shape == action.shape + assert scale.shape == action.shape + assert loc.dtype == torch.float32 + assert scale.dtype == torch.float32 + + pi_params = set(module.actor.parameters()) + for par in pi_params: + par.grad = None + loc.mean().backward() + assert any(p.grad is not None for p in pi_params) + assert all(p.grad is None for p in set(module.parameters()) - pi_params) + + for par in pi_params: + par.grad = None + module.actor(batch[SampleBatch.CUR_OBS])["scale"].mean().backward() + assert any(p.grad is not None for p in pi_params) + assert all(p.grad is None for p in set(module.parameters()) - pi_params) + + +def test_reproduce(cont_module, cont_batch): + module, batch = cont_module, cont_batch + + acts = batch[SampleBatch.ACTIONS] + acts_, logp_ = module.actor.reproduce(batch[SampleBatch.CUR_OBS], acts) + assert acts_.shape == acts.shape + assert acts_.dtype == acts.dtype + assert torch.allclose(acts_, acts, atol=1e-5) + assert logp_.shape == batch[SampleBatch.REWARDS].shape + + acts_.mean().backward() + pi_params = set(module.actor.parameters()) + assert all(p.grad is not None for p in pi_params) + assert all(p.grad is None for p in set(module.parameters()) - pi_params) From 3de5ed8d815d8e6e60f73ab7d073410aef44ef85 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=C3=82ngelo=20Lovatto?= Date: Mon, 29 Jun 2020 13:00:26 -0300 Subject: [PATCH 20/48] feat(networks): add stochastic policies and actor MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Ângelo Lovatto --- .../networks/actor/policy/deterministic.py | 42 +--- .../networks/actor/policy/state_mlp.py | 49 +++++ .../networks/actor/policy/stochastic.py | 198 ++++++++++++++++++ raylab/modules/networks/actor/stochastic.py | 78 +++++++ 
tests/modules/networks/actor/conftest.py | 45 ++++ .../actor/policy/test_deterministic.py | 5 + .../networks/actor/policy/test_stochastic.py | 128 +++++++++++ .../networks/actor/test_deterministic.py | 10 + .../modules/networks/actor/test_stochastic.py | 134 +----------- 9 files changed, 525 insertions(+), 164 deletions(-) create mode 100644 raylab/modules/networks/actor/policy/state_mlp.py create mode 100644 raylab/modules/networks/actor/policy/stochastic.py create mode 100644 raylab/modules/networks/actor/stochastic.py create mode 100644 tests/modules/networks/actor/conftest.py create mode 100644 tests/modules/networks/actor/policy/test_stochastic.py diff --git a/raylab/modules/networks/actor/policy/deterministic.py b/raylab/modules/networks/actor/policy/deterministic.py index 4159e6b4..286fcb64 100644 --- a/raylab/modules/networks/actor/policy/deterministic.py +++ b/raylab/modules/networks/actor/policy/deterministic.py @@ -1,8 +1,5 @@ """Parameterized deterministic policies.""" import warnings -from dataclasses import dataclass -from dataclasses import field -from typing import List from typing import Optional import torch @@ -13,6 +10,8 @@ import raylab.pytorch.nn as nnx from raylab.pytorch.nn.init import initialize_ +from .state_mlp import StateMLP + class DeterministicPolicy(nn.Module): """Continuous action deterministic policy as a sequence of modules. @@ -81,22 +80,6 @@ def add_gaussian_noise(cls, policy, noise_stddev: float): return cls(policy.encoder, policy.action_linear, policy.squashing, noise=noise) -@dataclass -class StateMLPSpec: - """Specifications for creating a multilayer perceptron. - - Args: - units: Number of units in each hidden layer - activation: Nonlinearity following each linear layer - layer_norm: Whether to apply layer normalization between each linear layer - and following activation - """ - - units: List[int] = field(default_factory=list) - activation: Optional[str] = None - layer_norm: bool = False - - class MLPDeterministicPolicy(DeterministicPolicy): """DeterministicPolicy with multilayer perceptron encoder. 
@@ -114,28 +97,18 @@ class MLPDeterministicPolicy(DeterministicPolicy): spec_cls: Expected class of `spec` init argument """ - spec_cls = StateMLPSpec + spec_cls = StateMLP.spec_cls def __init__( self, obs_space: Box, action_space: Box, - mlp_spec: StateMLPSpec, + mlp_spec: StateMLP.spec_cls, norm_beta: float, ): - obs_size = obs_space.shape[0] - action_size = action_space.shape[0] - action_low, action_high = map( - torch.as_tensor, (action_space.low, action_space.high) - ) - - encoder = nnx.FullyConnected( - obs_size, - mlp_spec.units, - mlp_spec.activation, - layer_norm=mlp_spec.layer_norm, - ) + encoder = StateMLP(obs_space, mlp_spec).encoder + action_size = action_space.shape[0] if norm_beta: action_linear = nnx.NormalizedLinear( encoder.out_features, action_size, beta=norm_beta @@ -143,6 +116,9 @@ def __init__( else: action_linear = nn.Linear(encoder.out_features, action_size) + action_low, action_high = map( + torch.as_tensor, (action_space.low, action_space.high) + ) squash = nnx.TanhSquash(action_low, action_high) super().__init__(encoder, action_linear, squash) diff --git a/raylab/modules/networks/actor/policy/state_mlp.py b/raylab/modules/networks/actor/policy/state_mlp.py new file mode 100644 index 00000000..add46733 --- /dev/null +++ b/raylab/modules/networks/actor/policy/state_mlp.py @@ -0,0 +1,49 @@ +# pylint:disable=missing-module-docstring +from dataclasses import dataclass +from dataclasses import field +from typing import List +from typing import Optional + +import torch.nn as nn +from dataclasses_json import DataClassJsonMixin +from gym.spaces import Box +from torch import Tensor + +import raylab.pytorch.nn as nnx + + +@dataclass +class StateMLPSpec(DataClassJsonMixin): + """Specifications for creating a multilayer perceptron. + + Args: + units: Number of units in each hidden layer + activation: Nonlinearity following each linear layer + layer_norm: Whether to apply layer normalization between each linear layer + and following activation + """ + + units: List[int] = field(default_factory=list) + activation: Optional[str] = None + layer_norm: bool = False + + +class StateMLP(nn.Module): + """Multilayer perceptron for encoding state inputs. 
+ + Attributes: + encoder: Fully connected module with multiple layers + """ + + spec_cls = StateMLPSpec + + def __init__(self, obs_space: Box, spec: StateMLPSpec): + super().__init__() + obs_size = obs_space.shape[0] + self.encoder = nnx.FullyConnected( + obs_size, spec.units, spec.activation, layer_norm=spec.layer_norm, + ) + + def forward(self, obs: Tensor) -> Tensor: + # pylint:disable=arguments-differ + return self.encoder(obs) diff --git a/raylab/modules/networks/actor/policy/stochastic.py b/raylab/modules/networks/actor/policy/stochastic.py new file mode 100644 index 00000000..ac27f03d --- /dev/null +++ b/raylab/modules/networks/actor/policy/stochastic.py @@ -0,0 +1,198 @@ +"""Parameterized stochastic policies.""" +from typing import Callable +from typing import List + +import torch +import torch.nn as nn +from gym.spaces import Box +from gym.spaces import Discrete + +import raylab.pytorch.nn as nnx +import raylab.pytorch.nn.distributions as ptd + +from .state_mlp import StateMLP + + +class StochasticPolicy(nn.Module): + """Represents a stochastic policy as a conditional distribution module.""" + + # pylint:disable=abstract-method + + def __init__(self, params_module, dist_module): + super().__init__() + self.params = params_module + self.dist = dist_module + + def forward(self, obs): # pylint:disable=arguments-differ + return self.params(obs) + + @torch.jit.export + def sample(self, obs, sample_shape: List[int] = ()): + """ + Generates a sample_shape shaped sample or sample_shape shaped batch of + samples if the distribution parameters are batched. Returns a (sample, log_prob) + pair. + """ + params = self(obs) + return self.dist.sample(params, sample_shape) + + @torch.jit.export + def rsample(self, obs, sample_shape: List[int] = ()): + """ + Generates a sample_shape shaped reparameterized sample or sample_shape + shaped batch of reparameterized samples if the distribution parameters + are batched. Returns a (rsample, log_prob) pair. + """ + params = self(obs) + return self.dist.rsample(params, sample_shape) + + @torch.jit.export + def log_prob(self, obs, action): + """ + Returns the log of the probability density/mass function evaluated at `action`. + """ + params = self(obs) + return self.dist.log_prob(action, params) + + @torch.jit.export + def cdf(self, obs, action): + """Returns the cumulative density/mass function evaluated at `action`.""" + params = self(obs) + return self.dist.cdf(action, params) + + @torch.jit.export + def icdf(self, obs, prob): + """Returns the inverse cumulative density/mass function evaluated at `prob`.""" + params = self(obs) + return self.dist.icdf(prob, params) + + @torch.jit.export + def entropy(self, obs): + """Returns entropy of distribution.""" + params = self(obs) + return self.dist.entropy(params) + + @torch.jit.export + def perplexity(self, obs): + """Returns perplexity of distribution.""" + params = self(obs) + return self.dist.perplexity(params) + + @torch.jit.export + def reproduce(self, obs, action): + """Produce a reparametrized sample with the same value as `action`.""" + params = self(obs) + return self.dist.reproduce(action, params) + + @torch.jit.export + def deterministic(self, obs): + """ + Generates a deterministic sample or batch of samples if the distribution + parameters are batched. Returns a (rsample, log_prob) pair. + """ + params = self(obs) + return self.dist.deterministic(params) + + +class MLPStochasticPolicy(StochasticPolicy): + """Stochastic policy with multilayer perceptron state encoder. 
+ + Args: + obs_space: Observation space + spec: Specifications for the encoder + params_fn: Callable that builds a module for computing distribution + parameters given the number of state features + dist: Conditional distribution module + """ + + spec_cls = StateMLP.spec_cls + + def __init__( + self, + obs_space: Box, + spec: StateMLP.spec_cls, + params_fn: Callable[[int], nn.Module], + dist: ptd.ConditionalDistribution, + ): + encoder = StateMLP(obs_space, spec).encoder + params = params_fn(encoder.out_features) + params_module = nn.Sequential(encoder, params) + super().__init__(params_module, dist) + + +class MLPContinuousPolicy(MLPStochasticPolicy): + """Multilayer perceptron policy for continuous actions. + + Args: + obs_space: Observation space + action_space: Action space + mlp_spec: Specifications for the multilayer perceptron + input_dependent_scale: Whether to parameterize the Gaussian standard + deviation as a function of the state + """ + + def __init__( + self, + obs_space: Box, + action_space: Box, + mlp_spec: MLPStochasticPolicy.spec_cls, + input_dependent_scale: bool, + ): + def params_fn(out_features): + return nnx.NormalParams( + out_features, + action_space.shape[0], + input_dependent_scale=input_dependent_scale, + ) + + dist = ptd.TransformedDistribution( + ptd.Independent(ptd.Normal(), reinterpreted_batch_ndims=1), + ptd.flows.TanhSquashTransform( + low=torch.as_tensor(action_space.low), + high=torch.as_tensor(action_space.high), + event_dim=1, + ), + ) + super().__init__(obs_space, mlp_spec, params_fn, dist) + + +class MLPDiscretePolicy(MLPStochasticPolicy): + """Multilayer perceptron policy for discrete actions. + + Args: + obs_space: Observation space + action_space: Action space + mlp_spec: Specifications for the multilayer perceptron + """ + + def __init__( + self, + obs_space: Box, + action_space: Discrete, + mlp_spec: MLPStochasticPolicy.spec_cls, + ): + def params_fn(out_features): + return nnx.CategoricalParams(out_features, action_space.n) + + dist = ptd.Categorical() + super().__init__(obs_space, mlp_spec, params_fn, dist) + + +class Alpha(nn.Module): + """Wraps a single scalar coefficient parameter. + + Allows learning said coefficient by having it as a parameter + + Args: + initial_alpha: Value to initialize the coefficient to + + Attributes: + lob_alpha: Natural logarithm of the current coefficient + """ + + def __init__(self, initial_alpha: float): + super().__init__() + self.log_alpha = nn.Parameter(torch.empty([]).fill_(initial_alpha).log()) + + def forward(self): # pylint:disable=arguments-differ + return self.log_alpha.exp() diff --git a/raylab/modules/networks/actor/stochastic.py b/raylab/modules/networks/actor/stochastic.py new file mode 100644 index 00000000..e0a2f6a2 --- /dev/null +++ b/raylab/modules/networks/actor/stochastic.py @@ -0,0 +1,78 @@ +"""Network and configurations for modules with stochastic policies.""" +import warnings +from dataclasses import dataclass +from dataclasses import field +from typing import Union + +import torch.nn as nn +from dataclasses_json import DataClassJsonMixin +from gym.spaces import Box +from gym.spaces import Discrete + +from .policy.stochastic import Alpha +from .policy.stochastic import MLPContinuousPolicy +from .policy.stochastic import MLPDiscretePolicy +from .policy.stochastic import MLPStochasticPolicy + +MLPSpec = MLPStochasticPolicy.spec_cls + + +@dataclass +class StochasticActorSpec(DataClassJsonMixin): + """Specifications for stochastic policy. 
+
+    Args:
+        encoder: Specifications for building the multilayer perceptron state
+            processor
+        input_dependent_scale: Whether to parameterize the Gaussian standard
+            deviation as a function of the state
+        initial_entropy_coeff: Optional initial value of the entropy bonus term.
+            The actor creates an `alpha` attribute with this initial value.
+    """
+
+    encoder: MLPSpec = field(default_factory=MLPSpec)
+    input_dependent_scale: bool = False
+    initial_entropy_coeff: float = 0.0
+
+    def __post_init__(self):
+        cls_name = type(self).__name__
+        ent_coeff = self.initial_entropy_coeff
+        if ent_coeff < 0:
+            warnings.warn(f"Entropy coefficient is negative in {cls_name}: {ent_coeff}")
+
+
+class StochasticActor(nn.Module):
+    """NN with stochastic policy.
+
+    Args:
+        obs_space: Observation space
+        action_space: Action space
+        spec: Specifications for stochastic policy
+
+    Attributes:
+        policy: Stochastic policy to be learned
+        alpha: Entropy bonus coefficient
+    """
+
+    # pylint:disable=abstract-method
+    spec_cls = StochasticActorSpec
+
+    def __init__(
+        self,
+        obs_space: Box,
+        action_space: Union[Box, Discrete],
+        spec: StochasticActorSpec,
+    ):
+        super().__init__()
+
+        if isinstance(action_space, Box):
+            policy = MLPContinuousPolicy(
+                obs_space, action_space, spec.encoder, spec.input_dependent_scale
+            )
+        elif isinstance(action_space, Discrete):
+            policy = MLPDiscretePolicy(obs_space, action_space, spec.encoder)
+        else:
+            raise ValueError(f"Unsupported action space type {type(action_space)}")
+
+        self.policy = policy
+        self.alpha = Alpha(spec.initial_entropy_coeff)
diff --git a/tests/modules/networks/actor/conftest.py b/tests/modules/networks/actor/conftest.py
new file mode 100644
index 00000000..09782b58
--- /dev/null
+++ b/tests/modules/networks/actor/conftest.py
@@ -0,0 +1,45 @@
+# pylint: disable=missing-docstring,redefined-outer-name,protected-access
+import pytest
+import torch
+from gym.spaces import Box
+from gym.spaces import Discrete
+
+from raylab.utils.debug import fake_batch
+
+
+DISC_SPACES = (Discrete(2), Discrete(8))
+CONT_SPACES = (Box(-1, 1, shape=(1,)), Box(-1, 1, shape=(3,)))
+ACTION_SPACES = CONT_SPACES + DISC_SPACES
+
+
+@pytest.fixture(params=DISC_SPACES, ids=(repr(a) for a in DISC_SPACES))
+def disc_space(request):
+    return request.param
+
+
+@pytest.fixture(params=CONT_SPACES, ids=(repr(a) for a in CONT_SPACES))
+def cont_space(request):
+    return request.param
+
+
+@pytest.fixture(params=ACTION_SPACES, ids=(repr(a) for a in ACTION_SPACES))
+def action_space(request):
+    return request.param
+
+
+@pytest.fixture
+def disc_batch(obs_space, disc_space):
+    samples = fake_batch(obs_space, disc_space, batch_size=32)
+    return {k: torch.from_numpy(v) for k, v in samples.items()}
+
+
+@pytest.fixture
+def cont_batch(obs_space, cont_space):
+    samples = fake_batch(obs_space, cont_space, batch_size=32)
+    return {k: torch.from_numpy(v) for k, v in samples.items()}
+
+
+@pytest.fixture
+def batch(obs_space, action_space):
+    samples = fake_batch(obs_space, action_space, batch_size=32)
+    return {k: torch.from_numpy(v) for k, v in samples.items()}
diff --git a/tests/modules/networks/actor/policy/test_deterministic.py b/tests/modules/networks/actor/policy/test_deterministic.py
index 43682da0..77dacbea 100644
--- a/tests/modules/networks/actor/policy/test_deterministic.py
+++ b/tests/modules/networks/actor/policy/test_deterministic.py
@@ -23,6 +23,11 @@ def spec(module_cls):
     return module_cls.spec_cls()
 
 
+@pytest.fixture
+def action_space(cont_space):
+    return cont_space
+
+
 @pytest.fixture
 def
module(module_cls, obs_space, action_space, spec, norm_beta): return module_cls(obs_space, action_space, spec, norm_beta) diff --git a/tests/modules/networks/actor/policy/test_stochastic.py b/tests/modules/networks/actor/policy/test_stochastic.py new file mode 100644 index 00000000..69bced1b --- /dev/null +++ b/tests/modules/networks/actor/policy/test_stochastic.py @@ -0,0 +1,128 @@ +# pylint:disable=missing-docstring,redefined-outer-name,protected-access +import pytest +import torch +from ray.rllib import SampleBatch + + +@pytest.fixture(scope="module") +def base_cls(): + from raylab.modules.networks.actor.policy.stochastic import MLPStochasticPolicy + + return MLPStochasticPolicy + + +@pytest.fixture(scope="module") +def cont_cls(): + from raylab.modules.networks.actor.policy.stochastic import MLPContinuousPolicy + + return MLPContinuousPolicy + + +@pytest.fixture(scope="module") +def disc_cls(): + from raylab.modules.networks.actor.policy.stochastic import MLPDiscretePolicy + + return MLPDiscretePolicy + + +@pytest.fixture +def spec(base_cls): + return base_cls.spec_cls() + + +@pytest.fixture(params=(True, False), ids=lambda x: f"InputDependentScale({x})") +def input_dependent_scale(request): + return request.param + + +@pytest.fixture +def cont_policy(cont_cls, obs_space, cont_space, spec, input_dependent_scale): + return cont_cls(obs_space, cont_space, spec, input_dependent_scale) + + +@pytest.fixture +def disc_policy(disc_cls, obs_space, disc_space, spec): + return disc_cls(obs_space, disc_space, spec) + + +def test_continuous_sample(cont_policy, cont_batch): + policy, batch = cont_policy, cont_batch + action = batch[SampleBatch.ACTIONS] + + sampler = policy.rsample + samples, logp = sampler(batch[SampleBatch.CUR_OBS]) + samples_, _ = sampler(batch[SampleBatch.CUR_OBS]) + assert samples.shape == action.shape + assert samples.dtype == action.dtype + assert logp.shape == batch[SampleBatch.REWARDS].shape + assert logp.dtype == batch[SampleBatch.REWARDS].dtype + assert not torch.allclose(samples, samples_) + + +def test_discrete_sample(disc_policy, disc_batch): + policy, batch = disc_policy, disc_batch + action = batch[SampleBatch.ACTIONS] + + sampler = policy.sample + samples, logp = sampler(batch[SampleBatch.CUR_OBS]) + samples_, _ = sampler(batch[SampleBatch.CUR_OBS]) + assert samples.shape == action.shape + assert samples.dtype == action.dtype + assert logp.shape == batch[SampleBatch.REWARDS].shape + assert logp.dtype == batch[SampleBatch.REWARDS].dtype + assert not torch.allclose(samples, samples_) + + +def test_continuous_params(cont_policy, cont_batch): + policy, batch = cont_policy, cont_batch + params = policy(batch[SampleBatch.CUR_OBS]) + assert "loc" in params + assert "scale" in params + + loc, scale = params["loc"], params["scale"] + action = batch[SampleBatch.ACTIONS] + assert loc.shape == action.shape + assert scale.shape == action.shape + assert loc.dtype == torch.float32 + assert scale.dtype == torch.float32 + + pi_params = set(policy.parameters()) + for par in pi_params: + par.grad = None + loc.mean().backward() + assert any(p.grad is not None for p in pi_params) + + for par in pi_params: + par.grad = None + policy(batch[SampleBatch.CUR_OBS])["scale"].mean().backward() + assert any(p.grad is not None for p in pi_params) + + +def test_discrete_params(disc_policy, disc_space, disc_batch): + policy, batch = disc_policy, disc_batch + + params = policy(batch[SampleBatch.CUR_OBS]) + assert "logits" in params + logits = params["logits"] + assert logits.shape[-1] == 
disc_space.n + + pi_params = set(policy.parameters()) + for par in pi_params: + par.grad = None + logits.mean().backward() + assert any(p.grad is not None for p in pi_params) + + +def test_reproduce(cont_policy, cont_batch): + policy, batch = cont_policy, cont_batch + + acts = batch[SampleBatch.ACTIONS] + acts_, logp_ = policy.reproduce(batch[SampleBatch.CUR_OBS], acts) + assert acts_.shape == acts.shape + assert acts_.dtype == acts.dtype + assert torch.allclose(acts_, acts, atol=1e-5) + assert logp_.shape == batch[SampleBatch.REWARDS].shape + + acts_.mean().backward() + pi_params = set(policy.parameters()) + assert all(p.grad is not None for p in pi_params) diff --git a/tests/modules/networks/actor/test_deterministic.py b/tests/modules/networks/actor/test_deterministic.py index e71da01f..3e6ec146 100644 --- a/tests/modules/networks/actor/test_deterministic.py +++ b/tests/modules/networks/actor/test_deterministic.py @@ -4,6 +4,16 @@ from ray.rllib import SampleBatch +@pytest.fixture +def action_space(cont_space): + return cont_space + + +@pytest.fixture +def batch(cont_batch): + return cont_batch + + @pytest.fixture(scope="module") def module_cls(): from raylab.modules.networks.actor.deterministic import DeterministicActor diff --git a/tests/modules/networks/actor/test_stochastic.py b/tests/modules/networks/actor/test_stochastic.py index 7312732e..6a0c3455 100644 --- a/tests/modules/networks/actor/test_stochastic.py +++ b/tests/modules/networks/actor/test_stochastic.py @@ -1,13 +1,6 @@ # pylint: disable=missing-docstring,redefined-outer-name,protected-access import pytest import torch -from gym.spaces import Box -from gym.spaces import Discrete -from ray.rllib import SampleBatch - -from raylab.utils.debug import fake_batch - -pytest.skip(reason="Not implemented") @pytest.fixture(scope="module") @@ -32,26 +25,6 @@ def spec(module_cls): return module_cls.spec_cls() -DISC_SPACES = (Discrete(2), Discrete(8)) -CONT_SPACES = (Box(-1, 1, shape=(1,)), Box(-1, 1, shape=(3,))) -ACTION_SPACES = CONT_SPACES + DISC_SPACES - - -@pytest.fixture(params=DISC_SPACES, ids=(repr(a) for a in DISC_SPACES)) -def disc_space(request): - return request.param - - -@pytest.fixture(params=CONT_SPACES, ids=(repr(a) for a in CONT_SPACES)) -def cont_space(request): - return request.param - - -@pytest.fixture(params=ACTION_SPACES, ids=(repr(a) for a in ACTION_SPACES)) -def action_space(request): - return request.param - - @pytest.fixture def disc_module(module_cls, obs_space, disc_space, spec, torch_script): mod = module_cls(obs_space, disc_space, spec) @@ -70,107 +43,6 @@ def module(module_cls, obs_space, action_space, spec, torch_script): return torch.jit.script(mod) if torch_script else mod -@pytest.fixture -def disc_batch(obs_space, disc_space): - samples = fake_batch(obs_space, disc_space, batch_size=32) - return {k: torch.from_numpy(v) for k, v in samples.items()} - - -@pytest.fixture -def cont_batch(obs_space, cont_space): - samples = fake_batch(obs_space, cont_space, batch_size=32) - return {k: torch.from_numpy(v) for k, v in samples.items()} - - -@pytest.fixture -def batch(obs_space, action_space): - samples = fake_batch(obs_space, action_space, batch_size=32) - return {k: torch.from_numpy(v) for k, v in samples.items()} - - -def test_discrete_sampler(disc_module, disc_batch): - module, batch = disc_module, disc_batch - action = batch[SampleBatch.ACTIONS] - - sampler = module.actor.sample - samples, logp = sampler(batch[SampleBatch.CUR_OBS]) - samples_, _ = sampler(batch[SampleBatch.CUR_OBS]) - assert 
samples.shape == action.shape - assert samples.dtype == action.dtype - assert logp.shape == batch[SampleBatch.REWARDS].shape - assert logp.dtype == batch[SampleBatch.REWARDS].dtype - assert not torch.allclose(samples, samples_) - - -def test_continuous_sampler(cont_module, cont_batch): - module = cont_module - batch = cont_batch - action = batch[SampleBatch.ACTIONS] - - sampler = module.actor.rsample - samples, logp = sampler(batch[SampleBatch.CUR_OBS]) - samples_, _ = sampler(batch[SampleBatch.CUR_OBS]) - assert samples.shape == action.shape - assert samples.dtype == action.dtype - assert logp.shape == batch[SampleBatch.REWARDS].shape - assert logp.dtype == batch[SampleBatch.REWARDS].dtype - assert not torch.allclose(samples, samples_) - - -def test_discrete_params(disc_module, disc_batch): - module, batch = disc_module, disc_batch - - params = module.actor(batch[SampleBatch.CUR_OBS]) - assert "logits" in params - logits = params["logits"] - assert logits.shape[-1] == disc_space.n - - pi_params = set(module.actor.parameters()) - for par in pi_params: - par.grad = None - logits.mean().backward() - assert any(p.grad is not None for p in pi_params) - assert all(p.grad is None for p in set(module.parameters()) - pi_params) - - -def test_continuous_params(cont_module, cont_batch): - module, batch = cont_module, cont_batch - params = module.actor(batch[SampleBatch.CUR_OBS]) - assert "loc" in params - assert "scale" in params - - loc, scale = params["loc"], params["scale"] - action = batch[SampleBatch.ACTIONS] - assert loc.shape == action.shape - assert scale.shape == action.shape - assert loc.dtype == torch.float32 - assert scale.dtype == torch.float32 - - pi_params = set(module.actor.parameters()) - for par in pi_params: - par.grad = None - loc.mean().backward() - assert any(p.grad is not None for p in pi_params) - assert all(p.grad is None for p in set(module.parameters()) - pi_params) - - for par in pi_params: - par.grad = None - module.actor(batch[SampleBatch.CUR_OBS])["scale"].mean().backward() - assert any(p.grad is not None for p in pi_params) - assert all(p.grad is None for p in set(module.parameters()) - pi_params) - - -def test_reproduce(cont_module, cont_batch): - module, batch = cont_module, cont_batch - - acts = batch[SampleBatch.ACTIONS] - acts_, logp_ = module.actor.reproduce(batch[SampleBatch.CUR_OBS], acts) - assert acts_.shape == acts.shape - assert acts_.dtype == acts.dtype - assert torch.allclose(acts_, acts, atol=1e-5) - assert logp_.shape == batch[SampleBatch.REWARDS].shape - - acts_.mean().backward() - pi_params = set(module.actor.parameters()) - assert all(p.grad is not None for p in pi_params) - assert all(p.grad is None for p in set(module.parameters()) - pi_params) +def test_init(module): + for attr in "policy alpha".split(): + assert hasattr(module, attr) From c99b9e44916f79c70eca6cebc821b9fd8a2581f9 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=C3=82ngelo=20Lovatto?= Date: Mon, 29 Jun 2020 14:27:45 -0300 Subject: [PATCH 21/48] feat(modules): add refactored SAC MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Ângelo Lovatto --- .../networks/actor/policy/stochastic.py | 22 ++++++ raylab/modules/networks/actor/stochastic.py | 7 +- raylab/modules/sac.py | 67 +++++++++++++++++++ tests/modules/test_ddpg.py | 2 +- tests/modules/test_sac.py | 34 ++++++++++ 5 files changed, 130 insertions(+), 2 deletions(-) create mode 100644 raylab/modules/sac.py create mode 100644 tests/modules/test_sac.py diff --git 
a/raylab/modules/networks/actor/policy/stochastic.py b/raylab/modules/networks/actor/policy/stochastic.py index ac27f03d..9562af3e 100644 --- a/raylab/modules/networks/actor/policy/stochastic.py +++ b/raylab/modules/networks/actor/policy/stochastic.py @@ -9,6 +9,7 @@ import raylab.pytorch.nn as nnx import raylab.pytorch.nn.distributions as ptd +from raylab.pytorch.nn.init import initialize_ from .state_mlp import StateMLP @@ -103,6 +104,10 @@ class MLPStochasticPolicy(StochasticPolicy): params_fn: Callable that builds a module for computing distribution parameters given the number of state features dist: Conditional distribution module + + Attributes: + encoder: Multilayer perceptron state encoder + spec: MLP spec instance """ spec_cls = StateMLP.spec_cls @@ -119,6 +124,23 @@ def __init__( params_module = nn.Sequential(encoder, params) super().__init__(params_module, dist) + self.encoder = encoder + self.spec = spec + + def initialize_parameters(self, initializer_spec: dict): + """Initialize all Linear models in the encoder. + + Uses `raylab.pytorch.nn.init.initialize_` to create an initializer + function. + + Args: + initializer_spec: Dictionary with mandatory `type` key corresponding + to the initializer function name in `torch.nn.init` and optional + keyword arguments. + """ + initializer = initialize_(activation=self.spec.activation, **initializer_spec) + self.encoder.apply(initializer) + class MLPContinuousPolicy(MLPStochasticPolicy): """Multilayer perceptron policy for continuous actions. diff --git a/raylab/modules/networks/actor/stochastic.py b/raylab/modules/networks/actor/stochastic.py index e0a2f6a2..088975e1 100644 --- a/raylab/modules/networks/actor/stochastic.py +++ b/raylab/modules/networks/actor/stochastic.py @@ -19,7 +19,7 @@ @dataclass class StochasticActorSpec(DataClassJsonMixin): - """Specifications for stochastic policy. + """Specifications for stochastic policy and entropy coefficient. Args: encoder: Specifications for building the multilayer perceptron state @@ -28,11 +28,15 @@ class StochasticActorSpec(DataClassJsonMixin): deviation as a function of the state initial_entropy_coeff: Optional initial value of the entropy bonus term. The actor creates an `alpha` attribute with this initial value. + initializer: Optional dictionary with mandatory `type` key corresponding + to the initializer function name in `torch.nn.init` and optional + keyword arguments. 
""" encoder: MLPSpec = field(default_factory=MLPSpec) input_dependent_scale: bool = False initial_entropy_coeff: float = 0.0 + initializer: dict = field(default_factory=dict) def __post_init__(self): cls_name = type(self).__name__ @@ -73,6 +77,7 @@ def __init__( policy = MLPDiscretePolicy(obs_space, action_space, spec.encoder) else: raise ValueError(f"Unsopported action space type {type(action_space)}") + policy.initialize_parameters(spec.initializer) self.policy = policy self.alpha = Alpha(spec.initial_entropy_coeff) diff --git a/raylab/modules/sac.py b/raylab/modules/sac.py new file mode 100644 index 00000000..5ef60232 --- /dev/null +++ b/raylab/modules/sac.py @@ -0,0 +1,67 @@ +"""NN architecture used in Soft Actor-Critic.""" +from dataclasses import dataclass +from dataclasses import field + +import torch.nn as nn +from dataclasses_json import DataClassJsonMixin +from gym.spaces import Box + +from .networks.actor.stochastic import StochasticActor +from .networks.critic.action_value import ActionValueCritic + +ActorSpec = StochasticActor.spec_cls +CriticSpec = ActionValueCritic.spec_cls + + +@dataclass +class SACSpec(DataClassJsonMixin): + """Specifications for SAC modules + + Args: + actor: Specifications for stochastic policy and entropy coefficient + critic: Specifications for action-value estimators + initializer: Optional dictionary with mandatory `type` key corresponding + to the initializer function name in `torch.nn.init` and optional + keyword arguments. + """ + + actor: ActorSpec = field(default_factory=ActorSpec) + critic: CriticSpec = field(default_factory=CriticSpec) + initializer: dict = field(default_factory=dict) + + +class SAC(nn.Module): + """NN module for Soft Actor-Critic algorithms. + + Args: + obs_space: Observation space + action_space: Action space + spec: Specifications for SAC modules + + Attributes: + actor (StochasticPolicy): Stochastic policy to be learned + alpha (Alpha): Entropy bonus coefficient + critics (QValueEnsemble): The action-value estimators to be learned + target_critics (QValueEnsemble): The action-value estimators used for + bootstrapping in Q-Learning + spec_cls: Expected class of `spec` init argument + """ + + # pylint:disable=abstract-method + spec_cls = SACSpec + + def __init__(self, obs_space: Box, action_space: Box, spec: SACSpec): + super().__init__() + # Top-level initializer options take precedence over individual + # component's options + if spec.initializer: + spec.actor.initializer = spec.initializer + spec.critic.initializer = spec.initializer + + actor = StochasticActor(obs_space, action_space, spec.actor) + self.actor = actor.policy + self.alpha = actor.alpha + + critic = ActionValueCritic(obs_space, action_space, spec.critic) + self.critics = critic.q_values + self.target_critics = critic.target_q_values diff --git a/tests/modules/test_ddpg.py b/tests/modules/test_ddpg.py index 1545fb80..d21c1a50 100644 --- a/tests/modules/test_ddpg.py +++ b/tests/modules/test_ddpg.py @@ -26,7 +26,7 @@ def test_spec(spec_cls): def test_init(module): assert isinstance(module, nn.Module) - for attr in ["actor", "behavior", "target_actor", "critics", "target_critics"]: + for attr in "actor behavior target_actor critics target_critics".split(): assert hasattr(module, attr) diff --git a/tests/modules/test_sac.py b/tests/modules/test_sac.py new file mode 100644 index 00000000..4adddb29 --- /dev/null +++ b/tests/modules/test_sac.py @@ -0,0 +1,34 @@ +# pylint: disable=missing-docstring,redefined-outer-name,protected-access +import pytest +import torch 
+import torch.nn as nn + +from raylab.modules.sac import SAC + + +@pytest.fixture +def spec_cls(): + return SAC.spec_cls + + +@pytest.fixture +def module(obs_space, action_space, spec_cls): + return SAC(obs_space, action_space, spec_cls()) + + +def test_spec(spec_cls): + default_config = spec_cls().to_dict() + + for key in ["actor", "critic", "initializer"]: + assert key in default_config + + +def test_init(module): + assert isinstance(module, nn.Module) + + for attr in "actor alpha critics target_critics".split(): + assert hasattr(module, attr) + + +def test_script(module): + torch.jit.script(module) From 214d813e8dd1c7983ad1cef16ae08402c1322a16 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=C3=82ngelo=20Lovatto?= Date: Mon, 29 Jun 2020 15:29:40 -0300 Subject: [PATCH 22/48] feat(agents): use new DDPG and SAC modules as defaults MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Ângelo Lovatto --- raylab/agents/sac/policy.py | 4 ++-- raylab/agents/sac/trainer.py | 5 +---- raylab/agents/sop/policy.py | 4 ++-- raylab/agents/sop/trainer.py | 2 +- raylab/modules/catalog.py | 3 ++- raylab/policy/target_networks_mixin.py | 4 +++- 6 files changed, 11 insertions(+), 11 deletions(-) diff --git a/raylab/agents/sac/policy.py b/raylab/agents/sac/policy.py index c2cc425a..fca42f63 100644 --- a/raylab/agents/sac/policy.py +++ b/raylab/agents/sac/policy.py @@ -55,7 +55,7 @@ def make_optimizers(self): components = "actor critics alpha".split() return { - name: build_optimizer(self.module[name], config[name]) + name: build_optimizer(getattr(self.module, name), config[name]) for name in components } @@ -105,6 +105,6 @@ def extra_grad_info(self, component): """Return statistics right after components are updated.""" return { f"grad_norm({component})": nn.utils.clip_grad_norm_( - self.module[component].parameters(), float("inf") + getattr(self.module, component).parameters(), float("inf") ).item() } diff --git a/raylab/agents/sac/trainer.py b/raylab/agents/sac/trainer.py index 5ec3c9b0..9c55404b 100644 --- a/raylab/agents/sac/trainer.py +++ b/raylab/agents/sac/trainer.py @@ -28,10 +28,7 @@ # Interpolation factor in polyak averaging for target networks. "polyak": 0.995, # === Network === - # Size and activation of the fully connected networks computing the logits - # for the policy and action-value function. No layers means the component is - # linear in states and/or actions. - "module": {"type": "SACModule"}, + "module": {"type": "SAC"}, # === Exploration Settings === # Default exploration behavior, iff `explore`=None is passed into # compute_action(s). 
diff --git a/raylab/agents/sop/policy.py b/raylab/agents/sop/policy.py index 522271f8..ef1435e0 100644 --- a/raylab/agents/sop/policy.py +++ b/raylab/agents/sop/policy.py @@ -54,7 +54,7 @@ def make_optimizers(self): components = "actor critics".split() return { - name: build_optimizer(self.module[name], config[name]) + name: build_optimizer(getattr(self.module, name), config[name]) for name in components } @@ -92,7 +92,7 @@ def extra_grad_info(self, component): """Return statistics right after components are updated.""" return { f"grad_norm({component})": nn.utils.clip_grad_norm_( - self.module[component].parameters(), float("inf") + getattr(self.module, component).parameters(), float("inf") ).item() } diff --git a/raylab/agents/sop/trainer.py b/raylab/agents/sop/trainer.py index 3ea295fd..006cc632 100644 --- a/raylab/agents/sop/trainer.py +++ b/raylab/agents/sop/trainer.py @@ -20,7 +20,7 @@ "polyak": 0.995, # Update policy every this number of calls to `learn_on_batch` "policy_delay": 1, - "module": {"type": "DDPGModule"}, + "module": {"type": "DDPG"}, # === Exploration Settings === # Default exploration behavior, iff `explore`=None is passed into # compute_action(s). diff --git a/raylab/modules/catalog.py b/raylab/modules/catalog.py index 6d8dad00..9a985c14 100644 --- a/raylab/modules/catalog.py +++ b/raylab/modules/catalog.py @@ -3,6 +3,7 @@ from gym.spaces import Space from .ddpg import DDPG +from .sac import SAC from .v0.ddpg_module import DDPGModule from .v0.maxent_model_based import MaxEntModelBased from .v0.model_based_ddpg import ModelBasedDDPG @@ -38,7 +39,7 @@ ) } -MODULESv1 = {cls.__name__: cls for cls in (DDPG,)} +MODULESv1 = {cls.__name__: cls for cls in (DDPG, SAC)} def get_module(obs_space: Space, action_space: Space, config: dict) -> nn.Module: diff --git a/raylab/policy/target_networks_mixin.py b/raylab/policy/target_networks_mixin.py index 6638e7c0..4028c76b 100644 --- a/raylab/policy/target_networks_mixin.py +++ b/raylab/policy/target_networks_mixin.py @@ -16,5 +16,7 @@ def update_targets(self, module, target_module): target_module (str): name of target module in the policy's module dict """ update_polyak( - self.module[module], self.module[target_module], self.config["polyak"] + getattr(self.module, module), + getattr(self.module, target_module), + self.config["polyak"], ) From ce4204c7d7dd88b90ab606def9f0eaac2d9344b4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=C3=82ngelo=20Lovatto?= Date: Mon, 29 Jun 2020 19:20:12 -0300 Subject: [PATCH 23/48] refactor(networks): move mlp encoders to separate module MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Ângelo Lovatto --- .../networks/actor/policy/deterministic.py | 12 +-- .../networks/actor/policy/state_mlp.py | 49 --------- .../networks/actor/policy/stochastic.py | 9 +- raylab/modules/networks/critic/q_value.py | 47 ++------- raylab/modules/networks/mlp.py | 99 +++++++++++++++++++ 5 files changed, 113 insertions(+), 103 deletions(-) delete mode 100644 raylab/modules/networks/actor/policy/state_mlp.py diff --git a/raylab/modules/networks/actor/policy/deterministic.py b/raylab/modules/networks/actor/policy/deterministic.py index 286fcb64..fd618a97 100644 --- a/raylab/modules/networks/actor/policy/deterministic.py +++ b/raylab/modules/networks/actor/policy/deterministic.py @@ -8,9 +8,7 @@ from torch import Tensor import raylab.pytorch.nn as nnx -from raylab.pytorch.nn.init import initialize_ - -from .state_mlp import StateMLP +from raylab.modules.networks.mlp import 
StateMLP class DeterministicPolicy(nn.Module): @@ -106,7 +104,7 @@ def __init__( mlp_spec: StateMLP.spec_cls, norm_beta: float, ): - encoder = StateMLP(obs_space, mlp_spec).encoder + encoder = StateMLP(obs_space, mlp_spec) action_size = action_space.shape[0] if norm_beta: @@ -122,7 +120,6 @@ def __init__( squash = nnx.TanhSquash(action_low, action_high) super().__init__(encoder, action_linear, squash) - self.mlp_spec = mlp_spec def initialize_parameters(self, initializer_spec: dict): """Initialize all Linear models in the encoder. @@ -135,7 +132,4 @@ def initialize_parameters(self, initializer_spec: dict): to the initializer function name in `torch.nn.init` and optional keyword arguments. """ - initializer = initialize_( - activation=self.mlp_spec.activation, **initializer_spec - ) - self.encoder.apply(initializer) + self.encoder.initialize_parameters(initializer_spec) diff --git a/raylab/modules/networks/actor/policy/state_mlp.py b/raylab/modules/networks/actor/policy/state_mlp.py deleted file mode 100644 index add46733..00000000 --- a/raylab/modules/networks/actor/policy/state_mlp.py +++ /dev/null @@ -1,49 +0,0 @@ -# pylint:disable=missing-module-docstring -from dataclasses import dataclass -from dataclasses import field -from typing import List -from typing import Optional - -import torch.nn as nn -from dataclasses_json import DataClassJsonMixin -from gym.spaces import Box -from torch import Tensor - -import raylab.pytorch.nn as nnx - - -@dataclass -class StateMLPSpec(DataClassJsonMixin): - """Specifications for creating a multilayer perceptron. - - Args: - units: Number of units in each hidden layer - activation: Nonlinearity following each linear layer - layer_norm: Whether to apply layer normalization between each linear layer - and following activation - """ - - units: List[int] = field(default_factory=list) - activation: Optional[str] = None - layer_norm: bool = False - - -class StateMLP(nn.Module): - """Multilayer perceptron for encoding state inputs. - - Attributes: - encoder: Fully connected module with multiple layers - """ - - spec_cls = StateMLPSpec - - def __init__(self, obs_space: Box, spec: StateMLPSpec): - super().__init__() - obs_size = obs_space.shape[0] - self.encoder = nnx.FullyConnected( - obs_size, spec.units, spec.activation, layer_norm=spec.layer_norm, - ) - - def forward(self, obs: Tensor) -> Tensor: - # pylint:disable=arguments-differ - return self.encoder(obs) diff --git a/raylab/modules/networks/actor/policy/stochastic.py b/raylab/modules/networks/actor/policy/stochastic.py index 9562af3e..7bde0eba 100644 --- a/raylab/modules/networks/actor/policy/stochastic.py +++ b/raylab/modules/networks/actor/policy/stochastic.py @@ -9,9 +9,7 @@ import raylab.pytorch.nn as nnx import raylab.pytorch.nn.distributions as ptd -from raylab.pytorch.nn.init import initialize_ - -from .state_mlp import StateMLP +from raylab.modules.networks.mlp import StateMLP class StochasticPolicy(nn.Module): @@ -119,7 +117,7 @@ def __init__( params_fn: Callable[[int], nn.Module], dist: ptd.ConditionalDistribution, ): - encoder = StateMLP(obs_space, spec).encoder + encoder = StateMLP(obs_space, spec) params = params_fn(encoder.out_features) params_module = nn.Sequential(encoder, params) super().__init__(params_module, dist) @@ -138,8 +136,7 @@ def initialize_parameters(self, initializer_spec: dict): to the initializer function name in `torch.nn.init` and optional keyword arguments. 
""" - initializer = initialize_(activation=self.spec.activation, **initializer_spec) - self.encoder.apply(initializer) + self.encoder.initialize_parameters(initializer_spec) class MLPContinuousPolicy(MLPStochasticPolicy): diff --git a/raylab/modules/networks/critic/q_value.py b/raylab/modules/networks/critic/q_value.py index accf86f1..5d35db7d 100644 --- a/raylab/modules/networks/critic/q_value.py +++ b/raylab/modules/networks/critic/q_value.py @@ -1,17 +1,13 @@ """Parameterized action-value estimators.""" -from dataclasses import dataclass -from dataclasses import field -from typing import List -from typing import Optional - import torch import torch.nn as nn -from dataclasses_json import DataClassJsonMixin from gym.spaces import Box from torch import Tensor -import raylab.pytorch.nn as nnx -from raylab.pytorch.nn.init import initialize_ +from raylab.modules.networks.mlp import StateActionMLP + + +MLPSpec = StateActionMLP.spec_cls class QValue(nn.Module): @@ -39,22 +35,6 @@ def forward(self, obs: Tensor, action: Tensor) -> Tensor: return self.value_linear(features) -@dataclass -class StateActionMLPSpec(DataClassJsonMixin): - """Specifications for building an MLP with state and action inputs. - - Args: - units: Number of units in each hidden layer - activation: Nonlinearity following each linear layer - delay_action: Whether to apply an initial preprocessing layer on the - observation before concatenating the action to the input. - """ - - units: List[int] = field(default_factory=list) - activation: Optional[str] = None - delay_action: bool = False - - class MLPQValue(QValue): """Q-value function with a multilayer perceptron encoder. @@ -64,21 +44,11 @@ class MLPQValue(QValue): mlp_spec: Multilayer perceptron specifications """ - spec_cls = StateActionMLPSpec - - def __init__(self, obs_space: Box, action_space: Box, spec: StateActionMLPSpec): - obs_size = obs_space.shape[0] - action_size = action_space.shape[0] + spec_cls = MLPSpec - encoder = nnx.StateActionEncoder( - obs_size, - action_size, - units=spec.units, - activation=spec.activation, - delay_action=spec.delay_action, - ) + def __init__(self, obs_space: Box, action_space: Box, spec: MLPSpec): + encoder = StateActionMLP(obs_space, action_space, spec) super().__init__(encoder) - self.spec = spec def initialize_parameters(self, initializer_spec: dict): """Initialize all Linear models in the encoder. @@ -91,8 +61,7 @@ def initialize_parameters(self, initializer_spec: dict): to the initializer function name in `torch.nn.init` and optional keyword arguments. """ - initializer = initialize_(activation=self.spec.activation, **initializer_spec) - self.encoder.apply(initializer) + self.encoder.initialize_parameters(initializer_spec) class QValueEnsemble(nn.ModuleList): diff --git a/raylab/modules/networks/mlp.py b/raylab/modules/networks/mlp.py index 9eb74c0c..17a58f0b 100644 --- a/raylab/modules/networks/mlp.py +++ b/raylab/modules/networks/mlp.py @@ -1,13 +1,112 @@ # pylint:disable=missing-module-docstring +from dataclasses import dataclass +from dataclasses import field from typing import Dict +from typing import List from typing import Optional import torch import torch.nn as nn +from dataclasses_json import DataClassJsonMixin +from gym.spaces import Box +import raylab.pytorch.nn as nnx +from raylab.pytorch.nn.init import initialize_ from raylab.pytorch.nn.utils import get_activation +@dataclass +class StateMLPSpec(DataClassJsonMixin): + """Specifications for creating a multilayer perceptron. 
+ + Args: + units: Number of units in each hidden layer + activation: Nonlinearity following each linear layer + layer_norm: Whether to apply layer normalization between each linear layer + and following activation + """ + + units: List[int] = field(default_factory=list) + activation: Optional[str] = None + layer_norm: bool = False + + +class StateMLP(nnx.FullyConnected): + """Multilayer perceptron for encoding state inputs.""" + + spec_cls = StateMLPSpec + + def __init__(self, obs_space: Box, spec: StateMLPSpec): + obs_size = obs_space.shape[0] + super().__init__( + obs_size, spec.units, spec.activation, layer_norm=spec.layer_norm, + ) + self.spec = spec + + def initialize_parameters(self, initializer_spec: dict): + """Initialize all Linear models in the encoder. + + Uses `raylab.pytorch.nn.init.initialize_` to create an initializer + function. + + Args: + initializer_spec: Dictionary with mandatory `type` key corresponding + to the initializer function name in `torch.nn.init` and optional + keyword arguments. + """ + initializer = initialize_(activation=self.spec.activation, **initializer_spec) + self.apply(initializer) + + +@dataclass +class StateActionMLPSpec(DataClassJsonMixin): + """Specifications for building an MLP with state and action inputs. + + Args: + units: Number of units in each hidden layer + activation: Nonlinearity following each linear layer + delay_action: Whether to apply an initial preprocessing layer on the + observation before concatenating the action to the input. + """ + + units: List[int] = field(default_factory=list) + activation: Optional[str] = None + delay_action: bool = False + + +class StateActionMLP(nnx.StateActionEncoder): + """Multilayer perceptron for encoding state-action inputs.""" + + spec_cls = StateActionMLPSpec + + def __init__(self, obs_space: Box, action_space: Box, spec: StateActionMLPSpec): + obs_size = obs_space.shape[0] + action_size = action_space.shape[0] + + super().__init__( + obs_size, + action_size, + units=spec.units, + activation=spec.activation, + delay_action=spec.delay_action, + ) + self.spec = spec + + def initialize_parameters(self, initializer_spec: dict): + """Initialize all Linear models in the encoder. + + Uses `raylab.pytorch.nn.init.initialize_` to create an initializer + function. + + Args: + initializer_spec: Dictionary with mandatory `type` key corresponding + to the initializer function name in `torch.nn.init` and optional + keyword arguments. 
+ """ + initializer = initialize_(activation=self.spec.activation, **initializer_spec) + self.apply(initializer) + + class MLP(nn.Module): """A general purpose Multi-Layer Perceptron.""" From 47eb42e9e3094e57fc1527c2a3cb772b13d04006 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=C3=82ngelo=20Lovatto?= Date: Tue, 30 Jun 2020 06:16:53 -0300 Subject: [PATCH 24/48] chore(networks): add stochastic models skeleton --- raylab/modules/networks/critic/q_value.py | 3 +- raylab/modules/networks/model/__init__.py | 0 .../networks/model/stochastic/__init__.py | 0 .../networks/model/stochastic/builders.py | 2 + .../networks/model/stochastic/ensemble.py | 75 ++++++++ .../networks/model/stochastic/single.py | 180 ++++++++++++++++++ 6 files changed, 259 insertions(+), 1 deletion(-) create mode 100644 raylab/modules/networks/model/__init__.py create mode 100644 raylab/modules/networks/model/stochastic/__init__.py create mode 100644 raylab/modules/networks/model/stochastic/builders.py create mode 100644 raylab/modules/networks/model/stochastic/ensemble.py create mode 100644 raylab/modules/networks/model/stochastic/single.py diff --git a/raylab/modules/networks/critic/q_value.py b/raylab/modules/networks/critic/q_value.py index 5d35db7d..36914d09 100644 --- a/raylab/modules/networks/critic/q_value.py +++ b/raylab/modules/networks/critic/q_value.py @@ -72,9 +72,10 @@ class QValueEnsemble(nn.ModuleList): """ def __init__(self, q_values): + cls_name = type(self).__name__ assert all( isinstance(q, QValue) for q in q_values - ), """All modules in QValueEnsemble must be instances of QValue.""" + ), f"All modules in {cls_name} must be instances of QValue." super().__init__(q_values) def forward(self, obs: Tensor, action: Tensor, clip: bool = False) -> Tensor: diff --git a/raylab/modules/networks/model/__init__.py b/raylab/modules/networks/model/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/raylab/modules/networks/model/stochastic/__init__.py b/raylab/modules/networks/model/stochastic/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/raylab/modules/networks/model/stochastic/builders.py b/raylab/modules/networks/model/stochastic/builders.py new file mode 100644 index 00000000..d7f799cd --- /dev/null +++ b/raylab/modules/networks/model/stochastic/builders.py @@ -0,0 +1,2 @@ +from .single import MLPModel, ResidualMLPModel +from .ensemble import StochasticModelEnsemble, ForkedStochasticModelEnsemble diff --git a/raylab/modules/networks/model/stochastic/ensemble.py b/raylab/modules/networks/model/stochastic/ensemble.py new file mode 100644 index 00000000..2b41943c --- /dev/null +++ b/raylab/modules/networks/model/stochastic/ensemble.py @@ -0,0 +1,75 @@ +"""Network and configurations for modules with stochastic model ensembles.""" +from typing import List + +import torch +import torch.nn as nn + + +from .single import StochasticModel + + +class StochasticModelEnsemble(nn.ModuleList): + """A static list of stochastic dynamics models. + + Args: + models: List of StochasticModel modules + """ + + # pylint:disable=abstract-method + + def __init__(self, models: List[StochasticModel]): + cls_name = type(self).__name__ + assert all( + isinstance(m, StochasticModel) for m in models + ), f"All modules in {cls_name} must be instances of StochasticModel." 
+        super().__init__(models)
+
+    @torch.jit.export
+    def sample(self, obs, action, sample_shape: List[int] = ()):
+        """Compute samples and likelihoods for each model in the ensemble."""
+        outputs = [m.sample(obs, action, sample_shape) for m in self]
+        sample = torch.stack([s for s, _ in outputs])
+        logp = torch.stack([p for _, p in outputs])
+        return sample, logp
+
+    @torch.jit.export
+    def rsample(self, obs, action, sample_shape: List[int] = ()):
+        """Compute reparameterized samples and likelihoods for each model."""
+        outputs = [m.rsample(obs, action, sample_shape) for m in self]
+        sample = torch.stack([s for s, _ in outputs])
+        logp = torch.stack([p for _, p in outputs])
+        return sample, logp
+
+    @torch.jit.export
+    def log_prob(self, obs, action, next_obs):
+        """Compute likelihoods for each model in the ensemble."""
+        return torch.stack([m.log_prob(obs, action, next_obs) for m in self])
+
+
+class ForkedStochasticModelEnsemble(StochasticModelEnsemble):
+    """Ensemble of stochastic models with parallelized methods."""
+
+    # pylint:disable=abstract-method,protected-access
+
+    @torch.jit.export
+    def sample(self, obs, action, sample_shape: List[int] = ()):
+        futures = [torch.jit._fork(m.sample, obs, action, sample_shape) for m in self]
+        outputs = [torch.jit._wait(f) for f in futures]
+        sample = torch.stack([s for s, _ in outputs])
+        logp = torch.stack([p for _, p in outputs])
+        return sample, logp
+
+    @torch.jit.export
+    def rsample(self, obs, action, sample_shape: List[int] = ()):
+        futures = [
+            torch.jit._fork(m.rsample, obs, action, sample_shape) for m in self
+        ]
+        outputs = [torch.jit._wait(f) for f in futures]
+        sample = torch.stack([s for s, _ in outputs])
+        logp = torch.stack([p for _, p in outputs])
+        return sample, logp
+
+    @torch.jit.export
+    def log_prob(self, obs, action, next_obs):
+        futures = [torch.jit._fork(m.log_prob, obs, action, next_obs) for m in self]
+        return torch.stack([torch.jit._wait(f) for f in futures])
diff --git a/raylab/modules/networks/model/stochastic/single.py b/raylab/modules/networks/model/stochastic/single.py
new file mode 100644
index 00000000..df7bfb3d
--- /dev/null
+++ b/raylab/modules/networks/model/stochastic/single.py
@@ -0,0 +1,180 @@
+"""NN modules for stochastic dynamics estimation."""
+from dataclasses import dataclass
+from dataclasses import field
+from typing import List
+
+import torch
+import torch.nn as nn
+from dataclasses_json import DataClassJsonMixin
+from gym.spaces import Box
+
+import rayalb.pytorch.nn as nnx
+import raylab.pytorch.nn.distributions as ptd
+from raylab.modules.networks.mlp import StateActionMLP
+
+
+class StochasticModel(nn.Module):
+    """Represents a stochastic model as a conditional distribution module."""
+
+    def __init__(
+        self, params_module: nn.Module, dist_module: ptd.ConditionalDistribution
+    ):
+        super().__init__()
+        self.params = params_module
+        self.dist = dist_module
+
+    def forward(self, obs, action):  # pylint:disable=arguments-differ
+        return self.params(obs, action)
+
+    @torch.jit.export
+    def sample(self, obs, action, sample_shape: List[int] = ()):
+        """
+        Generates a sample_shape shaped sample or sample_shape shaped batch of
+        samples if the distribution parameters are batched. Returns a (sample, log_prob)
+        pair.
+ """ + params = self(obs, action) + return self.dist.sample(params, sample_shape) + + @torch.jit.export + def rsample(self, obs, action, sample_shape: List[int] = ()): + """ + Generates a sample_shape shaped reparameterized sample or sample_shape + shaped batch of reparameterized samples if the distribution parameters + are batched. Returns a (rsample, log_prob) pair. + """ + params = self(obs, action) + return self.dist.rsample(params, sample_shape) + + @torch.jit.export + def log_prob(self, obs, action, next_obs): + """ + Returns the log probability density/mass function evaluated at `next_obs`. + """ + params = self(obs, action) + return self.dist.log_prob(next_obs, params) + + @torch.jit.export + def cdf(self, obs, action, next_obs): + """Returns the cumulative density/mass function evaluated at `next_obs`.""" + params = self(obs, action) + return self.dist.cdf(next_obs, params) + + @torch.jit.export + def icdf(self, obs, action, prob): + """Returns the inverse cumulative density/mass function evaluated at `prob`.""" + params = self(obs, action) + return self.dist.icdf(prob, params) + + @torch.jit.export + def entropy(self, obs, action): + """Returns entropy of distribution.""" + params = self(obs, action) + return self.dist.entropy(params) + + @torch.jit.export + def perplexity(self, obs, action): + """Returns perplexity of distribution.""" + params = self(obs, action) + return self.dist.perplexity(params) + + @torch.jit.export + def reproduce(self, obs, action, next_obs): + """Produce a reparametrized sample with the same value as `next_obs`.""" + params = self(obs, action) + return self.dist.reproduce(next_obs, params) + + +class ResidualMixin: + """Overrides StochasticModel interface to model state transition residuals.""" + + @torch.jit.export + def sample(self, obs, action, sample_shape: List[int] = ()): + params = self(obs, action) + res, log_prob = self.dist.sample(params, sample_shape) + return obs + res, log_prob + + @torch.jit.export + def rsample(self, obs, action, sample_shape: List[int] = ()): + params = self(obs, action) + res, log_prob = self.dist.rsample(params, sample_shape) + return obs + res, log_prob + + @torch.jit.export + def log_prob(self, obs, action, next_obs): + params = self(obs, action) + return self.dist.log_prob(next_obs - obs, params) + + @torch.jit.export + def cdf(self, obs, action, next_obs): + params = self(obs, action) + return self.dist.cdf(next_obs - obs, params) + + @torch.jit.export + def icdf(self, obs, action, prob): + params = self(obs, action) + residual = self.dist.icdf(prob, params) + return obs + residual + + @torch.jit.export + def reproduce(self, obs, action, next_obs): + params = self(obs, action) + sample_, log_prob_ = self.dist.reproduce(next_obs - obs, params) + return obs + sample_, log_prob_ + + +class DynamicsParams(nn.Module): + """Neural network mapping state-action pairs to distribution parameters. + + Args: + encoder: Module mapping state-action pairs to 1D features + params: Module mapping 1D features to distribution parameters + """ + + def __init__(self, encoder: nn.Module, params: nn.Module): + super().__init__() + self.encoder = encoder + self.params = params + + def forward(self, obs, actions): # pylint:disable=arguments-differ + return self.params(self.encoder(obs, actions)) + + +MLPSpec = StateActionMLP.spec_cls + + +@dataclass +class MLPModelSpec(DataClassJsonMixin): + """Specifications for stochastic model networks. 
+ + Args: + mlp: Specifications for building an MLP with state and action inputs + input_dependent_scale: Whether to parameterize the Gaussian standard + deviation as a function of the state and action + """ + + mlp: MLPSpec = field(default_factory=MLPSpec) + input_dependent_scale: bool = True + + +class MLPModel(StochasticModel): + """Stochastic model with multilayer perceptron state-action encoder.""" + + spec_cls = MLPModelSpec + + def __init__( + self, obs_space: Box, action_space: Box, spec: MLPModelSpec, + ): + encoder = StateActionMLP(obs_space, action_space, spec.mlp) + params = nnx.NormalParams( + encoder.out_features, + obs_space.shape[0], + input_dependent_scale=spec.input_dependent_scale, + ) + params = DynamicsParams(encoder, params) + dist = ptd.Independent(ptd.Normal(), reinterpreted_batch_ndims=1) + super().__init__(params, dist) + + +class ResidualMLPModel(ResidualMixin, MLPModel): + """Residual stochastic multilayer perceptron model.""" From 5ae70c1c9bcdf2dd3977a7e099d1a1258391bd57 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=C3=82ngelo=20Lovatto?= Date: Tue, 30 Jun 2020 08:06:36 -0300 Subject: [PATCH 25/48] feat(networks): add stochastic models MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Ângelo Lovatto --- raylab/modules/catalog.py | 4 +- raylab/modules/ddpg.py | 16 ++-- raylab/modules/mbddpg.py | 61 +++++++++++++ raylab/modules/mbsac.py | 59 ++++++++++++ .../networks/actor/policy/deterministic.py | 3 - .../networks/actor/policy/stochastic.py | 3 - .../networks/model/stochastic/__init__.py | 4 + .../networks/model/stochastic/builders.py | 89 ++++++++++++++++++- .../networks/model/stochastic/single.py | 15 +++- raylab/modules/sac.py | 16 ++-- 10 files changed, 246 insertions(+), 24 deletions(-) create mode 100644 raylab/modules/mbddpg.py create mode 100644 raylab/modules/mbsac.py diff --git a/raylab/modules/catalog.py b/raylab/modules/catalog.py index 9a985c14..4953cff3 100644 --- a/raylab/modules/catalog.py +++ b/raylab/modules/catalog.py @@ -3,6 +3,8 @@ from gym.spaces import Space from .ddpg import DDPG +from .mbddpg import MBDDPG +from .mbsac import MBSAC from .sac import SAC from .v0.ddpg_module import DDPGModule from .v0.maxent_model_based import MaxEntModelBased @@ -39,7 +41,7 @@ ) } -MODULESv1 = {cls.__name__: cls for cls in (DDPG, SAC)} +MODULESv1 = {cls.__name__: cls for cls in (DDPG, MBDDPG, MBSAC, SAC)} def get_module(obs_space: Space, action_space: Space, config: dict) -> nn.Module: diff --git a/raylab/modules/ddpg.py b/raylab/modules/ddpg.py index 3e45a509..3ac76b42 100644 --- a/raylab/modules/ddpg.py +++ b/raylab/modules/ddpg.py @@ -22,13 +22,21 @@ class DDPGSpec(DataClassJsonMixin): critic: Specifications for action-value estimators initializer: Optional dictionary with mandatory `type` key corresponding to the initializer function name in `torch.nn.init` and optional - keyword arguments. + keyword arguments. Overrides actor and critic initializer + specifications. """ actor: ActorSpec = field(default_factory=ActorSpec) critic: CriticSpec = field(default_factory=CriticSpec) initializer: dict = field(default_factory=dict) + def __post_init__(self): + # Top-level initializer options take precedence over individual + # component's options + if self.initializer: + self.actor.initializer = self.initializer + self.critic.initializer = self.initializer + class DDPG(nn.Module): """NN module for DDPG-like algorithms. 
@@ -54,12 +62,6 @@ class DDPG(nn.Module): def __init__(self, obs_space: Box, action_space: Box, spec: DDPGSpec): super().__init__() - # Top-level initializer options take precedence over individual - # component's options - if spec.initializer: - spec.actor.initializer = spec.initializer - spec.critic.initializer = spec.initializer - # Build actor actor = DeterministicActor(obs_space, action_space, spec.actor) self.actor = actor.policy diff --git a/raylab/modules/mbddpg.py b/raylab/modules/mbddpg.py new file mode 100644 index 00000000..09174f0d --- /dev/null +++ b/raylab/modules/mbddpg.py @@ -0,0 +1,61 @@ +"""Network and configurations for model-based DDPG algorithms.""" +from dataclasses import dataclass +from dataclasses import field + +from gym.spaces import Box + +from .ddpg import DDPG +from .ddpg import DDPGSpec +from .networks.model.stochastic import build_ensemble +from .networks.model.stochastic import EnsembleSpec + + +@dataclass +class MBDDPGSpec(DDPGSpec): + """Specifications for model-based DDPG modules. + + Args: + model: Specifications for stochastic dynamics model ensemble + actor: Specifications for policy, behavior, and target policy + critic: Specifications for action-value estimators + initializer: Optional dictionary with mandatory `type` key corresponding + to the initializer function name in `torch.nn.init` and optional + keyword arguments. Overrides model, actor, and critic initializer + specifications. + """ + + model: EnsembleSpec = field(default_factory=EnsembleSpec) + + def __post_init__(self): + super().__post_init__() + if self.initializer: + self.model.initializer = self.initializer + + +class MBDDPG(DDPG): + """NN module for Model-Based DDPG algorithms. + + Args: + obs_space: Observation space + action_space: Action space + spec: Specifications for model-based DDPG modules + + Attributes: + model (StochasticModelEnsemble): Stochastic dynamics model ensemble + actor (DeterministicPolicy): The deterministic policy to be learned + behavior (DeterministicPolicy): The policy for exploration + target_actor (DeterministicPolicy): The policy used for estimating the + arg max in Q-Learning + critics (QValueEnsemble): The action-value estimators to be learned + target_critics (QValueEnsemble): The action-value estimators used for + bootstrapping in Q-Learning + spec_cls: Expected class of `spec` init argument + """ + + # pylint:disable=abstract-method + spec_cls = MBDDPGSpec + + def __init__(self, obs_space: Box, action_space: Box, spec: MBDDPGSpec): + super().__init__(obs_space, action_space, spec) + + self.models = build_ensemble(obs_space, action_space, spec.model) diff --git a/raylab/modules/mbsac.py b/raylab/modules/mbsac.py new file mode 100644 index 00000000..0d1be992 --- /dev/null +++ b/raylab/modules/mbsac.py @@ -0,0 +1,59 @@ +"""Network and configurations for model-based SAC algorithms.""" +from dataclasses import dataclass +from dataclasses import field + +from gym.spaces import Box + +from .networks.model.stochastic import build_ensemble +from .networks.model.stochastic import EnsembleSpec +from .sac import SAC +from .sac import SACSpec + + +@dataclass +class MBSACSpec(SACSpec): + """Specifications for model-based SAC modules. 
+ + Args: + model: Specifications for stochastic dynamics model ensemble + actor: Specifications for stochastic policy and entropy coefficient + critic: Specifications for action-value estimators + initializer: Optional dictionary with mandatory `type` key corresponding + to the initializer function name in `torch.nn.init` and optional + keyword arguments. Overrides model, actor, and critic initializer + specifications. + """ + + model: EnsembleSpec = field(default_factory=EnsembleSpec) + + def __post_init__(self): + super().__post_init__() + if self.initializer: + self.model.initializer = self.initializer + + +class MBSAC(SAC): + """NN module for Model-Based Soft Actor-Critic algorithms. + + Args: + obs_space: Observation space + action_space: Action space + spec: Specifications for model-based SAC modules + + Attributes: + model (StochasticModelEnsemble): Stochastic dynamics model ensemble + actor (StochasticPolicy): Stochastic policy to be learned + alpha (Alpha): Entropy bonus coefficient + critics (QValueEnsemble): The action-value estimators to be learned + target_critics (QValueEnsemble): The action-value estimators used for + bootstrapping in Q-Learning + spec_cls: Expected class of `spec` init argument + """ + + # pylint:disable=abstract-method + spec_cls = MBSACSpec + + def __init__(self, obs_space: Box, action_space: Box, spec: MBSACSpec): + super().__init__(obs_space, action_space, spec) + + self.models = build_ensemble(obs_space, action_space, spec.model) diff --git a/raylab/modules/networks/actor/policy/deterministic.py b/raylab/modules/networks/actor/policy/deterministic.py index fd618a97..0c5ca32a 100644 --- a/raylab/modules/networks/actor/policy/deterministic.py +++ b/raylab/modules/networks/actor/policy/deterministic.py @@ -124,9 +124,6 @@ def __init__( def initialize_parameters(self, initializer_spec: dict): """Initialize all Linear models in the encoder. - Uses `raylab.pytorch.nn.init.initialize_` to create an initializer - function. - Args: initializer_spec: Dictionary with mandatory `type` key corresponding to the initializer function name in `torch.nn.init` and optional diff --git a/raylab/modules/networks/actor/policy/stochastic.py b/raylab/modules/networks/actor/policy/stochastic.py index 7bde0eba..826d0f73 100644 --- a/raylab/modules/networks/actor/policy/stochastic.py +++ b/raylab/modules/networks/actor/policy/stochastic.py @@ -128,9 +128,6 @@ def __init__( def initialize_parameters(self, initializer_spec: dict): """Initialize all Linear models in the encoder. - Uses `raylab.pytorch.nn.init.initialize_` to create an initializer - function. 
- Args: initializer_spec: Dictionary with mandatory `type` key corresponding to the initializer function name in `torch.nn.init` and optional diff --git a/raylab/modules/networks/model/stochastic/__init__.py b/raylab/modules/networks/model/stochastic/__init__.py index e69de29b..342df363 100644 --- a/raylab/modules/networks/model/stochastic/__init__.py +++ b/raylab/modules/networks/model/stochastic/__init__.py @@ -0,0 +1,4 @@ +"""Implementations of stochastic dynamics models.""" + +from .builders import build_ensemble +from .builders import EnsembleSpec diff --git a/raylab/modules/networks/model/stochastic/builders.py b/raylab/modules/networks/model/stochastic/builders.py index d7f799cd..5de0693f 100644 --- a/raylab/modules/networks/model/stochastic/builders.py +++ b/raylab/modules/networks/model/stochastic/builders.py @@ -1,2 +1,87 @@ -from .single import MLPModel, ResidualMLPModel -from .ensemble import StochasticModelEnsemble, ForkedStochasticModelEnsemble +"""Constructors for stochastic dynamics models.""" +from dataclasses import dataclass +from dataclasses import field + +from dataclasses_json import DataClassJsonMixin +from gym.spaces import Box + +from .ensemble import ForkedStochasticModelEnsemble +from .ensemble import StochasticModelEnsemble +from .single import MLPModel +from .single import ResidualMLPModel + +ModelSpec = MLPModel.spec_cls + + +@dataclass +class Spec(DataClassJsonMixin): + """Specifications for stochastic dynamics model. + + Args: + network: Specifications for stochastic model network + residual: Whether to build model as a residual one, i.e., that + predicts the change in state rather than the next state itself + initializer: Optional dictionary with mandatory `type` key corresponding + to the initializer function name in `torch.nn.init` and optional + keyword arguments. Used to initialize the model's Linear layers. + """ + + network: ModelSpec = field(default_factory=ModelSpec) + residual: bool = True + initializer: dict = field(default_factory=dict) + + +def build(obs_space: Box, action_space: Box, spec: Spec) -> MLPModel: + """Construct stochastic dynamics model. + + Args: + obs_space: Observation space + action_space: Action space + spec: Specifications for stochastic dynamics model + + Returns: + A stochastic dynamics model + """ + cls = ResidualMLPModel if spec.residual else MLPModel + model = cls(obs_space, action_space, spec.network) + model.initialize_parameters(spec.initializer) + return model + + +@dataclass +class EnsembleSpec(Spec): + """Specifications for stochastic dynamics model ensemble. + + Args: + network: Specifications for stochastic model networks + ensemble_size: Number of models in the collection. + residual: Whether to build each model as a residual one, i.e., that + predicts the change in state rather than the next state itself + parallelize: Whether to use an ensemble with parallelized `sample`, + `rsample`, and `log_prob` methods + initializer: Optional dictionary with mandatory `type` key corresponding + to the initializer function name in `torch.nn.init` and optional + keyword arguments. Used to initialize the models' Linear layers. + """ + + ensemble_size: int = 1 + parallelize: bool = False + + +def build_ensemble( + obs_space: Box, action_space: Box, spec: EnsembleSpec +) -> StochasticModelEnsemble: + """Construct stochastic dynamics model ensemble. 
+ + Args: + obs_space: Observation space + action_space: Action space + spec: Specifications for stochastic dynamics model ensemble + + Returns: + A stochastic dynamics model ensemble + """ + models = [build(obs_space, action_space, spec) for _ in range(spec.ensemble_size)] + cls = ForkedStochasticModelEnsemble if spec.parallelize else StochasticModelEnsemble + ensemble = cls(models) + return ensemble diff --git a/raylab/modules/networks/model/stochastic/single.py b/raylab/modules/networks/model/stochastic/single.py index df7bfb3d..1ece4f65 100644 --- a/raylab/modules/networks/model/stochastic/single.py +++ b/raylab/modules/networks/model/stochastic/single.py @@ -8,7 +8,7 @@ from dataclasses_json import DataClassJsonMixin from gym.spaces import Box -import rayalb.pytorch.nn as nnx +import raylab.pytorch.nn as nnx import raylab.pytorch.nn.distributions as ptd from raylab.modules.networks.mlp import StateActionMLP @@ -88,6 +88,8 @@ def reproduce(self, obs, action, next_obs): class ResidualMixin: """Overrides StochasticModel interface to model state transition residuals.""" + # pylint:disable=missing-function-docstring,not-callable + @torch.jit.export def sample(self, obs, action, sample_shape: List[int] = ()): params = self(obs, action) @@ -174,6 +176,17 @@ def __init__( params = DynamicsParams(encoder, params) dist = ptd.Independent(ptd.Normal(), reinterpreted_batch_ndims=1) super().__init__(params, dist) + self.encoder = encoder + + def initialize_parameters(self, initializer_spec: dict): + """Initialize all encoder parameters. + + Args: + initializer_spec: Dictionary with mandatory `type` key corresponding + to the initializer function name in `torch.nn.init` and optional + keyword arguments. + """ + self.encoder.initialize_parameters(initializer_spec) class ResidualMLPModel(ResidualMixin, MLPModel): diff --git a/raylab/modules/sac.py b/raylab/modules/sac.py index 5ef60232..bbf1d1f7 100644 --- a/raylab/modules/sac.py +++ b/raylab/modules/sac.py @@ -22,13 +22,21 @@ class SACSpec(DataClassJsonMixin): critic: Specifications for action-value estimators initializer: Optional dictionary with mandatory `type` key corresponding to the initializer function name in `torch.nn.init` and optional - keyword arguments. + keyword arguments. Overrides actor and critic initializer + specifications. """ actor: ActorSpec = field(default_factory=ActorSpec) critic: CriticSpec = field(default_factory=CriticSpec) initializer: dict = field(default_factory=dict) + def __post_init__(self): + # Top-level initializer options take precedence over individual + # component's options + if self.initializer: + self.actor.initializer = self.initializer + self.critic.initializer = self.initializer + class SAC(nn.Module): """NN module for Soft Actor-Critic algorithms. 
@@ -52,12 +60,6 @@ class SAC(nn.Module): def __init__(self, obs_space: Box, action_space: Box, spec: SACSpec): super().__init__() - # Top-level initializer options take precedence over individual - # component's options - if spec.initializer: - spec.actor.initializer = spec.initializer - spec.critic.initializer = spec.initializer - actor = StochasticActor(obs_space, action_space, spec.actor) self.actor = actor.policy self.alpha = actor.alpha From e5150c71844bd01c26780ff3b90dd5b9bc9a62a4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=C3=82ngelo=20Lovatto?= Date: Tue, 30 Jun 2020 08:38:55 -0300 Subject: [PATCH 26/48] refactor(networks): move input_dependent_scale to model builders MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Ângelo Lovatto --- .../networks/model/stochastic/builders.py | 5 +++- .../networks/model/stochastic/single.py | 27 +++++-------------- 2 files changed, 11 insertions(+), 21 deletions(-) diff --git a/raylab/modules/networks/model/stochastic/builders.py b/raylab/modules/networks/model/stochastic/builders.py index 5de0693f..aa90d52a 100644 --- a/raylab/modules/networks/model/stochastic/builders.py +++ b/raylab/modules/networks/model/stochastic/builders.py @@ -19,6 +19,8 @@ class Spec(DataClassJsonMixin): Args: network: Specifications for stochastic model network + input_dependent_scale: Whether to parameterize the Gaussian standard + deviation as a function of the state and action residual: Whether to build model as a residual one, i.e., that predicts the change in state rather than the next state itself initializer: Optional dictionary with mandatory `type` key corresponding @@ -27,6 +29,7 @@ class Spec(DataClassJsonMixin): """ network: ModelSpec = field(default_factory=ModelSpec) + input_dependent_scale: bool = True residual: bool = True initializer: dict = field(default_factory=dict) @@ -43,7 +46,7 @@ def build(obs_space: Box, action_space: Box, spec: Spec) -> MLPModel: A stochastic dynamics model """ cls = ResidualMLPModel if spec.residual else MLPModel - model = cls(obs_space, action_space, spec.network) + model = cls(obs_space, action_space, spec.network, spec.input_dependent_scale) model.initialize_parameters(spec.initializer) return model diff --git a/raylab/modules/networks/model/stochastic/single.py b/raylab/modules/networks/model/stochastic/single.py index 1ece4f65..ad42ae10 100644 --- a/raylab/modules/networks/model/stochastic/single.py +++ b/raylab/modules/networks/model/stochastic/single.py @@ -1,11 +1,8 @@ """NN modules for stochastic dynamics estimation.""" -from dataclasses import dataclass -from dataclasses import field from typing import List import torch import torch.nn as nn -from dataclasses_json import DataClassJsonMixin from gym.spaces import Box import raylab.pytorch.nn as nnx @@ -145,33 +142,23 @@ def forward(self, obs, actions): # pylint:disable=arguments-differ MLPSpec = StateActionMLP.spec_cls -@dataclass -class MLPModelSpec(DataClassJsonMixin): - """Specifications for stochastic model networks. 
- - Args: - mlp: Specifications for building an MLP with state and action inputs - input_dependent_scale: Whether to parameterize the Gaussian standard - deviation as a function of the state and action - """ - - mlp: MLPSpec = field(default_factory=MLPSpec) - input_dependent_scale: bool = True - - class MLPModel(StochasticModel): """Stochastic model with multilayer perceptron state-action encoder.""" - spec_cls = MLPModelSpec + spec_cls = MLPSpec def __init__( - self, obs_space: Box, action_space: Box, spec: MLPModelSpec, + self, + obs_space: Box, + action_space: Box, + spec: MLPSpec, + input_dependent_scale: bool, ): encoder = StateActionMLP(obs_space, action_space, spec.mlp) params = nnx.NormalParams( encoder.out_features, obs_space.shape[0], - input_dependent_scale=spec.input_dependent_scale, + input_dependent_scale=input_dependent_scale, ) params = DynamicsParams(encoder, params) dist = ptd.Independent(ptd.Normal(), reinterpreted_batch_ndims=1) From 25fb7e48724696f5913d23c2d951a0dedc7cd287 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=C3=82ngelo=20Lovatto?= Date: Tue, 30 Jun 2020 08:50:46 -0300 Subject: [PATCH 27/48] test(networks): add stochastic model MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Ângelo Lovatto --- tests/modules/networks/model/__init__.py | 0 .../networks/model/stochastic/__init__.py | 0 .../networks/model/stochastic/conftest.py | 16 +++ .../model/stochastic/test_ensemble.py | 46 ++++++++ .../networks/model/stochastic/test_single.py | 102 ++++++++++++++++++ 5 files changed, 164 insertions(+) create mode 100644 tests/modules/networks/model/__init__.py create mode 100644 tests/modules/networks/model/stochastic/__init__.py create mode 100644 tests/modules/networks/model/stochastic/conftest.py create mode 100644 tests/modules/networks/model/stochastic/test_ensemble.py create mode 100644 tests/modules/networks/model/stochastic/test_single.py diff --git a/tests/modules/networks/model/__init__.py b/tests/modules/networks/model/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/modules/networks/model/stochastic/__init__.py b/tests/modules/networks/model/stochastic/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/modules/networks/model/stochastic/conftest.py b/tests/modules/networks/model/stochastic/conftest.py new file mode 100644 index 00000000..f544d2cc --- /dev/null +++ b/tests/modules/networks/model/stochastic/conftest.py @@ -0,0 +1,16 @@ +# pylint: disable=missing-docstring,redefined-outer-name,protected-access +import pytest +from ray.rllib import SampleBatch + + +@pytest.fixture +def log_prob_inputs(batch): + return [ + batch[k] + for k in (SampleBatch.CUR_OBS, SampleBatch.ACTIONS, SampleBatch.NEXT_OBS) + ] + + +@pytest.fixture +def sample_inputs(batch): + return [batch[k] for k in (SampleBatch.CUR_OBS, SampleBatch.ACTIONS)] diff --git a/tests/modules/networks/model/stochastic/test_ensemble.py b/tests/modules/networks/model/stochastic/test_ensemble.py new file mode 100644 index 00000000..b57682e1 --- /dev/null +++ b/tests/modules/networks/model/stochastic/test_ensemble.py @@ -0,0 +1,46 @@ +# pylint: disable=missing-docstring,redefined-outer-name,protected-access +import pytest +import torch + + +@pytest.fixture(scope="module", params=(True, False), ids=lambda x: f"Forked({x})") +def module_cls(request): + from raylab.modules.networks.model.stochastic.ensemble import ( + StochasticModelEnsemble, + ) + from raylab.modules.networks.model.stochastic.ensemble import 
( + ForkedStochasticModelEnsemble, + ) + + return ForkedStochasticModelEnsemble if request.param else StochasticModelEnsemble + + +@pytest.fixture(params=(1, 4), ids=lambda x: f"Ensemble({x})") +def ensemble_size(request): + return request.param + + +@pytest.fixture +def build_single(obs_space, action_space): + from raylab.modules.networks.model.stochastic.single import MLPModel + + spec = MLPModel.spec_cls() + input_dependent_scale = True + + return lambda: MLPModel(obs_space, action_space, spec, input_dependent_scale) + + +@pytest.fixture +def module(module_cls, build_single, ensemble_size, torch_script): + models = [build_single() for _ in range(ensemble_size)] + + module = module_cls(models) + return torch.jit.script(module) if torch_script else module + + +def test_log_prob(module, log_prob_inputs, ensemble_size): + obs = log_prob_inputs[0] + log_prob = module.log_prob(*log_prob_inputs) + + assert torch.is_tensor(log_prob) + assert log_prob.shape == (ensemble_size,) + obs.shape[:-1] diff --git a/tests/modules/networks/model/stochastic/test_single.py b/tests/modules/networks/model/stochastic/test_single.py new file mode 100644 index 00000000..be87c779 --- /dev/null +++ b/tests/modules/networks/model/stochastic/test_single.py @@ -0,0 +1,102 @@ +# pylint: disable=missing-docstring,redefined-outer-name,protected-access +import pytest +import torch +from ray.rllib import SampleBatch + + +@pytest.fixture(scope="module", params=(True, False), ids=lambda x: f"Residual({x})") +def module_cls(request): + from raylab.modules.networks.model.stochastic.single import MLPModel + from raylab.modules.networks.model.stochastic.single import ResidualMLPModel + + return ResidualMLPModel if request.param else MLPModel + + +@pytest.fixture +def spec(module_cls): + return module_cls.spec_cls() + + +@pytest.fixture(params=(True, False), ids=lambda x: f"InputDependentScale({x})") +def input_dependent_scale(request): + return request.param + + +@pytest.fixture +def module(module_cls, obs_space, action_space, spec, input_dependent_scale): + return module_cls(obs_space, action_space, spec, input_dependent_scale) + + +def test_sample(module, batch): + new_obs = batch[SampleBatch.NEXT_OBS] + sampler = module.rsample + inputs = (batch[SampleBatch.CUR_OBS], batch[SampleBatch.ACTIONS]) + + samples, logp = sampler(*inputs) + samples_, _ = sampler(*inputs) + assert samples.shape == new_obs.shape + assert samples.dtype == new_obs.dtype + assert logp.shape == batch[SampleBatch.REWARDS].shape + assert logp.dtype == batch[SampleBatch.REWARDS].dtype + assert not torch.allclose(samples, samples_) + + +def test_params(module, batch): + inputs = (batch[SampleBatch.CUR_OBS], batch[SampleBatch.ACTIONS]) + new_obs = batch[SampleBatch.NEXT_OBS] + + params = module(*inputs) + assert "loc" in params + assert "scale" in params + + loc, scale = params["loc"], params["scale"] + assert loc.shape == new_obs.shape + assert scale.shape == new_obs.shape + assert loc.dtype == torch.float32 + assert scale.dtype == torch.float32 + + params = set(module.parameters()) + for par in params: + par.grad = None + loc.mean().backward() + assert any(p.grad is not None for p in params) + + for par in params: + par.grad = None + module(*inputs)["scale"].mean().backward() + assert any(p.grad is not None for p in params) + + +def test_log_prob(module, batch): + logp = module.log_prob( + batch[SampleBatch.CUR_OBS], + batch[SampleBatch.ACTIONS], + batch[SampleBatch.NEXT_OBS], + ) + + assert torch.is_tensor(logp) + assert logp.shape == 
batch[SampleBatch.REWARDS].shape + + logp.sum().backward() + assert all(p.grad is not None for p in module.parameters()) + + +def test_reproduce(module, batch): + obs, act, new_obs = [ + batch[k] + for k in (SampleBatch.CUR_OBS, SampleBatch.ACTIONS, SampleBatch.NEXT_OBS) + ] + + new_obs_, logp_ = module.reproduce(obs, act, new_obs) + assert new_obs_.shape == new_obs.shape + assert new_obs_.dtype == new_obs.dtype + assert torch.allclose(new_obs_, new_obs, atol=1e-5) + assert logp_.shape == batch[SampleBatch.REWARDS].shape + + new_obs_.mean().backward() + params = set(module.parameters()) + assert all(p.grad is not None for p in params) + + +def test_script(module): + torch.jit.script(module) From 9ff72abbc43d09b3cc65ff5e62a0746bc4c1be77 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=C3=82ngelo=20Lovatto?= Date: Tue, 30 Jun 2020 08:51:50 -0300 Subject: [PATCH 28/48] fix(networks): correctly pass arguments torch jit._fork MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Ângelo Lovatto --- raylab/modules/networks/critic/q_value.py | 2 +- .../modules/networks/model/stochastic/ensemble.py | 9 +++------ raylab/modules/networks/model/stochastic/single.py | 2 +- tests/modules/networks/critic/test_action_value.py | 13 +++++++++++-- 4 files changed, 16 insertions(+), 10 deletions(-) diff --git a/raylab/modules/networks/critic/q_value.py b/raylab/modules/networks/critic/q_value.py index 36914d09..d722feaa 100644 --- a/raylab/modules/networks/critic/q_value.py +++ b/raylab/modules/networks/critic/q_value.py @@ -113,7 +113,7 @@ class ForkedQValueEnsemble(QValueEnsemble): def forward(self, obs: Tensor, action: Tensor, clip: bool = False) -> Tensor: # pylint:disable=protected-access - futures = [torch.jit._fork(m, (obs, action)) for m in self] + futures = [torch.jit._fork(m, obs, action) for m in self] action_values = torch.cat([torch.jit._wait(f) for f in futures], dim=-1) if clip: action_values, _ = action_values.min(keepdim=True, dim=-1) diff --git a/raylab/modules/networks/model/stochastic/ensemble.py b/raylab/modules/networks/model/stochastic/ensemble.py index 2b41943c..dae46a2a 100644 --- a/raylab/modules/networks/model/stochastic/ensemble.py +++ b/raylab/modules/networks/model/stochastic/ensemble.py @@ -4,7 +4,6 @@ import torch import torch.nn as nn - from .single import StochasticModel @@ -53,7 +52,7 @@ class ForkedStochasticModelEnsemble(StochasticModelEnsemble): @torch.jit.export def sample(self, obs, action, sample_shape: List[int] = ()): - futures = [torch.jit._fork(m.sample, (obs, action, sample_shape)) for m in self] + futures = [torch.jit._fork(m.sample, obs, action, sample_shape) for m in self] outputs = [torch.jit._wait(f) for f in futures] sample = torch.stack([s for s, _ in outputs]) logp = torch.stack([p for _, p in outputs]) @@ -61,9 +60,7 @@ def sample(self, obs, action, sample_shape: List[int] = ()): @torch.jit.export def rsample(self, obs, action, sample_shape: List[int] = ()): - futures = [ - torch.jit._fork(m.rsample, (obs, action, sample_shape)) for m in self - ] + futures = [torch.jit._fork(m.rsample, obs, action, sample_shape) for m in self] outputs = [torch.jit._wait(f) for f in futures] sample = torch.stack([s for s, _ in outputs]) logp = torch.stack([p for _, p in outputs]) @@ -71,5 +68,5 @@ def rsample(self, obs, action, sample_shape: List[int] = ()): @torch.jit.export def log_prob(self, obs, action, next_obs): - futures = [torch.jit._fork(m.log_prob, (obs, action, next_obs)) for m in self] + futures = 
[torch.jit._fork(m.log_prob, obs, action, next_obs) for m in self] return torch.stack([torch.jit._wait(f) for f in futures]) diff --git a/raylab/modules/networks/model/stochastic/single.py b/raylab/modules/networks/model/stochastic/single.py index ad42ae10..dcbeb749 100644 --- a/raylab/modules/networks/model/stochastic/single.py +++ b/raylab/modules/networks/model/stochastic/single.py @@ -154,7 +154,7 @@ def __init__( spec: MLPSpec, input_dependent_scale: bool, ): - encoder = StateActionMLP(obs_space, action_space, spec.mlp) + encoder = StateActionMLP(obs_space, action_space, spec) params = nnx.NormalParams( encoder.out_features, obs_space.shape[0], diff --git a/tests/modules/networks/critic/test_action_value.py b/tests/modules/networks/critic/test_action_value.py index fe4a0510..4fd5f89b 100644 --- a/tests/modules/networks/critic/test_action_value.py +++ b/tests/modules/networks/critic/test_action_value.py @@ -16,9 +16,14 @@ def double_q(request): return request.param +@pytest.fixture(params=(True, False), ids=lambda x: "Parallelize({x})") +def parallelize(request): + return request.param + + @pytest.fixture -def spec(module_cls, double_q): - return module_cls.spec_cls(double_q=double_q) +def spec(module_cls, double_q, parallelize): + return module_cls.spec_cls(double_q=double_q, parallelize=parallelize) @pytest.fixture @@ -48,3 +53,7 @@ def test_module_creation(module, batch, spec): torch.allclose(p, t) for p, t in zip(q_values.parameters(), targets.parameters()) ) + + +def test_script(module): + torch.jit.script(module) From 48f197c140da47f6757655fd2ace9ee8cf872771 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=C3=82ngelo=20Lovatto?= Date: Tue, 30 Jun 2020 09:03:44 -0300 Subject: [PATCH 29/48] feat(agents): use new model-based DDPG and SAC modules as defaults MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Ângelo Lovatto --- raylab/agents/mage/trainer.py | 2 +- raylab/agents/mapo/trainer.py | 11 ++++++++++- raylab/agents/mbpo/policy.py | 2 +- raylab/agents/mbpo/trainer.py | 8 ++++++-- raylab/modules/mbsac.py | 2 +- tests/agents/mbpo/test_policy.py | 8 +++----- 6 files changed, 22 insertions(+), 11 deletions(-) diff --git a/raylab/agents/mage/trainer.py b/raylab/agents/mage/trainer.py index 4975b776..1761bae3 100644 --- a/raylab/agents/mage/trainer.py +++ b/raylab/agents/mage/trainer.py @@ -39,7 +39,7 @@ patience_epochs=None, improvement_threshold=None, ).to_dict(), - "module": {"type": "ModelBasedDDPG", "model": {"ensemble_size": 1}}, + "module": {"type": "MBDDPG"}, # === Exploration Settings === # Default exploration behavior, iff `explore`=None is passed into # compute_action(s). diff --git a/raylab/agents/mapo/trainer.py b/raylab/agents/mapo/trainer.py index 3d03926e..d64eae94 100644 --- a/raylab/agents/mapo/trainer.py +++ b/raylab/agents/mapo/trainer.py @@ -12,7 +12,16 @@ DEFAULT_CONFIG = with_base_config( { # === MAPOTorchPolicy === - "module": {"type": "ModelBasedSAC", "model": {"ensemble_size": 1}}, + "module": { + "type": "MBSAC", + "model": { + "network": {"units": (128, 128), "activation": "Swish"}, + "ensemble_size": 1, + "input_dependent_scale": True, + "parallelize": False, + "residual": True, + }, + }, "losses": { # Gradient estimator for optimizing expectations. 
Possible types include # SF: score function diff --git a/raylab/agents/mbpo/policy.py b/raylab/agents/mbpo/policy.py index 17beb491..774cce94 100644 --- a/raylab/agents/mbpo/policy.py +++ b/raylab/agents/mbpo/policy.py @@ -37,6 +37,6 @@ def make_optimizers(self): components = "models actor critics alpha".split() return { - name: build_optimizer(self.module[name], config[name]) + name: build_optimizer(getattr(self.module, name), config[name]) for name in components } diff --git a/raylab/agents/mbpo/trainer.py b/raylab/agents/mbpo/trainer.py index 30c49918..3bbffabe 100644 --- a/raylab/agents/mbpo/trainer.py +++ b/raylab/agents/mbpo/trainer.py @@ -10,11 +10,13 @@ { # === MBPOTorchPolicy === "module": { - "type": "ModelBasedSAC", + "type": "MBSAC", "model": { - "encoder": {"units": (128, 128), "activation": "Swish"}, + "network": {"units": (128, 128), "activation": "Swish"}, "ensemble_size": 7, "input_dependent_scale": True, + "parallelize": True, + "residual": True, }, "actor": { "encoder": {"units": (128, 128), "activation": "Swish"}, @@ -48,6 +50,8 @@ "learning_starts": 5000, # === OffPolicyTrainer === "train_batch_size": 512, + # === Trainer === + "compile_policy": True, } ) diff --git a/raylab/modules/mbsac.py b/raylab/modules/mbsac.py index 0d1be992..5cb620a0 100644 --- a/raylab/modules/mbsac.py +++ b/raylab/modules/mbsac.py @@ -41,7 +41,7 @@ class MBSAC(SAC): spec: Specifications for model-based SAC modules Attributes: - model (StochasticModelEnsemble): Stochastic dynamics model ensemble + models (StochasticModelEnsemble): Stochastic dynamics model ensemble actor (StochasticPolicy): Stochastic policy to be learned alpha (Alpha): Entropy bonus coefficient critics (QValueEnsemble): The action-value estimators to be learned diff --git a/tests/agents/mbpo/test_policy.py b/tests/agents/mbpo/test_policy.py index f887e931..e792ed94 100644 --- a/tests/agents/mbpo/test_policy.py +++ b/tests/agents/mbpo/test_policy.py @@ -22,7 +22,7 @@ def config(ensemble_size): "patience_epochs": 5, }, "model_sampling": {"rollout_schedule": [(0, 10)], "num_elites": 1}, - "module": {"ensemble_size": ensemble_size}, + "module": {"model": {"ensemble_size": ensemble_size}}, } @@ -32,10 +32,8 @@ def policy(policy_cls, config): def test_policy_creation(policy): - assert "models" in policy.module - assert "actor" in policy.module - assert "critics" in policy.module - assert "alpha" in policy.module + for attr in "models actor alpha critics".split(): + assert hasattr(policy.module, attr) assert "models" in policy.optimizers assert "actor" in policy.optimizers From b36be36b7da8d8bb9ad2faa2e2fbdea1e846365d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=C3=82ngelo=20Lovatto?= Date: Tue, 30 Jun 2020 13:31:02 -0300 Subject: [PATCH 30/48] fix(networks): pass `name` argument to initialize_ MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Ângelo Lovatto --- raylab/modules/networks/actor/policy/deterministic.py | 2 +- raylab/modules/networks/actor/policy/stochastic.py | 2 +- raylab/modules/networks/critic/q_value.py | 4 ++-- raylab/modules/networks/mlp.py | 4 ++-- raylab/modules/networks/model/stochastic/single.py | 2 +- 5 files changed, 7 insertions(+), 7 deletions(-) diff --git a/raylab/modules/networks/actor/policy/deterministic.py b/raylab/modules/networks/actor/policy/deterministic.py index 0c5ca32a..7e6da6b5 100644 --- a/raylab/modules/networks/actor/policy/deterministic.py +++ b/raylab/modules/networks/actor/policy/deterministic.py @@ -125,7 +125,7 @@ def 
initialize_parameters(self, initializer_spec: dict): """Initialize all Linear models in the encoder. Args: - initializer_spec: Dictionary with mandatory `type` key corresponding + initializer_spec: Dictionary with mandatory `name` key corresponding to the initializer function name in `torch.nn.init` and optional keyword arguments. """ diff --git a/raylab/modules/networks/actor/policy/stochastic.py b/raylab/modules/networks/actor/policy/stochastic.py index 826d0f73..08b2e820 100644 --- a/raylab/modules/networks/actor/policy/stochastic.py +++ b/raylab/modules/networks/actor/policy/stochastic.py @@ -129,7 +129,7 @@ def initialize_parameters(self, initializer_spec: dict): """Initialize all Linear models in the encoder. Args: - initializer_spec: Dictionary with mandatory `type` key corresponding + initializer_spec: Dictionary with mandatory `name` key corresponding to the initializer function name in `torch.nn.init` and optional keyword arguments. """ diff --git a/raylab/modules/networks/critic/q_value.py b/raylab/modules/networks/critic/q_value.py index d722feaa..795e6606 100644 --- a/raylab/modules/networks/critic/q_value.py +++ b/raylab/modules/networks/critic/q_value.py @@ -57,7 +57,7 @@ def initialize_parameters(self, initializer_spec: dict): function. Args: - initializer_spec: Dictionary with mandatory `type` key corresponding + initializer_spec: Dictionary with mandatory `name` key corresponding to the initializer function name in `torch.nn.init` and optional keyword arguments. """ @@ -100,7 +100,7 @@ def initialize_parameters(self, initializer_spec: dict): """Initialize each Q estimator in the ensemble. Args: - initializer_spec: Dictionary with mandatory `type` key corresponding + initializer_spec: Dictionary with mandatory `name` key corresponding to the initializer function name in `torch.nn.init` and optional keyword arguments. """ diff --git a/raylab/modules/networks/mlp.py b/raylab/modules/networks/mlp.py index 17a58f0b..859060db 100644 --- a/raylab/modules/networks/mlp.py +++ b/raylab/modules/networks/mlp.py @@ -50,7 +50,7 @@ def initialize_parameters(self, initializer_spec: dict): function. Args: - initializer_spec: Dictionary with mandatory `type` key corresponding + initializer_spec: Dictionary with mandatory `name` key corresponding to the initializer function name in `torch.nn.init` and optional keyword arguments. """ @@ -99,7 +99,7 @@ def initialize_parameters(self, initializer_spec: dict): function. Args: - initializer_spec: Dictionary with mandatory `type` key corresponding + initializer_spec: Dictionary with mandatory `name` key corresponding to the initializer function name in `torch.nn.init` and optional keyword arguments. """ diff --git a/raylab/modules/networks/model/stochastic/single.py b/raylab/modules/networks/model/stochastic/single.py index dcbeb749..f4c97c26 100644 --- a/raylab/modules/networks/model/stochastic/single.py +++ b/raylab/modules/networks/model/stochastic/single.py @@ -169,7 +169,7 @@ def initialize_parameters(self, initializer_spec: dict): """Initialize all encoder parameters. Args: - initializer_spec: Dictionary with mandatory `type` key corresponding + initializer_spec: Dictionary with mandatory `name` key corresponding to the initializer function name in `torch.nn.init` and optional keyword arguments. 
""" From 0f646fefa2213d2b55030a5bd3dd8ddc9d71d966 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=C3=82ngelo=20Lovatto?= Date: Tue, 30 Jun 2020 13:54:00 -0300 Subject: [PATCH 31/48] chore(examples): use MBSAC in MAPO MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Ângelo Lovatto --- examples/MAPO/swingup.py | 17 +++++++++++++---- 1 file changed, 13 insertions(+), 4 deletions(-) diff --git a/examples/MAPO/swingup.py b/examples/MAPO/swingup.py index a2310707..3ca0af5c 100644 --- a/examples/MAPO/swingup.py +++ b/examples/MAPO/swingup.py @@ -8,14 +8,23 @@ def get_config(): "env_config": {"max_episode_steps": 500, "time_aware": False}, # === MAPOTorchPolicy === "module": { - "type": "ModelBasedSAC", + "type": "MBSAC", "model": { "ensemble_size": 1, + "residual": True, + "input_dependent_scale": True, + "network": {"units": (128, 128), "activation": "Swish"}, + }, + "actor": { + "encoder": {"units": (128, 128), "activation": "Swish"}, + "input_dependent_scale": True, + "initial_entropy_coeff": 0.05, + }, + "critic": { "encoder": {"units": (128, 128), "activation": "Swish"}, + "double_q": True, }, - "actor": {"encoder": {"units": (128, 128), "activation": "Swish"}}, - "critic": {"encoder": {"units": (128, 128), "activation": "Swish"}}, - "entropy": {"initial_alpha": 0.05}, + "initializer": {"name": "xavier_uniform"}, }, "losses": { # Gradient estimator for optimizing expectations. Possible types include From 2f103c6c9935372e79a14a7d37ea2b0f161aad0e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=C3=82ngelo=20Lovatto?= Date: Tue, 30 Jun 2020 15:12:56 -0300 Subject: [PATCH 32/48] chore(examples): use MBSAC in MBPO MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Ângelo Lovatto --- examples/MBPO/swingup.py | 15 +++++++++++---- 1 file changed, 11 insertions(+), 4 deletions(-) diff --git a/examples/MBPO/swingup.py b/examples/MBPO/swingup.py index 48f7136e..d5d0bb4e 100644 --- a/examples/MBPO/swingup.py +++ b/examples/MBPO/swingup.py @@ -8,18 +8,24 @@ def get_config(): "env_config": {"max_episode_steps": 500, "time_aware": False}, # === MBPOTorchPolicy === "module": { - "type": "ModelBasedSAC", + "type": "MBSAC", "model": { - "encoder": {"units": (128, 128), "activation": "Swish"}, "ensemble_size": 7, + "parallelize": True, + "residual": True, "input_dependent_scale": True, + "network": {"units": (128, 128), "activation": "Swish"}, }, "actor": { "encoder": {"units": (128, 128), "activation": "Swish"}, "input_dependent_scale": True, + "initial_entropy_coeff": 0.05, + }, + "critic": { + "encoder": {"units": (128, 128), "activation": "Swish"}, + "double_q": True, }, - "critic": {"encoder": {"units": (128, 128), "activation": "Swish"}}, - "entropy": {"initial_alpha": 0.05}, + "initializer": {"name": "xavier_uniform"}, }, "torch_optimizer": { "models": {"type": "Adam", "lr": 3e-4, "weight_decay": 0.0001}, @@ -63,6 +69,7 @@ def get_config(): "evaluation_num_episodes": 10, "timesteps_per_iteration": 1000, "num_cpus_for_driver": 4, + "compile_policy": True, # === RolloutWorker === "rollout_fragment_length": 25, "batch_mode": "truncate_episodes", From 3c5fa8fdaa0df6c36b0a8ab1996a2b41386fcf98 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=C3=82ngelo=20Lovatto?= Date: Tue, 30 Jun 2020 15:50:03 -0300 Subject: [PATCH 33/48] refactor: move modules to raylab.policy MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Ângelo Lovatto --- raylab/agents/mage/policy.py | 1 + 
raylab/losses/mle.py | 2 +- raylab/losses/policy_gradient.py | 2 +- raylab/losses/svg.py | 4 +- raylab/modules/catalog.py | 61 ------------------- raylab/{ => policy}/modules/__init__.py | 0 .../modules}/actor/__init__.py | 0 .../modules}/actor/deterministic.py | 0 .../modules}/actor/policy/__init__.py | 0 .../modules}/actor/policy/deterministic.py | 2 +- .../modules}/actor/policy/stochastic.py | 2 +- .../modules}/actor/stochastic.py | 0 raylab/policy/modules/catalog.py | 28 +++++++++ .../modules}/critic/__init__.py | 0 .../modules}/critic/action_value.py | 0 .../modules}/critic/q_value.py | 2 +- raylab/{ => policy}/modules/ddpg.py | 4 +- raylab/{ => policy}/modules/mbddpg.py | 4 +- raylab/{ => policy}/modules/mbsac.py | 4 +- .../modules}/model/__init__.py | 0 .../modules}/model/stochastic/__init__.py | 0 .../modules}/model/stochastic/builders.py | 0 .../modules}/model/stochastic/ensemble.py | 0 .../modules}/model/stochastic/single.py | 2 +- .../{ => policy}/modules/networks/__init__.py | 0 raylab/{ => policy}/modules/networks/mlp.py | 0 .../{ => policy}/modules/networks/resnet.py | 0 raylab/{ => policy}/modules/sac.py | 4 +- raylab/{ => policy}/modules/v0/__init__.py | 0 raylab/{ => policy}/modules/v0/abstract.py | 0 raylab/policy/modules/v0/catalog.py | 50 +++++++++++++++ raylab/{ => policy}/modules/v0/ddpg_module.py | 0 .../modules/v0/maxent_model_based.py | 0 .../modules/v0/mixins/__init__.py | 0 .../modules/v0/mixins/action_value_mixin.py | 0 .../v0/mixins/deterministic_actor_mixin.py | 0 .../v0/mixins/normalizing_flow_actor_mixin.py | 2 +- .../v0/mixins/normalizing_flow_model_mixin.py | 2 +- .../modules/v0/mixins/state_value_mixin.py | 0 .../v0/mixins/stochastic_actor_mixin.py | 0 .../v0/mixins/stochastic_model_mixin.py | 0 .../modules/v0/mixins/svg_model_mixin.py | 0 .../modules/v0/model_based_ddpg.py | 0 .../modules/v0/model_based_sac.py | 0 raylab/{ => policy}/modules/v0/naf_module.py | 0 raylab/{ => policy}/modules/v0/nfmbrl.py | 0 .../modules/v0/off_policy_nfac.py | 0 .../modules/v0/on_policy_actor_critic.py | 0 .../{ => policy}/modules/v0/on_policy_nfac.py | 0 raylab/{ => policy}/modules/v0/sac_module.py | 0 .../modules/v0/simple_model_based.py | 0 raylab/{ => policy}/modules/v0/svg_module.py | 0 .../modules/v0/svg_realnvp_actor.py | 0 .../{ => policy}/modules/v0/trpo_tang2018.py | 2 +- raylab/policy/torch_policy.py | 2 +- tests/losses/conftest.py | 10 +-- tests/{ => policy}/modules/__init__.py | 0 .../modules/actor}/__init__.py | 0 .../modules}/actor/conftest.py | 0 .../modules/actor/policy}/__init__.py | 0 .../actor/policy/test_deterministic.py | 4 +- .../modules}/actor/policy/test_stochastic.py | 6 +- .../modules}/actor/test_deterministic.py | 2 +- .../modules}/actor/test_stochastic.py | 2 +- .../networks => policy/modules}/conftest.py | 11 +++- .../modules/critic}/__init__.py | 0 .../modules}/critic/test_action_value.py | 2 +- .../modules/model}/__init__.py | 0 .../modules/model/stochastic}/__init__.py | 0 .../modules}/model/stochastic/conftest.py | 0 .../model/stochastic/test_ensemble.py | 8 +-- .../modules}/model/stochastic/test_single.py | 4 +- .../modules/networks}/__init__.py | 0 .../{ => policy}/modules/networks/test_mlp.py | 2 +- .../modules/networks/test_resnet.py | 2 +- tests/{ => policy}/modules/test_ddpg.py | 2 +- tests/{ => policy}/modules/test_sac.py | 2 +- tests/{ => policy}/modules/v0/__init__.py | 0 .../modules/v0}/conftest.py | 2 +- .../modules/v0/test_action_value_mixin.py | 2 +- .../v0/test_deterministic_actor_mixin.py | 2 +- 
.../modules/v0/test_naf_module.py | 2 +- .../v0/test_normalizing_flow_actor_mixin.py | 2 +- .../v0/test_normalizing_flow_model_mixin.py | 2 +- .../modules/v0/test_state_value_mixin.py | 2 +- .../modules/v0/test_stochastic_actor_mixin.py | 2 +- .../modules/v0/test_stochastic_model_mixin.py | 2 +- .../modules/v0/test_svg_module.py | 2 +- .../modules/v0/test_trpo_extensions.py | 2 +- tests/{ => policy}/modules/v0/utils.py | 0 .../nn/distributions/flows/test_couplings.py | 4 +- 91 files changed, 144 insertions(+), 119 deletions(-) delete mode 100644 raylab/modules/catalog.py rename raylab/{ => policy}/modules/__init__.py (100%) rename raylab/{modules/networks => policy/modules}/actor/__init__.py (100%) rename raylab/{modules/networks => policy/modules}/actor/deterministic.py (100%) rename raylab/{modules/networks => policy/modules}/actor/policy/__init__.py (100%) rename raylab/{modules/networks => policy/modules}/actor/policy/deterministic.py (98%) rename raylab/{modules/networks => policy/modules}/actor/policy/stochastic.py (99%) rename raylab/{modules/networks => policy/modules}/actor/stochastic.py (100%) create mode 100644 raylab/policy/modules/catalog.py rename raylab/{modules/networks => policy/modules}/critic/__init__.py (100%) rename raylab/{modules/networks => policy/modules}/critic/action_value.py (100%) rename raylab/{modules/networks => policy/modules}/critic/q_value.py (98%) rename raylab/{ => policy}/modules/ddpg.py (95%) rename raylab/{ => policy}/modules/mbddpg.py (95%) rename raylab/{ => policy}/modules/mbsac.py (94%) rename raylab/{modules/networks => policy/modules}/model/__init__.py (100%) rename raylab/{modules/networks => policy/modules}/model/stochastic/__init__.py (100%) rename raylab/{modules/networks => policy/modules}/model/stochastic/builders.py (100%) rename raylab/{modules/networks => policy/modules}/model/stochastic/ensemble.py (100%) rename raylab/{modules/networks => policy/modules}/model/stochastic/single.py (98%) rename raylab/{ => policy}/modules/networks/__init__.py (100%) rename raylab/{ => policy}/modules/networks/mlp.py (100%) rename raylab/{ => policy}/modules/networks/resnet.py (100%) rename raylab/{ => policy}/modules/sac.py (95%) rename raylab/{ => policy}/modules/v0/__init__.py (100%) rename raylab/{ => policy}/modules/v0/abstract.py (100%) create mode 100644 raylab/policy/modules/v0/catalog.py rename raylab/{ => policy}/modules/v0/ddpg_module.py (100%) rename raylab/{ => policy}/modules/v0/maxent_model_based.py (100%) rename raylab/{ => policy}/modules/v0/mixins/__init__.py (100%) rename raylab/{ => policy}/modules/v0/mixins/action_value_mixin.py (100%) rename raylab/{ => policy}/modules/v0/mixins/deterministic_actor_mixin.py (100%) rename raylab/{ => policy}/modules/v0/mixins/normalizing_flow_actor_mixin.py (98%) rename raylab/{ => policy}/modules/v0/mixins/normalizing_flow_model_mixin.py (99%) rename raylab/{ => policy}/modules/v0/mixins/state_value_mixin.py (100%) rename raylab/{ => policy}/modules/v0/mixins/stochastic_actor_mixin.py (100%) rename raylab/{ => policy}/modules/v0/mixins/stochastic_model_mixin.py (100%) rename raylab/{ => policy}/modules/v0/mixins/svg_model_mixin.py (100%) rename raylab/{ => policy}/modules/v0/model_based_ddpg.py (100%) rename raylab/{ => policy}/modules/v0/model_based_sac.py (100%) rename raylab/{ => policy}/modules/v0/naf_module.py (100%) rename raylab/{ => policy}/modules/v0/nfmbrl.py (100%) rename raylab/{ => policy}/modules/v0/off_policy_nfac.py (100%) rename raylab/{ => 
policy}/modules/v0/on_policy_actor_critic.py (100%) rename raylab/{ => policy}/modules/v0/on_policy_nfac.py (100%) rename raylab/{ => policy}/modules/v0/sac_module.py (100%) rename raylab/{ => policy}/modules/v0/simple_model_based.py (100%) rename raylab/{ => policy}/modules/v0/svg_module.py (100%) rename raylab/{ => policy}/modules/v0/svg_realnvp_actor.py (100%) rename raylab/{ => policy}/modules/v0/trpo_tang2018.py (99%) rename tests/{ => policy}/modules/__init__.py (100%) rename tests/{modules/networks => policy/modules/actor}/__init__.py (100%) rename tests/{modules/networks => policy/modules}/actor/conftest.py (100%) rename tests/{modules/networks/actor => policy/modules/actor/policy}/__init__.py (100%) rename tests/{modules/networks => policy/modules}/actor/policy/test_deterministic.py (90%) rename tests/{modules/networks => policy/modules}/actor/policy/test_stochastic.py (93%) rename tests/{modules/networks => policy/modules}/actor/test_deterministic.py (96%) rename tests/{modules/networks => policy/modules}/actor/test_stochastic.py (94%) rename tests/{modules/networks => policy/modules}/conftest.py (50%) rename tests/{modules/networks/actor/policy => policy/modules/critic}/__init__.py (100%) rename tests/{modules/networks => policy/modules}/critic/test_action_value.py (95%) rename tests/{modules/networks/critic => policy/modules/model}/__init__.py (100%) rename tests/{modules/networks/model => policy/modules/model/stochastic}/__init__.py (100%) rename tests/{modules/networks => policy/modules}/model/stochastic/conftest.py (100%) rename tests/{modules/networks => policy/modules}/model/stochastic/test_ensemble.py (82%) rename tests/{modules/networks => policy/modules}/model/stochastic/test_single.py (94%) rename tests/{modules/networks/model/stochastic => policy/modules/networks}/__init__.py (100%) rename tests/{ => policy}/modules/networks/test_mlp.py (90%) rename tests/{ => policy}/modules/networks/test_resnet.py (89%) rename tests/{ => policy}/modules/test_ddpg.py (94%) rename tests/{ => policy}/modules/test_sac.py (94%) rename tests/{ => policy}/modules/v0/__init__.py (100%) rename tests/{modules => policy/modules/v0}/conftest.py (92%) rename tests/{ => policy}/modules/v0/test_action_value_mixin.py (96%) rename tests/{ => policy}/modules/v0/test_deterministic_actor_mixin.py (98%) rename tests/{ => policy}/modules/v0/test_naf_module.py (91%) rename tests/{ => policy}/modules/v0/test_normalizing_flow_actor_mixin.py (98%) rename tests/{ => policy}/modules/v0/test_normalizing_flow_model_mixin.py (98%) rename tests/{ => policy}/modules/v0/test_state_value_mixin.py (96%) rename tests/{ => policy}/modules/v0/test_stochastic_actor_mixin.py (98%) rename tests/{ => policy}/modules/v0/test_stochastic_model_mixin.py (98%) rename tests/{ => policy}/modules/v0/test_svg_module.py (95%) rename tests/{ => policy}/modules/v0/test_trpo_extensions.py (97%) rename tests/{ => policy}/modules/v0/utils.py (100%) diff --git a/raylab/agents/mage/policy.py b/raylab/agents/mage/policy.py index 8672fce1..67368731 100644 --- a/raylab/agents/mage/policy.py +++ b/raylab/agents/mage/policy.py @@ -37,6 +37,7 @@ def __init__(self, observation_space, action_space, config): @staticmethod def get_default_config(): + # pylint:disable=cyclic-import from raylab.agents.mage import DEFAULT_CONFIG return DEFAULT_CONFIG diff --git a/raylab/losses/mle.py b/raylab/losses/mle.py index 585393d4..f24d0114 100644 --- a/raylab/losses/mle.py +++ b/raylab/losses/mle.py @@ -8,7 +8,7 @@ from ray.rllib import SampleBatch from torch 
import Tensor -from raylab.modules.v0.mixins.stochastic_model_mixin import StochasticModel +from raylab.policy.modules.v0.mixins.stochastic_model_mixin import StochasticModel from raylab.utils.dictionaries import get_keys from .abstract import Loss diff --git a/raylab/losses/policy_gradient.py b/raylab/losses/policy_gradient.py index 620a0d68..548529a0 100644 --- a/raylab/losses/policy_gradient.py +++ b/raylab/losses/policy_gradient.py @@ -8,7 +8,7 @@ from ray.rllib import SampleBatch from torch import Tensor -from raylab.modules.v0.mixins.stochastic_actor_mixin import StochasticPolicy +from raylab.policy.modules.v0.mixins.stochastic_actor_mixin import StochasticPolicy from raylab.utils.annotations import DetPolicy from raylab.utils.annotations import DynamicsFn from raylab.utils.annotations import RewardFn diff --git a/raylab/losses/svg.py b/raylab/losses/svg.py index cbd97f1b..dae05b4b 100644 --- a/raylab/losses/svg.py +++ b/raylab/losses/svg.py @@ -10,8 +10,8 @@ from ray.rllib.utils import override from torch import Tensor -from raylab.modules.v0.mixins.stochastic_actor_mixin import StochasticPolicy -from raylab.modules.v0.mixins.stochastic_model_mixin import StochasticModel +from raylab.policy.modules.v0.mixins.stochastic_actor_mixin import StochasticPolicy +from raylab.policy.modules.v0.mixins.stochastic_model_mixin import StochasticModel from raylab.utils.annotations import RewardFn from raylab.utils.annotations import StateValue from raylab.utils.dictionaries import get_keys diff --git a/raylab/modules/catalog.py b/raylab/modules/catalog.py deleted file mode 100644 index 4953cff3..00000000 --- a/raylab/modules/catalog.py +++ /dev/null @@ -1,61 +0,0 @@ -"""Registry of modules for PyTorch policies.""" -import torch.nn as nn -from gym.spaces import Space - -from .ddpg import DDPG -from .mbddpg import MBDDPG -from .mbsac import MBSAC -from .sac import SAC -from .v0.ddpg_module import DDPGModule -from .v0.maxent_model_based import MaxEntModelBased -from .v0.model_based_ddpg import ModelBasedDDPG -from .v0.model_based_sac import ModelBasedSAC -from .v0.naf_module import NAFModule -from .v0.nfmbrl import NFMBRL -from .v0.off_policy_nfac import OffPolicyNFAC -from .v0.on_policy_actor_critic import OnPolicyActorCritic -from .v0.on_policy_nfac import OnPolicyNFAC -from .v0.sac_module import SACModule -from .v0.simple_model_based import SimpleModelBased -from .v0.svg_module import SVGModule -from .v0.svg_realnvp_actor import SVGRealNVPActor -from .v0.trpo_tang2018 import TRPOTang2018 - -MODULESv0 = { - cls.__name__: cls - for cls in ( - NAFModule, - DDPGModule, - SACModule, - SimpleModelBased, - SVGModule, - MaxEntModelBased, - ModelBasedDDPG, - ModelBasedSAC, - NFMBRL, - OnPolicyActorCritic, - OnPolicyNFAC, - OffPolicyNFAC, - TRPOTang2018, - SVGRealNVPActor, - ) -} - -MODULESv1 = {cls.__name__: cls for cls in (DDPG, MBDDPG, MBSAC, SAC)} - - -def get_module(obs_space: Space, action_space: Space, config: dict) -> nn.Module: - """Retrieve and construct module of given name. 
- - Args: - obs_space: Observation space - action_space: Action space - config: Configurations for module construction and initialization - """ - type_ = config.pop("type") - if type_ in MODULESv0: - return MODULESv0[type_](obs_space, action_space, config) - - cls = MODULESv1[type_] - spec = cls.spec_cls.from_dict(config) - return cls(obs_space, action_space, spec) diff --git a/raylab/modules/__init__.py b/raylab/policy/modules/__init__.py similarity index 100% rename from raylab/modules/__init__.py rename to raylab/policy/modules/__init__.py diff --git a/raylab/modules/networks/actor/__init__.py b/raylab/policy/modules/actor/__init__.py similarity index 100% rename from raylab/modules/networks/actor/__init__.py rename to raylab/policy/modules/actor/__init__.py diff --git a/raylab/modules/networks/actor/deterministic.py b/raylab/policy/modules/actor/deterministic.py similarity index 100% rename from raylab/modules/networks/actor/deterministic.py rename to raylab/policy/modules/actor/deterministic.py diff --git a/raylab/modules/networks/actor/policy/__init__.py b/raylab/policy/modules/actor/policy/__init__.py similarity index 100% rename from raylab/modules/networks/actor/policy/__init__.py rename to raylab/policy/modules/actor/policy/__init__.py diff --git a/raylab/modules/networks/actor/policy/deterministic.py b/raylab/policy/modules/actor/policy/deterministic.py similarity index 98% rename from raylab/modules/networks/actor/policy/deterministic.py rename to raylab/policy/modules/actor/policy/deterministic.py index 7e6da6b5..457b5f53 100644 --- a/raylab/modules/networks/actor/policy/deterministic.py +++ b/raylab/policy/modules/actor/policy/deterministic.py @@ -8,7 +8,7 @@ from torch import Tensor import raylab.pytorch.nn as nnx -from raylab.modules.networks.mlp import StateMLP +from raylab.policy.modules.networks.mlp import StateMLP class DeterministicPolicy(nn.Module): diff --git a/raylab/modules/networks/actor/policy/stochastic.py b/raylab/policy/modules/actor/policy/stochastic.py similarity index 99% rename from raylab/modules/networks/actor/policy/stochastic.py rename to raylab/policy/modules/actor/policy/stochastic.py index 08b2e820..944d9983 100644 --- a/raylab/modules/networks/actor/policy/stochastic.py +++ b/raylab/policy/modules/actor/policy/stochastic.py @@ -9,7 +9,7 @@ import raylab.pytorch.nn as nnx import raylab.pytorch.nn.distributions as ptd -from raylab.modules.networks.mlp import StateMLP +from raylab.policy.modules.networks.mlp import StateMLP class StochasticPolicy(nn.Module): diff --git a/raylab/modules/networks/actor/stochastic.py b/raylab/policy/modules/actor/stochastic.py similarity index 100% rename from raylab/modules/networks/actor/stochastic.py rename to raylab/policy/modules/actor/stochastic.py diff --git a/raylab/policy/modules/catalog.py b/raylab/policy/modules/catalog.py new file mode 100644 index 00000000..7975035b --- /dev/null +++ b/raylab/policy/modules/catalog.py @@ -0,0 +1,28 @@ +"""Registry of modules for PyTorch policies.""" +import torch.nn as nn +from gym.spaces import Space + +from .ddpg import DDPG +from .mbddpg import MBDDPG +from .mbsac import MBSAC +from .sac import SAC +from .v0.catalog import get_module as get_v0_module + +MODULES = {cls.__name__: cls for cls in (DDPG, MBDDPG, MBSAC, SAC)} + + +def get_module(obs_space: Space, action_space: Space, config: dict) -> nn.Module: + """Retrieve and construct module of given name. 
+ + Args: + obs_space: Observation space + action_space: Action space + config: Configurations for module construction and initialization + """ + type_ = config.pop("type") + if type_ not in MODULES: + return get_v0_module(obs_space, action_space, {"type": type_, **config}) + + cls = MODULES[type_] + spec = cls.spec_cls.from_dict(config) + return cls(obs_space, action_space, spec) diff --git a/raylab/modules/networks/critic/__init__.py b/raylab/policy/modules/critic/__init__.py similarity index 100% rename from raylab/modules/networks/critic/__init__.py rename to raylab/policy/modules/critic/__init__.py diff --git a/raylab/modules/networks/critic/action_value.py b/raylab/policy/modules/critic/action_value.py similarity index 100% rename from raylab/modules/networks/critic/action_value.py rename to raylab/policy/modules/critic/action_value.py diff --git a/raylab/modules/networks/critic/q_value.py b/raylab/policy/modules/critic/q_value.py similarity index 98% rename from raylab/modules/networks/critic/q_value.py rename to raylab/policy/modules/critic/q_value.py index 795e6606..1059f472 100644 --- a/raylab/modules/networks/critic/q_value.py +++ b/raylab/policy/modules/critic/q_value.py @@ -4,7 +4,7 @@ from gym.spaces import Box from torch import Tensor -from raylab.modules.networks.mlp import StateActionMLP +from raylab.policy.modules.networks.mlp import StateActionMLP MLPSpec = StateActionMLP.spec_cls diff --git a/raylab/modules/ddpg.py b/raylab/policy/modules/ddpg.py similarity index 95% rename from raylab/modules/ddpg.py rename to raylab/policy/modules/ddpg.py index 3ac76b42..6c114be4 100644 --- a/raylab/modules/ddpg.py +++ b/raylab/policy/modules/ddpg.py @@ -6,8 +6,8 @@ from dataclasses_json import DataClassJsonMixin from gym.spaces import Box -from .networks.actor.deterministic import DeterministicActor -from .networks.critic.action_value import ActionValueCritic +from .actor.deterministic import DeterministicActor +from .critic.action_value import ActionValueCritic ActorSpec = DeterministicActor.spec_cls CriticSpec = ActionValueCritic.spec_cls diff --git a/raylab/modules/mbddpg.py b/raylab/policy/modules/mbddpg.py similarity index 95% rename from raylab/modules/mbddpg.py rename to raylab/policy/modules/mbddpg.py index 09174f0d..931dc338 100644 --- a/raylab/modules/mbddpg.py +++ b/raylab/policy/modules/mbddpg.py @@ -6,8 +6,8 @@ from .ddpg import DDPG from .ddpg import DDPGSpec -from .networks.model.stochastic import build_ensemble -from .networks.model.stochastic import EnsembleSpec +from .model.stochastic import build_ensemble +from .model.stochastic import EnsembleSpec @dataclass diff --git a/raylab/modules/mbsac.py b/raylab/policy/modules/mbsac.py similarity index 94% rename from raylab/modules/mbsac.py rename to raylab/policy/modules/mbsac.py index 5cb620a0..dc7c897c 100644 --- a/raylab/modules/mbsac.py +++ b/raylab/policy/modules/mbsac.py @@ -4,8 +4,8 @@ from gym.spaces import Box -from .networks.model.stochastic import build_ensemble -from .networks.model.stochastic import EnsembleSpec +from .model.stochastic import build_ensemble +from .model.stochastic import EnsembleSpec from .sac import SAC from .sac import SACSpec diff --git a/raylab/modules/networks/model/__init__.py b/raylab/policy/modules/model/__init__.py similarity index 100% rename from raylab/modules/networks/model/__init__.py rename to raylab/policy/modules/model/__init__.py diff --git a/raylab/modules/networks/model/stochastic/__init__.py b/raylab/policy/modules/model/stochastic/__init__.py similarity index 100% 
rename from raylab/modules/networks/model/stochastic/__init__.py rename to raylab/policy/modules/model/stochastic/__init__.py diff --git a/raylab/modules/networks/model/stochastic/builders.py b/raylab/policy/modules/model/stochastic/builders.py similarity index 100% rename from raylab/modules/networks/model/stochastic/builders.py rename to raylab/policy/modules/model/stochastic/builders.py diff --git a/raylab/modules/networks/model/stochastic/ensemble.py b/raylab/policy/modules/model/stochastic/ensemble.py similarity index 100% rename from raylab/modules/networks/model/stochastic/ensemble.py rename to raylab/policy/modules/model/stochastic/ensemble.py diff --git a/raylab/modules/networks/model/stochastic/single.py b/raylab/policy/modules/model/stochastic/single.py similarity index 98% rename from raylab/modules/networks/model/stochastic/single.py rename to raylab/policy/modules/model/stochastic/single.py index f4c97c26..0b214159 100644 --- a/raylab/modules/networks/model/stochastic/single.py +++ b/raylab/policy/modules/model/stochastic/single.py @@ -7,7 +7,7 @@ import raylab.pytorch.nn as nnx import raylab.pytorch.nn.distributions as ptd -from raylab.modules.networks.mlp import StateActionMLP +from raylab.policy.modules.networks.mlp import StateActionMLP class StochasticModel(nn.Module): diff --git a/raylab/modules/networks/__init__.py b/raylab/policy/modules/networks/__init__.py similarity index 100% rename from raylab/modules/networks/__init__.py rename to raylab/policy/modules/networks/__init__.py diff --git a/raylab/modules/networks/mlp.py b/raylab/policy/modules/networks/mlp.py similarity index 100% rename from raylab/modules/networks/mlp.py rename to raylab/policy/modules/networks/mlp.py diff --git a/raylab/modules/networks/resnet.py b/raylab/policy/modules/networks/resnet.py similarity index 100% rename from raylab/modules/networks/resnet.py rename to raylab/policy/modules/networks/resnet.py diff --git a/raylab/modules/sac.py b/raylab/policy/modules/sac.py similarity index 95% rename from raylab/modules/sac.py rename to raylab/policy/modules/sac.py index bbf1d1f7..02e179f7 100644 --- a/raylab/modules/sac.py +++ b/raylab/policy/modules/sac.py @@ -6,8 +6,8 @@ from dataclasses_json import DataClassJsonMixin from gym.spaces import Box -from .networks.actor.stochastic import StochasticActor -from .networks.critic.action_value import ActionValueCritic +from .actor.stochastic import StochasticActor +from .critic.action_value import ActionValueCritic ActorSpec = StochasticActor.spec_cls CriticSpec = ActionValueCritic.spec_cls diff --git a/raylab/modules/v0/__init__.py b/raylab/policy/modules/v0/__init__.py similarity index 100% rename from raylab/modules/v0/__init__.py rename to raylab/policy/modules/v0/__init__.py diff --git a/raylab/modules/v0/abstract.py b/raylab/policy/modules/v0/abstract.py similarity index 100% rename from raylab/modules/v0/abstract.py rename to raylab/policy/modules/v0/abstract.py diff --git a/raylab/policy/modules/v0/catalog.py b/raylab/policy/modules/v0/catalog.py new file mode 100644 index 00000000..e72bcb8c --- /dev/null +++ b/raylab/policy/modules/v0/catalog.py @@ -0,0 +1,50 @@ +"""Registry of old modules for PyTorch policies.""" +import torch.nn as nn +from gym.spaces import Space + +from .ddpg_module import DDPGModule +from .maxent_model_based import MaxEntModelBased +from .model_based_ddpg import ModelBasedDDPG +from .model_based_sac import ModelBasedSAC +from .naf_module import NAFModule +from .nfmbrl import NFMBRL +from .off_policy_nfac import OffPolicyNFAC 
+from .on_policy_actor_critic import OnPolicyActorCritic +from .on_policy_nfac import OnPolicyNFAC +from .sac_module import SACModule +from .simple_model_based import SimpleModelBased +from .svg_module import SVGModule +from .svg_realnvp_actor import SVGRealNVPActor +from .trpo_tang2018 import TRPOTang2018 + +MODULES = { + cls.__name__: cls + for cls in ( + NAFModule, + DDPGModule, + SACModule, + SimpleModelBased, + SVGModule, + MaxEntModelBased, + ModelBasedDDPG, + ModelBasedSAC, + NFMBRL, + OnPolicyActorCritic, + OnPolicyNFAC, + OffPolicyNFAC, + TRPOTang2018, + SVGRealNVPActor, + ) +} + + +def get_module(obs_space: Space, action_space: Space, config: dict) -> nn.Module: + """Retrieve and construct module of given name. + + Args: + obs_space: Observation space + action_space: Action space + config: Configurations for module construction and initialization + """ + type_ = config.pop("type") + return MODULES[type_](obs_space, action_space, config) diff --git a/raylab/modules/v0/ddpg_module.py b/raylab/policy/modules/v0/ddpg_module.py similarity index 100% rename from raylab/modules/v0/ddpg_module.py rename to raylab/policy/modules/v0/ddpg_module.py diff --git a/raylab/modules/v0/maxent_model_based.py b/raylab/policy/modules/v0/maxent_model_based.py similarity index 100% rename from raylab/modules/v0/maxent_model_based.py rename to raylab/policy/modules/v0/maxent_model_based.py diff --git a/raylab/modules/v0/mixins/__init__.py b/raylab/policy/modules/v0/mixins/__init__.py similarity index 100% rename from raylab/modules/v0/mixins/__init__.py rename to raylab/policy/modules/v0/mixins/__init__.py diff --git a/raylab/modules/v0/mixins/action_value_mixin.py b/raylab/policy/modules/v0/mixins/action_value_mixin.py similarity index 100% rename from raylab/modules/v0/mixins/action_value_mixin.py rename to raylab/policy/modules/v0/mixins/action_value_mixin.py diff --git a/raylab/modules/v0/mixins/deterministic_actor_mixin.py b/raylab/policy/modules/v0/mixins/deterministic_actor_mixin.py similarity index 100% rename from raylab/modules/v0/mixins/deterministic_actor_mixin.py rename to raylab/policy/modules/v0/mixins/deterministic_actor_mixin.py diff --git a/raylab/modules/v0/mixins/normalizing_flow_actor_mixin.py b/raylab/policy/modules/v0/mixins/normalizing_flow_actor_mixin.py similarity index 98% rename from raylab/modules/v0/mixins/normalizing_flow_actor_mixin.py rename to raylab/policy/modules/v0/mixins/normalizing_flow_actor_mixin.py index 5c0bce2c..88a9a610 100644 --- a/raylab/modules/v0/mixins/normalizing_flow_actor_mixin.py +++ b/raylab/policy/modules/v0/mixins/normalizing_flow_actor_mixin.py @@ -6,7 +6,7 @@ import torch.nn as nn from ray.rllib.utils import override -import raylab.modules.networks as networks +import raylab.policy.modules.networks as networks import raylab.pytorch.nn as nnx import raylab.pytorch.nn.distributions as ptd from raylab.utils.dictionaries import deep_merge diff --git a/raylab/modules/v0/mixins/normalizing_flow_model_mixin.py b/raylab/policy/modules/v0/mixins/normalizing_flow_model_mixin.py similarity index 99% rename from raylab/modules/v0/mixins/normalizing_flow_model_mixin.py rename to raylab/policy/modules/v0/mixins/normalizing_flow_model_mixin.py index 62015ea7..031cc061 100644 --- a/raylab/modules/v0/mixins/normalizing_flow_model_mixin.py +++ b/raylab/policy/modules/v0/mixins/normalizing_flow_model_mixin.py @@ -5,7 +5,7 @@ import torch.nn as nn from ray.rllib.utils import override -import raylab.modules.networks as networks +import raylab.policy.modules.networks 
as networks import raylab.pytorch.nn as nnx import raylab.pytorch.nn.distributions as ptd from raylab.utils.dictionaries import deep_merge diff --git a/raylab/modules/v0/mixins/state_value_mixin.py b/raylab/policy/modules/v0/mixins/state_value_mixin.py similarity index 100% rename from raylab/modules/v0/mixins/state_value_mixin.py rename to raylab/policy/modules/v0/mixins/state_value_mixin.py diff --git a/raylab/modules/v0/mixins/stochastic_actor_mixin.py b/raylab/policy/modules/v0/mixins/stochastic_actor_mixin.py similarity index 100% rename from raylab/modules/v0/mixins/stochastic_actor_mixin.py rename to raylab/policy/modules/v0/mixins/stochastic_actor_mixin.py diff --git a/raylab/modules/v0/mixins/stochastic_model_mixin.py b/raylab/policy/modules/v0/mixins/stochastic_model_mixin.py similarity index 100% rename from raylab/modules/v0/mixins/stochastic_model_mixin.py rename to raylab/policy/modules/v0/mixins/stochastic_model_mixin.py diff --git a/raylab/modules/v0/mixins/svg_model_mixin.py b/raylab/policy/modules/v0/mixins/svg_model_mixin.py similarity index 100% rename from raylab/modules/v0/mixins/svg_model_mixin.py rename to raylab/policy/modules/v0/mixins/svg_model_mixin.py diff --git a/raylab/modules/v0/model_based_ddpg.py b/raylab/policy/modules/v0/model_based_ddpg.py similarity index 100% rename from raylab/modules/v0/model_based_ddpg.py rename to raylab/policy/modules/v0/model_based_ddpg.py diff --git a/raylab/modules/v0/model_based_sac.py b/raylab/policy/modules/v0/model_based_sac.py similarity index 100% rename from raylab/modules/v0/model_based_sac.py rename to raylab/policy/modules/v0/model_based_sac.py diff --git a/raylab/modules/v0/naf_module.py b/raylab/policy/modules/v0/naf_module.py similarity index 100% rename from raylab/modules/v0/naf_module.py rename to raylab/policy/modules/v0/naf_module.py diff --git a/raylab/modules/v0/nfmbrl.py b/raylab/policy/modules/v0/nfmbrl.py similarity index 100% rename from raylab/modules/v0/nfmbrl.py rename to raylab/policy/modules/v0/nfmbrl.py diff --git a/raylab/modules/v0/off_policy_nfac.py b/raylab/policy/modules/v0/off_policy_nfac.py similarity index 100% rename from raylab/modules/v0/off_policy_nfac.py rename to raylab/policy/modules/v0/off_policy_nfac.py diff --git a/raylab/modules/v0/on_policy_actor_critic.py b/raylab/policy/modules/v0/on_policy_actor_critic.py similarity index 100% rename from raylab/modules/v0/on_policy_actor_critic.py rename to raylab/policy/modules/v0/on_policy_actor_critic.py diff --git a/raylab/modules/v0/on_policy_nfac.py b/raylab/policy/modules/v0/on_policy_nfac.py similarity index 100% rename from raylab/modules/v0/on_policy_nfac.py rename to raylab/policy/modules/v0/on_policy_nfac.py diff --git a/raylab/modules/v0/sac_module.py b/raylab/policy/modules/v0/sac_module.py similarity index 100% rename from raylab/modules/v0/sac_module.py rename to raylab/policy/modules/v0/sac_module.py diff --git a/raylab/modules/v0/simple_model_based.py b/raylab/policy/modules/v0/simple_model_based.py similarity index 100% rename from raylab/modules/v0/simple_model_based.py rename to raylab/policy/modules/v0/simple_model_based.py diff --git a/raylab/modules/v0/svg_module.py b/raylab/policy/modules/v0/svg_module.py similarity index 100% rename from raylab/modules/v0/svg_module.py rename to raylab/policy/modules/v0/svg_module.py diff --git a/raylab/modules/v0/svg_realnvp_actor.py b/raylab/policy/modules/v0/svg_realnvp_actor.py similarity index 100% rename from raylab/modules/v0/svg_realnvp_actor.py rename to 
raylab/policy/modules/v0/svg_realnvp_actor.py diff --git a/raylab/modules/v0/trpo_tang2018.py b/raylab/policy/modules/v0/trpo_tang2018.py similarity index 99% rename from raylab/modules/v0/trpo_tang2018.py rename to raylab/policy/modules/v0/trpo_tang2018.py index d76cef70..e83f32c0 100644 --- a/raylab/modules/v0/trpo_tang2018.py +++ b/raylab/policy/modules/v0/trpo_tang2018.py @@ -7,7 +7,7 @@ from ray.rllib.utils import merge_dicts from ray.rllib.utils import override -import raylab.modules.networks as networks +import raylab.policy.modules.networks as networks from raylab.pytorch.nn import FullyConnected from raylab.pytorch.nn.distributions import flows from raylab.pytorch.nn.distributions import Independent diff --git a/raylab/policy/torch_policy.py b/raylab/policy/torch_policy.py index 055efb36..58bc3fa5 100644 --- a/raylab/policy/torch_policy.py +++ b/raylab/policy/torch_policy.py @@ -22,11 +22,11 @@ from torch.optim import Optimizer from raylab.agents import Trainer -from raylab.modules.catalog import get_module from raylab.pytorch.utils import convert_to_tensor from raylab.utils.dictionaries import deep_merge from .action_dist import WrapModuleDist +from .modules.catalog import get_module from .optimizer_collection import OptimizerCollection diff --git a/tests/losses/conftest.py b/tests/losses/conftest.py index c740c662..11f0de24 100644 --- a/tests/losses/conftest.py +++ b/tests/losses/conftest.py @@ -5,10 +5,12 @@ import raylab.pytorch.nn as nnx import raylab.pytorch.nn.distributions as ptd -from raylab.modules.v0.mixins.action_value_mixin import ActionValueFunction -from raylab.modules.v0.mixins.deterministic_actor_mixin import DeterministicPolicy -from raylab.modules.v0.mixins.stochastic_actor_mixin import StochasticPolicy -from raylab.modules.v0.mixins.stochastic_model_mixin import StochasticModelMixin +from raylab.policy.modules.v0.mixins.action_value_mixin import ActionValueFunction +from raylab.policy.modules.v0.mixins.deterministic_actor_mixin import ( + DeterministicPolicy, +) +from raylab.policy.modules.v0.mixins.stochastic_actor_mixin import StochasticPolicy +from raylab.policy.modules.v0.mixins.stochastic_model_mixin import StochasticModelMixin from raylab.utils.debug import fake_batch diff --git a/tests/modules/__init__.py b/tests/policy/modules/__init__.py similarity index 100% rename from tests/modules/__init__.py rename to tests/policy/modules/__init__.py diff --git a/tests/modules/networks/__init__.py b/tests/policy/modules/actor/__init__.py similarity index 100% rename from tests/modules/networks/__init__.py rename to tests/policy/modules/actor/__init__.py diff --git a/tests/modules/networks/actor/conftest.py b/tests/policy/modules/actor/conftest.py similarity index 100% rename from tests/modules/networks/actor/conftest.py rename to tests/policy/modules/actor/conftest.py diff --git a/tests/modules/networks/actor/__init__.py b/tests/policy/modules/actor/policy/__init__.py similarity index 100% rename from tests/modules/networks/actor/__init__.py rename to tests/policy/modules/actor/policy/__init__.py diff --git a/tests/modules/networks/actor/policy/test_deterministic.py b/tests/policy/modules/actor/policy/test_deterministic.py similarity index 90% rename from tests/modules/networks/actor/policy/test_deterministic.py rename to tests/policy/modules/actor/policy/test_deterministic.py index 77dacbea..528e1ef1 100644 --- a/tests/modules/networks/actor/policy/test_deterministic.py +++ b/tests/policy/modules/actor/policy/test_deterministic.py @@ -6,9 +6,7 @@ 
@pytest.fixture(scope="module") def module_cls(): - from raylab.modules.networks.actor.policy.deterministic import ( - MLPDeterministicPolicy, - ) + from raylab.policy.modules.actor.policy.deterministic import MLPDeterministicPolicy return MLPDeterministicPolicy diff --git a/tests/modules/networks/actor/policy/test_stochastic.py b/tests/policy/modules/actor/policy/test_stochastic.py similarity index 93% rename from tests/modules/networks/actor/policy/test_stochastic.py rename to tests/policy/modules/actor/policy/test_stochastic.py index 69bced1b..1cc8da67 100644 --- a/tests/modules/networks/actor/policy/test_stochastic.py +++ b/tests/policy/modules/actor/policy/test_stochastic.py @@ -6,21 +6,21 @@ @pytest.fixture(scope="module") def base_cls(): - from raylab.modules.networks.actor.policy.stochastic import MLPStochasticPolicy + from raylab.policy.modules.actor.policy.stochastic import MLPStochasticPolicy return MLPStochasticPolicy @pytest.fixture(scope="module") def cont_cls(): - from raylab.modules.networks.actor.policy.stochastic import MLPContinuousPolicy + from raylab.policy.modules.actor.policy.stochastic import MLPContinuousPolicy return MLPContinuousPolicy @pytest.fixture(scope="module") def disc_cls(): - from raylab.modules.networks.actor.policy.stochastic import MLPDiscretePolicy + from raylab.policy.modules.actor.policy.stochastic import MLPDiscretePolicy return MLPDiscretePolicy diff --git a/tests/modules/networks/actor/test_deterministic.py b/tests/policy/modules/actor/test_deterministic.py similarity index 96% rename from tests/modules/networks/actor/test_deterministic.py rename to tests/policy/modules/actor/test_deterministic.py index 3e6ec146..33959308 100644 --- a/tests/modules/networks/actor/test_deterministic.py +++ b/tests/policy/modules/actor/test_deterministic.py @@ -16,7 +16,7 @@ def batch(cont_batch): @pytest.fixture(scope="module") def module_cls(): - from raylab.modules.networks.actor.deterministic import DeterministicActor + from raylab.policy.modules.actor.deterministic import DeterministicActor return DeterministicActor diff --git a/tests/modules/networks/actor/test_stochastic.py b/tests/policy/modules/actor/test_stochastic.py similarity index 94% rename from tests/modules/networks/actor/test_stochastic.py rename to tests/policy/modules/actor/test_stochastic.py index 6a0c3455..8ebf382b 100644 --- a/tests/modules/networks/actor/test_stochastic.py +++ b/tests/policy/modules/actor/test_stochastic.py @@ -5,7 +5,7 @@ @pytest.fixture(scope="module") def module_cls(): - from raylab.modules.networks.actor.stochastic import StochasticActor + from raylab.policy.modules.actor.stochastic import StochasticActor return StochasticActor diff --git a/tests/modules/networks/conftest.py b/tests/policy/modules/conftest.py similarity index 50% rename from tests/modules/networks/conftest.py rename to tests/policy/modules/conftest.py index 23005dc6..3156cba9 100644 --- a/tests/modules/networks/conftest.py +++ b/tests/policy/modules/conftest.py @@ -1,8 +1,17 @@ -# pylint: disable=missing-docstring,redefined-outer-name,protected-access +# pylint:disable=missing-docstring,redefined-outer-name,protected-access import pytest import torch +@pytest.fixture( + params=(pytest.param(True, marks=pytest.mark.slow), False), + ids=("TorchScript", "Eager"), + scope="module", +) +def torch_script(request): + return request.param + + @pytest.fixture(scope="module") def batch(obs_space, action_space): from raylab.utils.debug import fake_batch diff --git a/tests/modules/networks/actor/policy/__init__.py 
b/tests/policy/modules/critic/__init__.py similarity index 100% rename from tests/modules/networks/actor/policy/__init__.py rename to tests/policy/modules/critic/__init__.py diff --git a/tests/modules/networks/critic/test_action_value.py b/tests/policy/modules/critic/test_action_value.py similarity index 95% rename from tests/modules/networks/critic/test_action_value.py rename to tests/policy/modules/critic/test_action_value.py index 4fd5f89b..3a2f7976 100644 --- a/tests/modules/networks/critic/test_action_value.py +++ b/tests/policy/modules/critic/test_action_value.py @@ -6,7 +6,7 @@ @pytest.fixture(scope="module") def module_cls(): - from raylab.modules.networks.critic.action_value import ActionValueCritic + from raylab.policy.modules.critic.action_value import ActionValueCritic return ActionValueCritic diff --git a/tests/modules/networks/critic/__init__.py b/tests/policy/modules/model/__init__.py similarity index 100% rename from tests/modules/networks/critic/__init__.py rename to tests/policy/modules/model/__init__.py diff --git a/tests/modules/networks/model/__init__.py b/tests/policy/modules/model/stochastic/__init__.py similarity index 100% rename from tests/modules/networks/model/__init__.py rename to tests/policy/modules/model/stochastic/__init__.py diff --git a/tests/modules/networks/model/stochastic/conftest.py b/tests/policy/modules/model/stochastic/conftest.py similarity index 100% rename from tests/modules/networks/model/stochastic/conftest.py rename to tests/policy/modules/model/stochastic/conftest.py diff --git a/tests/modules/networks/model/stochastic/test_ensemble.py b/tests/policy/modules/model/stochastic/test_ensemble.py similarity index 82% rename from tests/modules/networks/model/stochastic/test_ensemble.py rename to tests/policy/modules/model/stochastic/test_ensemble.py index b57682e1..0ee8ae0e 100644 --- a/tests/modules/networks/model/stochastic/test_ensemble.py +++ b/tests/policy/modules/model/stochastic/test_ensemble.py @@ -5,10 +5,8 @@ @pytest.fixture(scope="module", params=(True, False), ids=lambda x: f"Forked({x})") def module_cls(request): - from raylab.modules.networks.model.stochastic.ensemble import ( - StochasticModelEnsemble, - ) - from raylab.modules.networks.model.stochastic.ensemble import ( + from raylab.policy.modules.model.stochastic.ensemble import StochasticModelEnsemble + from raylab.policy.modules.model.stochastic.ensemble import ( ForkedStochasticModelEnsemble, ) @@ -22,7 +20,7 @@ def ensemble_size(request): @pytest.fixture def build_single(obs_space, action_space): - from raylab.modules.networks.model.stochastic.single import MLPModel + from raylab.policy.modules.model.stochastic.single import MLPModel spec = MLPModel.spec_cls() input_dependent_scale = True diff --git a/tests/modules/networks/model/stochastic/test_single.py b/tests/policy/modules/model/stochastic/test_single.py similarity index 94% rename from tests/modules/networks/model/stochastic/test_single.py rename to tests/policy/modules/model/stochastic/test_single.py index be87c779..d4a25239 100644 --- a/tests/modules/networks/model/stochastic/test_single.py +++ b/tests/policy/modules/model/stochastic/test_single.py @@ -6,8 +6,8 @@ @pytest.fixture(scope="module", params=(True, False), ids=lambda x: f"Residual({x})") def module_cls(request): - from raylab.modules.networks.model.stochastic.single import MLPModel - from raylab.modules.networks.model.stochastic.single import ResidualMLPModel + from raylab.policy.modules.model.stochastic.single import MLPModel + from 
raylab.policy.modules.model.stochastic.single import ResidualMLPModel return ResidualMLPModel if request.param else MLPModel diff --git a/tests/modules/networks/model/stochastic/__init__.py b/tests/policy/modules/networks/__init__.py similarity index 100% rename from tests/modules/networks/model/stochastic/__init__.py rename to tests/policy/modules/networks/__init__.py diff --git a/tests/modules/networks/test_mlp.py b/tests/policy/modules/networks/test_mlp.py similarity index 90% rename from tests/modules/networks/test_mlp.py rename to tests/policy/modules/networks/test_mlp.py index c7c01d98..71ff0a0f 100644 --- a/tests/modules/networks/test_mlp.py +++ b/tests/policy/modules/networks/test_mlp.py @@ -2,7 +2,7 @@ import pytest import torch -from raylab.modules.networks import MLP +from raylab.policy.modules.networks import MLP PARAMS = (None, {}, {"state": torch.randn(10, 4)}) diff --git a/tests/modules/networks/test_resnet.py b/tests/policy/modules/networks/test_resnet.py similarity index 89% rename from tests/modules/networks/test_resnet.py rename to tests/policy/modules/networks/test_resnet.py index b64d04bd..77d03d9d 100644 --- a/tests/modules/networks/test_resnet.py +++ b/tests/policy/modules/networks/test_resnet.py @@ -2,7 +2,7 @@ import pytest import torch -from raylab.modules.networks import ResidualNet +from raylab.policy.modules.networks import ResidualNet PARAMS = (None, {}, {"state": torch.randn(10, 4)}) diff --git a/tests/modules/test_ddpg.py b/tests/policy/modules/test_ddpg.py similarity index 94% rename from tests/modules/test_ddpg.py rename to tests/policy/modules/test_ddpg.py index d21c1a50..7cd2c53d 100644 --- a/tests/modules/test_ddpg.py +++ b/tests/policy/modules/test_ddpg.py @@ -3,7 +3,7 @@ import torch import torch.nn as nn -from raylab.modules.ddpg import DDPG +from raylab.policy.modules.ddpg import DDPG @pytest.fixture diff --git a/tests/modules/test_sac.py b/tests/policy/modules/test_sac.py similarity index 94% rename from tests/modules/test_sac.py rename to tests/policy/modules/test_sac.py index 4adddb29..aa3f611f 100644 --- a/tests/modules/test_sac.py +++ b/tests/policy/modules/test_sac.py @@ -3,7 +3,7 @@ import torch import torch.nn as nn -from raylab.modules.sac import SAC +from raylab.policy.modules.sac import SAC @pytest.fixture diff --git a/tests/modules/v0/__init__.py b/tests/policy/modules/v0/__init__.py similarity index 100% rename from tests/modules/v0/__init__.py rename to tests/policy/modules/v0/__init__.py diff --git a/tests/modules/conftest.py b/tests/policy/modules/v0/conftest.py similarity index 92% rename from tests/modules/conftest.py rename to tests/policy/modules/v0/conftest.py index c8fcbee4..77f8c204 100644 --- a/tests/modules/conftest.py +++ b/tests/policy/modules/v0/conftest.py @@ -1,4 +1,4 @@ -# pylint: disable=missing-docstring,redefined-outer-name,protected-access +# pylint:disable=missing-docstring,redefined-outer-name,protected-access from functools import partial import pytest diff --git a/tests/modules/v0/test_action_value_mixin.py b/tests/policy/modules/v0/test_action_value_mixin.py similarity index 96% rename from tests/modules/v0/test_action_value_mixin.py rename to tests/policy/modules/v0/test_action_value_mixin.py index a230050d..18af7f92 100644 --- a/tests/modules/v0/test_action_value_mixin.py +++ b/tests/policy/modules/v0/test_action_value_mixin.py @@ -6,7 +6,7 @@ import torch.nn as nn from ray.rllib import SampleBatch -from raylab.modules.v0.mixins import ActionValueMixin +from raylab.policy.modules.v0.mixins import 
ActionValueMixin class DummyModule(ActionValueMixin, nn.ModuleDict): diff --git a/tests/modules/v0/test_deterministic_actor_mixin.py b/tests/policy/modules/v0/test_deterministic_actor_mixin.py similarity index 98% rename from tests/modules/v0/test_deterministic_actor_mixin.py rename to tests/policy/modules/v0/test_deterministic_actor_mixin.py index 37ed59e5..0297762b 100644 --- a/tests/modules/v0/test_deterministic_actor_mixin.py +++ b/tests/policy/modules/v0/test_deterministic_actor_mixin.py @@ -5,7 +5,7 @@ from ray.rllib import SampleBatch from ray.rllib.utils import merge_dicts -from raylab.modules.v0.mixins import DeterministicActorMixin +from raylab.policy.modules.v0.mixins import DeterministicActorMixin BASE_CONFIG = { diff --git a/tests/modules/v0/test_naf_module.py b/tests/policy/modules/v0/test_naf_module.py similarity index 91% rename from tests/modules/v0/test_naf_module.py rename to tests/policy/modules/v0/test_naf_module.py index 95889065..a5ec2c51 100644 --- a/tests/modules/v0/test_naf_module.py +++ b/tests/policy/modules/v0/test_naf_module.py @@ -4,7 +4,7 @@ import pytest import torch -from raylab.modules.v0.naf_module import NAFModule +from raylab.policy.modules.v0.naf_module import NAFModule @pytest.fixture(params=(True, False), ids=("Double Q", "Single Q")) diff --git a/tests/modules/v0/test_normalizing_flow_actor_mixin.py b/tests/policy/modules/v0/test_normalizing_flow_actor_mixin.py similarity index 98% rename from tests/modules/v0/test_normalizing_flow_actor_mixin.py rename to tests/policy/modules/v0/test_normalizing_flow_actor_mixin.py index c61f0e2e..e8648b81 100644 --- a/tests/modules/v0/test_normalizing_flow_actor_mixin.py +++ b/tests/policy/modules/v0/test_normalizing_flow_actor_mixin.py @@ -6,7 +6,7 @@ from gym.spaces import Box from ray.rllib import SampleBatch -from raylab.modules.v0.mixins import NormalizingFlowActorMixin +from raylab.policy.modules.v0.mixins import NormalizingFlowActorMixin from .utils import make_batch from .utils import make_module diff --git a/tests/modules/v0/test_normalizing_flow_model_mixin.py b/tests/policy/modules/v0/test_normalizing_flow_model_mixin.py similarity index 98% rename from tests/modules/v0/test_normalizing_flow_model_mixin.py rename to tests/policy/modules/v0/test_normalizing_flow_model_mixin.py index 98e590af..534c6cbd 100644 --- a/tests/modules/v0/test_normalizing_flow_model_mixin.py +++ b/tests/policy/modules/v0/test_normalizing_flow_model_mixin.py @@ -7,7 +7,7 @@ from gym.spaces import Box from ray.rllib import SampleBatch -from raylab.modules.v0.mixins import NormalizingFlowModelMixin +from raylab.policy.modules.v0.mixins import NormalizingFlowModelMixin from .utils import make_batch from .utils import make_module diff --git a/tests/modules/v0/test_state_value_mixin.py b/tests/policy/modules/v0/test_state_value_mixin.py similarity index 96% rename from tests/modules/v0/test_state_value_mixin.py rename to tests/policy/modules/v0/test_state_value_mixin.py index 2b69f255..b30368c5 100644 --- a/tests/modules/v0/test_state_value_mixin.py +++ b/tests/policy/modules/v0/test_state_value_mixin.py @@ -6,7 +6,7 @@ import torch.nn as nn from ray.rllib import SampleBatch -from raylab.modules.v0.mixins import StateValueMixin +from raylab.policy.modules.v0.mixins import StateValueMixin class DummyModule(StateValueMixin, nn.ModuleDict): diff --git a/tests/modules/v0/test_stochastic_actor_mixin.py b/tests/policy/modules/v0/test_stochastic_actor_mixin.py similarity index 98% rename from tests/modules/v0/test_stochastic_actor_mixin.py 
rename to tests/policy/modules/v0/test_stochastic_actor_mixin.py index 346303ba..4030da2e 100644 --- a/tests/modules/v0/test_stochastic_actor_mixin.py +++ b/tests/policy/modules/v0/test_stochastic_actor_mixin.py @@ -7,7 +7,7 @@ from gym.spaces import Discrete from ray.rllib import SampleBatch -from raylab.modules.v0.mixins import StochasticActorMixin +from raylab.policy.modules.v0.mixins import StochasticActorMixin from .utils import make_batch from .utils import make_module diff --git a/tests/modules/v0/test_stochastic_model_mixin.py b/tests/policy/modules/v0/test_stochastic_model_mixin.py similarity index 98% rename from tests/modules/v0/test_stochastic_model_mixin.py rename to tests/policy/modules/v0/test_stochastic_model_mixin.py index 43cb9d9e..faaeb3c1 100644 --- a/tests/modules/v0/test_stochastic_model_mixin.py +++ b/tests/policy/modules/v0/test_stochastic_model_mixin.py @@ -4,7 +4,7 @@ import torch.nn as nn from ray.rllib import SampleBatch -from raylab.modules.v0.mixins import StochasticModelMixin +from raylab.policy.modules.v0.mixins import StochasticModelMixin class DummyModule(StochasticModelMixin, nn.ModuleDict): diff --git a/tests/modules/v0/test_svg_module.py b/tests/policy/modules/v0/test_svg_module.py similarity index 95% rename from tests/modules/v0/test_svg_module.py rename to tests/policy/modules/v0/test_svg_module.py index 22c7b456..3271a045 100644 --- a/tests/modules/v0/test_svg_module.py +++ b/tests/policy/modules/v0/test_svg_module.py @@ -3,7 +3,7 @@ import torch from ray.rllib import SampleBatch -from raylab.modules.v0.svg_module import SVGModule +from raylab.policy.modules.v0.svg_module import SVGModule @pytest.fixture diff --git a/tests/modules/v0/test_trpo_extensions.py b/tests/policy/modules/v0/test_trpo_extensions.py similarity index 97% rename from tests/modules/v0/test_trpo_extensions.py rename to tests/policy/modules/v0/test_trpo_extensions.py index edd48d08..a520fd89 100644 --- a/tests/modules/v0/test_trpo_extensions.py +++ b/tests/policy/modules/v0/test_trpo_extensions.py @@ -5,7 +5,7 @@ from gym.spaces import Box from ray.rllib import SampleBatch -from raylab.modules.catalog import TRPOTang2018 +from raylab.policy.modules.v0.trpo_tang2018 import TRPOTang2018 from .utils import make_batch from .utils import make_module diff --git a/tests/modules/v0/utils.py b/tests/policy/modules/v0/utils.py similarity index 100% rename from tests/modules/v0/utils.py rename to tests/policy/modules/v0/utils.py diff --git a/tests/pytorch/nn/distributions/flows/test_couplings.py b/tests/pytorch/nn/distributions/flows/test_couplings.py index aff975f4..69e7088a 100644 --- a/tests/pytorch/nn/distributions/flows/test_couplings.py +++ b/tests/pytorch/nn/distributions/flows/test_couplings.py @@ -2,8 +2,8 @@ import pytest import torch -from raylab.modules.networks import MLP -from raylab.modules.networks import ResidualNet +from raylab.policy.modules.networks import MLP +from raylab.policy.modules.networks import ResidualNet from raylab.pytorch.nn.distributions.flows.coupling import AdditiveCouplingTransform from raylab.pytorch.nn.distributions.flows.coupling import AffineCouplingTransform from raylab.pytorch.nn.distributions.flows.coupling import PiecewiseRQSCouplingTransform From c915389920a592643914b4679081b002e42c2c6a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=C3=82ngelo=20Lovatto?= Date: Tue, 30 Jun 2020 15:58:51 -0300 Subject: [PATCH 34/48] refactor: move losses to raylab.policy MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit 
Signed-off-by: Ângelo Lovatto --- raylab/agents/mage/policy.py | 6 +++--- raylab/agents/mapo/policy.py | 6 +++--- raylab/agents/mbpo/policy.py | 2 +- raylab/agents/naf/policy.py | 2 +- raylab/agents/sac/policy.py | 6 +++--- raylab/agents/sop/policy.py | 4 ++-- raylab/agents/svg/inf/policy.py | 2 +- raylab/agents/svg/one/policy.py | 2 +- raylab/agents/svg/policy.py | 4 ++-- raylab/agents/svg/soft/policy.py | 6 +++--- raylab/{ => policy}/losses/__init__.py | 0 raylab/{ => policy}/losses/abstract.py | 0 raylab/{ => policy}/losses/cdq_learning.py | 0 raylab/{ => policy}/losses/daml.py | 0 raylab/{ => policy}/losses/isfv_iteration.py | 0 raylab/{ => policy}/losses/mage.py | 0 raylab/{ => policy}/losses/mapo.py | 0 raylab/{ => policy}/losses/maximum_entropy.py | 0 raylab/{ => policy}/losses/mixins.py | 0 raylab/{ => policy}/losses/mle.py | 0 raylab/{ => policy}/losses/paml.py | 0 raylab/{ => policy}/losses/policy_gradient.py | 0 raylab/{ => policy}/losses/svg.py | 0 raylab/{ => policy}/losses/utils.py | 0 raylab/policy/model_based/training_mixin.py | 2 +- tests/agents/mage/test_policy.py | 8 ++++---- tests/agents/mapo/test_policy.py | 10 +++++----- tests/agents/sac/test_critics.py | 2 +- tests/{ => policy}/losses/__init__.py | 0 tests/{ => policy}/losses/conftest.py | 0 tests/{ => policy}/losses/test_cdq_learning.py | 2 +- tests/{ => policy}/losses/test_mage.py | 4 ++-- tests/{ => policy}/losses/test_mapo.py | 8 ++++---- tests/{ => policy}/losses/test_mle.py | 2 +- tests/{ => policy}/losses/test_paml.py | 6 +++--- 35 files changed, 42 insertions(+), 42 deletions(-) rename raylab/{ => policy}/losses/__init__.py (100%) rename raylab/{ => policy}/losses/abstract.py (100%) rename raylab/{ => policy}/losses/cdq_learning.py (100%) rename raylab/{ => policy}/losses/daml.py (100%) rename raylab/{ => policy}/losses/isfv_iteration.py (100%) rename raylab/{ => policy}/losses/mage.py (100%) rename raylab/{ => policy}/losses/mapo.py (100%) rename raylab/{ => policy}/losses/maximum_entropy.py (100%) rename raylab/{ => policy}/losses/mixins.py (100%) rename raylab/{ => policy}/losses/mle.py (100%) rename raylab/{ => policy}/losses/paml.py (100%) rename raylab/{ => policy}/losses/policy_gradient.py (100%) rename raylab/{ => policy}/losses/svg.py (100%) rename raylab/{ => policy}/losses/utils.py (100%) rename tests/{ => policy}/losses/__init__.py (100%) rename tests/{ => policy}/losses/conftest.py (100%) rename tests/{ => policy}/losses/test_cdq_learning.py (97%) rename tests/{ => policy}/losses/test_mage.py (96%) rename tests/{ => policy}/losses/test_mapo.py (93%) rename tests/{ => policy}/losses/test_mle.py (92%) rename tests/{ => policy}/losses/test_paml.py (96%) diff --git a/raylab/agents/mage/policy.py b/raylab/agents/mage/policy.py index 67368731..ae2cf2e3 100644 --- a/raylab/agents/mage/policy.py +++ b/raylab/agents/mage/policy.py @@ -1,10 +1,10 @@ """Policy for MAGE using PyTorch.""" from raylab.agents.sop import SOPTorchPolicy -from raylab.losses import MAGE -from raylab.losses import ModelEnsembleMLE -from raylab.losses.mage import MAGEModules from raylab.policy import EnvFnMixin from raylab.policy import ModelTrainingMixin +from raylab.policy.losses import MAGE +from raylab.policy.losses import ModelEnsembleMLE +from raylab.policy.losses.mage import MAGEModules from raylab.pytorch.optim import build_optimizer diff --git a/raylab/agents/mapo/policy.py b/raylab/agents/mapo/policy.py index aade2abd..a1491a13 100644 --- a/raylab/agents/mapo/policy.py +++ b/raylab/agents/mapo/policy.py @@ -2,11 +2,11 @@ 
from ray.rllib.utils import override from raylab.agents.sac import SACTorchPolicy -from raylab.losses import DAPO -from raylab.losses import MAPO -from raylab.losses import SPAML from raylab.policy import EnvFnMixin from raylab.policy import ModelTrainingMixin +from raylab.policy.losses import DAPO +from raylab.policy.losses import MAPO +from raylab.policy.losses import SPAML from raylab.pytorch.optim import build_optimizer diff --git a/raylab/agents/mbpo/policy.py b/raylab/agents/mbpo/policy.py index 774cce94..df48ffb3 100644 --- a/raylab/agents/mbpo/policy.py +++ b/raylab/agents/mbpo/policy.py @@ -2,10 +2,10 @@ from ray.rllib.utils import override from raylab.agents.sac import SACTorchPolicy -from raylab.losses import ModelEnsembleMLE from raylab.policy import EnvFnMixin from raylab.policy import ModelSamplingMixin from raylab.policy import ModelTrainingMixin +from raylab.policy.losses import ModelEnsembleMLE from raylab.pytorch.optim import build_optimizer diff --git a/raylab/agents/naf/policy.py b/raylab/agents/naf/policy.py index 5931b204..09437323 100644 --- a/raylab/agents/naf/policy.py +++ b/raylab/agents/naf/policy.py @@ -3,9 +3,9 @@ import torch.nn as nn from ray.rllib.utils import override -from raylab.losses import ClippedDoubleQLearning from raylab.policy import TargetNetworksMixin from raylab.policy import TorchPolicy +from raylab.policy.losses import ClippedDoubleQLearning from raylab.pytorch.optim import build_optimizer diff --git a/raylab/agents/sac/policy.py b/raylab/agents/sac/policy.py index fca42f63..661935c3 100644 --- a/raylab/agents/sac/policy.py +++ b/raylab/agents/sac/policy.py @@ -3,11 +3,11 @@ import torch.nn as nn from ray.rllib.utils import override -from raylab.losses import MaximumEntropyDual -from raylab.losses import ReparameterizedSoftPG -from raylab.losses import SoftCDQLearning from raylab.policy import TargetNetworksMixin from raylab.policy import TorchPolicy +from raylab.policy.losses import MaximumEntropyDual +from raylab.policy.losses import ReparameterizedSoftPG +from raylab.policy.losses import SoftCDQLearning from raylab.pytorch.optim import build_optimizer diff --git a/raylab/agents/sop/policy.py b/raylab/agents/sop/policy.py index ef1435e0..95911c3e 100644 --- a/raylab/agents/sop/policy.py +++ b/raylab/agents/sop/policy.py @@ -3,10 +3,10 @@ import torch.nn as nn from ray.rllib.utils import override -from raylab.losses import ClippedDoubleQLearning -from raylab.losses import DeterministicPolicyGradient from raylab.policy import TargetNetworksMixin from raylab.policy import TorchPolicy +from raylab.policy.losses import ClippedDoubleQLearning +from raylab.policy.losses import DeterministicPolicyGradient from raylab.pytorch.optim import build_optimizer diff --git a/raylab/agents/svg/inf/policy.py b/raylab/agents/svg/inf/policy.py index b018ce7b..0ec937f2 100644 --- a/raylab/agents/svg/inf/policy.py +++ b/raylab/agents/svg/inf/policy.py @@ -7,9 +7,9 @@ from ray.rllib.utils import override from raylab.agents.svg import SVGTorchPolicy -from raylab.losses import TrajectorySVG from raylab.policy import AdaptiveKLCoeffMixin from raylab.policy import EnvFnMixin +from raylab.policy.losses import TrajectorySVG from raylab.pytorch.optim import build_optimizer diff --git a/raylab/agents/svg/one/policy.py b/raylab/agents/svg/one/policy.py index 49cc7df1..c95c8a41 100644 --- a/raylab/agents/svg/one/policy.py +++ b/raylab/agents/svg/one/policy.py @@ -7,10 +7,10 @@ from ray.rllib.utils import override from raylab.agents.svg import SVGTorchPolicy -from 
raylab.losses import OneStepSVG from raylab.policy import AdaptiveKLCoeffMixin from raylab.policy import EnvFnMixin from raylab.policy import TorchPolicy +from raylab.policy.losses import OneStepSVG from raylab.pytorch.optim import get_optimizer_class diff --git a/raylab/agents/svg/policy.py b/raylab/agents/svg/policy.py index e9b5fe98..c264687e 100644 --- a/raylab/agents/svg/policy.py +++ b/raylab/agents/svg/policy.py @@ -2,11 +2,11 @@ import torch from ray.rllib import SampleBatch -from raylab.losses import ISFittedVIteration -from raylab.losses import MaximumLikelihood from raylab.policy import EnvFnMixin from raylab.policy import TargetNetworksMixin from raylab.policy import TorchPolicy +from raylab.policy.losses import ISFittedVIteration +from raylab.policy.losses import MaximumLikelihood class SVGTorchPolicy(EnvFnMixin, TargetNetworksMixin, TorchPolicy): diff --git a/raylab/agents/svg/soft/policy.py b/raylab/agents/svg/soft/policy.py index f1f8e4a2..b17fa56e 100644 --- a/raylab/agents/svg/soft/policy.py +++ b/raylab/agents/svg/soft/policy.py @@ -5,10 +5,10 @@ from ray.rllib.utils import override from raylab.agents.svg import SVGTorchPolicy -from raylab.losses import ISSoftVIteration -from raylab.losses import MaximumEntropyDual -from raylab.losses import OneStepSoftSVG from raylab.policy import EnvFnMixin +from raylab.policy.losses import ISSoftVIteration +from raylab.policy.losses import MaximumEntropyDual +from raylab.policy.losses import OneStepSoftSVG from raylab.pytorch.optim import build_optimizer diff --git a/raylab/losses/__init__.py b/raylab/policy/losses/__init__.py similarity index 100% rename from raylab/losses/__init__.py rename to raylab/policy/losses/__init__.py diff --git a/raylab/losses/abstract.py b/raylab/policy/losses/abstract.py similarity index 100% rename from raylab/losses/abstract.py rename to raylab/policy/losses/abstract.py diff --git a/raylab/losses/cdq_learning.py b/raylab/policy/losses/cdq_learning.py similarity index 100% rename from raylab/losses/cdq_learning.py rename to raylab/policy/losses/cdq_learning.py diff --git a/raylab/losses/daml.py b/raylab/policy/losses/daml.py similarity index 100% rename from raylab/losses/daml.py rename to raylab/policy/losses/daml.py diff --git a/raylab/losses/isfv_iteration.py b/raylab/policy/losses/isfv_iteration.py similarity index 100% rename from raylab/losses/isfv_iteration.py rename to raylab/policy/losses/isfv_iteration.py diff --git a/raylab/losses/mage.py b/raylab/policy/losses/mage.py similarity index 100% rename from raylab/losses/mage.py rename to raylab/policy/losses/mage.py diff --git a/raylab/losses/mapo.py b/raylab/policy/losses/mapo.py similarity index 100% rename from raylab/losses/mapo.py rename to raylab/policy/losses/mapo.py diff --git a/raylab/losses/maximum_entropy.py b/raylab/policy/losses/maximum_entropy.py similarity index 100% rename from raylab/losses/maximum_entropy.py rename to raylab/policy/losses/maximum_entropy.py diff --git a/raylab/losses/mixins.py b/raylab/policy/losses/mixins.py similarity index 100% rename from raylab/losses/mixins.py rename to raylab/policy/losses/mixins.py diff --git a/raylab/losses/mle.py b/raylab/policy/losses/mle.py similarity index 100% rename from raylab/losses/mle.py rename to raylab/policy/losses/mle.py diff --git a/raylab/losses/paml.py b/raylab/policy/losses/paml.py similarity index 100% rename from raylab/losses/paml.py rename to raylab/policy/losses/paml.py diff --git a/raylab/losses/policy_gradient.py b/raylab/policy/losses/policy_gradient.py 
similarity index 100% rename from raylab/losses/policy_gradient.py rename to raylab/policy/losses/policy_gradient.py diff --git a/raylab/losses/svg.py b/raylab/policy/losses/svg.py similarity index 100% rename from raylab/losses/svg.py rename to raylab/policy/losses/svg.py diff --git a/raylab/losses/utils.py b/raylab/policy/losses/utils.py similarity index 100% rename from raylab/losses/utils.py rename to raylab/policy/losses/utils.py diff --git a/raylab/policy/model_based/training_mixin.py b/raylab/policy/model_based/training_mixin.py index 20676270..c232bc5b 100644 --- a/raylab/policy/model_based/training_mixin.py +++ b/raylab/policy/model_based/training_mixin.py @@ -20,7 +20,7 @@ from torch.utils.data import DataLoader from torch.utils.data import RandomSampler -from raylab.losses.abstract import Loss +from raylab.policy.losses.abstract import Loss from raylab.pytorch.utils import TensorDictDataset diff --git a/tests/agents/mage/test_policy.py b/tests/agents/mage/test_policy.py index 0234f1c6..e52511b2 100644 --- a/tests/agents/mage/test_policy.py +++ b/tests/agents/mage/test_policy.py @@ -6,9 +6,9 @@ import torch from raylab.agents.mage import MAGETorchPolicy -from raylab.losses import DeterministicPolicyGradient -from raylab.losses import MAGE -from raylab.losses import ModelEnsembleMLE +from raylab.policy.losses import DeterministicPolicyGradient +from raylab.policy.losses import MAGE +from raylab.policy.losses import ModelEnsembleMLE from raylab.utils.debug import fake_batch @@ -86,7 +86,7 @@ def test_learn_on_batch(policy, samples): def test_compile(policy): - with mock.patch("raylab.losses.MAGE.compile") as mocked_method: + with mock.patch("raylab.policy.losses.MAGE.compile") as mocked_method: policy.compile() assert isinstance(policy.module, torch.jit.ScriptModule) assert mocked_method.called diff --git a/tests/agents/mapo/test_policy.py b/tests/agents/mapo/test_policy.py index 19352d3e..62a29fcd 100644 --- a/tests/agents/mapo/test_policy.py +++ b/tests/agents/mapo/test_policy.py @@ -8,11 +8,11 @@ from raylab.agents.mapo import MAPOTorchPolicy from raylab.agents.sac import SACTorchPolicy -from raylab.losses import MAPO -from raylab.losses import SoftCDQLearning -from raylab.losses import SPAML from raylab.policy import EnvFnMixin from raylab.policy import ModelTrainingMixin +from raylab.policy.losses import MAPO +from raylab.policy.losses import SoftCDQLearning +from raylab.policy.losses import SPAML from raylab.utils.debug import fake_batch @@ -79,8 +79,8 @@ def test_learn_on_batch(policy, sample_batch): def test_compile(policy, mocker): - mapo = mocker.patch("raylab.losses.MAPO.compile") - spaml = mocker.patch("raylab.losses.SPAML.compile") + mapo = mocker.patch("raylab.policy.losses.MAPO.compile") + spaml = mocker.patch("raylab.policy.losses.SPAML.compile") policy.compile() assert isinstance(policy.module, torch.jit.ScriptModule) diff --git a/tests/agents/sac/test_critics.py b/tests/agents/sac/test_critics.py index 551b537c..6c19ef3f 100644 --- a/tests/agents/sac/test_critics.py +++ b/tests/agents/sac/test_critics.py @@ -5,7 +5,7 @@ from ray.rllib import SampleBatch import raylab.utils.dictionaries as dutil -from raylab.losses import SoftCDQLearning +from raylab.policy.losses import SoftCDQLearning @pytest.fixture(params=(True, False)) diff --git a/tests/losses/__init__.py b/tests/policy/losses/__init__.py similarity index 100% rename from tests/losses/__init__.py rename to tests/policy/losses/__init__.py diff --git a/tests/losses/conftest.py b/tests/policy/losses/conftest.py 
similarity index 100% rename from tests/losses/conftest.py rename to tests/policy/losses/conftest.py diff --git a/tests/losses/test_cdq_learning.py b/tests/policy/losses/test_cdq_learning.py similarity index 97% rename from tests/losses/test_cdq_learning.py rename to tests/policy/losses/test_cdq_learning.py index da70fa34..b170a3aa 100644 --- a/tests/losses/test_cdq_learning.py +++ b/tests/policy/losses/test_cdq_learning.py @@ -5,7 +5,7 @@ from ray.rllib import SampleBatch import raylab.utils.dictionaries as dutil -from raylab.losses import ClippedDoubleQLearning +from raylab.policy.losses import ClippedDoubleQLearning @pytest.fixture diff --git a/tests/losses/test_mage.py b/tests/policy/losses/test_mage.py similarity index 96% rename from tests/losses/test_mage.py rename to tests/policy/losses/test_mage.py index 7ca82577..536a43e9 100644 --- a/tests/losses/test_mage.py +++ b/tests/policy/losses/test_mage.py @@ -2,8 +2,8 @@ import pytest import torch -from raylab.losses.mage import MAGE -from raylab.losses.mage import MAGEModules +from raylab.policy.losses.mage import MAGE +from raylab.policy.losses.mage import MAGEModules @pytest.fixture diff --git a/tests/losses/test_mapo.py b/tests/policy/losses/test_mapo.py similarity index 93% rename from tests/losses/test_mapo.py rename to tests/policy/losses/test_mapo.py index 4332e596..aadb1e72 100644 --- a/tests/losses/test_mapo.py +++ b/tests/policy/losses/test_mapo.py @@ -2,10 +2,10 @@ import pytest import torch -from raylab.losses import DAPO -from raylab.losses import MAPO -from raylab.losses.abstract import Loss -from raylab.losses.mixins import EnvFunctionsMixin +from raylab.policy.losses import DAPO +from raylab.policy.losses import MAPO +from raylab.policy.losses.abstract import Loss +from raylab.policy.losses.mixins import EnvFunctionsMixin @pytest.fixture diff --git a/tests/losses/test_mle.py b/tests/policy/losses/test_mle.py similarity index 92% rename from tests/losses/test_mle.py rename to tests/policy/losses/test_mle.py index fd7365c8..ea6d6d6b 100644 --- a/tests/losses/test_mle.py +++ b/tests/policy/losses/test_mle.py @@ -2,7 +2,7 @@ import pytest import torch -from raylab.losses import ModelEnsembleMLE +from raylab.policy.losses import ModelEnsembleMLE @pytest.fixture diff --git a/tests/losses/test_paml.py b/tests/policy/losses/test_paml.py similarity index 96% rename from tests/losses/test_paml.py rename to tests/policy/losses/test_paml.py index 5c981731..62f1101f 100644 --- a/tests/losses/test_paml.py +++ b/tests/policy/losses/test_paml.py @@ -3,9 +3,9 @@ import torch from ray.rllib import SampleBatch -from raylab.losses import ModelEnsembleMLE -from raylab.losses import SPAML -from raylab.losses.abstract import Loss +from raylab.policy.losses import ModelEnsembleMLE +from raylab.policy.losses import SPAML +from raylab.policy.losses.abstract import Loss @pytest.fixture From 880e02e24a0cce156f8f329c1e7c9cda788ff76d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=C3=82ngelo=20Lovatto?= Date: Tue, 30 Jun 2020 16:30:05 -0300 Subject: [PATCH 35/48] refactor(tests): match test directory structure with package MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Ângelo Lovatto --- tests/conftest.py | 47 ++++-------------- tests/general/conftest.py | 16 ------ tests/general/test_rollout.py | 31 ------------ tests/general/test_worker.py | 20 -------- tests/{agents => raylab}/__init__.py | 0 .../mage => raylab/agents}/__init__.py | 0 tests/{ => raylab}/agents/conftest.py | 0 .../mapo => 
raylab/agents/mage}/__init__.py | 0 tests/{ => raylab}/agents/mage/test_policy.py | 0 .../{ => raylab}/agents/mage/test_trainer.py | 0 .../mbpo => raylab/agents/mapo}/__init__.py | 0 tests/{ => raylab}/agents/mapo/test_policy.py | 0 .../{ => raylab}/agents/mapo/test_trainer.py | 0 .../naf => raylab/agents/mbpo}/__init__.py | 0 tests/{ => raylab}/agents/mbpo/conftest.py | 0 tests/{ => raylab}/agents/mbpo/test_policy.py | 0 .../{ => raylab}/agents/mbpo/test_trainer.py | 0 .../sac => raylab/agents/naf}/__init__.py | 0 tests/{ => raylab}/agents/naf/test_policy.py | 0 .../sop => raylab/agents/sac}/__init__.py | 0 tests/{ => raylab}/agents/sac/conftest.py | 0 tests/{ => raylab}/agents/sac/test_actor.py | 0 tests/{ => raylab}/agents/sac/test_critics.py | 0 .../agents/sac/test_entropy_coeff.py | 0 .../svg => raylab/agents/sop}/__init__.py | 0 tests/{ => raylab}/agents/sop/test_policy.py | 0 tests/{cli => raylab/agents/svg}/__init__.py | 0 tests/{ => raylab}/agents/svg/conftest.py | 0 .../agents/svg/test_rollout_module.py | 0 tests/{ => raylab}/agents/svg/test_svg_one.py | 0 .../agents/svg/test_value_function.py | 0 .../agents/test_registry.py} | 49 ++++++++++++++++++- tests/{ => raylab}/agents/test_trainer.py | 0 tests/{envs => raylab/cli}/__init__.py | 0 tests/{ => raylab}/cli/test_cli.py | 0 tests/raylab/conftest.py | 38 ++++++++++++++ .../environments => raylab/envs}/__init__.py | 0 tests/{ => raylab}/envs/conftest.py | 0 .../envs/environments}/__init__.py | 0 .../environments/test_cartpole_swingup.py | 0 .../envs/environments/test_hvac.py | 0 .../envs/environments/test_navigation.py | 0 .../envs/environments/test_reservoir.py | 0 tests/{ => raylab}/envs/test_basic.py | 0 .../envs/test_gaussian_random_walks.py | 0 tests/{ => raylab}/envs/test_rewards.py | 0 tests/{ => raylab}/envs/test_termination.py | 0 tests/{ => raylab}/policy/__init__.py | 0 tests/{ => raylab}/policy/losses/__init__.py | 0 tests/{ => raylab}/policy/losses/conftest.py | 0 .../policy/losses/test_cdq_learning.py | 0 tests/{ => raylab}/policy/losses/test_mage.py | 0 tests/{ => raylab}/policy/losses/test_mapo.py | 0 tests/{ => raylab}/policy/losses/test_mle.py | 0 tests/{ => raylab}/policy/losses/test_paml.py | 0 .../policy/model_based/__init__.py | 0 .../policy/model_based/test_envfn_mixin.py | 0 .../policy/model_based/test_sampling_mixin.py | 0 .../policy/model_based/test_training_mixin.py | 0 tests/{ => raylab}/policy/modules/__init__.py | 0 .../policy/modules/actor/__init__.py | 0 .../policy/modules/actor/conftest.py | 0 .../policy/modules/actor/policy/__init__.py | 0 .../actor/policy/test_deterministic.py | 0 .../modules/actor/policy/test_stochastic.py | 0 .../modules/actor/test_deterministic.py | 0 .../policy/modules/actor/test_stochastic.py | 0 tests/{ => raylab}/policy/modules/conftest.py | 0 .../policy/modules/critic/__init__.py | 0 .../modules/critic/test_action_value.py | 0 .../policy/modules/model/__init__.py | 0 .../modules/model/stochastic/__init__.py | 0 .../modules/model/stochastic/conftest.py | 0 .../modules/model/stochastic/test_ensemble.py | 0 .../modules/model/stochastic/test_single.py | 0 .../policy/modules/networks/__init__.py | 0 .../policy/modules/networks/test_mlp.py | 0 .../policy/modules/networks/test_resnet.py | 0 .../{ => raylab}/policy/modules/test_ddpg.py | 0 tests/{ => raylab}/policy/modules/test_sac.py | 0 .../policy/modules/v0/__init__.py | 0 .../policy/modules/v0/conftest.py | 0 .../modules/v0/test_action_value_mixin.py | 0 .../v0/test_deterministic_actor_mixin.py | 0 
.../policy/modules/v0/test_naf_module.py | 0 .../v0/test_normalizing_flow_actor_mixin.py | 0 .../v0/test_normalizing_flow_model_mixin.py | 0 .../modules/v0/test_state_value_mixin.py | 0 .../modules/v0/test_stochastic_actor_mixin.py | 0 .../modules/v0/test_stochastic_model_mixin.py | 0 .../policy/modules/v0/test_svg_module.py | 0 .../policy/modules/v0/test_trpo_extensions.py | 0 tests/{ => raylab}/policy/modules/v0/utils.py | 0 .../policy/test_optimizer_collection.py | 0 tests/{ => raylab}/pytorch/__init__.py | 0 tests/{ => raylab}/pytorch/conftest.py | 0 tests/{ => raylab}/pytorch/nn/__init__.py | 0 .../pytorch/nn/distributions/__init__.py | 0 .../pytorch/nn/distributions/conftest.py | 0 .../nn/distributions/flows/__init__.py | 0 .../flows/test_affine_constant.py | 0 .../nn/distributions/flows/test_couplings.py | 0 .../flows/test_flow_distribution.py | 0 .../nn/distributions/flows/test_maf.py | 0 .../nn/distributions/test_categorical.py | 0 .../pytorch/nn/distributions/test_normal.py | 0 .../nn/distributions/test_transforms.py | 0 .../pytorch/nn/distributions/test_uniform.py | 0 .../pytorch/nn/distributions/utils.py | 0 .../pytorch/nn/modules/__init__.py | 0 .../pytorch/nn/modules/test_action_output.py | 0 .../nn/modules/test_fully_connected.py | 0 .../pytorch/nn/modules/test_gaussian_noise.py | 0 .../pytorch/nn/modules/test_lambd.py | 0 .../pytorch/nn/modules/test_made.py | 0 .../nn/modules/test_normalized_linear.py | 0 .../pytorch/nn/modules/test_tanh_squash.py | 0 .../pytorch/nn/modules/test_tril_matrix.py | 0 tests/{ => raylab}/utils/__init__.py | 0 tests/{ => raylab}/utils/test_replay.py | 0 120 files changed, 95 insertions(+), 106 deletions(-) delete mode 100644 tests/general/conftest.py delete mode 100644 tests/general/test_rollout.py delete mode 100644 tests/general/test_worker.py rename tests/{agents => raylab}/__init__.py (100%) rename tests/{agents/mage => raylab/agents}/__init__.py (100%) rename tests/{ => raylab}/agents/conftest.py (100%) rename tests/{agents/mapo => raylab/agents/mage}/__init__.py (100%) rename tests/{ => raylab}/agents/mage/test_policy.py (100%) rename tests/{ => raylab}/agents/mage/test_trainer.py (100%) rename tests/{agents/mbpo => raylab/agents/mapo}/__init__.py (100%) rename tests/{ => raylab}/agents/mapo/test_policy.py (100%) rename tests/{ => raylab}/agents/mapo/test_trainer.py (100%) rename tests/{agents/naf => raylab/agents/mbpo}/__init__.py (100%) rename tests/{ => raylab}/agents/mbpo/conftest.py (100%) rename tests/{ => raylab}/agents/mbpo/test_policy.py (100%) rename tests/{ => raylab}/agents/mbpo/test_trainer.py (100%) rename tests/{agents/sac => raylab/agents/naf}/__init__.py (100%) rename tests/{ => raylab}/agents/naf/test_policy.py (100%) rename tests/{agents/sop => raylab/agents/sac}/__init__.py (100%) rename tests/{ => raylab}/agents/sac/conftest.py (100%) rename tests/{ => raylab}/agents/sac/test_actor.py (100%) rename tests/{ => raylab}/agents/sac/test_critics.py (100%) rename tests/{ => raylab}/agents/sac/test_entropy_coeff.py (100%) rename tests/{agents/svg => raylab/agents/sop}/__init__.py (100%) rename tests/{ => raylab}/agents/sop/test_policy.py (100%) rename tests/{cli => raylab/agents/svg}/__init__.py (100%) rename tests/{ => raylab}/agents/svg/conftest.py (100%) rename tests/{ => raylab}/agents/svg/test_rollout_module.py (100%) rename tests/{ => raylab}/agents/svg/test_svg_one.py (100%) rename tests/{ => raylab}/agents/svg/test_value_function.py (100%) rename tests/{general/test_trainer.py => raylab/agents/test_registry.py} (52%) 
rename tests/{ => raylab}/agents/test_trainer.py (100%) rename tests/{envs => raylab/cli}/__init__.py (100%) rename tests/{ => raylab}/cli/test_cli.py (100%) create mode 100644 tests/raylab/conftest.py rename tests/{envs/environments => raylab/envs}/__init__.py (100%) rename tests/{ => raylab}/envs/conftest.py (100%) rename tests/{general => raylab/envs/environments}/__init__.py (100%) rename tests/{ => raylab}/envs/environments/test_cartpole_swingup.py (100%) rename tests/{ => raylab}/envs/environments/test_hvac.py (100%) rename tests/{ => raylab}/envs/environments/test_navigation.py (100%) rename tests/{ => raylab}/envs/environments/test_reservoir.py (100%) rename tests/{ => raylab}/envs/test_basic.py (100%) rename tests/{ => raylab}/envs/test_gaussian_random_walks.py (100%) rename tests/{ => raylab}/envs/test_rewards.py (100%) rename tests/{ => raylab}/envs/test_termination.py (100%) rename tests/{ => raylab}/policy/__init__.py (100%) rename tests/{ => raylab}/policy/losses/__init__.py (100%) rename tests/{ => raylab}/policy/losses/conftest.py (100%) rename tests/{ => raylab}/policy/losses/test_cdq_learning.py (100%) rename tests/{ => raylab}/policy/losses/test_mage.py (100%) rename tests/{ => raylab}/policy/losses/test_mapo.py (100%) rename tests/{ => raylab}/policy/losses/test_mle.py (100%) rename tests/{ => raylab}/policy/losses/test_paml.py (100%) rename tests/{ => raylab}/policy/model_based/__init__.py (100%) rename tests/{ => raylab}/policy/model_based/test_envfn_mixin.py (100%) rename tests/{ => raylab}/policy/model_based/test_sampling_mixin.py (100%) rename tests/{ => raylab}/policy/model_based/test_training_mixin.py (100%) rename tests/{ => raylab}/policy/modules/__init__.py (100%) rename tests/{ => raylab}/policy/modules/actor/__init__.py (100%) rename tests/{ => raylab}/policy/modules/actor/conftest.py (100%) rename tests/{ => raylab}/policy/modules/actor/policy/__init__.py (100%) rename tests/{ => raylab}/policy/modules/actor/policy/test_deterministic.py (100%) rename tests/{ => raylab}/policy/modules/actor/policy/test_stochastic.py (100%) rename tests/{ => raylab}/policy/modules/actor/test_deterministic.py (100%) rename tests/{ => raylab}/policy/modules/actor/test_stochastic.py (100%) rename tests/{ => raylab}/policy/modules/conftest.py (100%) rename tests/{ => raylab}/policy/modules/critic/__init__.py (100%) rename tests/{ => raylab}/policy/modules/critic/test_action_value.py (100%) rename tests/{ => raylab}/policy/modules/model/__init__.py (100%) rename tests/{ => raylab}/policy/modules/model/stochastic/__init__.py (100%) rename tests/{ => raylab}/policy/modules/model/stochastic/conftest.py (100%) rename tests/{ => raylab}/policy/modules/model/stochastic/test_ensemble.py (100%) rename tests/{ => raylab}/policy/modules/model/stochastic/test_single.py (100%) rename tests/{ => raylab}/policy/modules/networks/__init__.py (100%) rename tests/{ => raylab}/policy/modules/networks/test_mlp.py (100%) rename tests/{ => raylab}/policy/modules/networks/test_resnet.py (100%) rename tests/{ => raylab}/policy/modules/test_ddpg.py (100%) rename tests/{ => raylab}/policy/modules/test_sac.py (100%) rename tests/{ => raylab}/policy/modules/v0/__init__.py (100%) rename tests/{ => raylab}/policy/modules/v0/conftest.py (100%) rename tests/{ => raylab}/policy/modules/v0/test_action_value_mixin.py (100%) rename tests/{ => raylab}/policy/modules/v0/test_deterministic_actor_mixin.py (100%) rename tests/{ => raylab}/policy/modules/v0/test_naf_module.py (100%) rename tests/{ => 
raylab}/policy/modules/v0/test_normalizing_flow_actor_mixin.py (100%) rename tests/{ => raylab}/policy/modules/v0/test_normalizing_flow_model_mixin.py (100%) rename tests/{ => raylab}/policy/modules/v0/test_state_value_mixin.py (100%) rename tests/{ => raylab}/policy/modules/v0/test_stochastic_actor_mixin.py (100%) rename tests/{ => raylab}/policy/modules/v0/test_stochastic_model_mixin.py (100%) rename tests/{ => raylab}/policy/modules/v0/test_svg_module.py (100%) rename tests/{ => raylab}/policy/modules/v0/test_trpo_extensions.py (100%) rename tests/{ => raylab}/policy/modules/v0/utils.py (100%) rename tests/{ => raylab}/policy/test_optimizer_collection.py (100%) rename tests/{ => raylab}/pytorch/__init__.py (100%) rename tests/{ => raylab}/pytorch/conftest.py (100%) rename tests/{ => raylab}/pytorch/nn/__init__.py (100%) rename tests/{ => raylab}/pytorch/nn/distributions/__init__.py (100%) rename tests/{ => raylab}/pytorch/nn/distributions/conftest.py (100%) rename tests/{ => raylab}/pytorch/nn/distributions/flows/__init__.py (100%) rename tests/{ => raylab}/pytorch/nn/distributions/flows/test_affine_constant.py (100%) rename tests/{ => raylab}/pytorch/nn/distributions/flows/test_couplings.py (100%) rename tests/{ => raylab}/pytorch/nn/distributions/flows/test_flow_distribution.py (100%) rename tests/{ => raylab}/pytorch/nn/distributions/flows/test_maf.py (100%) rename tests/{ => raylab}/pytorch/nn/distributions/test_categorical.py (100%) rename tests/{ => raylab}/pytorch/nn/distributions/test_normal.py (100%) rename tests/{ => raylab}/pytorch/nn/distributions/test_transforms.py (100%) rename tests/{ => raylab}/pytorch/nn/distributions/test_uniform.py (100%) rename tests/{ => raylab}/pytorch/nn/distributions/utils.py (100%) rename tests/{ => raylab}/pytorch/nn/modules/__init__.py (100%) rename tests/{ => raylab}/pytorch/nn/modules/test_action_output.py (100%) rename tests/{ => raylab}/pytorch/nn/modules/test_fully_connected.py (100%) rename tests/{ => raylab}/pytorch/nn/modules/test_gaussian_noise.py (100%) rename tests/{ => raylab}/pytorch/nn/modules/test_lambd.py (100%) rename tests/{ => raylab}/pytorch/nn/modules/test_made.py (100%) rename tests/{ => raylab}/pytorch/nn/modules/test_normalized_linear.py (100%) rename tests/{ => raylab}/pytorch/nn/modules/test_tanh_squash.py (100%) rename tests/{ => raylab}/pytorch/nn/modules/test_tril_matrix.py (100%) rename tests/{ => raylab}/utils/__init__.py (100%) rename tests/{ => raylab}/utils/test_replay.py (100%) diff --git a/tests/conftest.py b/tests/conftest.py index b325a623..6575ab4e 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,14 +1,8 @@ -# pylint: disable=missing-docstring,redefined-outer-name,protected-access -import logging - -import gym -import gym.spaces as spaces +# pylint:disable=missing-docstring,redefined-outer-name,protected-access import pytest from .mock_env import MockEnv -gym.logger.set_level(logging.ERROR) - # Test setup from: # https://docs.pytest.org/en/latest/example/simple.html#control-skipping-of-tests-according-to-command-line-option @@ -40,9 +34,16 @@ def init_ray(): ray.shutdown() +@pytest.fixture(autouse=True, scope="session") +def disable_gym_logger_warnings(): + import logging + import gym + + gym.logger.set_level(logging.ERROR) + + @pytest.fixture(autouse=True, scope="session") def register_envs(): - # pylint:disable=import-outside-toplevel import raylab from raylab.envs.registry import ENVS @@ -51,33 +52,3 @@ def _mock_env_maker(config): ENVS["MockEnv"] = _mock_env_maker 
raylab.register_all_environments() - - -@pytest.fixture(scope="module", params=((1,), (4,)), ids=("Obs1Dim", "Obs4Dim")) -def obs_space(request): - return spaces.Box(-10, 10, shape=request.param) - - -@pytest.fixture(scope="module", params=((1,), (4,)), ids=("Act1Dim", "Act4Dim")) -def action_space(request): - return spaces.Box(-1, 1, shape=request.param) - - -@pytest.fixture(scope="module") -def envs(): - from raylab.envs.registry import ENVS # pylint:disable=import-outside-toplevel - - return ENVS.copy() - - -ENV_IDS = ("MockEnv", "Navigation", "Reservoir", "HVAC", "MountainCarContinuous-v0") - - -@pytest.fixture(params=ENV_IDS) -def env_name(request): - return request.param - - -@pytest.fixture -def env_creator(envs, env_name): - return envs[env_name] diff --git a/tests/general/conftest.py b/tests/general/conftest.py deleted file mode 100644 index e264fbb1..00000000 --- a/tests/general/conftest.py +++ /dev/null @@ -1,16 +0,0 @@ -# pylint: disable=missing-docstring,redefined-outer-name,protected-access -import pytest - -from raylab.agents.registry import AGENTS - -TRAINER_NAMES, TRAINER_IMPORTS = zip(*AGENTS.items()) - - -@pytest.fixture(scope="module", params=TRAINER_IMPORTS, ids=TRAINER_NAMES) -def trainer_cls(request): - return request.param() - - -@pytest.fixture -def policy_cls(trainer_cls): - return trainer_cls._policy diff --git a/tests/general/test_rollout.py b/tests/general/test_rollout.py deleted file mode 100644 index c66b6fca..00000000 --- a/tests/general/test_rollout.py +++ /dev/null @@ -1,31 +0,0 @@ -# pylint: disable=missing-docstring,redefined-outer-name,protected-access -import pytest -from ray.rllib import RolloutWorker - - -@pytest.fixture -def worker_kwargs(): - return {"rollout_fragment_length": 200, "batch_mode": "truncate_episodes"} - - -def test_compute_single_action(envs, env_name, policy_cls): - env = envs[env_name]({}) - policy = policy_cls(env.observation_space, env.action_space, {"env": env_name}) - - obs = env.observation_space.sample() - action, states, info = policy.compute_single_action(obs, []) - assert action in env.action_space - assert isinstance(states, list) - assert isinstance(info, dict) - - -def test_policy_in_rollout_worker(envs, env_name, policy_cls, worker_kwargs): - env_creator = envs[env_name] - policy_config = {"env": env_name} - worker = RolloutWorker( - env_creator=env_creator, - policy=policy_cls, - policy_config=policy_config, - **worker_kwargs - ) - worker.sample() diff --git a/tests/general/test_worker.py b/tests/general/test_worker.py deleted file mode 100644 index e78afdf9..00000000 --- a/tests/general/test_worker.py +++ /dev/null @@ -1,20 +0,0 @@ -# pylint: disable=missing-docstring,redefined-outer-name,protected-access -import pytest -from ray.rllib import RolloutWorker -from ray.rllib import SampleBatch - - -@pytest.fixture -def worker(envs, env_name, policy_cls): - return RolloutWorker( - env_creator=envs[env_name], - policy=policy_cls, - policy_config={"env": env_name}, - rollout_fragment_length=200, - batch_mode="truncate_episodes", - ) - - -def test_collect_traj(worker): - traj = worker.sample() - assert isinstance(traj, SampleBatch) diff --git a/tests/agents/__init__.py b/tests/raylab/__init__.py similarity index 100% rename from tests/agents/__init__.py rename to tests/raylab/__init__.py diff --git a/tests/agents/mage/__init__.py b/tests/raylab/agents/__init__.py similarity index 100% rename from tests/agents/mage/__init__.py rename to tests/raylab/agents/__init__.py diff --git a/tests/agents/conftest.py 
b/tests/raylab/agents/conftest.py similarity index 100% rename from tests/agents/conftest.py rename to tests/raylab/agents/conftest.py diff --git a/tests/agents/mapo/__init__.py b/tests/raylab/agents/mage/__init__.py similarity index 100% rename from tests/agents/mapo/__init__.py rename to tests/raylab/agents/mage/__init__.py diff --git a/tests/agents/mage/test_policy.py b/tests/raylab/agents/mage/test_policy.py similarity index 100% rename from tests/agents/mage/test_policy.py rename to tests/raylab/agents/mage/test_policy.py diff --git a/tests/agents/mage/test_trainer.py b/tests/raylab/agents/mage/test_trainer.py similarity index 100% rename from tests/agents/mage/test_trainer.py rename to tests/raylab/agents/mage/test_trainer.py diff --git a/tests/agents/mbpo/__init__.py b/tests/raylab/agents/mapo/__init__.py similarity index 100% rename from tests/agents/mbpo/__init__.py rename to tests/raylab/agents/mapo/__init__.py diff --git a/tests/agents/mapo/test_policy.py b/tests/raylab/agents/mapo/test_policy.py similarity index 100% rename from tests/agents/mapo/test_policy.py rename to tests/raylab/agents/mapo/test_policy.py diff --git a/tests/agents/mapo/test_trainer.py b/tests/raylab/agents/mapo/test_trainer.py similarity index 100% rename from tests/agents/mapo/test_trainer.py rename to tests/raylab/agents/mapo/test_trainer.py diff --git a/tests/agents/naf/__init__.py b/tests/raylab/agents/mbpo/__init__.py similarity index 100% rename from tests/agents/naf/__init__.py rename to tests/raylab/agents/mbpo/__init__.py diff --git a/tests/agents/mbpo/conftest.py b/tests/raylab/agents/mbpo/conftest.py similarity index 100% rename from tests/agents/mbpo/conftest.py rename to tests/raylab/agents/mbpo/conftest.py diff --git a/tests/agents/mbpo/test_policy.py b/tests/raylab/agents/mbpo/test_policy.py similarity index 100% rename from tests/agents/mbpo/test_policy.py rename to tests/raylab/agents/mbpo/test_policy.py diff --git a/tests/agents/mbpo/test_trainer.py b/tests/raylab/agents/mbpo/test_trainer.py similarity index 100% rename from tests/agents/mbpo/test_trainer.py rename to tests/raylab/agents/mbpo/test_trainer.py diff --git a/tests/agents/sac/__init__.py b/tests/raylab/agents/naf/__init__.py similarity index 100% rename from tests/agents/sac/__init__.py rename to tests/raylab/agents/naf/__init__.py diff --git a/tests/agents/naf/test_policy.py b/tests/raylab/agents/naf/test_policy.py similarity index 100% rename from tests/agents/naf/test_policy.py rename to tests/raylab/agents/naf/test_policy.py diff --git a/tests/agents/sop/__init__.py b/tests/raylab/agents/sac/__init__.py similarity index 100% rename from tests/agents/sop/__init__.py rename to tests/raylab/agents/sac/__init__.py diff --git a/tests/agents/sac/conftest.py b/tests/raylab/agents/sac/conftest.py similarity index 100% rename from tests/agents/sac/conftest.py rename to tests/raylab/agents/sac/conftest.py diff --git a/tests/agents/sac/test_actor.py b/tests/raylab/agents/sac/test_actor.py similarity index 100% rename from tests/agents/sac/test_actor.py rename to tests/raylab/agents/sac/test_actor.py diff --git a/tests/agents/sac/test_critics.py b/tests/raylab/agents/sac/test_critics.py similarity index 100% rename from tests/agents/sac/test_critics.py rename to tests/raylab/agents/sac/test_critics.py diff --git a/tests/agents/sac/test_entropy_coeff.py b/tests/raylab/agents/sac/test_entropy_coeff.py similarity index 100% rename from tests/agents/sac/test_entropy_coeff.py rename to tests/raylab/agents/sac/test_entropy_coeff.py diff 
--git a/tests/agents/svg/__init__.py b/tests/raylab/agents/sop/__init__.py similarity index 100% rename from tests/agents/svg/__init__.py rename to tests/raylab/agents/sop/__init__.py diff --git a/tests/agents/sop/test_policy.py b/tests/raylab/agents/sop/test_policy.py similarity index 100% rename from tests/agents/sop/test_policy.py rename to tests/raylab/agents/sop/test_policy.py diff --git a/tests/cli/__init__.py b/tests/raylab/agents/svg/__init__.py similarity index 100% rename from tests/cli/__init__.py rename to tests/raylab/agents/svg/__init__.py diff --git a/tests/agents/svg/conftest.py b/tests/raylab/agents/svg/conftest.py similarity index 100% rename from tests/agents/svg/conftest.py rename to tests/raylab/agents/svg/conftest.py diff --git a/tests/agents/svg/test_rollout_module.py b/tests/raylab/agents/svg/test_rollout_module.py similarity index 100% rename from tests/agents/svg/test_rollout_module.py rename to tests/raylab/agents/svg/test_rollout_module.py diff --git a/tests/agents/svg/test_svg_one.py b/tests/raylab/agents/svg/test_svg_one.py similarity index 100% rename from tests/agents/svg/test_svg_one.py rename to tests/raylab/agents/svg/test_svg_one.py diff --git a/tests/agents/svg/test_value_function.py b/tests/raylab/agents/svg/test_value_function.py similarity index 100% rename from tests/agents/svg/test_value_function.py rename to tests/raylab/agents/svg/test_value_function.py diff --git a/tests/general/test_trainer.py b/tests/raylab/agents/test_registry.py similarity index 52% rename from tests/general/test_trainer.py rename to tests/raylab/agents/test_registry.py index 331bc592..5dfe46a8 100644 --- a/tests/general/test_trainer.py +++ b/tests/raylab/agents/test_registry.py @@ -1,7 +1,23 @@ -# pylint: disable=missing-docstring,redefined-outer-name,protected-access +# pylint:disable=missing-docstring,redefined-outer-name,protected-access from collections import defaultdict import pytest +from ray.rllib import RolloutWorker +from ray.rllib import SampleBatch + +from raylab.agents.registry import AGENTS + +TRAINER_NAMES, TRAINER_IMPORTS = zip(*AGENTS.items()) + + +@pytest.fixture(scope="module", params=TRAINER_IMPORTS, ids=TRAINER_NAMES) +def trainer_cls(request): + return request.param() + + +@pytest.fixture +def policy_cls(trainer_cls): + return trainer_cls._policy CONFIG = defaultdict( @@ -54,3 +70,34 @@ def test_trainer_eval(trainer): def test_trainer_restore(trainer): obj = trainer.save_to_object() trainer.restore_from_object(obj) + + +@pytest.fixture +def worker_kwargs(): + return {"rollout_fragment_length": 200, "batch_mode": "truncate_episodes"} + + +@pytest.fixture +def worker(envs, env_name, policy_cls, worker_kwargs): + return RolloutWorker( + env_creator=envs[env_name], + policy=policy_cls, + policy_config={"env": env_name}, + **worker_kwargs, + ) + + +def test_compute_single_action(envs, env_name, policy_cls): + env = envs[env_name]({}) + policy = policy_cls(env.observation_space, env.action_space, {"env": env_name}) + + obs = env.observation_space.sample() + action, states, info = policy.compute_single_action(obs, []) + assert action in env.action_space + assert isinstance(states, list) + assert isinstance(info, dict) + + +def test_policy_in_rollout_worker(worker): + traj = worker.sample() + assert isinstance(traj, SampleBatch) diff --git a/tests/agents/test_trainer.py b/tests/raylab/agents/test_trainer.py similarity index 100% rename from tests/agents/test_trainer.py rename to tests/raylab/agents/test_trainer.py diff --git a/tests/envs/__init__.py 
b/tests/raylab/cli/__init__.py similarity index 100% rename from tests/envs/__init__.py rename to tests/raylab/cli/__init__.py diff --git a/tests/cli/test_cli.py b/tests/raylab/cli/test_cli.py similarity index 100% rename from tests/cli/test_cli.py rename to tests/raylab/cli/test_cli.py diff --git a/tests/raylab/conftest.py b/tests/raylab/conftest.py new file mode 100644 index 00000000..df25a1d2 --- /dev/null +++ b/tests/raylab/conftest.py @@ -0,0 +1,38 @@ +# pylint: disable=missing-docstring,redefined-outer-name,protected-access +import gym.spaces as spaces +import pytest + + +@pytest.fixture(scope="module", params=((1,), (4,)), ids=("Obs1Dim", "Obs4Dim")) +def obs_space(request): + return spaces.Box(-10, 10, shape=request.param) + + +@pytest.fixture(scope="module", params=((1,), (4,)), ids=("Act1Dim", "Act4Dim")) +def action_space(request): + return spaces.Box(-1, 1, shape=request.param) + + +@pytest.fixture(scope="module") +def envs(): + from raylab.envs.registry import ENVS # pylint:disable=import-outside-toplevel + + return ENVS.copy() + + +@pytest.fixture( + params=""" + MockEnv + Navigation + Reservoir + HVAC + MountainCarContinuous-v0 + """.split() +) +def env_name(request): + return request.param + + +@pytest.fixture +def env_creator(envs, env_name): + return envs[env_name] diff --git a/tests/envs/environments/__init__.py b/tests/raylab/envs/__init__.py similarity index 100% rename from tests/envs/environments/__init__.py rename to tests/raylab/envs/__init__.py diff --git a/tests/envs/conftest.py b/tests/raylab/envs/conftest.py similarity index 100% rename from tests/envs/conftest.py rename to tests/raylab/envs/conftest.py diff --git a/tests/general/__init__.py b/tests/raylab/envs/environments/__init__.py similarity index 100% rename from tests/general/__init__.py rename to tests/raylab/envs/environments/__init__.py diff --git a/tests/envs/environments/test_cartpole_swingup.py b/tests/raylab/envs/environments/test_cartpole_swingup.py similarity index 100% rename from tests/envs/environments/test_cartpole_swingup.py rename to tests/raylab/envs/environments/test_cartpole_swingup.py diff --git a/tests/envs/environments/test_hvac.py b/tests/raylab/envs/environments/test_hvac.py similarity index 100% rename from tests/envs/environments/test_hvac.py rename to tests/raylab/envs/environments/test_hvac.py diff --git a/tests/envs/environments/test_navigation.py b/tests/raylab/envs/environments/test_navigation.py similarity index 100% rename from tests/envs/environments/test_navigation.py rename to tests/raylab/envs/environments/test_navigation.py diff --git a/tests/envs/environments/test_reservoir.py b/tests/raylab/envs/environments/test_reservoir.py similarity index 100% rename from tests/envs/environments/test_reservoir.py rename to tests/raylab/envs/environments/test_reservoir.py diff --git a/tests/envs/test_basic.py b/tests/raylab/envs/test_basic.py similarity index 100% rename from tests/envs/test_basic.py rename to tests/raylab/envs/test_basic.py diff --git a/tests/envs/test_gaussian_random_walks.py b/tests/raylab/envs/test_gaussian_random_walks.py similarity index 100% rename from tests/envs/test_gaussian_random_walks.py rename to tests/raylab/envs/test_gaussian_random_walks.py diff --git a/tests/envs/test_rewards.py b/tests/raylab/envs/test_rewards.py similarity index 100% rename from tests/envs/test_rewards.py rename to tests/raylab/envs/test_rewards.py diff --git a/tests/envs/test_termination.py b/tests/raylab/envs/test_termination.py similarity index 100% rename from 
tests/envs/test_termination.py rename to tests/raylab/envs/test_termination.py diff --git a/tests/policy/__init__.py b/tests/raylab/policy/__init__.py similarity index 100% rename from tests/policy/__init__.py rename to tests/raylab/policy/__init__.py diff --git a/tests/policy/losses/__init__.py b/tests/raylab/policy/losses/__init__.py similarity index 100% rename from tests/policy/losses/__init__.py rename to tests/raylab/policy/losses/__init__.py diff --git a/tests/policy/losses/conftest.py b/tests/raylab/policy/losses/conftest.py similarity index 100% rename from tests/policy/losses/conftest.py rename to tests/raylab/policy/losses/conftest.py diff --git a/tests/policy/losses/test_cdq_learning.py b/tests/raylab/policy/losses/test_cdq_learning.py similarity index 100% rename from tests/policy/losses/test_cdq_learning.py rename to tests/raylab/policy/losses/test_cdq_learning.py diff --git a/tests/policy/losses/test_mage.py b/tests/raylab/policy/losses/test_mage.py similarity index 100% rename from tests/policy/losses/test_mage.py rename to tests/raylab/policy/losses/test_mage.py diff --git a/tests/policy/losses/test_mapo.py b/tests/raylab/policy/losses/test_mapo.py similarity index 100% rename from tests/policy/losses/test_mapo.py rename to tests/raylab/policy/losses/test_mapo.py diff --git a/tests/policy/losses/test_mle.py b/tests/raylab/policy/losses/test_mle.py similarity index 100% rename from tests/policy/losses/test_mle.py rename to tests/raylab/policy/losses/test_mle.py diff --git a/tests/policy/losses/test_paml.py b/tests/raylab/policy/losses/test_paml.py similarity index 100% rename from tests/policy/losses/test_paml.py rename to tests/raylab/policy/losses/test_paml.py diff --git a/tests/policy/model_based/__init__.py b/tests/raylab/policy/model_based/__init__.py similarity index 100% rename from tests/policy/model_based/__init__.py rename to tests/raylab/policy/model_based/__init__.py diff --git a/tests/policy/model_based/test_envfn_mixin.py b/tests/raylab/policy/model_based/test_envfn_mixin.py similarity index 100% rename from tests/policy/model_based/test_envfn_mixin.py rename to tests/raylab/policy/model_based/test_envfn_mixin.py diff --git a/tests/policy/model_based/test_sampling_mixin.py b/tests/raylab/policy/model_based/test_sampling_mixin.py similarity index 100% rename from tests/policy/model_based/test_sampling_mixin.py rename to tests/raylab/policy/model_based/test_sampling_mixin.py diff --git a/tests/policy/model_based/test_training_mixin.py b/tests/raylab/policy/model_based/test_training_mixin.py similarity index 100% rename from tests/policy/model_based/test_training_mixin.py rename to tests/raylab/policy/model_based/test_training_mixin.py diff --git a/tests/policy/modules/__init__.py b/tests/raylab/policy/modules/__init__.py similarity index 100% rename from tests/policy/modules/__init__.py rename to tests/raylab/policy/modules/__init__.py diff --git a/tests/policy/modules/actor/__init__.py b/tests/raylab/policy/modules/actor/__init__.py similarity index 100% rename from tests/policy/modules/actor/__init__.py rename to tests/raylab/policy/modules/actor/__init__.py diff --git a/tests/policy/modules/actor/conftest.py b/tests/raylab/policy/modules/actor/conftest.py similarity index 100% rename from tests/policy/modules/actor/conftest.py rename to tests/raylab/policy/modules/actor/conftest.py diff --git a/tests/policy/modules/actor/policy/__init__.py b/tests/raylab/policy/modules/actor/policy/__init__.py similarity index 100% rename from 
tests/policy/modules/actor/policy/__init__.py rename to tests/raylab/policy/modules/actor/policy/__init__.py diff --git a/tests/policy/modules/actor/policy/test_deterministic.py b/tests/raylab/policy/modules/actor/policy/test_deterministic.py similarity index 100% rename from tests/policy/modules/actor/policy/test_deterministic.py rename to tests/raylab/policy/modules/actor/policy/test_deterministic.py diff --git a/tests/policy/modules/actor/policy/test_stochastic.py b/tests/raylab/policy/modules/actor/policy/test_stochastic.py similarity index 100% rename from tests/policy/modules/actor/policy/test_stochastic.py rename to tests/raylab/policy/modules/actor/policy/test_stochastic.py diff --git a/tests/policy/modules/actor/test_deterministic.py b/tests/raylab/policy/modules/actor/test_deterministic.py similarity index 100% rename from tests/policy/modules/actor/test_deterministic.py rename to tests/raylab/policy/modules/actor/test_deterministic.py diff --git a/tests/policy/modules/actor/test_stochastic.py b/tests/raylab/policy/modules/actor/test_stochastic.py similarity index 100% rename from tests/policy/modules/actor/test_stochastic.py rename to tests/raylab/policy/modules/actor/test_stochastic.py diff --git a/tests/policy/modules/conftest.py b/tests/raylab/policy/modules/conftest.py similarity index 100% rename from tests/policy/modules/conftest.py rename to tests/raylab/policy/modules/conftest.py diff --git a/tests/policy/modules/critic/__init__.py b/tests/raylab/policy/modules/critic/__init__.py similarity index 100% rename from tests/policy/modules/critic/__init__.py rename to tests/raylab/policy/modules/critic/__init__.py diff --git a/tests/policy/modules/critic/test_action_value.py b/tests/raylab/policy/modules/critic/test_action_value.py similarity index 100% rename from tests/policy/modules/critic/test_action_value.py rename to tests/raylab/policy/modules/critic/test_action_value.py diff --git a/tests/policy/modules/model/__init__.py b/tests/raylab/policy/modules/model/__init__.py similarity index 100% rename from tests/policy/modules/model/__init__.py rename to tests/raylab/policy/modules/model/__init__.py diff --git a/tests/policy/modules/model/stochastic/__init__.py b/tests/raylab/policy/modules/model/stochastic/__init__.py similarity index 100% rename from tests/policy/modules/model/stochastic/__init__.py rename to tests/raylab/policy/modules/model/stochastic/__init__.py diff --git a/tests/policy/modules/model/stochastic/conftest.py b/tests/raylab/policy/modules/model/stochastic/conftest.py similarity index 100% rename from tests/policy/modules/model/stochastic/conftest.py rename to tests/raylab/policy/modules/model/stochastic/conftest.py diff --git a/tests/policy/modules/model/stochastic/test_ensemble.py b/tests/raylab/policy/modules/model/stochastic/test_ensemble.py similarity index 100% rename from tests/policy/modules/model/stochastic/test_ensemble.py rename to tests/raylab/policy/modules/model/stochastic/test_ensemble.py diff --git a/tests/policy/modules/model/stochastic/test_single.py b/tests/raylab/policy/modules/model/stochastic/test_single.py similarity index 100% rename from tests/policy/modules/model/stochastic/test_single.py rename to tests/raylab/policy/modules/model/stochastic/test_single.py diff --git a/tests/policy/modules/networks/__init__.py b/tests/raylab/policy/modules/networks/__init__.py similarity index 100% rename from tests/policy/modules/networks/__init__.py rename to tests/raylab/policy/modules/networks/__init__.py diff --git 
a/tests/policy/modules/networks/test_mlp.py b/tests/raylab/policy/modules/networks/test_mlp.py similarity index 100% rename from tests/policy/modules/networks/test_mlp.py rename to tests/raylab/policy/modules/networks/test_mlp.py diff --git a/tests/policy/modules/networks/test_resnet.py b/tests/raylab/policy/modules/networks/test_resnet.py similarity index 100% rename from tests/policy/modules/networks/test_resnet.py rename to tests/raylab/policy/modules/networks/test_resnet.py diff --git a/tests/policy/modules/test_ddpg.py b/tests/raylab/policy/modules/test_ddpg.py similarity index 100% rename from tests/policy/modules/test_ddpg.py rename to tests/raylab/policy/modules/test_ddpg.py diff --git a/tests/policy/modules/test_sac.py b/tests/raylab/policy/modules/test_sac.py similarity index 100% rename from tests/policy/modules/test_sac.py rename to tests/raylab/policy/modules/test_sac.py diff --git a/tests/policy/modules/v0/__init__.py b/tests/raylab/policy/modules/v0/__init__.py similarity index 100% rename from tests/policy/modules/v0/__init__.py rename to tests/raylab/policy/modules/v0/__init__.py diff --git a/tests/policy/modules/v0/conftest.py b/tests/raylab/policy/modules/v0/conftest.py similarity index 100% rename from tests/policy/modules/v0/conftest.py rename to tests/raylab/policy/modules/v0/conftest.py diff --git a/tests/policy/modules/v0/test_action_value_mixin.py b/tests/raylab/policy/modules/v0/test_action_value_mixin.py similarity index 100% rename from tests/policy/modules/v0/test_action_value_mixin.py rename to tests/raylab/policy/modules/v0/test_action_value_mixin.py diff --git a/tests/policy/modules/v0/test_deterministic_actor_mixin.py b/tests/raylab/policy/modules/v0/test_deterministic_actor_mixin.py similarity index 100% rename from tests/policy/modules/v0/test_deterministic_actor_mixin.py rename to tests/raylab/policy/modules/v0/test_deterministic_actor_mixin.py diff --git a/tests/policy/modules/v0/test_naf_module.py b/tests/raylab/policy/modules/v0/test_naf_module.py similarity index 100% rename from tests/policy/modules/v0/test_naf_module.py rename to tests/raylab/policy/modules/v0/test_naf_module.py diff --git a/tests/policy/modules/v0/test_normalizing_flow_actor_mixin.py b/tests/raylab/policy/modules/v0/test_normalizing_flow_actor_mixin.py similarity index 100% rename from tests/policy/modules/v0/test_normalizing_flow_actor_mixin.py rename to tests/raylab/policy/modules/v0/test_normalizing_flow_actor_mixin.py diff --git a/tests/policy/modules/v0/test_normalizing_flow_model_mixin.py b/tests/raylab/policy/modules/v0/test_normalizing_flow_model_mixin.py similarity index 100% rename from tests/policy/modules/v0/test_normalizing_flow_model_mixin.py rename to tests/raylab/policy/modules/v0/test_normalizing_flow_model_mixin.py diff --git a/tests/policy/modules/v0/test_state_value_mixin.py b/tests/raylab/policy/modules/v0/test_state_value_mixin.py similarity index 100% rename from tests/policy/modules/v0/test_state_value_mixin.py rename to tests/raylab/policy/modules/v0/test_state_value_mixin.py diff --git a/tests/policy/modules/v0/test_stochastic_actor_mixin.py b/tests/raylab/policy/modules/v0/test_stochastic_actor_mixin.py similarity index 100% rename from tests/policy/modules/v0/test_stochastic_actor_mixin.py rename to tests/raylab/policy/modules/v0/test_stochastic_actor_mixin.py diff --git a/tests/policy/modules/v0/test_stochastic_model_mixin.py b/tests/raylab/policy/modules/v0/test_stochastic_model_mixin.py similarity index 100% rename from 
tests/policy/modules/v0/test_stochastic_model_mixin.py rename to tests/raylab/policy/modules/v0/test_stochastic_model_mixin.py diff --git a/tests/policy/modules/v0/test_svg_module.py b/tests/raylab/policy/modules/v0/test_svg_module.py similarity index 100% rename from tests/policy/modules/v0/test_svg_module.py rename to tests/raylab/policy/modules/v0/test_svg_module.py diff --git a/tests/policy/modules/v0/test_trpo_extensions.py b/tests/raylab/policy/modules/v0/test_trpo_extensions.py similarity index 100% rename from tests/policy/modules/v0/test_trpo_extensions.py rename to tests/raylab/policy/modules/v0/test_trpo_extensions.py diff --git a/tests/policy/modules/v0/utils.py b/tests/raylab/policy/modules/v0/utils.py similarity index 100% rename from tests/policy/modules/v0/utils.py rename to tests/raylab/policy/modules/v0/utils.py diff --git a/tests/policy/test_optimizer_collection.py b/tests/raylab/policy/test_optimizer_collection.py similarity index 100% rename from tests/policy/test_optimizer_collection.py rename to tests/raylab/policy/test_optimizer_collection.py diff --git a/tests/pytorch/__init__.py b/tests/raylab/pytorch/__init__.py similarity index 100% rename from tests/pytorch/__init__.py rename to tests/raylab/pytorch/__init__.py diff --git a/tests/pytorch/conftest.py b/tests/raylab/pytorch/conftest.py similarity index 100% rename from tests/pytorch/conftest.py rename to tests/raylab/pytorch/conftest.py diff --git a/tests/pytorch/nn/__init__.py b/tests/raylab/pytorch/nn/__init__.py similarity index 100% rename from tests/pytorch/nn/__init__.py rename to tests/raylab/pytorch/nn/__init__.py diff --git a/tests/pytorch/nn/distributions/__init__.py b/tests/raylab/pytorch/nn/distributions/__init__.py similarity index 100% rename from tests/pytorch/nn/distributions/__init__.py rename to tests/raylab/pytorch/nn/distributions/__init__.py diff --git a/tests/pytorch/nn/distributions/conftest.py b/tests/raylab/pytorch/nn/distributions/conftest.py similarity index 100% rename from tests/pytorch/nn/distributions/conftest.py rename to tests/raylab/pytorch/nn/distributions/conftest.py diff --git a/tests/pytorch/nn/distributions/flows/__init__.py b/tests/raylab/pytorch/nn/distributions/flows/__init__.py similarity index 100% rename from tests/pytorch/nn/distributions/flows/__init__.py rename to tests/raylab/pytorch/nn/distributions/flows/__init__.py diff --git a/tests/pytorch/nn/distributions/flows/test_affine_constant.py b/tests/raylab/pytorch/nn/distributions/flows/test_affine_constant.py similarity index 100% rename from tests/pytorch/nn/distributions/flows/test_affine_constant.py rename to tests/raylab/pytorch/nn/distributions/flows/test_affine_constant.py diff --git a/tests/pytorch/nn/distributions/flows/test_couplings.py b/tests/raylab/pytorch/nn/distributions/flows/test_couplings.py similarity index 100% rename from tests/pytorch/nn/distributions/flows/test_couplings.py rename to tests/raylab/pytorch/nn/distributions/flows/test_couplings.py diff --git a/tests/pytorch/nn/distributions/flows/test_flow_distribution.py b/tests/raylab/pytorch/nn/distributions/flows/test_flow_distribution.py similarity index 100% rename from tests/pytorch/nn/distributions/flows/test_flow_distribution.py rename to tests/raylab/pytorch/nn/distributions/flows/test_flow_distribution.py diff --git a/tests/pytorch/nn/distributions/flows/test_maf.py b/tests/raylab/pytorch/nn/distributions/flows/test_maf.py similarity index 100% rename from tests/pytorch/nn/distributions/flows/test_maf.py rename to 
tests/raylab/pytorch/nn/distributions/flows/test_maf.py diff --git a/tests/pytorch/nn/distributions/test_categorical.py b/tests/raylab/pytorch/nn/distributions/test_categorical.py similarity index 100% rename from tests/pytorch/nn/distributions/test_categorical.py rename to tests/raylab/pytorch/nn/distributions/test_categorical.py diff --git a/tests/pytorch/nn/distributions/test_normal.py b/tests/raylab/pytorch/nn/distributions/test_normal.py similarity index 100% rename from tests/pytorch/nn/distributions/test_normal.py rename to tests/raylab/pytorch/nn/distributions/test_normal.py diff --git a/tests/pytorch/nn/distributions/test_transforms.py b/tests/raylab/pytorch/nn/distributions/test_transforms.py similarity index 100% rename from tests/pytorch/nn/distributions/test_transforms.py rename to tests/raylab/pytorch/nn/distributions/test_transforms.py diff --git a/tests/pytorch/nn/distributions/test_uniform.py b/tests/raylab/pytorch/nn/distributions/test_uniform.py similarity index 100% rename from tests/pytorch/nn/distributions/test_uniform.py rename to tests/raylab/pytorch/nn/distributions/test_uniform.py diff --git a/tests/pytorch/nn/distributions/utils.py b/tests/raylab/pytorch/nn/distributions/utils.py similarity index 100% rename from tests/pytorch/nn/distributions/utils.py rename to tests/raylab/pytorch/nn/distributions/utils.py diff --git a/tests/pytorch/nn/modules/__init__.py b/tests/raylab/pytorch/nn/modules/__init__.py similarity index 100% rename from tests/pytorch/nn/modules/__init__.py rename to tests/raylab/pytorch/nn/modules/__init__.py diff --git a/tests/pytorch/nn/modules/test_action_output.py b/tests/raylab/pytorch/nn/modules/test_action_output.py similarity index 100% rename from tests/pytorch/nn/modules/test_action_output.py rename to tests/raylab/pytorch/nn/modules/test_action_output.py diff --git a/tests/pytorch/nn/modules/test_fully_connected.py b/tests/raylab/pytorch/nn/modules/test_fully_connected.py similarity index 100% rename from tests/pytorch/nn/modules/test_fully_connected.py rename to tests/raylab/pytorch/nn/modules/test_fully_connected.py diff --git a/tests/pytorch/nn/modules/test_gaussian_noise.py b/tests/raylab/pytorch/nn/modules/test_gaussian_noise.py similarity index 100% rename from tests/pytorch/nn/modules/test_gaussian_noise.py rename to tests/raylab/pytorch/nn/modules/test_gaussian_noise.py diff --git a/tests/pytorch/nn/modules/test_lambd.py b/tests/raylab/pytorch/nn/modules/test_lambd.py similarity index 100% rename from tests/pytorch/nn/modules/test_lambd.py rename to tests/raylab/pytorch/nn/modules/test_lambd.py diff --git a/tests/pytorch/nn/modules/test_made.py b/tests/raylab/pytorch/nn/modules/test_made.py similarity index 100% rename from tests/pytorch/nn/modules/test_made.py rename to tests/raylab/pytorch/nn/modules/test_made.py diff --git a/tests/pytorch/nn/modules/test_normalized_linear.py b/tests/raylab/pytorch/nn/modules/test_normalized_linear.py similarity index 100% rename from tests/pytorch/nn/modules/test_normalized_linear.py rename to tests/raylab/pytorch/nn/modules/test_normalized_linear.py diff --git a/tests/pytorch/nn/modules/test_tanh_squash.py b/tests/raylab/pytorch/nn/modules/test_tanh_squash.py similarity index 100% rename from tests/pytorch/nn/modules/test_tanh_squash.py rename to tests/raylab/pytorch/nn/modules/test_tanh_squash.py diff --git a/tests/pytorch/nn/modules/test_tril_matrix.py b/tests/raylab/pytorch/nn/modules/test_tril_matrix.py similarity index 100% rename from tests/pytorch/nn/modules/test_tril_matrix.py rename 
to tests/raylab/pytorch/nn/modules/test_tril_matrix.py diff --git a/tests/utils/__init__.py b/tests/raylab/utils/__init__.py similarity index 100% rename from tests/utils/__init__.py rename to tests/raylab/utils/__init__.py diff --git a/tests/utils/test_replay.py b/tests/raylab/utils/test_replay.py similarity index 100% rename from tests/utils/test_replay.py rename to tests/raylab/utils/test_replay.py From 811ae35275650a5766b6cfc69269581c77e2360e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=C3=82ngelo=20Lovatto?= Date: Tue, 30 Jun 2020 19:39:17 -0300 Subject: [PATCH 36/48] test: remove all uses of unittest.mock MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Ângelo Lovatto --- tests/raylab/agents/mage/test_policy.py | 12 +++--- .../policy/model_based/test_envfn_mixin.py | 42 +++++++++---------- 2 files changed, 26 insertions(+), 28 deletions(-) diff --git a/tests/raylab/agents/mage/test_policy.py b/tests/raylab/agents/mage/test_policy.py index e52511b2..e436ae56 100644 --- a/tests/raylab/agents/mage/test_policy.py +++ b/tests/raylab/agents/mage/test_policy.py @@ -1,6 +1,4 @@ # pylint: disable=missing-docstring,redefined-outer-name,protected-access -from unittest import mock - import numpy as np import pytest import torch @@ -85,8 +83,8 @@ def test_learn_on_batch(policy, samples): assert np.isfinite(info["grad_norm(critics)"]) -def test_compile(policy): - with mock.patch("raylab.policy.losses.MAGE.compile") as mocked_method: - policy.compile() - assert isinstance(policy.module, torch.jit.ScriptModule) - assert mocked_method.called +def test_compile(policy, mocker): + method = mocker.spy(MAGE, "compile") + policy.compile() + assert isinstance(policy.module, torch.jit.ScriptModule) + assert method.called diff --git a/tests/raylab/policy/model_based/test_envfn_mixin.py b/tests/raylab/policy/model_based/test_envfn_mixin.py index e7be8810..e11b374b 100644 --- a/tests/raylab/policy/model_based/test_envfn_mixin.py +++ b/tests/raylab/policy/model_based/test_envfn_mixin.py @@ -1,6 +1,5 @@ # pylint:disable=missing-docstring,redefined-outer-name,protected-access import math -from unittest import mock import pytest import torch @@ -61,7 +60,8 @@ def test_init(policy): assert hasattr(policy, "dynamics_fn") -def test_set_reward_from_config(policy, envs): # pylint:disable=unused-argument +def test_set_reward_from_config(policy, mocker): + obs_space, action_space = policy.observation_space, policy.action_space batch_size = 10 obs = fake_space_samples(obs_space, batch_size=batch_size) @@ -69,9 +69,9 @@ def test_set_reward_from_config(policy, envs): # pylint:disable=unused-argument new_obs = fake_space_samples(obs_space, batch_size=batch_size) obs, act, new_obs = map(policy.convert_to_tensor, (obs, act, new_obs)) - with mock.patch("raylab.policy.EnvFnMixin._set_reward_hook") as hook: - policy.set_reward_from_config("MockEnv", {}) - assert hook.called + hook = mocker.spy(EnvFnMixin, "_set_reward_hook") + policy.set_reward_from_config("MockEnv", {}) + assert hook.called original_fn = get_reward_fn("MockEnv", {}) expected_rew = original_fn(obs, act, new_obs) @@ -80,7 +80,7 @@ def test_set_reward_from_config(policy, envs): # pylint:disable=unused-argument assert torch.allclose(rew, expected_rew) -def test_set_termination_from_config(policy, envs): # pylint:disable=unused-argument +def test_set_termination_from_config(policy, mocker): obs_space, action_space = policy.observation_space, policy.action_space batch_size = 10 obs = fake_space_samples(obs_space, 
batch_size=batch_size) @@ -88,9 +88,9 @@ def test_set_termination_from_config(policy, envs): # pylint:disable=unused-arg new_obs = fake_space_samples(obs_space, batch_size=batch_size) obs, act, new_obs = map(policy.convert_to_tensor, (obs, act, new_obs)) - with mock.patch("raylab.policy.EnvFnMixin._set_termination_hook") as hook: - policy.set_termination_from_config("MockEnv", {}) - assert hook.called + hook = mocker.spy(EnvFnMixin, "_set_termination_hook") + policy.set_termination_from_config("MockEnv", {}) + assert hook.called done = policy.termination_fn(obs, act, new_obs) assert torch.is_tensor(done) @@ -98,28 +98,28 @@ def test_set_termination_from_config(policy, envs): # pylint:disable=unused-arg assert done.shape == obs.shape[:-1] -def test_set_reward_from_callable(policy, reward_fn): - with mock.patch("raylab.policy.EnvFnMixin._set_reward_hook") as hook: - policy.set_reward_from_callable(reward_fn) - assert hook.called +def test_set_reward_from_callable(policy, reward_fn, mocker): + hook = mocker.spy(EnvFnMixin, "_set_reward_hook") + policy.set_reward_from_callable(reward_fn) + assert hook.called assert hasattr(policy, "reward_fn") assert policy.reward_fn is reward_fn -def test_set_termination_from_callable(policy, termination_fn): - with mock.patch("raylab.policy.EnvFnMixin._set_termination_hook") as hook: - policy.set_termination_from_callable(termination_fn) - assert hook.called +def test_set_termination_from_callable(policy, termination_fn, mocker): + hook = mocker.spy(EnvFnMixin, "_set_termination_hook") + policy.set_termination_from_callable(termination_fn) + assert hook.called assert hasattr(policy, "termination_fn") assert policy.termination_fn is termination_fn -def test_set_dynamics_from_callable(policy, dynamics_fn): - with mock.patch("raylab.policy.EnvFnMixin._set_dynamics_hook") as hook: - policy.set_dynamics_from_callable(dynamics_fn) - assert hook.called +def test_set_dynamics_from_callable(policy, dynamics_fn, mocker): + hook = mocker.spy(EnvFnMixin, "_set_dynamics_hook") + policy.set_dynamics_from_callable(dynamics_fn) + assert hook.called assert hasattr(policy, "dynamics_fn") assert policy.dynamics_fn is dynamics_fn From 5edf01a781332e758fbacc25c2d54df495164212 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Thu, 2 Jul 2020 09:03:13 +0000 Subject: [PATCH 37/48] chore(deps-dev): bump pre-commit from 2.5.1 to 2.6.0 Bumps [pre-commit](https://github.com/pre-commit/pre-commit) from 2.5.1 to 2.6.0. 
- [Release notes](https://github.com/pre-commit/pre-commit/releases) - [Changelog](https://github.com/pre-commit/pre-commit/blob/master/CHANGELOG.md) - [Commits](https://github.com/pre-commit/pre-commit/compare/v2.5.1...v2.6.0) Signed-off-by: dependabot[bot] --- poetry.lock | 8 ++++---- pyproject.toml | 2 +- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/poetry.lock b/poetry.lock index c06d692a..69d1c0a5 100644 --- a/poetry.lock +++ b/poetry.lock @@ -1128,7 +1128,7 @@ description = "A framework for managing and maintaining multi-language pre-commi name = "pre-commit" optional = false python-versions = ">=3.6.1" -version = "2.5.1" +version = "2.6.0" [package.dependencies] cfgv = ">=2.0.0" @@ -2028,7 +2028,7 @@ docs = ["sphinx", "jaraco.packaging (>=3.2)", "rst.linker (>=1.9)"] testing = ["jaraco.itertools", "func-timeout"] [metadata] -content-hash = "a3ba29e6b9d210c4f16c69d0acb0470c21c10a039325b78cba7e1de32d0c231c" +content-hash = "4b8029fbc1de1ec8d50af443ac0d298a7b7ec72d7c228b653997500d3887a826" python-versions = "^3.7" [metadata.files] @@ -2712,8 +2712,8 @@ poetry-version = [ {file = "poetry_version-0.1.5-py2.py3-none-any.whl", hash = "sha256:ba259257640cd36c76375563a001b9e85c6f54d764cc56b3eed0f3b5cefb0317"}, ] pre-commit = [ - {file = "pre_commit-2.5.1-py2.py3-none-any.whl", hash = "sha256:c5c8fd4d0e1c363723aaf0a8f9cba0f434c160b48c4028f4bae6d219177945b3"}, - {file = "pre_commit-2.5.1.tar.gz", hash = "sha256:da463cf8f0e257f9af49047ba514f6b90dbd9b4f92f4c8847a3ccd36834874c7"}, + {file = "pre_commit-2.6.0-py2.py3-none-any.whl", hash = "sha256:e8b1315c585052e729ab7e99dcca5698266bedce9067d21dc909c23e3ceed626"}, + {file = "pre_commit-2.6.0.tar.gz", hash = "sha256:1657663fdd63a321a4a739915d7d03baedd555b25054449090f97bb0cb30a915"}, ] prometheus-client = [ {file = "prometheus_client-0.8.0-py2.py3-none-any.whl", hash = "sha256:983c7ac4b47478720db338f1491ef67a100b474e3bc7dafcbaefb7d0b8f9b01c"}, diff --git a/pyproject.toml b/pyproject.toml index 484a9570..cb4358fc 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -33,7 +33,7 @@ tox = "^3.16.1" sphinx = "^3.1.1" pytest = "^5.4.3" gym-cartpole-swingup = "^0.1.0" -pre-commit = "^2.5.1" +pre-commit = "^2.6.0" reorder-python-imports = "^2.3.1" mypy = "^0.782" coverage = "^5.1" From d68f494f28394e0aaa10df2d4eec70988be99bde Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=C3=82ngelo=20Lovatto?= Date: Sat, 4 Jul 2020 07:14:55 -0300 Subject: [PATCH 38/48] feat(policy): add stochastic and deterministic action dist wrappers MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Ângelo Lovatto --- raylab/agents/acktr/policy.py | 5 ++ raylab/agents/mage/policy.py | 2 + raylab/agents/mapo/policy.py | 2 + raylab/agents/mbpo/policy.py | 2 + raylab/agents/naf/policy.py | 3 + raylab/agents/sac/policy.py | 3 + raylab/agents/sop/policy.py | 3 + raylab/agents/svg/policy.py | 3 + raylab/agents/trpo/policy.py | 5 ++ raylab/policy/action_dist.py | 60 +++++++++++++++++-- .../modules/actor/policy/deterministic.py | 5 ++ raylab/policy/modules/v0/naf_module.py | 10 ++-- raylab/policy/torch_policy.py | 8 ++- raylab/utils/exploration/gaussian_noise.py | 20 ++----- 14 files changed, 103 insertions(+), 28 deletions(-) diff --git a/raylab/agents/acktr/policy.py b/raylab/agents/acktr/policy.py index 293fe1a2..4bac4f21 100644 --- a/raylab/agents/acktr/policy.py +++ b/raylab/agents/acktr/policy.py @@ -12,6 +12,7 @@ import raylab.utils.dictionaries as dutil from raylab.policy import TorchPolicy +from raylab.policy.action_dist 
import WrapStochasticPolicy from raylab.pytorch.nn.distributions import Normal from raylab.pytorch.optim import build_optimizer from raylab.pytorch.optim.hessian_free import line_search @@ -52,6 +53,10 @@ class ACKTRTorchPolicy(TorchPolicy): # pylint:disable=abstract-method + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.dist_class = WrapStochasticPolicy + @staticmethod @override(TorchPolicy) def get_default_config(): diff --git a/raylab/agents/mage/policy.py b/raylab/agents/mage/policy.py index ae2cf2e3..3692a0c3 100644 --- a/raylab/agents/mage/policy.py +++ b/raylab/agents/mage/policy.py @@ -2,6 +2,7 @@ from raylab.agents.sop import SOPTorchPolicy from raylab.policy import EnvFnMixin from raylab.policy import ModelTrainingMixin +from raylab.policy.action_dist import WrapDeterministicPolicy from raylab.policy.losses import MAGE from raylab.policy.losses import ModelEnsembleMLE from raylab.policy.losses.mage import MAGEModules @@ -21,6 +22,7 @@ class MAGETorchPolicy(ModelTrainingMixin, EnvFnMixin, SOPTorchPolicy): def __init__(self, observation_space, action_space, config): super().__init__(observation_space, action_space, config) + self.dist_class = WrapDeterministicPolicy module = self.module self.loss_model = ModelEnsembleMLE(module.models) diff --git a/raylab/agents/mapo/policy.py b/raylab/agents/mapo/policy.py index a1491a13..77bcbf53 100644 --- a/raylab/agents/mapo/policy.py +++ b/raylab/agents/mapo/policy.py @@ -4,6 +4,7 @@ from raylab.agents.sac import SACTorchPolicy from raylab.policy import EnvFnMixin from raylab.policy import ModelTrainingMixin +from raylab.policy.action_dist import WrapStochasticPolicy from raylab.policy.losses import DAPO from raylab.policy.losses import MAPO from raylab.policy.losses import SPAML @@ -17,6 +18,7 @@ class MAPOTorchPolicy(ModelTrainingMixin, EnvFnMixin, SACTorchPolicy): def __init__(self, observation_space, action_space, config): super().__init__(observation_space, action_space, config) + self.dist_class = WrapStochasticPolicy self.loss_model = SPAML( self.module.models, self.module.actor, self.module.critics diff --git a/raylab/agents/mbpo/policy.py b/raylab/agents/mbpo/policy.py index df48ffb3..6b61e9bd 100644 --- a/raylab/agents/mbpo/policy.py +++ b/raylab/agents/mbpo/policy.py @@ -5,6 +5,7 @@ from raylab.policy import EnvFnMixin from raylab.policy import ModelSamplingMixin from raylab.policy import ModelTrainingMixin +from raylab.policy.action_dist import WrapStochasticPolicy from raylab.policy.losses import ModelEnsembleMLE from raylab.pytorch.optim import build_optimizer @@ -18,6 +19,7 @@ class MBPOTorchPolicy( def __init__(self, observation_space, action_space, config): super().__init__(observation_space, action_space, config) + self.dist_class = WrapStochasticPolicy models = self.module.models self.loss_model = ModelEnsembleMLE(models) diff --git a/raylab/agents/naf/policy.py b/raylab/agents/naf/policy.py index 09437323..dee40f6f 100644 --- a/raylab/agents/naf/policy.py +++ b/raylab/agents/naf/policy.py @@ -5,6 +5,7 @@ from raylab.policy import TargetNetworksMixin from raylab.policy import TorchPolicy +from raylab.policy.action_dist import WrapDeterministicPolicy from raylab.policy.losses import ClippedDoubleQLearning from raylab.pytorch.optim import build_optimizer @@ -16,6 +17,8 @@ class NAFTorchPolicy(TargetNetworksMixin, TorchPolicy): def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) + self.dist_class = WrapDeterministicPolicy + target_critics = [lambda s, _, v=v: v(s) for v in 
self.module.target_vcritics] self.loss_fn = ClippedDoubleQLearning( self.module.critics, target_critics, actor=lambda _: None, diff --git a/raylab/agents/sac/policy.py b/raylab/agents/sac/policy.py index 661935c3..89545529 100644 --- a/raylab/agents/sac/policy.py +++ b/raylab/agents/sac/policy.py @@ -5,6 +5,7 @@ from raylab.policy import TargetNetworksMixin from raylab.policy import TorchPolicy +from raylab.policy.action_dist import WrapStochasticPolicy from raylab.policy.losses import MaximumEntropyDual from raylab.policy.losses import ReparameterizedSoftPG from raylab.policy.losses import SoftCDQLearning @@ -18,6 +19,8 @@ class SACTorchPolicy(TargetNetworksMixin, TorchPolicy): def __init__(self, observation_space, action_space, config): super().__init__(observation_space, action_space, config) + self.dist_class = WrapStochasticPolicy + self.loss_actor = ReparameterizedSoftPG(self.module.actor, self.module.critics) self.loss_critic = SoftCDQLearning( self.module.critics, self.module.target_critics, self.module.actor.sample diff --git a/raylab/agents/sop/policy.py b/raylab/agents/sop/policy.py index 95911c3e..97025ccc 100644 --- a/raylab/agents/sop/policy.py +++ b/raylab/agents/sop/policy.py @@ -5,6 +5,7 @@ from raylab.policy import TargetNetworksMixin from raylab.policy import TorchPolicy +from raylab.policy.action_dist import WrapDeterministicPolicy from raylab.policy.losses import ClippedDoubleQLearning from raylab.policy.losses import DeterministicPolicyGradient from raylab.pytorch.optim import build_optimizer @@ -17,6 +18,8 @@ class SOPTorchPolicy(TargetNetworksMixin, TorchPolicy): def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) + self.dist_class = WrapDeterministicPolicy + self.loss_actor = DeterministicPolicyGradient( self.module.actor, self.module.critics, ) diff --git a/raylab/agents/svg/policy.py b/raylab/agents/svg/policy.py index c264687e..7dcc4a52 100644 --- a/raylab/agents/svg/policy.py +++ b/raylab/agents/svg/policy.py @@ -5,6 +5,7 @@ from raylab.policy import EnvFnMixin from raylab.policy import TargetNetworksMixin from raylab.policy import TorchPolicy +from raylab.policy.action_dist import WrapStochasticPolicy from raylab.policy.losses import ISFittedVIteration from raylab.policy.losses import MaximumLikelihood @@ -15,6 +16,8 @@ class SVGTorchPolicy(EnvFnMixin, TargetNetworksMixin, TorchPolicy): # pylint: disable=abstract-method def __init__(self, observation_space, action_space, config): super().__init__(observation_space, action_space, config) + self.dist_class = WrapStochasticPolicy + self.loss_model = MaximumLikelihood(self.module.model) self.loss_critic = ISFittedVIteration( self.module.critic, self.module.target_critic diff --git a/raylab/agents/trpo/policy.py b/raylab/agents/trpo/policy.py index d8d3d691..702bd6db 100644 --- a/raylab/agents/trpo/policy.py +++ b/raylab/agents/trpo/policy.py @@ -11,6 +11,7 @@ from torch.nn.utils import vector_to_parameters from raylab.policy import TorchPolicy +from raylab.policy.action_dist import WrapStochasticPolicy from raylab.pytorch.optim import build_optimizer from raylab.pytorch.optim.hessian_free import conjugate_gradient from raylab.pytorch.optim.hessian_free import hessian_vector_product @@ -25,6 +26,10 @@ class TRPOTorchPolicy(TorchPolicy): # pylint:disable=abstract-method + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.dist_class = WrapStochasticPolicy + @staticmethod @override(TorchPolicy) def get_default_config(): diff --git a/raylab/policy/action_dist.py 
b/raylab/policy/action_dist.py index d6baeb0e..d8c99861 100644 --- a/raylab/policy/action_dist.py +++ b/raylab/policy/action_dist.py @@ -1,15 +1,26 @@ """Action distribution for compatibility with RLlib's interface.""" +import torch from ray.rllib.models.action_dist import ActionDistribution from ray.rllib.utils import override +from torch import Tensor +from .modules.actor.policy.deterministic import DeterministicPolicy +from .modules.actor.policy.stochastic import StochasticPolicy +from .modules.v0.mixins.stochastic_actor_mixin import StochasticPolicy as V0StochasticPi -class WrapModuleDist(ActionDistribution): - """Stores a nn.Module and inputs, delegation all methods to the module.""" + +class WrapStochasticPolicy(ActionDistribution): + """Wraps an nn.Module with a stochastic actor and its inputs. + + Expects actor to be a StochasticPolicy instance. + """ # pylint:disable=abstract-method def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) + assert hasattr(self.model, "actor") + assert isinstance(self.model.actor, (V0StochasticPi, StochasticPolicy)) self._sampled_logp = None @override(ActionDistribution) @@ -20,9 +31,7 @@ def sample(self): @override(ActionDistribution) def deterministic_sample(self): - if hasattr(self.model.actor, "deterministic"): - return self.model.actor.deterministic(**self.inputs) - return self.model.actor(**self.inputs), None + return self.model.actor.deterministic(**self.inputs) @override(ActionDistribution) def sampled_action_logp(self): @@ -35,3 +44,44 @@ def logp(self, x): @override(ActionDistribution) def entropy(self): return self.model.actor.entropy(**self.inputs) + + +class WrapDeterministicPolicy(ActionDistribution): + """Wraps an nn.Module with a deterministic actor and its inputs. + + Expects actor to be a DeterministicPolicy instance. 
+ """ + + # pylint:disable=abstract-method + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + assert hasattr(self.model, "actor") and isinstance( + self.model.actor, DeterministicPolicy + ) + assert hasattr(self.model, "behavior") and isinstance( + self.model.behavior, DeterministicPolicy + ) + + @override(ActionDistribution) + def sample(self): + action = self.model.behavior(**self.inputs) + return action, None + + def sample_inject_noise(self, noise_stddev: float) -> Tensor: + """Add zero-mean Gaussian noise to the actions prior to normalizing them.""" + unconstrained_action = self.model.behavior.unconstrained_action(**self.inputs) + unconstrained_action += torch.randn_like(unconstrained_action) * noise_stddev + return self.model.behavior.squash_action(unconstrained_action), None + + @override(ActionDistribution) + def deterministic_sample(self): + return self.model.actor(**self.inputs), None + + @override(ActionDistribution) + def sampled_action_logp(self): + return None + + @override(ActionDistribution) + def logp(self, x): + return None diff --git a/raylab/policy/modules/actor/policy/deterministic.py b/raylab/policy/modules/actor/policy/deterministic.py index 457b5f53..b9d50bd1 100644 --- a/raylab/policy/modules/actor/policy/deterministic.py +++ b/raylab/policy/modules/actor/policy/deterministic.py @@ -52,6 +52,11 @@ def unconstrained_action(self, obs: Tensor) -> Tensor: unconstrained_action = self.noise(unconstrained_action) return unconstrained_action + @torch.jit.export + def squash_action(self, unconstrained_action: Tensor) -> Tensor: + """Returns the action generated by the given unconstrained action.""" + return self.squashing(unconstrained_action) + @torch.jit.export def unsquash_action(self, action: Tensor) -> Tensor: """Returns the unconstrained action which generated the given action.""" diff --git a/raylab/policy/modules/v0/naf_module.py b/raylab/policy/modules/v0/naf_module.py index e57f745b..5f50c61b 100644 --- a/raylab/policy/modules/v0/naf_module.py +++ b/raylab/policy/modules/v0/naf_module.py @@ -5,14 +5,13 @@ import torch.nn as nn from ray.rllib.utils import override +from raylab.policy.modules.actor.policy.deterministic import DeterministicPolicy from raylab.pytorch.nn import FullyConnected from raylab.pytorch.nn import NormalizedLinear from raylab.pytorch.nn import TanhSquash from raylab.pytorch.nn import TrilMatrix from raylab.utils.dictionaries import deep_merge -from .mixins import DeterministicPolicy - BASE_CONFIG = { "double_q": False, @@ -68,8 +67,7 @@ def _make_encoder(obs_space, config): def _make_actor(self, obs_space, action_space, config): naf = self.critics[0] - mods = nn.ModuleList([naf.logits, naf.pre_act, naf.squash]) - actor = DeterministicPolicy(mods) + actor = DeterministicPolicy(naf.logits, naf.pre_act, naf.squash) behavior = actor if config["perturbed_policy"]: if not config["encoder"].get("layer_norm"): @@ -77,7 +75,9 @@ def _make_actor(self, obs_space, action_space, config): "'layer_norm' is deactivated even though a perturbed policy was " "requested. For optimal stability, set 'layer_norm': True." 
) - behavior = DeterministicPolicy.from_scratch(obs_space, action_space, config) + logits = self._make_encoder(obs_space, config) + _naf = NAF(logits, action_space, config) + behavior = DeterministicPolicy(_naf.logits, _naf.pre_act, _naf.squash) return {"actor": actor, "behavior": behavior} diff --git a/raylab/policy/torch_policy.py b/raylab/policy/torch_policy.py index 58bc3fa5..e5a5ec88 100644 --- a/raylab/policy/torch_policy.py +++ b/raylab/policy/torch_policy.py @@ -25,7 +25,6 @@ from raylab.pytorch.utils import convert_to_tensor from raylab.utils.dictionaries import deep_merge -from .action_dist import WrapModuleDist from .modules.catalog import get_module from .optimizer_collection import OptimizerCollection @@ -34,6 +33,8 @@ class TorchPolicy(Policy): """A Policy that uses PyTorch as a backend. Attributes: + dist_class: Action distribution class for computing actions. Must be set + by subclasses. device: Device in which the parameter tensors reside. All input samples will be converted to tensors and moved to this device module: The policy's neural network module. Should be compilable to @@ -64,7 +65,6 @@ def __init__(self, observation_space: Space, action_space: Space, config: dict): self.optimizers[name] = optimizer # === Policy attributes === - self.dist_class = WrapModuleDist self.framework = "torch" # Needed to create exploration self.exploration = self._create_exploration() @@ -142,7 +142,9 @@ def compute_actions( self.convert_to_tensor([1]), ) + # pylint:disable=not-callable action_dist = self.dist_class(dist_inputs, self.module) + # pylint:enable=not-callable actions, logp = self.exploration.get_exploration_action( action_distribution=action_dist, timestep=timestep, explore=explore ) @@ -243,7 +245,9 @@ def compute_log_likelihoods( state_batches, self.convert_to_tensor([1]), ) + # pylint:disable=not-callable action_dist = self.dist_class(dist_inputs, self.module) + # pylint:enable=not-callable log_likelihoods = action_dist.logp(input_dict[SampleBatch.ACTIONS]) return log_likelihoods diff --git a/raylab/utils/exploration/gaussian_noise.py b/raylab/utils/exploration/gaussian_noise.py index b57e9fe1..58eaa86a 100644 --- a/raylab/utils/exploration/gaussian_noise.py +++ b/raylab/utils/exploration/gaussian_noise.py @@ -1,14 +1,13 @@ # pylint:disable=missing-module-docstring -import torch from ray.rllib.utils import override -from raylab.pytorch.nn.distributions.flows import TanhSquashTransform - from .random_uniform import RandomUniform class GaussianNoise(RandomUniform): - """Adds fixed additive gaussian exploration noise to actions before squashing. + """Adds fixed additive gaussian exploration noise to actions. + + Args: noise_stddev (float): Standard deviation of the Gaussian samples. 
@@ -17,10 +16,6 @@ class GaussianNoise(RandomUniform): def __init__(self, *args, noise_stddev=None, **kwargs): super().__init__(*args, **kwargs) self._noise_stddev = noise_stddev - self._squash = TanhSquashTransform( - low=torch.as_tensor(self.action_space.low), - high=torch.as_tensor(self.action_space.high), - ) @override(RandomUniform) def get_exploration_action(self, *, action_distribution, timestep, explore=True): @@ -31,12 +26,5 @@ def get_exploration_action(self, *, action_distribution, timestep, explore=True) timestep=timestep, explore=explore, ) - return self._get_gaussian_perturbed_actions(action_distribution) + return action_distribution.sample_inject_noise(self._noise_stddev) return action_distribution.deterministic_sample() - - def _get_gaussian_perturbed_actions(self, action_distribution): - module, inputs = action_distribution.model, action_distribution.inputs - actions = module.actor(**inputs) - pre_squash, _ = self._squash(actions, reverse=True) - noise = torch.randn_like(pre_squash) * self._noise_stddev - return self._squash(pre_squash + noise)[0], None From 3c3d1d087fab08ecf2e618c1afa5125e49457103 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=C3=82ngelo=20Lovatto?= Date: Sat, 4 Jul 2020 07:59:07 -0300 Subject: [PATCH 39/48] feat(policy): add model property MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Ângelo Lovatto --- raylab/policy/torch_policy.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/raylab/policy/torch_policy.py b/raylab/policy/torch_policy.py index e5a5ec88..9ef3cae9 100644 --- a/raylab/policy/torch_policy.py +++ b/raylab/policy/torch_policy.py @@ -68,6 +68,14 @@ def __init__(self, observation_space: Space, action_space: Space, config: dict): self.framework = "torch" # Needed to create exploration self.exploration = self._create_exploration() + @property + def model(self): + """The policy's NN module. + + Mostly for compatibility with RLlib's API. + """ + return self.module + @staticmethod @abstractmethod def get_default_config() -> dict: From 86d6719707653167ac90b0c11071e8304ec1eaa0 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=C3=82ngelo=20Lovatto?= Date: Sat, 4 Jul 2020 08:01:56 -0300 Subject: [PATCH 40/48] chore(policy): add parameter noise compat with new modules MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Ângelo Lovatto --- examples/SOP/cheetah_defaults.py | 6 ++-- examples/SOP/hopper_defaults.py | 6 ++-- examples/SOP/ib_defaults.py | 6 ++-- examples/SOP/swingup_defaults.py | 3 +- examples/SOP/walker_defaults.py | 6 ++-- raylab/agents/sop/policy.py | 5 +-- raylab/policy/modules/actor/deterministic.py | 21 ++++++------ raylab/utils/exploration/gaussian_noise.py | 2 -- raylab/utils/exploration/parameter_noise.py | 32 +++++++++++-------- .../modules/actor/test_deterministic.py | 22 +++++++++---- 10 files changed, 61 insertions(+), 48 deletions(-) diff --git a/examples/SOP/cheetah_defaults.py b/examples/SOP/cheetah_defaults.py index 93e7c63d..5c2a8ba3 100644 --- a/examples/SOP/cheetah_defaults.py +++ b/examples/SOP/cheetah_defaults.py @@ -25,15 +25,16 @@ def get_config(): # for the policy and action-value function. No layers means the component is # linear in states and/or actions. 
"module": { - "type": "DDPGModule", + "type": "DDPG", + "initializer": {"name": "orthogonal"}, "actor": { + "parameter_noise": True, "smooth_target_policy": True, "target_gaussian_sigma": 0.3, "beta": 1.2, "encoder": { "units": (256, 256), "activation": "ELU", - "initializer_options": {"name": "Orthogonal"}, "layer_norm": True, }, }, @@ -41,7 +42,6 @@ def get_config(): "encoder": { "units": (256, 256), "activation": "ELU", - "initializer_options": {"name": "Orthogonal"}, "delay_action": True, }, }, diff --git a/examples/SOP/hopper_defaults.py b/examples/SOP/hopper_defaults.py index 8840efc8..c27df9d2 100644 --- a/examples/SOP/hopper_defaults.py +++ b/examples/SOP/hopper_defaults.py @@ -27,15 +27,16 @@ def get_config(): # for the policy and action-value function. No layers means the component is # linear in states and/or actions. "module": { - "type": "DDPGModule", + "type": "DDPG", + "initializer": {"name": "orthogonal"}, "actor": { + "parameter_noise": True, "smooth_target_policy": True, "target_gaussian_sigma": 0.3, "beta": 1.2, "encoder": { "units": (256, 256), "activation": "ReLU", - "initializer_options": {"name": "xavier_uniform"}, "layer_norm": False, }, }, @@ -44,7 +45,6 @@ def get_config(): "encoder": { "units": (256, 256), "activation": "ReLU", - "initializer_options": {"name": "xavier_uniform"}, "delay_action": True, }, }, diff --git a/examples/SOP/ib_defaults.py b/examples/SOP/ib_defaults.py index f6fc72e1..51c6d4e1 100644 --- a/examples/SOP/ib_defaults.py +++ b/examples/SOP/ib_defaults.py @@ -30,15 +30,16 @@ def get_config(): # for the policy and action-value function. No layers means the component is # linear in states and/or actions. "module": { - "type": "DDPGModule", + "type": "DDPG", + "initializer": {"name": "xavier_uniform"}, "actor": { + "parameter_noise": True, "smooth_target_policy": True, "target_gaussian_sigma": 0.3, "beta": 1.2, "encoder": { "units": (256, 256), "activation": "ReLU", - "initializer_options": {"name": "xavier_uniform"}, "layer_norm": False, }, }, @@ -47,7 +48,6 @@ def get_config(): "encoder": { "units": (256, 256), "activation": "ReLU", - "initializer_options": {"name": "xavier_uniform"}, "delay_action": True, }, }, diff --git a/examples/SOP/swingup_defaults.py b/examples/SOP/swingup_defaults.py index 0c1528d0..afe1e87d 100644 --- a/examples/SOP/swingup_defaults.py +++ b/examples/SOP/swingup_defaults.py @@ -19,8 +19,9 @@ def get_config(): "polyak": 0.995, # === Network === "module": { - "type": "DDPGModule", + "type": "DDPG", "actor": { + "parameter_noise": True, "smooth_target_policy": True, "target_gaussian_sigma": 0.3, "beta": 1.2, diff --git a/examples/SOP/walker_defaults.py b/examples/SOP/walker_defaults.py index 1b8d3e63..8e575342 100644 --- a/examples/SOP/walker_defaults.py +++ b/examples/SOP/walker_defaults.py @@ -42,15 +42,16 @@ def get_config(): # for the policy and action-value function. No layers means the component is # linear in states and/or actions. 
"module": { - "type": "DDPGModule", + "type": "DDPG", + "initializer": {"name": "xavier_uniform"}, "actor": { + "parameter_noise": True, "smooth_target_policy": True, "target_gaussian_sigma": 0.2, "beta": 1.2, "encoder": { "units": (256, 256), "activation": "ReLU", - "initializer_options": {"name": "xavier_uniform"}, "layer_norm": False, }, }, @@ -59,7 +60,6 @@ def get_config(): "encoder": { "units": (256, 256), "activation": "ReLU", - "initializer_options": {"name": "xavier_uniform"}, "delay_action": True, }, }, diff --git a/raylab/agents/sop/policy.py b/raylab/agents/sop/policy.py index 97025ccc..70551516 100644 --- a/raylab/agents/sop/policy.py +++ b/raylab/agents/sop/policy.py @@ -44,10 +44,11 @@ def make_module(self, obs_space, action_space, config): module_config.setdefault("critic", {}) module_config["critic"]["double_q"] = config["clipped_double_q"] module_config.setdefault("actor", {}) - module_config["actor"]["perturbed_policy"] = ( + if ( config["exploration_config"]["type"] == "raylab.utils.exploration.ParameterNoise" - ) + ): + module_config["actor"]["parameter_noise"] = True # pylint:disable=no-member return super().make_module(obs_space, action_space, config) diff --git a/raylab/policy/modules/actor/deterministic.py b/raylab/policy/modules/actor/deterministic.py index 41b4a609..662d044a 100644 --- a/raylab/policy/modules/actor/deterministic.py +++ b/raylab/policy/modules/actor/deterministic.py @@ -22,8 +22,9 @@ class DeterministicActorSpec(DataClassJsonMixin): states to pre-action linear features norm_beta: Maximum l1 norm of the unconstrained actions. If None, won't normalize actions before squashing function - behavior: Type of behavior policy. Either 'gaussian', 'parameter_noise', - or 'deterministic' + parameter_noise: Whether to create a separate behavior policy for + parameter noise exploration. It is recommended to enable + encoder.layer_norm alongside this option. smooth_target_policy: Whether to use a noisy target policy for Q-Learning target_gaussian_sigma: Gaussian standard deviation for noisy target @@ -37,7 +38,7 @@ class DeterministicActorSpec(DataClassJsonMixin): encoder: MLPSpec = field(default_factory=MLPSpec) norm_beta: float = 1.2 - behavior: str = "gaussian" + parameter_noise: bool = False smooth_target_policy: bool = True target_gaussian_sigma: float = 0.3 separate_target_policy: bool = False @@ -46,10 +47,6 @@ class DeterministicActorSpec(DataClassJsonMixin): def __post_init__(self): cls_name = type(self).__name__ assert self.norm_beta > 0, f"{cls_name}.norm_beta must be positive" - valid_behaviors = {"gaussian", "parameter_noise", "deterministic"} - assert ( - self.behavior in valid_behaviors - ), f"{cls_name}.behavior must be one of {valid_behaviors}" assert ( self.target_gaussian_sigma > 0 ), f"{cls_name}.target_gaussian_sigma must be positive" @@ -65,7 +62,8 @@ class DeterministicActor(nn.Module): Attributes: policy: The deterministic policy to be learned - behavior: The policy for exploration + behavior: The policy for exploration. `utils.exploration.GaussianNoise` + handles Gaussian action noise exploration separatedly. target_policy: The policy used for estimating the arg max in Q-Learning spec_cls: Expected class of `spec` init argument """ @@ -87,11 +85,12 @@ def make_policy(): policy.initialize_parameters(spec.initializer) behavior = policy - if spec.behavior == "parameter_noise": + if spec.parameter_noise: if not spec.encoder.layer_norm: warnings.warn( - f"Behavior is set to {spec.behavior} but layer normalization is " - "deactivated. 
Use layer normalization for better stability." + "Behavior policy for parameter noise exploration requested" + " but layer normalization is deactivated. Use layer" + " normalization for better stability." ) behavior = make_policy() behavior.load_state_dict(policy.state_dict()) diff --git a/raylab/utils/exploration/gaussian_noise.py b/raylab/utils/exploration/gaussian_noise.py index 58eaa86a..432812d7 100644 --- a/raylab/utils/exploration/gaussian_noise.py +++ b/raylab/utils/exploration/gaussian_noise.py @@ -7,8 +7,6 @@ class GaussianNoise(RandomUniform): """Adds fixed additive gaussian exploration noise to actions. - - Args: noise_stddev (float): Standard deviation of the Gaussian samples. """ diff --git a/raylab/utils/exploration/parameter_noise.py b/raylab/utils/exploration/parameter_noise.py index d544a004..09a964ff 100644 --- a/raylab/utils/exploration/parameter_noise.py +++ b/raylab/utils/exploration/parameter_noise.py @@ -8,10 +8,8 @@ from ray.rllib.models.action_dist import ActionDistribution from ray.rllib.utils import override from ray.rllib.utils.exploration import Exploration -from ray.rllib.utils.torch_ops import convert_to_non_torch_type from raylab.policy import TorchPolicy -from raylab.pytorch.nn.distributions.flows import TanhSquashTransform from raylab.pytorch.nn.utils import perturb_params from raylab.utils.param_noise import AdaptiveParamNoiseSpec from raylab.utils.param_noise import ddpg_distance_metric @@ -22,18 +20,26 @@ class ParameterNoise(RandomUniform): """Adds adaptive parameter noise exploration schedule to a Policy. + Expects `actor` attribute of `policy.module` to be an instance of + `raylab.policy.modules.actor.policy.deterministic.DeterministicPolicy`. + Args: param_noise_spec: Arguments for `AdaptiveParamNoiseSpec`. """ def __init__(self, *args, param_noise_spec: dict = None, **kwargs): super().__init__(*args, **kwargs) + assert ( + self.model is not None + ), f"Need to pass the model to {type(self).__name__} to check compatibility." + actor, behavior = self.model.actor, self.model.behavior + assert set(actor.parameters()).isdisjoint(set(behavior.parameters())), ( + "Target and behavior policy cannot share parameters in parameter " + "noise exploration." 
+ ) + param_noise_spec = param_noise_spec or {} self._param_noise_spec = AdaptiveParamNoiseSpec(**param_noise_spec) - self._squash = TanhSquashTransform( - low=torch.as_tensor(self.action_space.low), - high=torch.as_tensor(self.action_space.high), - ) @override(RandomUniform) def get_exploration_action( @@ -41,7 +47,7 @@ def get_exploration_action( *, action_distribution: ActionDistribution, timestep: int, - explore: bool = True + explore: bool = True, ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: model, inputs = action_distribution.model, action_distribution.inputs if explore: @@ -61,7 +67,7 @@ def on_episode_start( *, environment: Any = None, episode: Any = None, - tf_sess: Any = None + tf_sess: Any = None, ): # pylint:disable=unused-argument perturb_params( @@ -87,12 +93,10 @@ def update_parameter_noise(self, policy: TorchPolicy, sample_batch: SampleBatch) module = policy.module cur_obs = policy.convert_to_tensor(sample_batch[SampleBatch.CUR_OBS]) actions = policy.convert_to_tensor(sample_batch[SampleBatch.ACTIONS]) - target_actions = module.actor(cur_obs) - unsquashed_acts, _ = self._squash(actions, reverse=True) - unsquashed_targs, _ = self._squash(target_actions, reverse=True) - noisy, target = map( - convert_to_non_torch_type, (unsquashed_acts, unsquashed_targs) - ) + noisy = module.actor.unsquash_action(actions) + target = module.actor.unconstrained_action(cur_obs) + noisy, target = map(lambda x: x.cpu().detach().numpy(), (noisy, target)) + distance = ddpg_distance_metric(noisy, target) self._param_noise_spec.adapt(distance) diff --git a/tests/raylab/policy/modules/actor/test_deterministic.py b/tests/raylab/policy/modules/actor/test_deterministic.py index 33959308..95abee83 100644 --- a/tests/raylab/policy/modules/actor/test_deterministic.py +++ b/tests/raylab/policy/modules/actor/test_deterministic.py @@ -26,15 +26,15 @@ def separate_target_policy(request): return request.param -@pytest.fixture(params="gaussian deterministic parameter_noise".split()) -def behavior(request): +@pytest.fixture(params=(True, False), ids=lambda x: f"ParameterNoise({x})") +def parameter_noise(request): return request.param @pytest.fixture -def spec(module_cls, behavior, separate_target_policy): +def spec(module_cls, parameter_noise, separate_target_policy): return module_cls.spec_cls( - behavior=behavior, separate_target_policy=separate_target_policy + parameter_noise=parameter_noise, separate_target_policy=separate_target_policy ) @@ -54,6 +54,16 @@ def test_module_creation(module): ) +def test_parameter_noise(module_cls, obs_space, action_space): + spec = module_cls.spec_cls(parameter_noise=True) + module = module_cls(obs_space, action_space, spec) + + assert all( + torch.allclose(p, n) + for p, n in zip(module.policy.parameters(), module.behavior.parameters()) + ) + + def test_separate_target_policy(module, spec): policy, target = module.policy, module.target_policy @@ -63,11 +73,11 @@ def test_separate_target_policy(module, spec): assert all(p is t for p, t in zip(policy.parameters(), target.parameters())) -def test_behavior(module, batch, spec): +def test_behavior(module, batch): action = batch[SampleBatch.ACTIONS] samples = module.behavior(batch[SampleBatch.CUR_OBS]) samples_ = module.behavior(batch[SampleBatch.CUR_OBS]) assert samples.shape == action.shape assert samples.dtype == torch.float32 - assert spec.behavior == "gaussian" or torch.allclose(samples, samples_) + assert torch.allclose(samples, samples_) From aa5c0a71557c5b61aaef182bfe3bec2484b15218 Mon Sep 17 00:00:00 2001 From: 
=?UTF-8?q?=C3=82ngelo=20Lovatto?= Date: Sat, 4 Jul 2020 09:41:20 -0300 Subject: [PATCH 41/48] chore: bump version minor MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Ângelo Lovatto --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index cb4358fc..febc2c49 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "raylab" -version = "0.8.5" +version = "0.9.0" description = "Reinforcement learning algorithms in RLlib and PyTorch." authors = ["Ângelo Gregório Lovatto "] license = "MIT" From 72d9c2afce2ec5f57b16cbbbd3ef4f04d65f1c66 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=C3=82ngelo=20Lovatto?= Date: Sat, 4 Jul 2020 15:01:35 -0300 Subject: [PATCH 42/48] feat(policy): check module compatibility with action dist MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Ângelo Lovatto --- raylab/policy/action_dist.py | 103 +++++++++++++++++++++++++--------- raylab/policy/torch_policy.py | 3 +- 2 files changed, 80 insertions(+), 26 deletions(-) diff --git a/raylab/policy/action_dist.py b/raylab/policy/action_dist.py index d8c99861..cb35e4ef 100644 --- a/raylab/policy/action_dist.py +++ b/raylab/policy/action_dist.py @@ -1,7 +1,10 @@ """Action distribution for compatibility with RLlib's interface.""" +from abc import ABCMeta +from abc import abstractmethod + import torch +import torch.nn as nn from ray.rllib.models.action_dist import ActionDistribution -from ray.rllib.utils import override from torch import Tensor from .modules.actor.policy.deterministic import DeterministicPolicy @@ -9,61 +12,100 @@ from .modules.v0.mixins.stochastic_actor_mixin import StochasticPolicy as V0StochasticPi -class WrapStochasticPolicy(ActionDistribution): +class IncompatibleDistClsError(Exception): + """Exception raised for incompatible action distribution and NN module. + + Args: + dist_cls: Action distribution class + module: NN module + err: AssertionError explaining the reason why distribution and + module are incompatible + + Attributes: + message: Human-readable text explaining what caused the incompatibility + """ + + def __init__(self, dist_cls: type, module: nn.Module, err: Exception): + # pylint:disable=unused-argument + msg = ( + f"Action distribution type {dist_cls} is incompatible" + " with NN module of type {type(module)}. Reason:\n" + " {err}" + ) + super().__init__(msg) + self.message = msg + + +class BaseActionDist(ActionDistribution, metaclass=ABCMeta): + """Base class for TorchPolicy action distributions.""" + + @classmethod + def check_model_compat(cls, model: nn.Module): + """Assert the given NN module is compatible with the distribution. + + Raises: + IncompatibleDistClsError: If `model` is incompatible with the + distribution class + """ + try: + cls._check_model_compat(model) + except AssertionError as err: + raise IncompatibleDistClsError(cls, model, err) + + @classmethod + @abstractmethod + def _check_model_compat(cls, model: nn.Module): + pass + + +class WrapStochasticPolicy(BaseActionDist): """Wraps an nn.Module with a stochastic actor and its inputs. - Expects actor to be a StochasticPolicy instance. + Expects actor to be an instance of StochasticPolicy. 
""" # pylint:disable=abstract-method + valid_actor_cls = (V0StochasticPi, StochasticPolicy) def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) - assert hasattr(self.model, "actor") - assert isinstance(self.model.actor, (V0StochasticPi, StochasticPolicy)) self._sampled_logp = None - @override(ActionDistribution) def sample(self): action, logp = self.model.actor.sample(**self.inputs) self._sampled_logp = logp return action, logp - @override(ActionDistribution) def deterministic_sample(self): return self.model.actor.deterministic(**self.inputs) - @override(ActionDistribution) def sampled_action_logp(self): return self._sampled_logp - @override(ActionDistribution) def logp(self, x): return self.model.actor.log_prob(value=x, **self.inputs) - @override(ActionDistribution) def entropy(self): return self.model.actor.entropy(**self.inputs) + @classmethod + def _check_model_compat(cls, model): + assert hasattr(model, "actor"), "NN has no actor attribute." + assert isinstance(model.actor, cls.valid_actor_cls), ( + f"Expected actor to be an instance of {cls.valid_actor_cls};" + " found {type(model.actor)} instead." + ) + -class WrapDeterministicPolicy(ActionDistribution): +class WrapDeterministicPolicy(BaseActionDist): """Wraps an nn.Module with a deterministic actor and its inputs. - Expects actor to be a DeterministicPolicy instance. + Expects actor to be an instance of DeterministicPolicy. """ # pylint:disable=abstract-method + valid_actor_cls = valid_behavior_cls = DeterministicPolicy - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - assert hasattr(self.model, "actor") and isinstance( - self.model.actor, DeterministicPolicy - ) - assert hasattr(self.model, "behavior") and isinstance( - self.model.behavior, DeterministicPolicy - ) - - @override(ActionDistribution) def sample(self): action = self.model.behavior(**self.inputs) return action, None @@ -74,14 +116,25 @@ def sample_inject_noise(self, noise_stddev: float) -> Tensor: unconstrained_action += torch.randn_like(unconstrained_action) * noise_stddev return self.model.behavior.squash_action(unconstrained_action), None - @override(ActionDistribution) def deterministic_sample(self): return self.model.actor(**self.inputs), None - @override(ActionDistribution) def sampled_action_logp(self): return None - @override(ActionDistribution) def logp(self, x): return None + + @classmethod + def _check_model_compat(cls, model: nn.Module): + assert hasattr(model, "actor"), "NN has no actor attribute" + assert isinstance(model.actor, cls.valid_actor_cls), ( + f"Expected actor to be an instance of {cls.valid_actor_cls};" + " found {type(model.actor)} instead." + ) + + assert hasattr(model, "behavior"), "NN has no behavior attribute" + assert isinstance(model.actor, cls.valid_behavior_cls), ( + f"Expected behavior to be an instance of {cls.valid_behavior_cls};" + " found {type(model.behavior)} instead." + ) diff --git a/raylab/policy/torch_policy.py b/raylab/policy/torch_policy.py index 9ef3cae9..3f07f43a 100644 --- a/raylab/policy/torch_policy.py +++ b/raylab/policy/torch_policy.py @@ -34,7 +34,7 @@ class TorchPolicy(Policy): Attributes: dist_class: Action distribution class for computing actions. Must be set - by subclasses. + by subclasses before calling `__init__`. device: Device in which the parameter tensors reside. All input samples will be converted to tensors and moved to this device module: The policy's neural network module. 
Should be compilable to @@ -65,6 +65,7 @@ def __init__(self, observation_space: Space, action_space: Space, config: dict): self.optimizers[name] = optimizer # === Policy attributes === + self.dist_class.check_model_compat(self.module) self.framework = "torch" # Needed to create exploration self.exploration = self._create_exploration() From f3bb7a58e0b79107a05d9697409e4f1e34ee93ce Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=C3=82ngelo=20Lovatto?= Date: Sat, 4 Jul 2020 15:40:30 -0300 Subject: [PATCH 43/48] feat(policy): check model, action dist, and exploration compatibility MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Ângelo Lovatto --- raylab/policy/action_dist.py | 12 +---- raylab/utils/exploration/base.py | 54 +++++++++++++++++++++ raylab/utils/exploration/gaussian_noise.py | 45 +++++++++++++++-- raylab/utils/exploration/parameter_noise.py | 32 ++++++------ raylab/utils/exploration/random_uniform.py | 12 +++-- 5 files changed, 119 insertions(+), 36 deletions(-) create mode 100644 raylab/utils/exploration/base.py diff --git a/raylab/policy/action_dist.py b/raylab/policy/action_dist.py index cb35e4ef..3a8434be 100644 --- a/raylab/policy/action_dist.py +++ b/raylab/policy/action_dist.py @@ -2,10 +2,8 @@ from abc import ABCMeta from abc import abstractmethod -import torch import torch.nn as nn from ray.rllib.models.action_dist import ActionDistribution -from torch import Tensor from .modules.actor.policy.deterministic import DeterministicPolicy from .modules.actor.policy.stochastic import StochasticPolicy @@ -90,7 +88,7 @@ def entropy(self): @classmethod def _check_model_compat(cls, model): - assert hasattr(model, "actor"), "NN has no actor attribute." + assert hasattr(model, "actor"), f"NN model {type(model)} has no actor attribute" assert isinstance(model.actor, cls.valid_actor_cls), ( f"Expected actor to be an instance of {cls.valid_actor_cls};" " found {type(model.actor)} instead." @@ -110,12 +108,6 @@ def sample(self): action = self.model.behavior(**self.inputs) return action, None - def sample_inject_noise(self, noise_stddev: float) -> Tensor: - """Add zero-mean Gaussian noise to the actions prior to normalizing them.""" - unconstrained_action = self.model.behavior.unconstrained_action(**self.inputs) - unconstrained_action += torch.randn_like(unconstrained_action) * noise_stddev - return self.model.behavior.squash_action(unconstrained_action), None - def deterministic_sample(self): return self.model.actor(**self.inputs), None @@ -127,7 +119,7 @@ def logp(self, x): @classmethod def _check_model_compat(cls, model: nn.Module): - assert hasattr(model, "actor"), "NN has no actor attribute" + assert hasattr(model, "actor"), f"NN model {type(model)} has no actor attribute" assert isinstance(model.actor, cls.valid_actor_cls), ( f"Expected actor to be an instance of {cls.valid_actor_cls};" " found {type(model.actor)} instead." diff --git a/raylab/utils/exploration/base.py b/raylab/utils/exploration/base.py new file mode 100644 index 00000000..320d45d5 --- /dev/null +++ b/raylab/utils/exploration/base.py @@ -0,0 +1,54 @@ +"""Base implementations for all exploration strategies.""" +from abc import ABCMeta +from abc import abstractmethod +from typing import Optional + +import torch.nn as nn +from ray.rllib.utils.exploration import Exploration + +Model = Optional[nn.Module] + + +class IncompatibleExplorationError(Exception): + """Exception raised for incompatible exploration and NN module. 
+ + Args: + exp_cls: Exploration class + module: NN module + err: AssertionError explaining the reason why exploration and module are + incompatible + + Attributes: + message: Human-readable text explaining what caused the incompatibility + """ + + def __init__(self, exp_cls: type, module: Model, err: Exception): + # pylint:disable=unused-argument + msg = ( + f"Exploration type {exp_cls} is incompatible with NN module of type" + " {type(module)}. Reason:\n" + " {err}" + ) + super().__init__(msg) + self.message = msg + + +class BaseExploration(Exploration, metaclass=ABCMeta): + """Base class for exploration objects.""" + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + try: + self.check_model_compat(self.model) + except AssertionError as err: + raise IncompatibleExplorationError(type(self), self.model, err) + + @classmethod + @abstractmethod + def check_model_compat(cls, model: Model): + """Assert the given NN module is compatible with the exploration. + + Raises: + IncompatibleDistClsError: If `model` is incompatible with the + exploration class + """ diff --git a/raylab/utils/exploration/gaussian_noise.py b/raylab/utils/exploration/gaussian_noise.py index 432812d7..8280cfb0 100644 --- a/raylab/utils/exploration/gaussian_noise.py +++ b/raylab/utils/exploration/gaussian_noise.py @@ -1,6 +1,13 @@ # pylint:disable=missing-module-docstring -from ray.rllib.utils import override +from typing import Tuple +import torch +from torch import Tensor + +from raylab.policy.action_dist import BaseActionDist +from raylab.policy.modules.actor.policy.deterministic import DeterministicPolicy + +from .base import Model from .random_uniform import RandomUniform @@ -11,12 +18,19 @@ class GaussianNoise(RandomUniform): noise_stddev (float): Standard deviation of the Gaussian samples. """ + valid_behavior_cls = DeterministicPolicy + def __init__(self, *args, noise_stddev=None, **kwargs): super().__init__(*args, **kwargs) self._noise_stddev = noise_stddev - @override(RandomUniform) - def get_exploration_action(self, *, action_distribution, timestep, explore=True): + def get_exploration_action( + self, + *, + action_distribution: BaseActionDist, + timestep: int, + explore: bool = True, + ): if explore: if timestep < self._pure_exploration_steps: return super().get_exploration_action( @@ -24,5 +38,28 @@ def get_exploration_action(self, *, action_distribution, timestep, explore=True) timestep=timestep, explore=explore, ) - return action_distribution.sample_inject_noise(self._noise_stddev) + return self._inject_gaussian_noise(action_distribution) return action_distribution.deterministic_sample() + + def _inject_gaussian_noise( + self, action_distribution: BaseActionDist + ) -> Tuple[Tensor, None]: + model, inputs = action_distribution.model, action_distribution.inputs + unconstrained_action = model.behavior.unconstrained_action(**inputs) + unconstrained_action += ( + torch.randn_like(unconstrained_action) * self._noise_stddev + ) + action = model.behavior.squash_action(unconstrained_action) + return action, None + + @classmethod + def check_model_compat(cls, model: Model): + RandomUniform.check_model_compat(model) + assert model is not None, f"{cls} exploration needs access to the NN." + assert hasattr( + model, "behavior" + ), f"NN model {type(model)} has no behavior attribute." + assert isinstance(model.behavior, cls.valid_behavior_cls), ( + f"Expected behavior to be an instance of {cls.valid_behavior_cls};" + " found {type(model.behavior)} instead." 
+ ) diff --git a/raylab/utils/exploration/parameter_noise.py b/raylab/utils/exploration/parameter_noise.py index 09a964ff..4fe200cc 100644 --- a/raylab/utils/exploration/parameter_noise.py +++ b/raylab/utils/exploration/parameter_noise.py @@ -6,14 +6,13 @@ import torch from ray.rllib import SampleBatch from ray.rllib.models.action_dist import ActionDistribution -from ray.rllib.utils import override -from ray.rllib.utils.exploration import Exploration from raylab.policy import TorchPolicy from raylab.pytorch.nn.utils import perturb_params from raylab.utils.param_noise import AdaptiveParamNoiseSpec from raylab.utils.param_noise import ddpg_distance_metric +from .base import Model from .random_uniform import RandomUniform @@ -29,19 +28,9 @@ class ParameterNoise(RandomUniform): def __init__(self, *args, param_noise_spec: dict = None, **kwargs): super().__init__(*args, **kwargs) - assert ( - self.model is not None - ), f"Need to pass the model to {type(self).__name__} to check compatibility." - actor, behavior = self.model.actor, self.model.behavior - assert set(actor.parameters()).isdisjoint(set(behavior.parameters())), ( - "Target and behavior policy cannot share parameters in parameter " - "noise exploration." - ) - param_noise_spec = param_noise_spec or {} self._param_noise_spec = AdaptiveParamNoiseSpec(**param_noise_spec) - @override(RandomUniform) def get_exploration_action( self, *, @@ -49,7 +38,6 @@ def get_exploration_action( timestep: int, explore: bool = True, ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: - model, inputs = action_distribution.model, action_distribution.inputs if explore: if timestep < self._pure_exploration_steps: return super().get_exploration_action( @@ -57,10 +45,9 @@ def get_exploration_action( timestep=timestep, explore=explore, ) - return model.behavior(**inputs), None - return model.actor(**inputs), None + return action_distribution.sample() + return action_distribution.deterministic_sample() - @override(Exploration) def on_episode_start( self, policy: TorchPolicy, @@ -77,14 +64,12 @@ def on_episode_start( ) @torch.no_grad() - @override(Exploration) def postprocess_trajectory( self, policy: TorchPolicy, sample_batch: SampleBatch, tf_sess: Any = None ): self.update_parameter_noise(policy, sample_batch) return sample_batch - @override(Exploration) def get_info(self) -> dict: return {"param_noise_stddev": self._param_noise_spec.curr_stddev} @@ -100,3 +85,14 @@ def update_parameter_noise(self, policy: TorchPolicy, sample_batch: SampleBatch) distance = ddpg_distance_metric(noisy, target) self._param_noise_spec.adapt(distance) + + @classmethod + def check_model_compat(cls, model: Model): + assert ( + model is not None + ), f"Need to pass the model to {cls} to check compatibility." + actor, behavior = model.actor, model.behavior + assert set(actor.parameters()).isdisjoint(set(behavior.parameters())), ( + "Target and behavior policy cannot share parameters in parameter " + "noise exploration." 
+ ) diff --git a/raylab/utils/exploration/random_uniform.py b/raylab/utils/exploration/random_uniform.py index 10014476..2ce3192c 100644 --- a/raylab/utils/exploration/random_uniform.py +++ b/raylab/utils/exploration/random_uniform.py @@ -1,12 +1,13 @@ # pylint:disable=missing-module-docstring import numpy as np -from ray.rllib.utils import override -from ray.rllib.utils.exploration import Exploration import raylab.pytorch.utils as ptu +from .base import BaseExploration +from .base import Model -class RandomUniform(Exploration): + +class RandomUniform(BaseExploration): """Samples actions from the Gym action space Args: @@ -28,7 +29,6 @@ def __init__(self, *args, pure_exploration_steps=0, **kwargs): ) self._pure_exploration_steps = pure_exploration_steps - @override(Exploration) def get_exploration_action(self, *, action_distribution, timestep, explore=True): # pylint:disable=unused-argument if explore: @@ -43,3 +43,7 @@ def get_exploration_action(self, *, action_distribution, timestep, explore=True) ) return acts, logp return action_distribution.deterministic_sample() + + @classmethod + def check_model_compat(cls, model: Model): + pass From 4e4e3a0dade488bbdfb30199b30f136980915f05 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=C3=82ngelo=20Lovatto?= Date: Sat, 4 Jul 2020 15:57:37 -0300 Subject: [PATCH 44/48] chore(policy): allow subclasses to set `dist_class` before calling init MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Ângelo Lovatto --- raylab/agents/acktr/policy.py | 5 +---- raylab/agents/mage/policy.py | 2 +- raylab/agents/mapo/policy.py | 3 +-- raylab/agents/mbpo/policy.py | 3 +-- raylab/agents/naf/policy.py | 3 +-- raylab/agents/sac/policy.py | 3 +-- raylab/agents/sop/policy.py | 3 +-- raylab/agents/svg/policy.py | 4 ++-- raylab/agents/trpo/policy.py | 5 +---- raylab/policy/torch_policy.py | 3 +++ 10 files changed, 13 insertions(+), 21 deletions(-) diff --git a/raylab/agents/acktr/policy.py b/raylab/agents/acktr/policy.py index 4bac4f21..2900c615 100644 --- a/raylab/agents/acktr/policy.py +++ b/raylab/agents/acktr/policy.py @@ -52,10 +52,7 @@ class ACKTRTorchPolicy(TorchPolicy): """Policy class for Actor-Critic with Kronecker factored Trust Region.""" # pylint:disable=abstract-method - - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - self.dist_class = WrapStochasticPolicy + dist_class = WrapStochasticPolicy @staticmethod @override(TorchPolicy) diff --git a/raylab/agents/mage/policy.py b/raylab/agents/mage/policy.py index 3692a0c3..f28387a3 100644 --- a/raylab/agents/mage/policy.py +++ b/raylab/agents/mage/policy.py @@ -19,10 +19,10 @@ class MAGETorchPolicy(ModelTrainingMixin, EnvFnMixin, SOPTorchPolicy): """ # pylint: disable=abstract-method + dist_class = WrapDeterministicPolicy def __init__(self, observation_space, action_space, config): super().__init__(observation_space, action_space, config) - self.dist_class = WrapDeterministicPolicy module = self.module self.loss_model = ModelEnsembleMLE(module.models) diff --git a/raylab/agents/mapo/policy.py b/raylab/agents/mapo/policy.py index 77bcbf53..79efeb21 100644 --- a/raylab/agents/mapo/policy.py +++ b/raylab/agents/mapo/policy.py @@ -15,11 +15,10 @@ class MAPOTorchPolicy(ModelTrainingMixin, EnvFnMixin, SACTorchPolicy): """Model-Aware Policy Optimization policy in PyTorch to use with RLlib.""" # pylint: disable=abstract-method + dist_class = WrapStochasticPolicy def __init__(self, observation_space, action_space, config): 
super().__init__(observation_space, action_space, config) - self.dist_class = WrapStochasticPolicy - self.loss_model = SPAML( self.module.models, self.module.actor, self.module.critics ) diff --git a/raylab/agents/mbpo/policy.py b/raylab/agents/mbpo/policy.py index 6b61e9bd..f93747ef 100644 --- a/raylab/agents/mbpo/policy.py +++ b/raylab/agents/mbpo/policy.py @@ -16,11 +16,10 @@ class MBPOTorchPolicy( """Model-Based Policy Optimization policy in PyTorch to use with RLlib.""" # pylint:disable=abstract-method,too-many-ancestors + dist_class = WrapStochasticPolicy def __init__(self, observation_space, action_space, config): super().__init__(observation_space, action_space, config) - self.dist_class = WrapStochasticPolicy - models = self.module.models self.loss_model = ModelEnsembleMLE(models) diff --git a/raylab/agents/naf/policy.py b/raylab/agents/naf/policy.py index dee40f6f..174d03fe 100644 --- a/raylab/agents/naf/policy.py +++ b/raylab/agents/naf/policy.py @@ -14,11 +14,10 @@ class NAFTorchPolicy(TargetNetworksMixin, TorchPolicy): """Normalized Advantage Function policy in Pytorch to use with RLlib.""" # pylint: disable=abstract-method + dist_class = WrapDeterministicPolicy def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) - self.dist_class = WrapDeterministicPolicy - target_critics = [lambda s, _, v=v: v(s) for v in self.module.target_vcritics] self.loss_fn = ClippedDoubleQLearning( self.module.critics, target_critics, actor=lambda _: None, diff --git a/raylab/agents/sac/policy.py b/raylab/agents/sac/policy.py index 89545529..7bc4b611 100644 --- a/raylab/agents/sac/policy.py +++ b/raylab/agents/sac/policy.py @@ -16,11 +16,10 @@ class SACTorchPolicy(TargetNetworksMixin, TorchPolicy): """Soft Actor-Critic policy in PyTorch to use with RLlib.""" # pylint: disable=abstract-method + dist_class = WrapStochasticPolicy def __init__(self, observation_space, action_space, config): super().__init__(observation_space, action_space, config) - self.dist_class = WrapStochasticPolicy - self.loss_actor = ReparameterizedSoftPG(self.module.actor, self.module.critics) self.loss_critic = SoftCDQLearning( self.module.critics, self.module.target_critics, self.module.actor.sample diff --git a/raylab/agents/sop/policy.py b/raylab/agents/sop/policy.py index 70551516..94b29ef5 100644 --- a/raylab/agents/sop/policy.py +++ b/raylab/agents/sop/policy.py @@ -15,11 +15,10 @@ class SOPTorchPolicy(TargetNetworksMixin, TorchPolicy): """Streamlined Off-Policy policy in PyTorch to use with RLlib.""" # pylint: disable=abstract-method + dist_class = WrapDeterministicPolicy def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) - self.dist_class = WrapDeterministicPolicy - self.loss_actor = DeterministicPolicyGradient( self.module.actor, self.module.critics, ) diff --git a/raylab/agents/svg/policy.py b/raylab/agents/svg/policy.py index 7dcc4a52..d35f63d9 100644 --- a/raylab/agents/svg/policy.py +++ b/raylab/agents/svg/policy.py @@ -14,10 +14,10 @@ class SVGTorchPolicy(EnvFnMixin, TargetNetworksMixin, TorchPolicy): """Stochastic Value Gradients policy using PyTorch.""" # pylint: disable=abstract-method + dist_class = WrapStochasticPolicy + def __init__(self, observation_space, action_space, config): super().__init__(observation_space, action_space, config) - self.dist_class = WrapStochasticPolicy - self.loss_model = MaximumLikelihood(self.module.model) self.loss_critic = ISFittedVIteration( self.module.critic, self.module.target_critic diff --git a/raylab/agents/trpo/policy.py 
b/raylab/agents/trpo/policy.py index 702bd6db..b9ea9904 100644 --- a/raylab/agents/trpo/policy.py +++ b/raylab/agents/trpo/policy.py @@ -25,10 +25,7 @@ class TRPOTorchPolicy(TorchPolicy): """Policy class for Trust Region Policy Optimization.""" # pylint:disable=abstract-method - - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - self.dist_class = WrapStochasticPolicy + dist_class = WrapStochasticPolicy @staticmethod @override(TorchPolicy) diff --git a/raylab/policy/torch_policy.py b/raylab/policy/torch_policy.py index 3f07f43a..a866d809 100644 --- a/raylab/policy/torch_policy.py +++ b/raylab/policy/torch_policy.py @@ -54,6 +54,8 @@ def __init__(self, observation_space: Space, action_space: Space, config: dict): whitelist=Trainer._allow_unknown_subkeys, override_all_if_type_changes=Trainer._override_all_subkeys_if_type_changes, ) + # Allow subclasses to set `dist_class` before calling init + action_dist = getattr(self, "dist_class", None) super().__init__(observation_space, action_space, config) self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") @@ -65,6 +67,7 @@ def __init__(self, observation_space: Space, action_space: Space, config: dict): self.optimizers[name] = optimizer # === Policy attributes === + self.dist_class = action_dist self.dist_class.check_model_compat(self.module) self.framework = "torch" # Needed to create exploration self.exploration = self._create_exploration() From 8778647a65de597b470d7cad524173e67d57ec58 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=C3=82ngelo=20Lovatto?= Date: Sat, 4 Jul 2020 16:31:31 -0300 Subject: [PATCH 45/48] test(policy): use dummy action dist MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Ângelo Lovatto --- tests/raylab/policy/model_based/conftest.py | 28 +++++++++++++ .../policy/model_based/test_envfn_mixin.py | 26 ++++++------ .../policy/model_based/test_sampling_mixin.py | 40 +++++++++--------- .../policy/model_based/test_training_mixin.py | 42 +++++++++---------- 4 files changed, 82 insertions(+), 54 deletions(-) create mode 100644 tests/raylab/policy/model_based/conftest.py diff --git a/tests/raylab/policy/model_based/conftest.py b/tests/raylab/policy/model_based/conftest.py new file mode 100644 index 00000000..209986d9 --- /dev/null +++ b/tests/raylab/policy/model_based/conftest.py @@ -0,0 +1,28 @@ +# pylint:disable=missing-docstring,redefined-outer-name,protected-access +import pytest + +from raylab.policy import TorchPolicy +from raylab.policy.action_dist import BaseActionDist + + +@pytest.fixture(scope="module") +def action_dist(): + class ActionDist(BaseActionDist): + # pylint:disable=abstract-method + @classmethod + def _check_model_compat(cls, *args, **kwargs): + pass + + return ActionDist + + +@pytest.fixture(scope="module") +def base_policy_cls(action_dist, obs_space, action_space): + class Policy(TorchPolicy): + # pylint:disable=abstract-method + dist_class = action_dist + + def __init__(self, config): + super().__init__(obs_space, action_space, config) + + return Policy diff --git a/tests/raylab/policy/model_based/test_envfn_mixin.py b/tests/raylab/policy/model_based/test_envfn_mixin.py index e11b374b..6e0dde75 100644 --- a/tests/raylab/policy/model_based/test_envfn_mixin.py +++ b/tests/raylab/policy/model_based/test_envfn_mixin.py @@ -5,22 +5,24 @@ import torch from raylab.envs import get_reward_fn -from raylab.envs import get_termination_fn from raylab.policy import EnvFnMixin -from raylab.policy import TorchPolicy from 
raylab.utils.debug import fake_space_samples -class DummyPolicy(EnvFnMixin, TorchPolicy): - # pylint:disable=all - @staticmethod - def get_default_config(): - return {"module": {"type": "OnPolicyActorCritic"}} +@pytest.fixture +def policy_cls(base_policy_cls): + class Policy(EnvFnMixin, base_policy_cls): + @staticmethod + def get_default_config(): + return {"module": {"type": "OnPolicyActorCritic"}} + + return Policy @pytest.fixture def reward_fn(): - def func(obs, act, new_obs): + def func(*args): + act = args[1] return act.norm(p=1, dim=-1) return func @@ -28,7 +30,7 @@ def func(obs, act, new_obs): @pytest.fixture def termination_fn(): - def func(obs, act, new_obs): + def func(obs, *_): return torch.randn(obs.shape[:-1]) > 0 return func @@ -36,7 +38,7 @@ def func(obs, act, new_obs): @pytest.fixture def dynamics_fn(): - def func(obs, act): + def func(obs, _): sample = torch.randn_like(obs) log_prob = torch.sum( -(sample ** 2) / 2 @@ -50,8 +52,8 @@ def func(obs, act): @pytest.fixture -def policy(obs_space, action_space): - return DummyPolicy(obs_space, action_space, {}) +def policy(policy_cls): + return policy_cls({}) def test_init(policy): diff --git a/tests/raylab/policy/model_based/test_sampling_mixin.py b/tests/raylab/policy/model_based/test_sampling_mixin.py index 0b535603..31479549 100644 --- a/tests/raylab/policy/model_based/test_sampling_mixin.py +++ b/tests/raylab/policy/model_based/test_sampling_mixin.py @@ -15,32 +15,32 @@ ROLLOUT_SCHEDULE = ([(0, 1), (200, 10)], [(7, 2)]) -class DummyPolicy(ModelSamplingMixin, TorchPolicy): - # pylint:disable=all - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) +@pytest.fixture(scope="module") +def policy_cls(base_policy_cls): + class Policy(ModelSamplingMixin, base_policy_cls): + # pylint:disable=all - def reward_fn(obs, act, new_obs): - return act.norm(p=1, dim=-1) + def __init__(self, config): + super().__init__(config) - def termination_fn(obs, act, new_obs): - return torch.randn(obs.shape[:-1]) > 0 + def reward_fn(obs, act, new_obs): + return act.norm(p=1, dim=-1) - self.reward_fn = reward_fn - self.termination_fn = termination_fn + def termination_fn(obs, act, new_obs): + return torch.randn(obs.shape[:-1]) > 0 - @staticmethod - def get_default_config(): - return { - "model_sampling": ModelSamplingMixin.model_sampling_defaults(), - "module": {"type": "ModelBasedSAC"}, - "seed": None, - } + self.reward_fn = reward_fn + self.termination_fn = termination_fn + @staticmethod + def get_default_config(): + return { + "model_sampling": ModelSamplingMixin.model_sampling_defaults(), + "module": {"type": "ModelBasedSAC"}, + "seed": None, + } -@pytest.fixture(scope="module") -def policy_cls(obs_space, action_space): - return functools.partial(DummyPolicy, obs_space, action_space) + return Policy @pytest.fixture( diff --git a/tests/raylab/policy/model_based/test_training_mixin.py b/tests/raylab/policy/model_based/test_training_mixin.py index b2310960..7a1d4d4f 100644 --- a/tests/raylab/policy/model_based/test_training_mixin.py +++ b/tests/raylab/policy/model_based/test_training_mixin.py @@ -8,7 +8,6 @@ from raylab.policy import ModelTrainingMixin from raylab.policy import OptimizerCollection -from raylab.policy import TorchPolicy from raylab.policy.model_based.training_mixin import Evaluator from raylab.pytorch.optim import build_optimizer from raylab.utils.debug import fake_batch @@ -24,25 +23,6 @@ def __call__(self, _): return losses, {"loss(models)": losses.mean().item()} -class DummyPolicy(ModelTrainingMixin, TorchPolicy): 
- # pylint:disable=abstract-method - def __init__(self, observation_space, action_space, config): - super().__init__(observation_space, action_space, config) - loss = DummyLoss() - loss.ensemble_size = len(self.module.models) - self.loss_model = loss - - @staticmethod - def get_default_config(): - return { - "model_training": ModelTrainingMixin.model_training_defaults(), - "module": {"type": "ModelBasedSAC"}, - } - - def make_optimizers(self): - return {"models": build_optimizer(self.module.models, {"type": "Adam"})} - - @pytest.fixture def train_samples(obs_space, action_space): return fake_batch(obs_space, action_space, batch_size=80) @@ -54,8 +34,26 @@ def eval_samples(obs_space, action_space): @pytest.fixture(scope="module") -def policy_cls(obs_space, action_space): - return functools.partial(DummyPolicy, obs_space, action_space) +def policy_cls(base_policy_cls): + class Policy(ModelTrainingMixin, base_policy_cls): + # pylint:disable=abstract-method + def __init__(self, config): + super().__init__(config) + loss = DummyLoss() + loss.ensemble_size = len(self.module.models) + self.loss_model = loss + + @staticmethod + def get_default_config(): + return { + "model_training": ModelTrainingMixin.model_training_defaults(), + "module": {"type": "ModelBasedSAC"}, + } + + def make_optimizers(self): + return {"models": build_optimizer(self.module.models, {"type": "Adam"})} + + return Policy @pytest.fixture(scope="module", params=(1, 4), ids=lambda s: f"Ensemble({s})") From 24de650694b320da7d057f48ae4f5d5a8db20e16 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=C3=82ngelo=20Lovatto?= Date: Sat, 4 Jul 2020 17:11:40 -0300 Subject: [PATCH 46/48] chore(modules): rename parameter_noise option to separate_behavior MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Ângelo Lovatto --- raylab/agents/sop/trainer.py | 2 +- raylab/policy/modules/actor/deterministic.py | 16 ++++++++-------- .../policy/modules/actor/test_deterministic.py | 13 +++++++------ 3 files changed, 16 insertions(+), 15 deletions(-) diff --git a/raylab/agents/sop/trainer.py b/raylab/agents/sop/trainer.py index 006cc632..43ed20ca 100644 --- a/raylab/agents/sop/trainer.py +++ b/raylab/agents/sop/trainer.py @@ -20,7 +20,7 @@ "polyak": 0.995, # Update policy every this number of calls to `learn_on_batch` "policy_delay": 1, - "module": {"type": "DDPG"}, + "module": {"type": "DDPG", "actor": {"separate_behavior": True}}, # === Exploration Settings === # Default exploration behavior, iff `explore`=None is passed into # compute_action(s). diff --git a/raylab/policy/modules/actor/deterministic.py b/raylab/policy/modules/actor/deterministic.py index 662d044a..d109c04a 100644 --- a/raylab/policy/modules/actor/deterministic.py +++ b/raylab/policy/modules/actor/deterministic.py @@ -22,9 +22,9 @@ class DeterministicActorSpec(DataClassJsonMixin): states to pre-action linear features norm_beta: Maximum l1 norm of the unconstrained actions. If None, won't normalize actions before squashing function - parameter_noise: Whether to create a separate behavior policy for - parameter noise exploration. It is recommended to enable - encoder.layer_norm alongside this option. + separate_behavior: Whether to create a separate behavior policy. Usually + for parameter noise exploration, in which case it is recommended to + enable encoder.layer_norm alongside this option. 
smooth_target_policy: Whether to use a noisy target policy for Q-Learning target_gaussian_sigma: Gaussian standard deviation for noisy target @@ -38,7 +38,7 @@ class DeterministicActorSpec(DataClassJsonMixin): encoder: MLPSpec = field(default_factory=MLPSpec) norm_beta: float = 1.2 - parameter_noise: bool = False + separate_behavior: bool = False smooth_target_policy: bool = True target_gaussian_sigma: float = 0.3 separate_target_policy: bool = False @@ -85,12 +85,12 @@ def make_policy(): policy.initialize_parameters(spec.initializer) behavior = policy - if spec.parameter_noise: + if spec.separate_behavior: if not spec.encoder.layer_norm: warnings.warn( - "Behavior policy for parameter noise exploration requested" - " but layer normalization is deactivated. Use layer" - " normalization for better stability." + "Separate behavior policy requested and layer normalization" + " deactivated. If using parameter noise exploration, enable" + " layer normalization for better stability." ) behavior = make_policy() behavior.load_state_dict(policy.state_dict()) diff --git a/tests/raylab/policy/modules/actor/test_deterministic.py b/tests/raylab/policy/modules/actor/test_deterministic.py index 95abee83..eae71899 100644 --- a/tests/raylab/policy/modules/actor/test_deterministic.py +++ b/tests/raylab/policy/modules/actor/test_deterministic.py @@ -26,15 +26,16 @@ def separate_target_policy(request): return request.param -@pytest.fixture(params=(True, False), ids=lambda x: f"ParameterNoise({x})") -def parameter_noise(request): +@pytest.fixture(params=(True, False), ids=lambda x: f"SeparateBehavior({x})") +def separate_behavior(request): return request.param @pytest.fixture -def spec(module_cls, parameter_noise, separate_target_policy): +def spec(module_cls, separate_behavior, separate_target_policy): return module_cls.spec_cls( - parameter_noise=parameter_noise, separate_target_policy=separate_target_policy + separate_behavior=separate_behavior, + separate_target_policy=separate_target_policy, ) @@ -54,8 +55,8 @@ def test_module_creation(module): ) -def test_parameter_noise(module_cls, obs_space, action_space): - spec = module_cls.spec_cls(parameter_noise=True) +def test_separate_behavior(module_cls, obs_space, action_space): + spec = module_cls.spec_cls(separate_behavior=True) module = module_cls(obs_space, action_space, spec) assert all( From 8d99815972edcb16fc16accbdb6d977698535b85 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=C3=82ngelo=20Lovatto?= Date: Sat, 4 Jul 2020 17:32:18 -0300 Subject: [PATCH 47/48] refactor(policy): remove TargetNetworksMixin MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Ângelo Lovatto --- raylab/agents/naf/policy.py | 8 ++++--- raylab/agents/sac/policy.py | 8 ++++--- raylab/agents/sop/policy.py | 8 ++++--- raylab/agents/svg/inf/policy.py | 2 +- raylab/agents/svg/one/policy.py | 2 +- raylab/agents/svg/policy.py | 9 ++++++-- raylab/agents/svg/soft/policy.py | 2 +- raylab/policy/__init__.py | 2 -- raylab/policy/target_networks_mixin.py | 22 ------------------- tests/raylab/agents/sac/test_critics.py | 8 +------ tests/raylab/agents/sop/test_policy.py | 8 +------ .../raylab/agents/svg/test_value_function.py | 13 ----------- 12 files changed, 27 insertions(+), 65 deletions(-) delete mode 100644 raylab/policy/target_networks_mixin.py diff --git a/raylab/agents/naf/policy.py b/raylab/agents/naf/policy.py index 174d03fe..daebe840 100644 --- a/raylab/agents/naf/policy.py +++ b/raylab/agents/naf/policy.py @@ -3,14 +3,14 @@ import 
torch.nn as nn from ray.rllib.utils import override -from raylab.policy import TargetNetworksMixin from raylab.policy import TorchPolicy from raylab.policy.action_dist import WrapDeterministicPolicy from raylab.policy.losses import ClippedDoubleQLearning +from raylab.pytorch.nn.utils import update_polyak from raylab.pytorch.optim import build_optimizer -class NAFTorchPolicy(TargetNetworksMixin, TorchPolicy): +class NAFTorchPolicy(TorchPolicy): """Normalized Advantage Function policy in Pytorch to use with RLlib.""" # pylint: disable=abstract-method @@ -59,7 +59,9 @@ def learn_on_batch(self, samples): loss.backward() info.update(self.extra_grad_info()) - self.update_targets("vcritics", "target_vcritics") + update_polyak( + self.module.vcritics, self.module.target_vcritics, self.config["polyak"] + ) return info @torch.no_grad() diff --git a/raylab/agents/sac/policy.py b/raylab/agents/sac/policy.py index 7bc4b611..dd17824f 100644 --- a/raylab/agents/sac/policy.py +++ b/raylab/agents/sac/policy.py @@ -3,16 +3,16 @@ import torch.nn as nn from ray.rllib.utils import override -from raylab.policy import TargetNetworksMixin from raylab.policy import TorchPolicy from raylab.policy.action_dist import WrapStochasticPolicy from raylab.policy.losses import MaximumEntropyDual from raylab.policy.losses import ReparameterizedSoftPG from raylab.policy.losses import SoftCDQLearning +from raylab.pytorch.nn.utils import update_polyak from raylab.pytorch.optim import build_optimizer -class SACTorchPolicy(TargetNetworksMixin, TorchPolicy): +class SACTorchPolicy(TorchPolicy): """Soft Actor-Critic policy in PyTorch to use with RLlib.""" # pylint: disable=abstract-method @@ -75,7 +75,9 @@ def learn_on_batch(self, samples): if self.config["target_entropy"] is not None: info.update(self._update_alpha(batch_tensors)) - self.update_targets("critics", "target_critics") + update_polyak( + self.module.critics, self.module.target_critics, self.config["polyak"] + ) return info def _update_critic(self, batch_tensors): diff --git a/raylab/agents/sop/policy.py b/raylab/agents/sop/policy.py index 94b29ef5..bae9d1ca 100644 --- a/raylab/agents/sop/policy.py +++ b/raylab/agents/sop/policy.py @@ -3,15 +3,15 @@ import torch.nn as nn from ray.rllib.utils import override -from raylab.policy import TargetNetworksMixin from raylab.policy import TorchPolicy from raylab.policy.action_dist import WrapDeterministicPolicy from raylab.policy.losses import ClippedDoubleQLearning from raylab.policy.losses import DeterministicPolicyGradient +from raylab.pytorch.nn.utils import update_polyak from raylab.pytorch.optim import build_optimizer -class SOPTorchPolicy(TargetNetworksMixin, TorchPolicy): +class SOPTorchPolicy(TorchPolicy): """Streamlined Off-Policy policy in PyTorch to use with RLlib.""" # pylint: disable=abstract-method @@ -71,7 +71,9 @@ def learn_on_batch(self, samples): if self._grad_step % self.config["policy_delay"] == 0: info.update(self._update_policy(batch_tensors)) - self.update_targets("critics", "target_critics") + update_polyak( + self.module.critics, self.module.target_critics, self.config["polyak"] + ) return info def _update_critic(self, batch_tensors): diff --git a/raylab/agents/svg/inf/policy.py b/raylab/agents/svg/inf/policy.py index 0ec937f2..0c0c6075 100644 --- a/raylab/agents/svg/inf/policy.py +++ b/raylab/agents/svg/inf/policy.py @@ -88,7 +88,7 @@ def _learn_off_policy(self, batch_tensors): info.update(_info) loss.backward() - self.update_targets("critic", "target_critic") + self._update_polyak() return info def 
_learn_on_policy(self, batch_tensors, samples): diff --git a/raylab/agents/svg/one/policy.py b/raylab/agents/svg/one/policy.py index c95c8a41..783cc7cd 100644 --- a/raylab/agents/svg/one/policy.py +++ b/raylab/agents/svg/one/policy.py @@ -95,7 +95,7 @@ def learn_on_batch(self, samples): info.update(self.extra_grad_info(batch_tensors)) info.update(self.update_kl_coeff(samples)) - self.update_targets("critic", "target_critic") + self._update_polyak() return info @torch.no_grad() diff --git a/raylab/agents/svg/policy.py b/raylab/agents/svg/policy.py index d35f63d9..d2322fc4 100644 --- a/raylab/agents/svg/policy.py +++ b/raylab/agents/svg/policy.py @@ -3,14 +3,14 @@ from ray.rllib import SampleBatch from raylab.policy import EnvFnMixin -from raylab.policy import TargetNetworksMixin from raylab.policy import TorchPolicy from raylab.policy.action_dist import WrapStochasticPolicy from raylab.policy.losses import ISFittedVIteration from raylab.policy.losses import MaximumLikelihood +from raylab.pytorch.nn.utils import update_polyak -class SVGTorchPolicy(EnvFnMixin, TargetNetworksMixin, TorchPolicy): +class SVGTorchPolicy(EnvFnMixin, TorchPolicy): """Stochastic Value Gradients policy using PyTorch.""" # pylint: disable=abstract-method @@ -47,3 +47,8 @@ def compute_joint_model_value_loss(self, batch_tensors): loss = mle_loss + self.config["vf_loss_coeff"] * isfv_loss return loss, {**mle_info, **isfv_info} + + def _update_polyak(self): + update_polyak( + self.module.critic, self.module.target_critic, self.config["polyak"] + ) diff --git a/raylab/agents/svg/soft/policy.py b/raylab/agents/svg/soft/policy.py index b17fa56e..6786c6ab 100644 --- a/raylab/agents/svg/soft/policy.py +++ b/raylab/agents/svg/soft/policy.py @@ -109,7 +109,7 @@ def learn_on_batch(self, samples): if self.config["target_entropy"] is not None: info.update(self._update_alpha(batch_tensors)) - self.update_targets("critic", "target_critic") + self._update_polyak() return info def _update_model(self, batch_tensors): diff --git a/raylab/policy/__init__.py b/raylab/policy/__init__.py index 799eb296..3c6178c8 100644 --- a/raylab/policy/__init__.py +++ b/raylab/policy/__init__.py @@ -5,7 +5,6 @@ from .model_based import ModelSamplingMixin from .model_based import ModelTrainingMixin from .optimizer_collection import OptimizerCollection -from .target_networks_mixin import TargetNetworksMixin from .torch_policy import TorchPolicy __all__ = [ @@ -13,7 +12,6 @@ "EnvFnMixin", "ModelSamplingMixin", "ModelTrainingMixin", - "TargetNetworksMixin", "TorchPolicy", "OptimizerCollection", ] diff --git a/raylab/policy/target_networks_mixin.py b/raylab/policy/target_networks_mixin.py deleted file mode 100644 index 4028c76b..00000000 --- a/raylab/policy/target_networks_mixin.py +++ /dev/null @@ -1,22 +0,0 @@ -# pylint: disable=missing-docstring -# pylint: enable=missing-docstring -from raylab.pytorch.nn.utils import update_polyak - - -class TargetNetworksMixin: - """Adds method to update target networks by name.""" - - # pylint: disable=too-few-public-methods - - def update_targets(self, module, target_module): - """Update target networks through one step of polyak averaging. 
- - Arguments: - module (str): name of primary module in the policy's module dict - target_module (str): name of target module in the policy's module dict - """ - update_polyak( - getattr(self.module, module), - getattr(self.module, target_module), - self.config["polyak"], - ) diff --git a/tests/raylab/agents/sac/test_critics.py b/tests/raylab/agents/sac/test_critics.py index 6c19ef3f..2535dd75 100644 --- a/tests/raylab/agents/sac/test_critics.py +++ b/tests/raylab/agents/sac/test_critics.py @@ -75,14 +75,8 @@ def test_critic_loss(policy_and_batch): ) -def test_target_params_update(policy_and_batch): +def test_target_net_init(policy_and_batch): policy, _ = policy_and_batch params = list(policy.module.critics.parameters()) target_params = list(policy.module.target_critics.parameters()) assert all(torch.allclose(p, q) for p, q in zip(params, target_params)) - - old_params = [p.clone() for p in target_params] - for param in params: - param.data.add_(torch.ones_like(param)) - policy.update_targets("critics", "target_critics") - assert all(not torch.allclose(p, q) for p, q in zip(target_params, old_params)) diff --git a/tests/raylab/agents/sop/test_policy.py b/tests/raylab/agents/sop/test_policy.py index 4dc4118c..0fa414da 100644 --- a/tests/raylab/agents/sop/test_policy.py +++ b/tests/raylab/agents/sop/test_policy.py @@ -21,17 +21,11 @@ def policy(obs_space, action_space, config): return SOPTorchPolicy(obs_space, action_space, config) -def test_target_params_update(policy): +def test_target_critics_init(policy): params = list(policy.module.critics.parameters()) target_params = list(policy.module.target_critics.parameters()) assert all(torch.allclose(p, q) for p, q in zip(params, target_params)) - old_params = [p.clone() for p in target_params] - for param in params: - param.data.add_(torch.ones_like(param)) - policy.update_targets("critics", "target_critics") - assert all(not torch.allclose(p, q) for p, q in zip(target_params, old_params)) - @pytest.fixture def samples(obs_space, action_space): diff --git a/tests/raylab/agents/svg/test_value_function.py b/tests/raylab/agents/svg/test_value_function.py index 7a74bb85..ed0f88a5 100644 --- a/tests/raylab/agents/svg/test_value_function.py +++ b/tests/raylab/agents/svg/test_value_function.py @@ -40,16 +40,3 @@ def test_importance_sampling_weighted_loss(policy_and_batch): assert all(p.grad is None for p in other_params) assert "loss(critic)" in info - - -def test_target_params_update(policy_and_batch): - policy, _ = policy_and_batch - - old_params = [p.clone() for p in policy.module.target_critic.parameters()] - for param in policy.module.critic.parameters(): - param.data.add_(torch.ones_like(param)) - policy.update_targets("critic", "target_critic") - assert all( - not torch.allclose(p, p_) - for p, p_ in zip(policy.module.target_critic.parameters(), old_params) - ) From dd0b9c988e026a139760d044d7ec89dc6f76c6cf Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=C3=82ngelo=20Lovatto?= Date: Sat, 4 Jul 2020 17:50:32 -0300 Subject: [PATCH 48/48] refactor(policy): remove clipped_double_q top-level config MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Enable Clipped Double Q-Learning by configuring double Q in the NN module. 
Signed-off-by: Ângelo Lovatto --- raylab/agents/mage/trainer.py | 5 +---- raylab/agents/mapo/trainer.py | 2 +- raylab/agents/mbpo/trainer.py | 6 ++++-- raylab/agents/sac/policy.py | 7 ------- raylab/agents/sac/trainer.py | 5 +---- raylab/agents/sop/policy.py | 14 -------------- raylab/agents/sop/trainer.py | 9 +++++---- tests/raylab/agents/sac/test_actor.py | 21 ++++++++++++++------- tests/raylab/agents/sac/test_critics.py | 6 +++--- tests/raylab/agents/sop/test_policy.py | 6 +++--- 10 files changed, 32 insertions(+), 49 deletions(-) diff --git a/raylab/agents/mage/trainer.py b/raylab/agents/mage/trainer.py index 1761bae3..52b36d7d 100644 --- a/raylab/agents/mage/trainer.py +++ b/raylab/agents/mage/trainer.py @@ -16,9 +16,6 @@ "policy_improvements": 10, "real_data_ratio": 1, # === MAGETorchPolicy === - # Clipped Double Q-Learning: use the minimun of two target Q functions - # as the next action-value in the target for fitted Q iteration - "clipped_double_q": True, # TD error regularization for MAGE loss "lambda": 0.05, # PyTorch optimizers to use @@ -39,7 +36,7 @@ patience_epochs=None, improvement_threshold=None, ).to_dict(), - "module": {"type": "MBDDPG"}, + "module": {"type": "MBDDPG", "critic": {"double_q": True}}, # === Exploration Settings === # Default exploration behavior, iff `explore`=None is passed into # compute_action(s). diff --git a/raylab/agents/mapo/trainer.py b/raylab/agents/mapo/trainer.py index d64eae94..abc249f7 100644 --- a/raylab/agents/mapo/trainer.py +++ b/raylab/agents/mapo/trainer.py @@ -21,6 +21,7 @@ "parallelize": False, "residual": True, }, + "critic": {"double_q": True}, }, "losses": { # Gradient estimator for optimizing expectations. Possible types include @@ -44,7 +45,6 @@ }, # === SACTorchPolicy === "target_entropy": "auto", - "clipped_double_q": True, # === TargetNetworksMixin === "polyak": 0.995, # === ModelTrainingMixin === diff --git a/raylab/agents/mbpo/trainer.py b/raylab/agents/mbpo/trainer.py index 3bbffabe..c81bee77 100644 --- a/raylab/agents/mbpo/trainer.py +++ b/raylab/agents/mbpo/trainer.py @@ -22,7 +22,10 @@ "encoder": {"units": (128, 128), "activation": "Swish"}, "input_dependent_scale": True, }, - "critic": {"encoder": {"units": (128, 128), "activation": "Swish"}}, + "critic": { + "double_q": True, + "encoder": {"units": (128, 128), "activation": "Swish"}, + }, "entropy": {"initial_alpha": 0.05}, }, "torch_optimizer": { @@ -33,7 +36,6 @@ }, # === SACTorchPolicy === "target_entropy": "auto", - "clipped_double_q": True, "polyak": 0.995, # === ModelTrainingMixin === "model_training": TrainingSpec().to_dict(), diff --git a/raylab/agents/sac/policy.py b/raylab/agents/sac/policy.py index dd17824f..810a384d 100644 --- a/raylab/agents/sac/policy.py +++ b/raylab/agents/sac/policy.py @@ -44,13 +44,6 @@ def get_default_config(): return DEFAULT_CONFIG - @override(TorchPolicy) - def make_module(self, obs_space, action_space, config): - module_config = config["module"] - module_config.setdefault("critic", {}) - module_config["critic"]["double_q"] = config["clipped_double_q"] - return super().make_module(obs_space, action_space, config) - @override(TorchPolicy) def make_optimizers(self): config = self.config["torch_optimizer"] diff --git a/raylab/agents/sac/trainer.py b/raylab/agents/sac/trainer.py index 9c55404b..9834b2f6 100644 --- a/raylab/agents/sac/trainer.py +++ b/raylab/agents/sac/trainer.py @@ -15,9 +15,6 @@ # If "auto", will use the heuristic provided in the SAC paper: # H = -dim(A), where A is the action space "target_entropy": None, - # === Twin 
Delayed DDPG (TD3) tricks === - # Clipped Double Q-Learning - "clipped_double_q": True, # === Optimization === # PyTorch optimizers to use "torch_optimizer": { @@ -28,7 +25,7 @@ # Interpolation factor in polyak averaging for target networks. "polyak": 0.995, # === Network === - "module": {"type": "SAC"}, + "module": {"type": "SAC", "critic": {"double_q": True}}, # === Exploration Settings === # Default exploration behavior, iff `explore`=None is passed into # compute_action(s). diff --git a/raylab/agents/sop/policy.py b/raylab/agents/sop/policy.py index bae9d1ca..0aca14a9 100644 --- a/raylab/agents/sop/policy.py +++ b/raylab/agents/sop/policy.py @@ -37,20 +37,6 @@ def get_default_config(): return DEFAULT_CONFIG - @override(TorchPolicy) - def make_module(self, obs_space, action_space, config): - module_config = config["module"] - module_config.setdefault("critic", {}) - module_config["critic"]["double_q"] = config["clipped_double_q"] - module_config.setdefault("actor", {}) - if ( - config["exploration_config"]["type"] - == "raylab.utils.exploration.ParameterNoise" - ): - module_config["actor"]["parameter_noise"] = True - # pylint:disable=no-member - return super().make_module(obs_space, action_space, config) - @override(TorchPolicy) def make_optimizers(self): config = self.config["torch_optimizer"] diff --git a/raylab/agents/sop/trainer.py b/raylab/agents/sop/trainer.py index 43ed20ca..d2f9618a 100644 --- a/raylab/agents/sop/trainer.py +++ b/raylab/agents/sop/trainer.py @@ -8,9 +8,6 @@ DEFAULT_CONFIG = with_base_config( { # === SOPTorchPolicy === - # Clipped Double Q-Learning: use the minimun of two target Q functions - # as the next action-value in the target for fitted Q iteration - "clipped_double_q": True, # PyTorch optimizers to use "torch_optimizer": { "actor": {"type": "Adam", "lr": 1e-3}, @@ -20,7 +17,11 @@ "polyak": 0.995, # Update policy every this number of calls to `learn_on_batch` "policy_delay": 1, - "module": {"type": "DDPG", "actor": {"separate_behavior": True}}, + "module": { + "type": "DDPG", + "actor": {"separate_behavior": True}, + "critic": {"double_q": True}, + }, # === Exploration Settings === # Default exploration behavior, iff `explore`=None is passed into # compute_action(s). 
diff --git a/tests/raylab/agents/sac/test_actor.py b/tests/raylab/agents/sac/test_actor.py index 97cbd8e9..9ce44bab 100644 --- a/tests/raylab/agents/sac/test_actor.py +++ b/tests/raylab/agents/sac/test_actor.py @@ -5,16 +5,23 @@ @pytest.fixture(params=(True, False)) def input_dependent_scale(request): - return {"module": {"actor": {"input_dependent_scale": request.param}}} + return request.param @pytest.fixture(params=(True, False)) -def clipped_double_q(request): - return {"clipped_double_q": request.param} - - -def test_actor_loss(policy_and_batch_fn, clipped_double_q, input_dependent_scale): - policy, batch = policy_and_batch_fn({**clipped_double_q, **input_dependent_scale}) +def double_q(request): + return request.param + + +def test_actor_loss(policy_and_batch_fn, double_q, input_dependent_scale): + policy, batch = policy_and_batch_fn( + { + "module": { + "actor": {"input_dependent_scale": input_dependent_scale}, + "critic": {"double_q": double_q}, + } + } + ) loss, info = policy.loss_actor(batch) assert loss.shape == () diff --git a/tests/raylab/agents/sac/test_critics.py b/tests/raylab/agents/sac/test_critics.py index 2535dd75..7ac909a2 100644 --- a/tests/raylab/agents/sac/test_critics.py +++ b/tests/raylab/agents/sac/test_critics.py @@ -9,13 +9,13 @@ @pytest.fixture(params=(True, False)) -def clipped_double_q(request): +def double_q(request): return request.param @pytest.fixture -def policy_and_batch(policy_and_batch_fn, clipped_double_q): - config = {"clipped_double_q": clipped_double_q, "polyak": 0.5} +def policy_and_batch(policy_and_batch_fn, double_q): + config = {"module": {"critic": {"double_q": double_q}}, "polyak": 0.5} return policy_and_batch_fn(config) diff --git a/tests/raylab/agents/sop/test_policy.py b/tests/raylab/agents/sop/test_policy.py index 0fa414da..94453900 100644 --- a/tests/raylab/agents/sop/test_policy.py +++ b/tests/raylab/agents/sop/test_policy.py @@ -7,13 +7,13 @@ @pytest.fixture(params=(True, False)) -def clipped_double_q(request): +def double_q(request): return request.param @pytest.fixture -def config(clipped_double_q): - return {"clipped_double_q": clipped_double_q, "policy_delay": 2} +def config(double_q): + return {"module": {"critic": {"double_q": double_q}}, "policy_delay": 2} @pytest.fixture