diff --git a/examples/MAPO/swingup.py b/examples/MAPO/swingup.py index a2310707..3ca0af5c 100644 --- a/examples/MAPO/swingup.py +++ b/examples/MAPO/swingup.py @@ -8,14 +8,23 @@ def get_config(): "env_config": {"max_episode_steps": 500, "time_aware": False}, # === MAPOTorchPolicy === "module": { - "type": "ModelBasedSAC", + "type": "MBSAC", "model": { "ensemble_size": 1, + "residual": True, + "input_dependent_scale": True, + "network": {"units": (128, 128), "activation": "Swish"}, + }, + "actor": { + "encoder": {"units": (128, 128), "activation": "Swish"}, + "input_dependent_scale": True, + "initial_entropy_coeff": 0.05, + }, + "critic": { "encoder": {"units": (128, 128), "activation": "Swish"}, + "double_q": True, }, - "actor": {"encoder": {"units": (128, 128), "activation": "Swish"}}, - "critic": {"encoder": {"units": (128, 128), "activation": "Swish"}}, - "entropy": {"initial_alpha": 0.05}, + "initializer": {"name": "xavier_uniform"}, }, "losses": { # Gradient estimator for optimizing expectations. Possible types include diff --git a/examples/MBPO/swingup.py b/examples/MBPO/swingup.py index 48f7136e..d5d0bb4e 100644 --- a/examples/MBPO/swingup.py +++ b/examples/MBPO/swingup.py @@ -8,18 +8,24 @@ def get_config(): "env_config": {"max_episode_steps": 500, "time_aware": False}, # === MBPOTorchPolicy === "module": { - "type": "ModelBasedSAC", + "type": "MBSAC", "model": { - "encoder": {"units": (128, 128), "activation": "Swish"}, "ensemble_size": 7, + "parallelize": True, + "residual": True, "input_dependent_scale": True, + "network": {"units": (128, 128), "activation": "Swish"}, }, "actor": { "encoder": {"units": (128, 128), "activation": "Swish"}, "input_dependent_scale": True, + "initial_entropy_coeff": 0.05, + }, + "critic": { + "encoder": {"units": (128, 128), "activation": "Swish"}, + "double_q": True, }, - "critic": {"encoder": {"units": (128, 128), "activation": "Swish"}}, - "entropy": {"initial_alpha": 0.05}, + "initializer": {"name": "xavier_uniform"}, }, "torch_optimizer": { "models": {"type": "Adam", "lr": 3e-4, "weight_decay": 0.0001}, @@ -63,6 +69,7 @@ def get_config(): "evaluation_num_episodes": 10, "timesteps_per_iteration": 1000, "num_cpus_for_driver": 4, + "compile_policy": True, # === RolloutWorker === "rollout_fragment_length": 25, "batch_mode": "truncate_episodes", diff --git a/examples/SOP/cheetah_defaults.py b/examples/SOP/cheetah_defaults.py index 93e7c63d..5c2a8ba3 100644 --- a/examples/SOP/cheetah_defaults.py +++ b/examples/SOP/cheetah_defaults.py @@ -25,15 +25,16 @@ def get_config(): # for the policy and action-value function. No layers means the component is # linear in states and/or actions. "module": { - "type": "DDPGModule", + "type": "DDPG", + "initializer": {"name": "orthogonal"}, "actor": { + "parameter_noise": True, "smooth_target_policy": True, "target_gaussian_sigma": 0.3, "beta": 1.2, "encoder": { "units": (256, 256), "activation": "ELU", - "initializer_options": {"name": "Orthogonal"}, "layer_norm": True, }, }, @@ -41,7 +42,6 @@ def get_config(): "encoder": { "units": (256, 256), "activation": "ELU", - "initializer_options": {"name": "Orthogonal"}, "delay_action": True, }, }, diff --git a/examples/SOP/hopper_defaults.py b/examples/SOP/hopper_defaults.py index 8840efc8..c27df9d2 100644 --- a/examples/SOP/hopper_defaults.py +++ b/examples/SOP/hopper_defaults.py @@ -27,15 +27,16 @@ def get_config(): # for the policy and action-value function. No layers means the component is # linear in states and/or actions. 
"module": { - "type": "DDPGModule", + "type": "DDPG", + "initializer": {"name": "orthogonal"}, "actor": { + "parameter_noise": True, "smooth_target_policy": True, "target_gaussian_sigma": 0.3, "beta": 1.2, "encoder": { "units": (256, 256), "activation": "ReLU", - "initializer_options": {"name": "xavier_uniform"}, "layer_norm": False, }, }, @@ -44,7 +45,6 @@ def get_config(): "encoder": { "units": (256, 256), "activation": "ReLU", - "initializer_options": {"name": "xavier_uniform"}, "delay_action": True, }, }, diff --git a/examples/SOP/ib_defaults.py b/examples/SOP/ib_defaults.py index f6fc72e1..51c6d4e1 100644 --- a/examples/SOP/ib_defaults.py +++ b/examples/SOP/ib_defaults.py @@ -30,15 +30,16 @@ def get_config(): # for the policy and action-value function. No layers means the component is # linear in states and/or actions. "module": { - "type": "DDPGModule", + "type": "DDPG", + "initializer": {"name": "xavier_uniform"}, "actor": { + "parameter_noise": True, "smooth_target_policy": True, "target_gaussian_sigma": 0.3, "beta": 1.2, "encoder": { "units": (256, 256), "activation": "ReLU", - "initializer_options": {"name": "xavier_uniform"}, "layer_norm": False, }, }, @@ -47,7 +48,6 @@ def get_config(): "encoder": { "units": (256, 256), "activation": "ReLU", - "initializer_options": {"name": "xavier_uniform"}, "delay_action": True, }, }, diff --git a/examples/SOP/swingup_defaults.py b/examples/SOP/swingup_defaults.py index 0c1528d0..afe1e87d 100644 --- a/examples/SOP/swingup_defaults.py +++ b/examples/SOP/swingup_defaults.py @@ -19,8 +19,9 @@ def get_config(): "polyak": 0.995, # === Network === "module": { - "type": "DDPGModule", + "type": "DDPG", "actor": { + "parameter_noise": True, "smooth_target_policy": True, "target_gaussian_sigma": 0.3, "beta": 1.2, diff --git a/examples/SOP/walker_defaults.py b/examples/SOP/walker_defaults.py index 1b8d3e63..8e575342 100644 --- a/examples/SOP/walker_defaults.py +++ b/examples/SOP/walker_defaults.py @@ -42,15 +42,16 @@ def get_config(): # for the policy and action-value function. No layers means the component is # linear in states and/or actions. 
"module": { - "type": "DDPGModule", + "type": "DDPG", + "initializer": {"name": "xavier_uniform"}, "actor": { + "parameter_noise": True, "smooth_target_policy": True, "target_gaussian_sigma": 0.2, "beta": 1.2, "encoder": { "units": (256, 256), "activation": "ReLU", - "initializer_options": {"name": "xavier_uniform"}, "layer_norm": False, }, }, @@ -59,7 +60,6 @@ def get_config(): "encoder": { "units": (256, 256), "activation": "ReLU", - "initializer_options": {"name": "xavier_uniform"}, "delay_action": True, }, }, diff --git a/poetry.lock b/poetry.lock index 4e20abd7..69d1c0a5 100644 --- a/poetry.lock +++ b/poetry.lock @@ -298,7 +298,7 @@ description = "Extensible memoizing collections and decorators" name = "cachetools" optional = false python-versions = "~=3.5" -version = "4.1.0" +version = "4.1.1" [[package]] category = "main" @@ -619,7 +619,7 @@ description = "IPython: Productive Interactive Computing" name = "ipython" optional = false python-versions = ">=3.6" -version = "7.15.0" +version = "7.16.1" [package.dependencies] appnope = "*" @@ -1128,7 +1128,7 @@ description = "A framework for managing and maintaining multi-language pre-commi name = "pre-commit" optional = false python-versions = ">=3.6.1" -version = "2.5.1" +version = "2.6.0" [package.dependencies] cfgv = ">=2.0.0" @@ -1688,7 +1688,7 @@ description = "Frontend library for machine learning engineers" name = "streamlit" optional = false python-versions = ">=3.6" -version = "0.62.0" +version = "0.62.1" [package.dependencies] altair = ">=3.2.0" @@ -1828,7 +1828,7 @@ description = "tox is a generic virtualenv management and test command line tool name = "tox" optional = false python-versions = "!=3.0.*,!=3.1.*,!=3.2.*,!=3.3.*,!=3.4.*,>=2.7" -version = "3.15.2" +version = "3.16.1" [package.dependencies] colorama = ">=0.4.1" @@ -1959,7 +1959,7 @@ description = "Filesystem events monitoring" name = "watchdog" optional = false python-versions = "*" -version = "0.10.2" +version = "0.10.3" [package.dependencies] pathtools = ">=0.1.1" @@ -2028,7 +2028,7 @@ docs = ["sphinx", "jaraco.packaging (>=3.2)", "rst.linker (>=1.9)"] testing = ["jaraco.itertools", "func-timeout"] [metadata] -content-hash = "07c41546f29a2277543366994ebd8a11e660438d3c515a08863b3af5bb90769e" +content-hash = "4b8029fbc1de1ec8d50af443ac0d298a7b7ec72d7c228b653997500d3887a826" python-versions = "^3.7" [metadata.files] @@ -2145,8 +2145,8 @@ cached-property = [ {file = "cached_property-1.5.1-py2.py3-none-any.whl", hash = "sha256:3a026f1a54135677e7da5ce819b0c690f156f37976f3e30c5430740725203d7f"}, ] cachetools = [ - {file = "cachetools-4.1.0-py3-none-any.whl", hash = "sha256:de5d88f87781602201cde465d3afe837546663b168e8b39df67411b0bf10cefc"}, - {file = "cachetools-4.1.0.tar.gz", hash = "sha256:1d057645db16ca7fe1f3bd953558897603d6f0b9c51ed9d11eb4d071ec4e2aab"}, + {file = "cachetools-4.1.1-py3-none-any.whl", hash = "sha256:513d4ff98dd27f85743a8dc0e92f55ddb1b49e060c2d5961512855cda2c01a98"}, + {file = "cachetools-4.1.1.tar.gz", hash = "sha256:bbaa39c3dede00175df2dc2b03d0cf18dd2d32a7de7beb68072d13043c9edb20"}, ] certifi = [ {file = "certifi-2020.6.20-py2.py3-none-any.whl", hash = "sha256:8fc0819f1f30ba15bdb34cceffb9ef04d99f420f68eb75d901e9560b8749fc41"}, @@ -2337,8 +2337,8 @@ ipykernel = [ {file = "ipykernel-5.3.0.tar.gz", hash = "sha256:731adb3f2c4ebcaff52e10a855ddc87670359a89c9c784d711e62d66fccdafae"}, ] ipython = [ - {file = "ipython-7.15.0-py3-none-any.whl", hash = "sha256:1b85d65632211bf5d3e6f1406f3393c8c429a47d7b947b9a87812aa5bce6595c"}, - {file = "ipython-7.15.0.tar.gz", 
hash = "sha256:0ef1433879816a960cd3ae1ae1dc82c64732ca75cec8dab5a4e29783fb571d0e"}, + {file = "ipython-7.16.1-py3-none-any.whl", hash = "sha256:2dbcc8c27ca7d3cfe4fcdff7f45b27f9a8d3edfa70ff8024a71c7a8eb5f09d64"}, + {file = "ipython-7.16.1.tar.gz", hash = "sha256:9f4fcb31d3b2c533333893b9172264e4821c1ac91839500f31bd43f2c59b3ccf"}, ] ipython-genutils = [ {file = "ipython_genutils-0.2.0-py2.py3-none-any.whl", hash = "sha256:72dd37233799e619666c9f639a9da83c34013a73e8bbc79a7a6348d93c61fab8"}, @@ -2712,8 +2712,8 @@ poetry-version = [ {file = "poetry_version-0.1.5-py2.py3-none-any.whl", hash = "sha256:ba259257640cd36c76375563a001b9e85c6f54d764cc56b3eed0f3b5cefb0317"}, ] pre-commit = [ - {file = "pre_commit-2.5.1-py2.py3-none-any.whl", hash = "sha256:c5c8fd4d0e1c363723aaf0a8f9cba0f434c160b48c4028f4bae6d219177945b3"}, - {file = "pre_commit-2.5.1.tar.gz", hash = "sha256:da463cf8f0e257f9af49047ba514f6b90dbd9b4f92f4c8847a3ccd36834874c7"}, + {file = "pre_commit-2.6.0-py2.py3-none-any.whl", hash = "sha256:e8b1315c585052e729ab7e99dcca5698266bedce9067d21dc909c23e3ceed626"}, + {file = "pre_commit-2.6.0.tar.gz", hash = "sha256:1657663fdd63a321a4a739915d7d03baedd555b25054449090f97bb0cb30a915"}, ] prometheus-client = [ {file = "prometheus_client-0.8.0-py2.py3-none-any.whl", hash = "sha256:983c7ac4b47478720db338f1491ef67a100b474e3bc7dafcbaefb7d0b8f9b01c"}, @@ -2997,7 +2997,7 @@ sphinxcontrib-serializinghtml = [ {file = "sphinxcontrib_serializinghtml-1.1.4-py2.py3-none-any.whl", hash = "sha256:f242a81d423f59617a8e5cf16f5d4d74e28ee9a66f9e5b637a18082991db5a9a"}, ] streamlit = [ - {file = "streamlit-0.62.0-py2.py3-none-any.whl", hash = "sha256:114403a3f10885979744eb4cd5d9dccd4f3f606be55dd81baf6f0b2f47e4de06"}, + {file = "streamlit-0.62.1-py2.py3-none-any.whl", hash = "sha256:5b5e5219b103276bd2ad6005659de9381ed044e116bbaf5bdac7da5424f1a309"}, ] stringcase = [ {file = "stringcase-1.2.0.tar.gz", hash = "sha256:48a06980661908efe8d9d34eab2b6c13aefa2163b3ced26972902e3bdfd87008"}, @@ -3054,8 +3054,8 @@ tornado = [ {file = "tornado-6.0.4.tar.gz", hash = "sha256:0fe2d45ba43b00a41cd73f8be321a44936dc1aba233dee979f17a042b83eb6dc"}, ] tox = [ - {file = "tox-3.15.2-py2.py3-none-any.whl", hash = "sha256:50a188b8e17580c1fb931f494a754e6507d4185f54fb18aca5ba3e12d2ffd55e"}, - {file = "tox-3.15.2.tar.gz", hash = "sha256:c696d36cd7c6a28ada2da780400e44851b20ee19ef08cfe73344a1dcebbbe9f3"}, + {file = "tox-3.16.1-py2.py3-none-any.whl", hash = "sha256:60c3793f8ab194097ec75b5a9866138444f63742b0f664ec80be1222a40687c5"}, + {file = "tox-3.16.1.tar.gz", hash = "sha256:9a746cda9cadb9e1e05c7ab99f98cfcea355140d2ecac5f97520be94657c3bc7"}, ] traitlets = [ {file = "traitlets-4.3.3-py2.py3-none-any.whl", hash = "sha256:70b4c6a1d9019d7b4f6846832288f86998aa3b9207c6821f3578a6a6a467fe44"}, @@ -3110,7 +3110,7 @@ virtualenv = [ {file = "virtualenv-20.0.23.tar.gz", hash = "sha256:5102fbf1ec57e80671ef40ed98a84e980a71194cedf30c87c2b25c3a9e0b0107"}, ] watchdog = [ - {file = "watchdog-0.10.2.tar.gz", hash = "sha256:c560efb643faed5ef28784b2245cf8874f939569717a4a12826a173ac644456b"}, + {file = "watchdog-0.10.3.tar.gz", hash = "sha256:4214e1379d128b0588021880ccaf40317ee156d4603ac388b9adcf29165e0c04"}, ] wcwidth = [ {file = "wcwidth-0.2.4-py2.py3-none-any.whl", hash = "sha256:79375666b9954d4a1a10739315816324c3e73110af9d0e102d906fdb0aec009f"}, diff --git a/pyproject.toml b/pyproject.toml index e72e0cdd..febc2c49 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "raylab" -version = "0.8.5" +version = "0.9.0" description = 
"Reinforcement learning algorithms in RLlib and PyTorch." authors = ["Ângelo Gregório Lovatto "] license = "MIT" @@ -27,17 +27,17 @@ opencv-python = "^4.2.0" [tool.poetry.dev-dependencies] flake8 = "^3.8.3" pylint = "^2.5.3" -watchdog = "^0.10.2" +watchdog = "^0.10.3" black = "^19.10b0" -tox = "^3.15.2" +tox = "^3.16.1" sphinx = "^3.1.1" pytest = "^5.4.3" gym-cartpole-swingup = "^0.1.0" -pre-commit = "^2.5.1" +pre-commit = "^2.6.0" reorder-python-imports = "^2.3.1" mypy = "^0.782" coverage = "^5.1" -ipython = "^7.15.0" +ipython = "^7.16.1" poetry-version = "^0.1.5" pytest-mock = "^3.1.1" pytest-sugar = "^0.9.3" diff --git a/raylab/agents/acktr/policy.py b/raylab/agents/acktr/policy.py index 293fe1a2..2900c615 100644 --- a/raylab/agents/acktr/policy.py +++ b/raylab/agents/acktr/policy.py @@ -12,6 +12,7 @@ import raylab.utils.dictionaries as dutil from raylab.policy import TorchPolicy +from raylab.policy.action_dist import WrapStochasticPolicy from raylab.pytorch.nn.distributions import Normal from raylab.pytorch.optim import build_optimizer from raylab.pytorch.optim.hessian_free import line_search @@ -51,6 +52,7 @@ class ACKTRTorchPolicy(TorchPolicy): """Policy class for Actor-Critic with Kronecker factored Trust Region.""" # pylint:disable=abstract-method + dist_class = WrapStochasticPolicy @staticmethod @override(TorchPolicy) diff --git a/raylab/agents/mage/policy.py b/raylab/agents/mage/policy.py index 8672fce1..f28387a3 100644 --- a/raylab/agents/mage/policy.py +++ b/raylab/agents/mage/policy.py @@ -1,10 +1,11 @@ """Policy for MAGE using PyTorch.""" from raylab.agents.sop import SOPTorchPolicy -from raylab.losses import MAGE -from raylab.losses import ModelEnsembleMLE -from raylab.losses.mage import MAGEModules from raylab.policy import EnvFnMixin from raylab.policy import ModelTrainingMixin +from raylab.policy.action_dist import WrapDeterministicPolicy +from raylab.policy.losses import MAGE +from raylab.policy.losses import ModelEnsembleMLE +from raylab.policy.losses.mage import MAGEModules from raylab.pytorch.optim import build_optimizer @@ -18,6 +19,7 @@ class MAGETorchPolicy(ModelTrainingMixin, EnvFnMixin, SOPTorchPolicy): """ # pylint: disable=abstract-method + dist_class = WrapDeterministicPolicy def __init__(self, observation_space, action_space, config): super().__init__(observation_space, action_space, config) @@ -37,6 +39,7 @@ def __init__(self, observation_space, action_space, config): @staticmethod def get_default_config(): + # pylint:disable=cyclic-import from raylab.agents.mage import DEFAULT_CONFIG return DEFAULT_CONFIG diff --git a/raylab/agents/mage/trainer.py b/raylab/agents/mage/trainer.py index 4975b776..52b36d7d 100644 --- a/raylab/agents/mage/trainer.py +++ b/raylab/agents/mage/trainer.py @@ -16,9 +16,6 @@ "policy_improvements": 10, "real_data_ratio": 1, # === MAGETorchPolicy === - # Clipped Double Q-Learning: use the minimun of two target Q functions - # as the next action-value in the target for fitted Q iteration - "clipped_double_q": True, # TD error regularization for MAGE loss "lambda": 0.05, # PyTorch optimizers to use @@ -39,7 +36,7 @@ patience_epochs=None, improvement_threshold=None, ).to_dict(), - "module": {"type": "ModelBasedDDPG", "model": {"ensemble_size": 1}}, + "module": {"type": "MBDDPG", "critic": {"double_q": True}}, # === Exploration Settings === # Default exploration behavior, iff `explore`=None is passed into # compute_action(s). 
diff --git a/raylab/agents/mapo/policy.py b/raylab/agents/mapo/policy.py index aade2abd..79efeb21 100644 --- a/raylab/agents/mapo/policy.py +++ b/raylab/agents/mapo/policy.py @@ -2,11 +2,12 @@ from ray.rllib.utils import override from raylab.agents.sac import SACTorchPolicy -from raylab.losses import DAPO -from raylab.losses import MAPO -from raylab.losses import SPAML from raylab.policy import EnvFnMixin from raylab.policy import ModelTrainingMixin +from raylab.policy.action_dist import WrapStochasticPolicy +from raylab.policy.losses import DAPO +from raylab.policy.losses import MAPO +from raylab.policy.losses import SPAML from raylab.pytorch.optim import build_optimizer @@ -14,10 +15,10 @@ class MAPOTorchPolicy(ModelTrainingMixin, EnvFnMixin, SACTorchPolicy): """Model-Aware Policy Optimization policy in PyTorch to use with RLlib.""" # pylint: disable=abstract-method + dist_class = WrapStochasticPolicy def __init__(self, observation_space, action_space, config): super().__init__(observation_space, action_space, config) - self.loss_model = SPAML( self.module.models, self.module.actor, self.module.critics ) diff --git a/raylab/agents/mapo/trainer.py b/raylab/agents/mapo/trainer.py index 3d03926e..abc249f7 100644 --- a/raylab/agents/mapo/trainer.py +++ b/raylab/agents/mapo/trainer.py @@ -12,7 +12,17 @@ DEFAULT_CONFIG = with_base_config( { # === MAPOTorchPolicy === - "module": {"type": "ModelBasedSAC", "model": {"ensemble_size": 1}}, + "module": { + "type": "MBSAC", + "model": { + "network": {"units": (128, 128), "activation": "Swish"}, + "ensemble_size": 1, + "input_dependent_scale": True, + "parallelize": False, + "residual": True, + }, + "critic": {"double_q": True}, + }, "losses": { # Gradient estimator for optimizing expectations. Possible types include # SF: score function @@ -35,7 +45,6 @@ }, # === SACTorchPolicy === "target_entropy": "auto", - "clipped_double_q": True, # === TargetNetworksMixin === "polyak": 0.995, # === ModelTrainingMixin === diff --git a/raylab/agents/mbpo/policy.py b/raylab/agents/mbpo/policy.py index 17beb491..f93747ef 100644 --- a/raylab/agents/mbpo/policy.py +++ b/raylab/agents/mbpo/policy.py @@ -2,10 +2,11 @@ from ray.rllib.utils import override from raylab.agents.sac import SACTorchPolicy -from raylab.losses import ModelEnsembleMLE from raylab.policy import EnvFnMixin from raylab.policy import ModelSamplingMixin from raylab.policy import ModelTrainingMixin +from raylab.policy.action_dist import WrapStochasticPolicy +from raylab.policy.losses import ModelEnsembleMLE from raylab.pytorch.optim import build_optimizer @@ -15,10 +16,10 @@ class MBPOTorchPolicy( """Model-Based Policy Optimization policy in PyTorch to use with RLlib.""" # pylint:disable=abstract-method,too-many-ancestors + dist_class = WrapStochasticPolicy def __init__(self, observation_space, action_space, config): super().__init__(observation_space, action_space, config) - models = self.module.models self.loss_model = ModelEnsembleMLE(models) @@ -37,6 +38,6 @@ def make_optimizers(self): components = "models actor critics alpha".split() return { - name: build_optimizer(self.module[name], config[name]) + name: build_optimizer(getattr(self.module, name), config[name]) for name in components } diff --git a/raylab/agents/mbpo/trainer.py b/raylab/agents/mbpo/trainer.py index 30c49918..c81bee77 100644 --- a/raylab/agents/mbpo/trainer.py +++ b/raylab/agents/mbpo/trainer.py @@ -10,17 +10,22 @@ { # === MBPOTorchPolicy === "module": { - "type": "ModelBasedSAC", + "type": "MBSAC", "model": { - "encoder": 
{"units": (128, 128), "activation": "Swish"}, + "network": {"units": (128, 128), "activation": "Swish"}, "ensemble_size": 7, "input_dependent_scale": True, + "parallelize": True, + "residual": True, }, "actor": { "encoder": {"units": (128, 128), "activation": "Swish"}, "input_dependent_scale": True, }, - "critic": {"encoder": {"units": (128, 128), "activation": "Swish"}}, + "critic": { + "double_q": True, + "encoder": {"units": (128, 128), "activation": "Swish"}, + }, "entropy": {"initial_alpha": 0.05}, }, "torch_optimizer": { @@ -31,7 +36,6 @@ }, # === SACTorchPolicy === "target_entropy": "auto", - "clipped_double_q": True, "polyak": 0.995, # === ModelTrainingMixin === "model_training": TrainingSpec().to_dict(), @@ -48,6 +52,8 @@ "learning_starts": 5000, # === OffPolicyTrainer === "train_batch_size": 512, + # === Trainer === + "compile_policy": True, } ) diff --git a/raylab/agents/naf/policy.py b/raylab/agents/naf/policy.py index 5931b204..daebe840 100644 --- a/raylab/agents/naf/policy.py +++ b/raylab/agents/naf/policy.py @@ -3,16 +3,18 @@ import torch.nn as nn from ray.rllib.utils import override -from raylab.losses import ClippedDoubleQLearning -from raylab.policy import TargetNetworksMixin from raylab.policy import TorchPolicy +from raylab.policy.action_dist import WrapDeterministicPolicy +from raylab.policy.losses import ClippedDoubleQLearning +from raylab.pytorch.nn.utils import update_polyak from raylab.pytorch.optim import build_optimizer -class NAFTorchPolicy(TargetNetworksMixin, TorchPolicy): +class NAFTorchPolicy(TorchPolicy): """Normalized Advantage Function policy in Pytorch to use with RLlib.""" # pylint: disable=abstract-method + dist_class = WrapDeterministicPolicy def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) @@ -57,7 +59,9 @@ def learn_on_batch(self, samples): loss.backward() info.update(self.extra_grad_info()) - self.update_targets("vcritics", "target_vcritics") + update_polyak( + self.module.vcritics, self.module.target_vcritics, self.config["polyak"] + ) return info @torch.no_grad() diff --git a/raylab/agents/sac/policy.py b/raylab/agents/sac/policy.py index c2cc425a..810a384d 100644 --- a/raylab/agents/sac/policy.py +++ b/raylab/agents/sac/policy.py @@ -3,18 +3,20 @@ import torch.nn as nn from ray.rllib.utils import override -from raylab.losses import MaximumEntropyDual -from raylab.losses import ReparameterizedSoftPG -from raylab.losses import SoftCDQLearning -from raylab.policy import TargetNetworksMixin from raylab.policy import TorchPolicy +from raylab.policy.action_dist import WrapStochasticPolicy +from raylab.policy.losses import MaximumEntropyDual +from raylab.policy.losses import ReparameterizedSoftPG +from raylab.policy.losses import SoftCDQLearning +from raylab.pytorch.nn.utils import update_polyak from raylab.pytorch.optim import build_optimizer -class SACTorchPolicy(TargetNetworksMixin, TorchPolicy): +class SACTorchPolicy(TorchPolicy): """Soft Actor-Critic policy in PyTorch to use with RLlib.""" # pylint: disable=abstract-method + dist_class = WrapStochasticPolicy def __init__(self, observation_space, action_space, config): super().__init__(observation_space, action_space, config) @@ -42,20 +44,13 @@ def get_default_config(): return DEFAULT_CONFIG - @override(TorchPolicy) - def make_module(self, obs_space, action_space, config): - module_config = config["module"] - module_config.setdefault("critic", {}) - module_config["critic"]["double_q"] = config["clipped_double_q"] - return super().make_module(obs_space, action_space, config) - 
@override(TorchPolicy) def make_optimizers(self): config = self.config["torch_optimizer"] components = "actor critics alpha".split() return { - name: build_optimizer(self.module[name], config[name]) + name: build_optimizer(getattr(self.module, name), config[name]) for name in components } @@ -73,7 +68,9 @@ def learn_on_batch(self, samples): if self.config["target_entropy"] is not None: info.update(self._update_alpha(batch_tensors)) - self.update_targets("critics", "target_critics") + update_polyak( + self.module.critics, self.module.target_critics, self.config["polyak"] + ) return info def _update_critic(self, batch_tensors): @@ -105,6 +102,6 @@ def extra_grad_info(self, component): """Return statistics right after components are updated.""" return { f"grad_norm({component})": nn.utils.clip_grad_norm_( - self.module[component].parameters(), float("inf") + getattr(self.module, component).parameters(), float("inf") ).item() } diff --git a/raylab/agents/sac/trainer.py b/raylab/agents/sac/trainer.py index 5ec3c9b0..9834b2f6 100644 --- a/raylab/agents/sac/trainer.py +++ b/raylab/agents/sac/trainer.py @@ -15,9 +15,6 @@ # If "auto", will use the heuristic provided in the SAC paper: # H = -dim(A), where A is the action space "target_entropy": None, - # === Twin Delayed DDPG (TD3) tricks === - # Clipped Double Q-Learning - "clipped_double_q": True, # === Optimization === # PyTorch optimizers to use "torch_optimizer": { @@ -28,10 +25,7 @@ # Interpolation factor in polyak averaging for target networks. "polyak": 0.995, # === Network === - # Size and activation of the fully connected networks computing the logits - # for the policy and action-value function. No layers means the component is - # linear in states and/or actions. - "module": {"type": "SACModule"}, + "module": {"type": "SAC", "critic": {"double_q": True}}, # === Exploration Settings === # Default exploration behavior, iff `explore`=None is passed into # compute_action(s). 
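Note: NAF, SAC, SOP and the SVG policies all replace TargetNetworksMixin.update_targets with a call to raylab.pytorch.nn.utils.update_polyak(online, target, polyak). The helper's implementation is not part of this diff; the sketch below is the standard Polyak (soft) target update it is presumably equivalent to, with the signature taken from the call sites above:

import torch
import torch.nn as nn

def update_polyak(module: nn.Module, target_module: nn.Module, polyak: float) -> None:
    """Soft update: target <- polyak * target + (1 - polyak) * online."""
    # Matches call sites such as
    # update_polyak(self.module.critics, self.module.target_critics, self.config["polyak"])
    with torch.no_grad():
        for param, target in zip(module.parameters(), target_module.parameters()):
            target.mul_(polyak).add_(param, alpha=1 - polyak)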
diff --git a/raylab/agents/sop/policy.py b/raylab/agents/sop/policy.py index 522271f8..0aca14a9 100644 --- a/raylab/agents/sop/policy.py +++ b/raylab/agents/sop/policy.py @@ -3,17 +3,19 @@ import torch.nn as nn from ray.rllib.utils import override -from raylab.losses import ClippedDoubleQLearning -from raylab.losses import DeterministicPolicyGradient -from raylab.policy import TargetNetworksMixin from raylab.policy import TorchPolicy +from raylab.policy.action_dist import WrapDeterministicPolicy +from raylab.policy.losses import ClippedDoubleQLearning +from raylab.policy.losses import DeterministicPolicyGradient +from raylab.pytorch.nn.utils import update_polyak from raylab.pytorch.optim import build_optimizer -class SOPTorchPolicy(TargetNetworksMixin, TorchPolicy): +class SOPTorchPolicy(TorchPolicy): """Streamlined Off-Policy policy in PyTorch to use with RLlib.""" # pylint: disable=abstract-method + dist_class = WrapDeterministicPolicy def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) @@ -35,26 +37,13 @@ def get_default_config(): return DEFAULT_CONFIG - @override(TorchPolicy) - def make_module(self, obs_space, action_space, config): - module_config = config["module"] - module_config.setdefault("critic", {}) - module_config["critic"]["double_q"] = config["clipped_double_q"] - module_config.setdefault("actor", {}) - module_config["actor"]["perturbed_policy"] = ( - config["exploration_config"]["type"] - == "raylab.utils.exploration.ParameterNoise" - ) - # pylint:disable=no-member - return super().make_module(obs_space, action_space, config) - @override(TorchPolicy) def make_optimizers(self): config = self.config["torch_optimizer"] components = "actor critics".split() return { - name: build_optimizer(self.module[name], config[name]) + name: build_optimizer(getattr(self.module, name), config[name]) for name in components } @@ -68,7 +57,9 @@ def learn_on_batch(self, samples): if self._grad_step % self.config["policy_delay"] == 0: info.update(self._update_policy(batch_tensors)) - self.update_targets("critics", "target_critics") + update_polyak( + self.module.critics, self.module.target_critics, self.config["polyak"] + ) return info def _update_critic(self, batch_tensors): @@ -92,7 +83,7 @@ def extra_grad_info(self, component): """Return statistics right after components are updated.""" return { f"grad_norm({component})": nn.utils.clip_grad_norm_( - self.module[component].parameters(), float("inf") + getattr(self.module, component).parameters(), float("inf") ).item() } diff --git a/raylab/agents/sop/trainer.py b/raylab/agents/sop/trainer.py index 3ea295fd..d2f9618a 100644 --- a/raylab/agents/sop/trainer.py +++ b/raylab/agents/sop/trainer.py @@ -8,9 +8,6 @@ DEFAULT_CONFIG = with_base_config( { # === SOPTorchPolicy === - # Clipped Double Q-Learning: use the minimun of two target Q functions - # as the next action-value in the target for fitted Q iteration - "clipped_double_q": True, # PyTorch optimizers to use "torch_optimizer": { "actor": {"type": "Adam", "lr": 1e-3}, @@ -20,7 +17,11 @@ "polyak": 0.995, # Update policy every this number of calls to `learn_on_batch` "policy_delay": 1, - "module": {"type": "DDPGModule"}, + "module": { + "type": "DDPG", + "actor": {"separate_behavior": True}, + "critic": {"double_q": True}, + }, # === Exploration Settings === # Default exploration behavior, iff `explore`=None is passed into # compute_action(s). 
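Note: SOPTorchPolicy no longer overrides make_module to wire a perturbed behavior policy from the exploration config; the default module spec above now requests it explicitly through actor.separate_behavior. A sketch of a user config under the new scheme — the exploration type string comes from the removed override, the layer_norm recommendation from the DeterministicActorSpec docstring later in this diff, and the remaining values are illustrative:

config = {
    "module": {
        "type": "DDPG",
        "actor": {
            "separate_behavior": True,  # replaces the implicit "perturbed_policy" wiring
            "encoder": {"units": (256, 256), "layer_norm": True},
        },
        "critic": {"double_q": True},
    },
    "exploration_config": {"type": "raylab.utils.exploration.ParameterNoise"},
}

The SOP example configs near the top of this diff set an actor-level "parameter_noise" flag, which appears to serve the same purpose.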
diff --git a/raylab/agents/svg/inf/policy.py b/raylab/agents/svg/inf/policy.py index b018ce7b..0c0c6075 100644 --- a/raylab/agents/svg/inf/policy.py +++ b/raylab/agents/svg/inf/policy.py @@ -7,9 +7,9 @@ from ray.rllib.utils import override from raylab.agents.svg import SVGTorchPolicy -from raylab.losses import TrajectorySVG from raylab.policy import AdaptiveKLCoeffMixin from raylab.policy import EnvFnMixin +from raylab.policy.losses import TrajectorySVG from raylab.pytorch.optim import build_optimizer @@ -88,7 +88,7 @@ def _learn_off_policy(self, batch_tensors): info.update(_info) loss.backward() - self.update_targets("critic", "target_critic") + self._update_polyak() return info def _learn_on_policy(self, batch_tensors, samples): diff --git a/raylab/agents/svg/one/policy.py b/raylab/agents/svg/one/policy.py index 49cc7df1..783cc7cd 100644 --- a/raylab/agents/svg/one/policy.py +++ b/raylab/agents/svg/one/policy.py @@ -7,10 +7,10 @@ from ray.rllib.utils import override from raylab.agents.svg import SVGTorchPolicy -from raylab.losses import OneStepSVG from raylab.policy import AdaptiveKLCoeffMixin from raylab.policy import EnvFnMixin from raylab.policy import TorchPolicy +from raylab.policy.losses import OneStepSVG from raylab.pytorch.optim import get_optimizer_class @@ -95,7 +95,7 @@ def learn_on_batch(self, samples): info.update(self.extra_grad_info(batch_tensors)) info.update(self.update_kl_coeff(samples)) - self.update_targets("critic", "target_critic") + self._update_polyak() return info @torch.no_grad() diff --git a/raylab/agents/svg/policy.py b/raylab/agents/svg/policy.py index e9b5fe98..d2322fc4 100644 --- a/raylab/agents/svg/policy.py +++ b/raylab/agents/svg/policy.py @@ -2,17 +2,20 @@ import torch from ray.rllib import SampleBatch -from raylab.losses import ISFittedVIteration -from raylab.losses import MaximumLikelihood from raylab.policy import EnvFnMixin -from raylab.policy import TargetNetworksMixin from raylab.policy import TorchPolicy +from raylab.policy.action_dist import WrapStochasticPolicy +from raylab.policy.losses import ISFittedVIteration +from raylab.policy.losses import MaximumLikelihood +from raylab.pytorch.nn.utils import update_polyak -class SVGTorchPolicy(EnvFnMixin, TargetNetworksMixin, TorchPolicy): +class SVGTorchPolicy(EnvFnMixin, TorchPolicy): """Stochastic Value Gradients policy using PyTorch.""" # pylint: disable=abstract-method + dist_class = WrapStochasticPolicy + def __init__(self, observation_space, action_space, config): super().__init__(observation_space, action_space, config) self.loss_model = MaximumLikelihood(self.module.model) @@ -44,3 +47,8 @@ def compute_joint_model_value_loss(self, batch_tensors): loss = mle_loss + self.config["vf_loss_coeff"] * isfv_loss return loss, {**mle_info, **isfv_info} + + def _update_polyak(self): + update_polyak( + self.module.critic, self.module.target_critic, self.config["polyak"] + ) diff --git a/raylab/agents/svg/soft/policy.py b/raylab/agents/svg/soft/policy.py index f1f8e4a2..6786c6ab 100644 --- a/raylab/agents/svg/soft/policy.py +++ b/raylab/agents/svg/soft/policy.py @@ -5,10 +5,10 @@ from ray.rllib.utils import override from raylab.agents.svg import SVGTorchPolicy -from raylab.losses import ISSoftVIteration -from raylab.losses import MaximumEntropyDual -from raylab.losses import OneStepSoftSVG from raylab.policy import EnvFnMixin +from raylab.policy.losses import ISSoftVIteration +from raylab.policy.losses import MaximumEntropyDual +from raylab.policy.losses import OneStepSoftSVG from raylab.pytorch.optim 
import build_optimizer @@ -109,7 +109,7 @@ def learn_on_batch(self, samples): if self.config["target_entropy"] is not None: info.update(self._update_alpha(batch_tensors)) - self.update_targets("critic", "target_critic") + self._update_polyak() return info def _update_model(self, batch_tensors): diff --git a/raylab/agents/trpo/policy.py b/raylab/agents/trpo/policy.py index d8d3d691..b9ea9904 100644 --- a/raylab/agents/trpo/policy.py +++ b/raylab/agents/trpo/policy.py @@ -11,6 +11,7 @@ from torch.nn.utils import vector_to_parameters from raylab.policy import TorchPolicy +from raylab.policy.action_dist import WrapStochasticPolicy from raylab.pytorch.optim import build_optimizer from raylab.pytorch.optim.hessian_free import conjugate_gradient from raylab.pytorch.optim.hessian_free import hessian_vector_product @@ -24,6 +25,7 @@ class TRPOTorchPolicy(TorchPolicy): """Policy class for Trust Region Policy Optimization.""" # pylint:disable=abstract-method + dist_class = WrapStochasticPolicy @staticmethod @override(TorchPolicy) diff --git a/raylab/modules/catalog.py b/raylab/modules/catalog.py deleted file mode 100644 index 880fa0c6..00000000 --- a/raylab/modules/catalog.py +++ /dev/null @@ -1,40 +0,0 @@ -"""Registry of modules for PyTorch policies.""" - -from .ddpg_module import DDPGModule -from .maxent_model_based import MaxEntModelBased -from .model_based_ddpg import ModelBasedDDPG -from .model_based_sac import ModelBasedSAC -from .naf_module import NAFModule -from .nfmbrl import NFMBRL -from .off_policy_nfac import OffPolicyNFAC -from .on_policy_actor_critic import OnPolicyActorCritic -from .on_policy_nfac import OnPolicyNFAC -from .sac_module import SACModule -from .simple_model_based import SimpleModelBased -from .svg_module import SVGModule -from .svg_realnvp_actor import SVGRealNVPActor -from .trpo_tang2018 import TRPOTang2018 - -MODULES = { - "NAFModule": NAFModule, - "DDPGModule": DDPGModule, - "SACModule": SACModule, - "SimpleModelBased": SimpleModelBased, - "SVGModule": SVGModule, - "MaxEntModelBased": MaxEntModelBased, - "ModelBasedDDPG": ModelBasedDDPG, - "ModelBasedSAC": ModelBasedSAC, - "NFMBRL": NFMBRL, - "OnPolicyActorCritic": OnPolicyActorCritic, - "OnPolicyNFAC": OnPolicyNFAC, - "OffPolicyNFAC": OffPolicyNFAC, - "TRPOTang2018": TRPOTang2018, - "SVGRealNVPActor": SVGRealNVPActor, -} - - -def get_module(obs_space, action_space, config): - """Retrieve and construct module of given name.""" - type_ = config.pop("type") - module = MODULES[type_](obs_space, action_space, config) - return module diff --git a/raylab/modules/networks/mlp.py b/raylab/modules/networks/mlp.py deleted file mode 100644 index 9eb74c0c..00000000 --- a/raylab/modules/networks/mlp.py +++ /dev/null @@ -1,57 +0,0 @@ -# pylint:disable=missing-module-docstring -from typing import Dict -from typing import Optional - -import torch -import torch.nn as nn - -from raylab.pytorch.nn.utils import get_activation - - -class MLP(nn.Module): - """A general purpose Multi-Layer Perceptron.""" - - def __init__( - self, - in_features, - out_features, - hidden_features, - state_features=None, - num_blocks=2, - activation="ReLU", - activate_output=False, - ): - # pylint:disable=too-many-arguments - super().__init__() - activation = get_activation(activation) - self.stateful = bool(state_features) - if self.stateful: - self.initial_layer = nn.Linear( - in_features + state_features, hidden_features - ) - else: - self.initial_layer = nn.Linear(in_features, hidden_features) - - layers = [activation()] - layers += [ - layer - for _ in 
range(num_blocks) - for layer in (nn.Linear(hidden_features, hidden_features), activation()) - ] - layers += [nn.Linear(hidden_features, out_features)] - - if activate_output: - layers += [activation()] - - self.sequential = nn.Sequential(*layers) - - def forward(self, inputs, params: Optional[Dict[str, torch.Tensor]] = None): - # pylint:disable=arguments-differ - if self.stateful: - if params is None: - raise ValueError("Parameters required for stateful mlp.") - out = self.initial_layer(torch.cat([inputs, params["state"]], dim=-1)) - else: - out = self.initial_layer(inputs) - - return self.sequential(out) diff --git a/raylab/policy/__init__.py b/raylab/policy/__init__.py index 799eb296..3c6178c8 100644 --- a/raylab/policy/__init__.py +++ b/raylab/policy/__init__.py @@ -5,7 +5,6 @@ from .model_based import ModelSamplingMixin from .model_based import ModelTrainingMixin from .optimizer_collection import OptimizerCollection -from .target_networks_mixin import TargetNetworksMixin from .torch_policy import TorchPolicy __all__ = [ @@ -13,7 +12,6 @@ "EnvFnMixin", "ModelSamplingMixin", "ModelTrainingMixin", - "TargetNetworksMixin", "TorchPolicy", "OptimizerCollection", ] diff --git a/raylab/policy/action_dist.py b/raylab/policy/action_dist.py index d6baeb0e..3a8434be 100644 --- a/raylab/policy/action_dist.py +++ b/raylab/policy/action_dist.py @@ -1,37 +1,132 @@ """Action distribution for compatibility with RLlib's interface.""" +from abc import ABCMeta +from abc import abstractmethod + +import torch.nn as nn from ray.rllib.models.action_dist import ActionDistribution -from ray.rllib.utils import override +from .modules.actor.policy.deterministic import DeterministicPolicy +from .modules.actor.policy.stochastic import StochasticPolicy +from .modules.v0.mixins.stochastic_actor_mixin import StochasticPolicy as V0StochasticPi + + +class IncompatibleDistClsError(Exception): + """Exception raised for incompatible action distribution and NN module. + + Args: + dist_cls: Action distribution class + module: NN module + err: AssertionError explaining the reason why distribution and + module are incompatible + + Attributes: + message: Human-readable text explaining what caused the incompatibility + """ + + def __init__(self, dist_cls: type, module: nn.Module, err: Exception): + # pylint:disable=unused-argument + msg = ( + f"Action distribution type {dist_cls} is incompatible" + " with NN module of type {type(module)}. Reason:\n" + " {err}" + ) + super().__init__(msg) + self.message = msg + + +class BaseActionDist(ActionDistribution, metaclass=ABCMeta): + """Base class for TorchPolicy action distributions.""" -class WrapModuleDist(ActionDistribution): - """Stores a nn.Module and inputs, delegation all methods to the module.""" + @classmethod + def check_model_compat(cls, model: nn.Module): + """Assert the given NN module is compatible with the distribution. + + Raises: + IncompatibleDistClsError: If `model` is incompatible with the + distribution class + """ + try: + cls._check_model_compat(model) + except AssertionError as err: + raise IncompatibleDistClsError(cls, model, err) + + @classmethod + @abstractmethod + def _check_model_compat(cls, model: nn.Module): + pass + + +class WrapStochasticPolicy(BaseActionDist): + """Wraps an nn.Module with a stochastic actor and its inputs. + + Expects actor to be an instance of StochasticPolicy. 
+ """ # pylint:disable=abstract-method + valid_actor_cls = (V0StochasticPi, StochasticPolicy) def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) self._sampled_logp = None - @override(ActionDistribution) def sample(self): action, logp = self.model.actor.sample(**self.inputs) self._sampled_logp = logp return action, logp - @override(ActionDistribution) def deterministic_sample(self): - if hasattr(self.model.actor, "deterministic"): - return self.model.actor.deterministic(**self.inputs) - return self.model.actor(**self.inputs), None + return self.model.actor.deterministic(**self.inputs) - @override(ActionDistribution) def sampled_action_logp(self): return self._sampled_logp - @override(ActionDistribution) def logp(self, x): return self.model.actor.log_prob(value=x, **self.inputs) - @override(ActionDistribution) def entropy(self): return self.model.actor.entropy(**self.inputs) + + @classmethod + def _check_model_compat(cls, model): + assert hasattr(model, "actor"), f"NN model {type(model)} has no actor attribute" + assert isinstance(model.actor, cls.valid_actor_cls), ( + f"Expected actor to be an instance of {cls.valid_actor_cls};" + " found {type(model.actor)} instead." + ) + + +class WrapDeterministicPolicy(BaseActionDist): + """Wraps an nn.Module with a deterministic actor and its inputs. + + Expects actor to be an instance of DeterministicPolicy. + """ + + # pylint:disable=abstract-method + valid_actor_cls = valid_behavior_cls = DeterministicPolicy + + def sample(self): + action = self.model.behavior(**self.inputs) + return action, None + + def deterministic_sample(self): + return self.model.actor(**self.inputs), None + + def sampled_action_logp(self): + return None + + def logp(self, x): + return None + + @classmethod + def _check_model_compat(cls, model: nn.Module): + assert hasattr(model, "actor"), f"NN model {type(model)} has no actor attribute" + assert isinstance(model.actor, cls.valid_actor_cls), ( + f"Expected actor to be an instance of {cls.valid_actor_cls};" + " found {type(model.actor)} instead." + ) + + assert hasattr(model, "behavior"), "NN has no behavior attribute" + assert isinstance(model.actor, cls.valid_behavior_cls), ( + f"Expected behavior to be an instance of {cls.valid_behavior_cls};" + " found {type(model.behavior)} instead." 
+ ) diff --git a/raylab/losses/__init__.py b/raylab/policy/losses/__init__.py similarity index 100% rename from raylab/losses/__init__.py rename to raylab/policy/losses/__init__.py diff --git a/raylab/losses/abstract.py b/raylab/policy/losses/abstract.py similarity index 100% rename from raylab/losses/abstract.py rename to raylab/policy/losses/abstract.py diff --git a/raylab/losses/cdq_learning.py b/raylab/policy/losses/cdq_learning.py similarity index 100% rename from raylab/losses/cdq_learning.py rename to raylab/policy/losses/cdq_learning.py diff --git a/raylab/losses/daml.py b/raylab/policy/losses/daml.py similarity index 100% rename from raylab/losses/daml.py rename to raylab/policy/losses/daml.py diff --git a/raylab/losses/isfv_iteration.py b/raylab/policy/losses/isfv_iteration.py similarity index 100% rename from raylab/losses/isfv_iteration.py rename to raylab/policy/losses/isfv_iteration.py diff --git a/raylab/losses/mage.py b/raylab/policy/losses/mage.py similarity index 100% rename from raylab/losses/mage.py rename to raylab/policy/losses/mage.py diff --git a/raylab/losses/mapo.py b/raylab/policy/losses/mapo.py similarity index 100% rename from raylab/losses/mapo.py rename to raylab/policy/losses/mapo.py diff --git a/raylab/losses/maximum_entropy.py b/raylab/policy/losses/maximum_entropy.py similarity index 100% rename from raylab/losses/maximum_entropy.py rename to raylab/policy/losses/maximum_entropy.py diff --git a/raylab/losses/mixins.py b/raylab/policy/losses/mixins.py similarity index 100% rename from raylab/losses/mixins.py rename to raylab/policy/losses/mixins.py diff --git a/raylab/losses/mle.py b/raylab/policy/losses/mle.py similarity index 97% rename from raylab/losses/mle.py rename to raylab/policy/losses/mle.py index 6a5524b3..f24d0114 100644 --- a/raylab/losses/mle.py +++ b/raylab/policy/losses/mle.py @@ -8,7 +8,7 @@ from ray.rllib import SampleBatch from torch import Tensor -from raylab.modules.mixins.stochastic_model_mixin import StochasticModel +from raylab.policy.modules.v0.mixins.stochastic_model_mixin import StochasticModel from raylab.utils.dictionaries import get_keys from .abstract import Loss diff --git a/raylab/losses/paml.py b/raylab/policy/losses/paml.py similarity index 100% rename from raylab/losses/paml.py rename to raylab/policy/losses/paml.py diff --git a/raylab/losses/policy_gradient.py b/raylab/policy/losses/policy_gradient.py similarity index 98% rename from raylab/losses/policy_gradient.py rename to raylab/policy/losses/policy_gradient.py index 4a37832e..548529a0 100644 --- a/raylab/losses/policy_gradient.py +++ b/raylab/policy/losses/policy_gradient.py @@ -8,7 +8,7 @@ from ray.rllib import SampleBatch from torch import Tensor -from raylab.modules.mixins.stochastic_actor_mixin import StochasticPolicy +from raylab.policy.modules.v0.mixins.stochastic_actor_mixin import StochasticPolicy from raylab.utils.annotations import DetPolicy from raylab.utils.annotations import DynamicsFn from raylab.utils.annotations import RewardFn diff --git a/raylab/losses/svg.py b/raylab/policy/losses/svg.py similarity index 97% rename from raylab/losses/svg.py rename to raylab/policy/losses/svg.py index a749feb9..dae05b4b 100644 --- a/raylab/losses/svg.py +++ b/raylab/policy/losses/svg.py @@ -10,8 +10,8 @@ from ray.rllib.utils import override from torch import Tensor -from raylab.modules.mixins.stochastic_actor_mixin import StochasticPolicy -from raylab.modules.mixins.stochastic_model_mixin import StochasticModel +from 
raylab.policy.modules.v0.mixins.stochastic_actor_mixin import StochasticPolicy +from raylab.policy.modules.v0.mixins.stochastic_model_mixin import StochasticModel from raylab.utils.annotations import RewardFn from raylab.utils.annotations import StateValue from raylab.utils.dictionaries import get_keys diff --git a/raylab/losses/utils.py b/raylab/policy/losses/utils.py similarity index 100% rename from raylab/losses/utils.py rename to raylab/policy/losses/utils.py diff --git a/raylab/policy/model_based/training_mixin.py b/raylab/policy/model_based/training_mixin.py index 20676270..c232bc5b 100644 --- a/raylab/policy/model_based/training_mixin.py +++ b/raylab/policy/model_based/training_mixin.py @@ -20,7 +20,7 @@ from torch.utils.data import DataLoader from torch.utils.data import RandomSampler -from raylab.losses.abstract import Loss +from raylab.policy.losses.abstract import Loss from raylab.pytorch.utils import TensorDictDataset diff --git a/raylab/modules/__init__.py b/raylab/policy/modules/__init__.py similarity index 100% rename from raylab/modules/__init__.py rename to raylab/policy/modules/__init__.py diff --git a/tests/agents/__init__.py b/raylab/policy/modules/actor/__init__.py similarity index 100% rename from tests/agents/__init__.py rename to raylab/policy/modules/actor/__init__.py diff --git a/raylab/policy/modules/actor/deterministic.py b/raylab/policy/modules/actor/deterministic.py new file mode 100644 index 00000000..d109c04a --- /dev/null +++ b/raylab/policy/modules/actor/deterministic.py @@ -0,0 +1,109 @@ +"""Network and configurations for modules with deterministic policies.""" +import warnings +from dataclasses import dataclass +from dataclasses import field + +import torch.nn as nn +from dataclasses_json import DataClassJsonMixin +from gym.spaces import Box + +from .policy.deterministic import DeterministicPolicy +from .policy.deterministic import MLPDeterministicPolicy + +MLPSpec = MLPDeterministicPolicy.spec_cls + + +@dataclass +class DeterministicActorSpec(DataClassJsonMixin): + """Specifications for policy, behavior, and target policy. + + Args: + encoder: Specifications for creating the multilayer perceptron mapping + states to pre-action linear features + norm_beta: Maximum l1 norm of the unconstrained actions. If None, won't + normalize actions before squashing function + separate_behavior: Whether to create a separate behavior policy. Usually + for parameter noise exploration, in which case it is recommended to + enable encoder.layer_norm alongside this option. + smooth_target_policy: Whether to use a noisy target policy for + Q-Learning + target_gaussian_sigma: Gaussian standard deviation for noisy target + policy + separate_target_policy: Whether to use separate parameters for the + target policy. Intended for use with polyak averaging + initializer: Optional dictionary with mandatory `type` key corresponding + to the initializer function name in `torch.nn.init` and optional + keyword arguments. 
+ """ + + encoder: MLPSpec = field(default_factory=MLPSpec) + norm_beta: float = 1.2 + separate_behavior: bool = False + smooth_target_policy: bool = True + target_gaussian_sigma: float = 0.3 + separate_target_policy: bool = False + initializer: dict = field(default_factory=dict) + + def __post_init__(self): + cls_name = type(self).__name__ + assert self.norm_beta > 0, f"{cls_name}.norm_beta must be positive" + assert ( + self.target_gaussian_sigma > 0 + ), f"{cls_name}.target_gaussian_sigma must be positive" + + +class DeterministicActor(nn.Module): + """NN with deterministic policies. + + Args: + obs_space: Observation space + action_space: Action space + spec: Specifications for policy, behavior, and target policy + + Attributes: + policy: The deterministic policy to be learned + behavior: The policy for exploration. `utils.exploration.GaussianNoise` + handles Gaussian action noise exploration separatedly. + target_policy: The policy used for estimating the arg max in Q-Learning + spec_cls: Expected class of `spec` init argument + """ + + # pylint:disable=abstract-method + spec_cls = DeterministicActorSpec + + def __init__( + self, obs_space: Box, action_space: Box, spec: DeterministicActorSpec, + ): + super().__init__() + + def make_policy(): + return MLPDeterministicPolicy( + obs_space, action_space, spec.encoder, spec.norm_beta + ) + + policy = make_policy() + policy.initialize_parameters(spec.initializer) + + behavior = policy + if spec.separate_behavior: + if not spec.encoder.layer_norm: + warnings.warn( + "Separate behavior policy requested and layer normalization" + " deactivated. If using parameter noise exploration, enable" + " layer normalization for better stability." + ) + behavior = make_policy() + behavior.load_state_dict(policy.state_dict()) + + target_policy = policy + if spec.separate_target_policy: + target_policy = make_policy() + target_policy.load_state_dict(policy.state_dict()) + if spec.smooth_target_policy: + target_policy = DeterministicPolicy.add_gaussian_noise( + target_policy, noise_stddev=spec.target_gaussian_sigma + ) + + self.policy = policy + self.behavior = behavior + self.target_policy = target_policy diff --git a/tests/agents/mage/__init__.py b/raylab/policy/modules/actor/policy/__init__.py similarity index 100% rename from tests/agents/mage/__init__.py rename to raylab/policy/modules/actor/policy/__init__.py diff --git a/raylab/policy/modules/actor/policy/deterministic.py b/raylab/policy/modules/actor/policy/deterministic.py new file mode 100644 index 00000000..b9d50bd1 --- /dev/null +++ b/raylab/policy/modules/actor/policy/deterministic.py @@ -0,0 +1,137 @@ +"""Parameterized deterministic policies.""" +import warnings +from typing import Optional + +import torch +import torch.nn as nn +from gym.spaces import Box +from torch import Tensor + +import raylab.pytorch.nn as nnx +from raylab.policy.modules.networks.mlp import StateMLP + + +class DeterministicPolicy(nn.Module): + """Continuous action deterministic policy as a sequence of modules. + + If a noise module is passed, it is evaluated on unconstrained actions before + the squashing module. 
+ + Args: + encoder: NN module mapping states to 1D features + action_linear: Linear module mapping features to unconstrained actions + squashing: Invertible module mapping unconstrained actions to bounded + action space + noise: Optional stochastic module adding noise to unconstrained actions + """ + + def __init__( + self, + encoder: nn.Module, + action_linear: nn.Module, + squashing: nn.Module, + noise: Optional[nn.Module] = None, + ): + super().__init__() + self.encoder = encoder + self.action_linear = action_linear + self.squashing = squashing + self.noise = noise + + def forward(self, obs: Tensor) -> Tensor: # pylint:disable=arguments-differ + """Main forward pass mapping observations to actions.""" + unconstrained_action = self.unconstrained_action(obs) + return self.squashing(unconstrained_action) + + @torch.jit.export + def unconstrained_action(self, obs: Tensor) -> Tensor: + """Forward pass with no squashing at the end""" + features = self.encoder(obs) + unconstrained_action = self.action_linear(features) + if self.noise is not None: + unconstrained_action = self.noise(unconstrained_action) + return unconstrained_action + + @torch.jit.export + def squash_action(self, unconstrained_action: Tensor) -> Tensor: + """Returns the action generated by the given unconstrained action.""" + return self.squashing(unconstrained_action) + + @torch.jit.export + def unsquash_action(self, action: Tensor) -> Tensor: + """Returns the unconstrained action which generated the given action.""" + return self.squashing(action, reverse=True) + + @classmethod + def add_gaussian_noise(cls, policy, noise_stddev: float): + """Adds a zero-mean Gaussian noise module to a DeterministicPolicy. + + Args: + policy: The deterministic policy. + noise_stddev: Standard deviation of the Gaussian noise + + Returns: + A deterministic policy sharing all paremeters with the input one and + with additional noise module before squashing. + """ + if policy.noise is not None: + warnings.warn( + "Adding Gaussian noise to already noisy policy. Are you sure you" + " called `add_gaussian_noise` on the right policy?" + ) + noise = nnx.GaussianNoise(noise_stddev) + return cls(policy.encoder, policy.action_linear, policy.squashing, noise=noise) + + +class MLPDeterministicPolicy(DeterministicPolicy): + """DeterministicPolicy with multilayer perceptron encoder. + + The final Linear layer is initialized so that actions are near the origin + point. + + Args: + obs_space: Observation space + action_space: Action space + mlp_spec: Multilayer perceptron specifications + norm_beta: Maximum l1 norm of the unconstrained actions. If None, won't + normalize actions before squashing function. + + Attributes: + spec_cls: Expected class of `spec` init argument + """ + + spec_cls = StateMLP.spec_cls + + def __init__( + self, + obs_space: Box, + action_space: Box, + mlp_spec: StateMLP.spec_cls, + norm_beta: float, + ): + encoder = StateMLP(obs_space, mlp_spec) + + action_size = action_space.shape[0] + if norm_beta: + action_linear = nnx.NormalizedLinear( + encoder.out_features, action_size, beta=norm_beta + ) + else: + action_linear = nn.Linear(encoder.out_features, action_size) + + action_low, action_high = map( + torch.as_tensor, (action_space.low, action_space.high) + ) + squash = nnx.TanhSquash(action_low, action_high) + + super().__init__(encoder, action_linear, squash) + + def initialize_parameters(self, initializer_spec: dict): + """Initialize all Linear models in the encoder. 
+ + Args: + initializer_spec: Dictionary with mandatory `name` key corresponding + to the initializer function name in `torch.nn.init` and optional + keyword arguments. + """ + self.encoder.initialize_parameters(initializer_spec) diff --git a/raylab/policy/modules/actor/policy/stochastic.py b/raylab/policy/modules/actor/policy/stochastic.py new file mode 100644 index 00000000..944d9983 --- /dev/null +++ b/raylab/policy/modules/actor/policy/stochastic.py @@ -0,0 +1,214 @@ +"""Parameterized stochastic policies.""" +from typing import Callable +from typing import List + +import torch +import torch.nn as nn +from gym.spaces import Box +from gym.spaces import Discrete + +import raylab.pytorch.nn as nnx +import raylab.pytorch.nn.distributions as ptd +from raylab.policy.modules.networks.mlp import StateMLP + + +class StochasticPolicy(nn.Module): + """Represents a stochastic policy as a conditional distribution module.""" + + # pylint:disable=abstract-method + + def __init__(self, params_module, dist_module): + super().__init__() + self.params = params_module + self.dist = dist_module + + def forward(self, obs): # pylint:disable=arguments-differ + return self.params(obs) + + @torch.jit.export + def sample(self, obs, sample_shape: List[int] = ()): + """ + Generates a sample_shape shaped sample or sample_shape shaped batch of + samples if the distribution parameters are batched. Returns a (sample, log_prob) + pair. + """ + params = self(obs) + return self.dist.sample(params, sample_shape) + + @torch.jit.export + def rsample(self, obs, sample_shape: List[int] = ()): + """ + Generates a sample_shape shaped reparameterized sample or sample_shape + shaped batch of reparameterized samples if the distribution parameters + are batched. Returns a (rsample, log_prob) pair. + """ + params = self(obs) + return self.dist.rsample(params, sample_shape) + + @torch.jit.export + def log_prob(self, obs, action): + """ + Returns the log of the probability density/mass function evaluated at `action`. + """ + params = self(obs) + return self.dist.log_prob(action, params) + + @torch.jit.export + def cdf(self, obs, action): + """Returns the cumulative density/mass function evaluated at `action`.""" + params = self(obs) + return self.dist.cdf(action, params) + + @torch.jit.export + def icdf(self, obs, prob): + """Returns the inverse cumulative density/mass function evaluated at `prob`.""" + params = self(obs) + return self.dist.icdf(prob, params) + + @torch.jit.export + def entropy(self, obs): + """Returns entropy of distribution.""" + params = self(obs) + return self.dist.entropy(params) + + @torch.jit.export + def perplexity(self, obs): + """Returns perplexity of distribution.""" + params = self(obs) + return self.dist.perplexity(params) + + @torch.jit.export + def reproduce(self, obs, action): + """Produce a reparametrized sample with the same value as `action`.""" + params = self(obs) + return self.dist.reproduce(action, params) + + @torch.jit.export + def deterministic(self, obs): + """ + Generates a deterministic sample or batch of samples if the distribution + parameters are batched. Returns a (rsample, log_prob) pair. + """ + params = self(obs) + return self.dist.deterministic(params) + + +class MLPStochasticPolicy(StochasticPolicy): + """Stochastic policy with multilayer perceptron state encoder. 
+ + Args: + obs_space: Observation space + spec: Specifications for the encoder + params_fn: Callable that builds a module for computing distribution + parameters given the number of state features + dist: Conditional distribution module + + Attributes: + encoder: Multilayer perceptron state encoder + spec: MLP spec instance + """ + + spec_cls = StateMLP.spec_cls + + def __init__( + self, + obs_space: Box, + spec: StateMLP.spec_cls, + params_fn: Callable[[int], nn.Module], + dist: ptd.ConditionalDistribution, + ): + encoder = StateMLP(obs_space, spec) + params = params_fn(encoder.out_features) + params_module = nn.Sequential(encoder, params) + super().__init__(params_module, dist) + + self.encoder = encoder + self.spec = spec + + def initialize_parameters(self, initializer_spec: dict): + """Initialize all Linear models in the encoder. + + Args: + initializer_spec: Dictionary with mandatory `name` key corresponding + to the initializer function name in `torch.nn.init` and optional + keyword arguments. + """ + self.encoder.initialize_parameters(initializer_spec) + + +class MLPContinuousPolicy(MLPStochasticPolicy): + """Multilayer perceptron policy for continuous actions. + + Args: + obs_space: Observation space + action_space: Action space + mlp_spec: Specifications for the multilayer perceptron + input_dependent_scale: Whether to parameterize the Gaussian standard + deviation as a function of the state + """ + + def __init__( + self, + obs_space: Box, + action_space: Box, + mlp_spec: MLPStochasticPolicy.spec_cls, + input_dependent_scale: bool, + ): + def params_fn(out_features): + return nnx.NormalParams( + out_features, + action_space.shape[0], + input_dependent_scale=input_dependent_scale, + ) + + dist = ptd.TransformedDistribution( + ptd.Independent(ptd.Normal(), reinterpreted_batch_ndims=1), + ptd.flows.TanhSquashTransform( + low=torch.as_tensor(action_space.low), + high=torch.as_tensor(action_space.high), + event_dim=1, + ), + ) + super().__init__(obs_space, mlp_spec, params_fn, dist) + + +class MLPDiscretePolicy(MLPStochasticPolicy): + """Multilayer perceptron policy for discrete actions. + + Args: + obs_space: Observation space + action_space: Action space + mlp_spec: Specifications for the multilayer perceptron + """ + + def __init__( + self, + obs_space: Box, + action_space: Discrete, + mlp_spec: MLPStochasticPolicy.spec_cls, + ): + def params_fn(out_features): + return nnx.CategoricalParams(out_features, action_space.n) + + dist = ptd.Categorical() + super().__init__(obs_space, mlp_spec, params_fn, dist) + + +class Alpha(nn.Module): + """Wraps a single scalar coefficient parameter. 
+
+    Allows learning said coefficient by having it as a parameter.
+
+    Args:
+        initial_alpha: Value to initialize the coefficient to
+
+    Attributes:
+        log_alpha: Natural logarithm of the current coefficient
+    """
+
+    def __init__(self, initial_alpha: float):
+        super().__init__()
+        self.log_alpha = nn.Parameter(torch.empty([]).fill_(initial_alpha).log())
+
+    def forward(self):  # pylint:disable=arguments-differ
+        return self.log_alpha.exp()
diff --git a/raylab/policy/modules/actor/stochastic.py b/raylab/policy/modules/actor/stochastic.py
new file mode 100644
index 00000000..088975e1
--- /dev/null
+++ b/raylab/policy/modules/actor/stochastic.py
@@ -0,0 +1,83 @@
+"""Network and configurations for modules with stochastic policies."""
+import warnings
+from dataclasses import dataclass
+from dataclasses import field
+from typing import Union
+
+import torch.nn as nn
+from dataclasses_json import DataClassJsonMixin
+from gym.spaces import Box
+from gym.spaces import Discrete
+
+from .policy.stochastic import Alpha
+from .policy.stochastic import MLPContinuousPolicy
+from .policy.stochastic import MLPDiscretePolicy
+from .policy.stochastic import MLPStochasticPolicy
+
+MLPSpec = MLPStochasticPolicy.spec_cls
+
+
+@dataclass
+class StochasticActorSpec(DataClassJsonMixin):
+    """Specifications for stochastic policy and entropy coefficient.
+
+    Args:
+        encoder: Specifications for building the multilayer perceptron state
+            processor
+        input_dependent_scale: Whether to parameterize the Gaussian standard
+            deviation as a function of the state
+        initial_entropy_coeff: Optional initial value of the entropy bonus term.
+            The actor creates an `alpha` attribute with this initial value.
+        initializer: Optional dictionary with mandatory `type` key corresponding
+            to the initializer function name in `torch.nn.init` and optional
+            keyword arguments.
+    """
+
+    encoder: MLPSpec = field(default_factory=MLPSpec)
+    input_dependent_scale: bool = False
+    initial_entropy_coeff: float = 0.0
+    initializer: dict = field(default_factory=dict)
+
+    def __post_init__(self):
+        cls_name = type(self).__name__
+        ent_coeff = self.initial_entropy_coeff
+        if ent_coeff < 0:
+            warnings.warn(f"Entropy coefficient is negative in {cls_name}: {ent_coeff}")
+
+
+class StochasticActor(nn.Module):
+    """NN with stochastic policy.
+
+    Args:
+        obs_space: Observation space
+        action_space: Action space
+        spec: Specifications for stochastic policy
+
+    Attributes:
+        policy: Stochastic policy to be learned
+        alpha: Entropy bonus coefficient
+    """
+
+    # pylint:disable=abstract-method
+    spec_cls = StochasticActorSpec
+
+    def __init__(
+        self,
+        obs_space: Box,
+        action_space: Union[Box, Discrete],
+        spec: StochasticActorSpec,
+    ):
+        super().__init__()
+
+        if isinstance(action_space, Box):
+            policy = MLPContinuousPolicy(
+                obs_space, action_space, spec.encoder, spec.input_dependent_scale
+            )
+        elif isinstance(action_space, Discrete):
+            policy = MLPDiscretePolicy(obs_space, action_space, spec.encoder)
+        else:
+            raise ValueError(f"Unsupported action space type {type(action_space)}")
+        policy.initialize_parameters(spec.initializer)
+
+        self.policy = policy
+        self.alpha = Alpha(spec.initial_entropy_coeff)
diff --git a/raylab/policy/modules/catalog.py b/raylab/policy/modules/catalog.py
new file mode 100644
index 00000000..7975035b
--- /dev/null
+++ b/raylab/policy/modules/catalog.py
@@ -0,0 +1,28 @@
+"""Registry of modules for PyTorch policies."""
+import torch.nn as nn
+from gym.spaces import Space
+
+from .ddpg import DDPG
+from .mbddpg import MBDDPG
+from .mbsac import MBSAC
+from .sac import SAC
+from .v0.catalog import get_module as get_v0_module
+
+MODULES = {cls.__name__: cls for cls in (DDPG, MBDDPG, MBSAC, SAC)}
+
+
+def get_module(obs_space: Space, action_space: Space, config: dict) -> nn.Module:
+    """Retrieve and construct module of given name.
+
+    Args:
+        obs_space: Observation space
+        action_space: Action space
+        config: Configurations for module construction and initialization
+    """
+    type_ = config.pop("type")
+    if type_ not in MODULES:
+        return get_v0_module(obs_space, action_space, {"type": type_, **config})
+
+    cls = MODULES[type_]
+    spec = cls.spec_cls.from_dict(config)
+    return cls(obs_space, action_space, spec)
diff --git a/tests/agents/mapo/__init__.py b/raylab/policy/modules/critic/__init__.py
similarity index 100%
rename from tests/agents/mapo/__init__.py
rename to raylab/policy/modules/critic/__init__.py
diff --git a/raylab/policy/modules/critic/action_value.py b/raylab/policy/modules/critic/action_value.py
new file mode 100644
index 00000000..9953a8a3
--- /dev/null
+++ b/raylab/policy/modules/critic/action_value.py
@@ -0,0 +1,81 @@
+"""Network and configurations for modules with Q-value critics."""
+from dataclasses import dataclass
+from dataclasses import field
+
+import torch.nn as nn
+from dataclasses_json import DataClassJsonMixin
+from gym.spaces import Box
+
+from .q_value import ForkedQValueEnsemble
+from .q_value import MLPQValue
+from .q_value import QValueEnsemble
+
+
+QValueSpec = MLPQValue.spec_cls
+
+
+@dataclass
+class ActionValueCriticSpec(DataClassJsonMixin):
+    """Specifications for action-value estimators.
+
+    Args:
+        encoder: Specifications for creating the multilayer perceptron mapping
+            states and actions to pre-value function linear features
+        double_q: Whether to create two Q-value estimators instead of one.
+            Defaults to True.
+        parallelize: Whether to evaluate Q-values in parallel. Defaults to
+            False.
+        initializer: Optional dictionary with mandatory `type` key corresponding
+            to the initializer function name in `torch.nn.init` and optional
+            keyword arguments.
+ """ + + encoder: QValueSpec = field(default_factory=QValueSpec) + double_q: bool = True + parallelize: bool = False + initializer: dict = field(default_factory=dict) + + +class ActionValueCritic(nn.Module): + """NN with Q-value estimators. + + Since it is common to use clipped double Q-Learning, `q_values` is a + ModuleList of Q-value functions. + + Args: + obs_space: Observation space + action_space: Action space + spec: Specifications for action-value estimators + + Attributes: + q_values: The action-value estimators to be learned + target_q_values: The action-value estimators used for bootstrapping in + Q-Learning + spec_cls: Expected class of `spec` init argument + """ + + # pylint:disable=abstract-method + spec_cls = ActionValueCriticSpec + + def __init__(self, obs_space: Box, action_space: Box, spec: ActionValueCriticSpec): + super().__init__() + + def make_q_value(): + return MLPQValue(obs_space, action_space, spec.encoder) + + def make_q_value_ensemble(): + n_q_values = 2 if spec.double_q else 1 + q_values = [make_q_value() for _ in range(n_q_values)] + + if spec.parallelize: + return ForkedQValueEnsemble(q_values) + return QValueEnsemble(q_values) + + q_values = make_q_value_ensemble() + q_values.initialize_parameters(spec.initializer) + + target_q_values = make_q_value_ensemble() + target_q_values.load_state_dict(q_values.state_dict()) + + self.q_values = q_values + self.target_q_values = target_q_values diff --git a/raylab/policy/modules/critic/q_value.py b/raylab/policy/modules/critic/q_value.py new file mode 100644 index 00000000..1059f472 --- /dev/null +++ b/raylab/policy/modules/critic/q_value.py @@ -0,0 +1,120 @@ +"""Parameterized action-value estimators.""" +import torch +import torch.nn as nn +from gym.spaces import Box +from torch import Tensor + +from raylab.policy.modules.networks.mlp import StateActionMLP + + +MLPSpec = StateActionMLP.spec_cls + + +class QValue(nn.Module): + """Neural network module emulating a Q value function. + + Args: + encoder: NN module mapping states to 1D features. Must have an + `out_features` attribute with the size of the output features + """ + + def __init__(self, encoder: nn.Module): + super().__init__() + self.encoder = encoder + self.value_linear = nn.Linear(self.encoder.out_features, 1) + + def forward(self, obs: Tensor, action: Tensor) -> Tensor: + """Main forward pass mapping obs and actions to Q-values. + + Note: + The output tensor has a last singleton dimension, i.e., for a batch + of 10 obs-action pairs, the output will have shape (10, 1). + """ + # pylint:disable=arguments-differ + features = self.encoder(obs, action) + return self.value_linear(features) + + +class MLPQValue(QValue): + """Q-value function with a multilayer perceptron encoder. + + Args: + obs_space: Observation space + action_space: Action space + mlp_spec: Multilayer perceptron specifications + """ + + spec_cls = MLPSpec + + def __init__(self, obs_space: Box, action_space: Box, spec: MLPSpec): + encoder = StateActionMLP(obs_space, action_space, spec) + super().__init__(encoder) + + def initialize_parameters(self, initializer_spec: dict): + """Initialize all Linear models in the encoder. + + Uses `raylab.pytorch.nn.init.initialize_` to create an initializer + function. + + Args: + initializer_spec: Dictionary with mandatory `name` key corresponding + to the initializer function name in `torch.nn.init` and optional + keyword arguments. 
+ """ + self.encoder.initialize_parameters(initializer_spec) + + +class QValueEnsemble(nn.ModuleList): + """A static list of Q-value estimators. + + Args: + q_values: A list of QValue modules + """ + + def __init__(self, q_values): + cls_name = type(self).__name__ + assert all( + isinstance(q, QValue) for q in q_values + ), f"All modules in {cls_name} must be instances of QValue." + super().__init__(q_values) + + def forward(self, obs: Tensor, action: Tensor, clip: bool = False) -> Tensor: + """Evaluate each Q estimator in the ensemble. + + Args: + obs: The observation tensor + action: The action tensor + clip: Whether to output the minimum of the action-values. Preserves + output dimensions + + Returns: + A tensor of shape `(*, N)`, where `N` is the ensemble size + """ + # pylint:disable=arguments-differ + action_values = torch.cat([m(obs, action) for m in self], dim=-1) + if clip: + action_values, _ = action_values.min(keepdim=True, dim=-1) + return action_values + + def initialize_parameters(self, initializer_spec: dict): + """Initialize each Q estimator in the ensemble. + + Args: + initializer_spec: Dictionary with mandatory `name` key corresponding + to the initializer function name in `torch.nn.init` and optional + keyword arguments. + """ + for q_value in self: + q_value.initialize_parameters(initializer_spec) + + +class ForkedQValueEnsemble(QValueEnsemble): + """Ensemble of Q-value estimators with parallelized forward pass.""" + + def forward(self, obs: Tensor, action: Tensor, clip: bool = False) -> Tensor: + # pylint:disable=protected-access + futures = [torch.jit._fork(m, obs, action) for m in self] + action_values = torch.cat([torch.jit._wait(f) for f in futures], dim=-1) + if clip: + action_values, _ = action_values.min(keepdim=True, dim=-1) + return action_values diff --git a/raylab/policy/modules/ddpg.py b/raylab/policy/modules/ddpg.py new file mode 100644 index 00000000..6c114be4 --- /dev/null +++ b/raylab/policy/modules/ddpg.py @@ -0,0 +1,74 @@ +"""NN architecture used in Deep Deterministic Policy Gradients.""" +from dataclasses import dataclass +from dataclasses import field + +import torch.nn as nn +from dataclasses_json import DataClassJsonMixin +from gym.spaces import Box + +from .actor.deterministic import DeterministicActor +from .critic.action_value import ActionValueCritic + +ActorSpec = DeterministicActor.spec_cls +CriticSpec = ActionValueCritic.spec_cls + + +@dataclass +class DDPGSpec(DataClassJsonMixin): + """Specifications for DDPG modules. + + Args: + actor: Specifications for policy, behavior, and target policy + critic: Specifications for action-value estimators + initializer: Optional dictionary with mandatory `type` key corresponding + to the initializer function name in `torch.nn.init` and optional + keyword arguments. Overrides actor and critic initializer + specifications. + """ + + actor: ActorSpec = field(default_factory=ActorSpec) + critic: CriticSpec = field(default_factory=CriticSpec) + initializer: dict = field(default_factory=dict) + + def __post_init__(self): + # Top-level initializer options take precedence over individual + # component's options + if self.initializer: + self.actor.initializer = self.initializer + self.critic.initializer = self.initializer + + +class DDPG(nn.Module): + """NN module for DDPG-like algorithms. 
+ + Args: + obs_space: Observation space + action_space: Action space + spec: Specifications for DDPG modules + + Attributes: + actor (DeterministicPolicy): The deterministic policy to be learned + behavior (DeterministicPolicy): The policy for exploration + target_actor (DeterministicPolicy): The policy used for estimating the + arg max in Q-Learning + critics (QValueEnsemble): The action-value estimators to be learned + target_critics (QValueEnsemble): The action-value estimators used for + bootstrapping in Q-Learning + spec_cls: Expected class of `spec` init argument + """ + + # pylint:disable=abstract-method + spec_cls = DDPGSpec + + def __init__(self, obs_space: Box, action_space: Box, spec: DDPGSpec): + super().__init__() + # Build actor + actor = DeterministicActor(obs_space, action_space, spec.actor) + self.actor = actor.policy + self.behavior = actor.behavior + self.target_actor = actor.target_policy + + # Build critic + critic = ActionValueCritic(obs_space, action_space, spec.critic) + self.critics = critic.q_values + self.target_critics = critic.target_q_values diff --git a/raylab/policy/modules/mbddpg.py b/raylab/policy/modules/mbddpg.py new file mode 100644 index 00000000..931dc338 --- /dev/null +++ b/raylab/policy/modules/mbddpg.py @@ -0,0 +1,61 @@ +"""Network and configurations for model-based DDPG algorithms.""" +from dataclasses import dataclass +from dataclasses import field + +from gym.spaces import Box + +from .ddpg import DDPG +from .ddpg import DDPGSpec +from .model.stochastic import build_ensemble +from .model.stochastic import EnsembleSpec + + +@dataclass +class MBDDPGSpec(DDPGSpec): + """Specifications for model-based DDPG modules. + + Args: + model: Specifications for stochastic dynamics model ensemble + actor: Specifications for policy, behavior, and target policy + critic: Specifications for action-value estimators + initializer: Optional dictionary with mandatory `type` key corresponding + to the initializer function name in `torch.nn.init` and optional + keyword arguments. Overrides model, actor, and critic initializer + specifications. + """ + + model: EnsembleSpec = field(default_factory=EnsembleSpec) + + def __post_init__(self): + super().__post_init__() + if self.initializer: + self.model.initializer = self.initializer + + +class MBDDPG(DDPG): + """NN module for Model-Based DDPG algorithms. 
+ + Args: + obs_space: Observation space + action_space: Action space + spec: Specifications for model-based DDPG modules + + Attributes: + model (StochasticModelEnsemble): Stochastic dynamics model ensemble + actor (DeterministicPolicy): The deterministic policy to be learned + behavior (DeterministicPolicy): The policy for exploration + target_actor (DeterministicPolicy): The policy used for estimating the + arg max in Q-Learning + critics (QValueEnsemble): The action-value estimators to be learned + target_critics (QValueEnsemble): The action-value estimators used for + bootstrapping in Q-Learning + spec_cls: Expected class of `spec` init argument + """ + + # pylint:disable=abstract-method + spec_cls = MBDDPGSpec + + def __init__(self, obs_space: Box, action_space: Box, spec: MBDDPGSpec): + super().__init__(obs_space, action_space, spec) + + self.models = build_ensemble(obs_space, action_space, spec.model) diff --git a/raylab/policy/modules/mbsac.py b/raylab/policy/modules/mbsac.py new file mode 100644 index 00000000..dc7c897c --- /dev/null +++ b/raylab/policy/modules/mbsac.py @@ -0,0 +1,59 @@ +"""Network and configurations for model-based SAC algorithms.""" +from dataclasses import dataclass +from dataclasses import field + +from gym.spaces import Box + +from .model.stochastic import build_ensemble +from .model.stochastic import EnsembleSpec +from .sac import SAC +from .sac import SACSpec + + +@dataclass +class MBSACSpec(SACSpec): + """Specifications for model-based SAC modules. + + Args: + model: Specifications for stochastic dynamics model ensemble + actor: Specifications for stochastic policy and entropy coefficient + critic: Specifications for action-value estimators + initializer: Optional dictionary with mandatory `type` key corresponding + to the initializer function name in `torch.nn.init` and optional + keyword arguments. Overrides model, actor, and critic initializer + specifications. + """ + + model: EnsembleSpec = field(default_factory=EnsembleSpec) + + def __post_init__(self): + super().__post_init__() + if self.initializer: + self.model.initializer = self.initializer + + +class MBSAC(SAC): + """NN module for Model-Based Soft Actor-Critic algorithms. 
+ + Args: + obs_space: Observation space + action_space: Action space + spec: Specifications for model-based SAC modules + + Attributes: + models (StochasticModelEnsemble): Stochastic dynamics model ensemble + actor (StochasticPolicy): Stochastic policy to be learned + alpha (Alpha): Entropy bonus coefficient + critics (QValueEnsemble): The action-value estimators to be learned + target_critics (QValueEnsemble): The action-value estimators used for + bootstrapping in Q-Learning + spec_cls: Expected class of `spec` init argument + """ + + # pylint:disable=abstract-method + spec_cls = MBSACSpec + + def __init__(self, obs_space: Box, action_space: Box, spec: MBSACSpec): + super().__init__(obs_space, action_space, spec) + + self.models = build_ensemble(obs_space, action_space, spec.model) diff --git a/tests/agents/mbpo/__init__.py b/raylab/policy/modules/model/__init__.py similarity index 100% rename from tests/agents/mbpo/__init__.py rename to raylab/policy/modules/model/__init__.py diff --git a/raylab/policy/modules/model/stochastic/__init__.py b/raylab/policy/modules/model/stochastic/__init__.py new file mode 100644 index 00000000..342df363 --- /dev/null +++ b/raylab/policy/modules/model/stochastic/__init__.py @@ -0,0 +1,4 @@ +"""Implementations of stochastic dynamics models.""" + +from .builders import build_ensemble +from .builders import EnsembleSpec diff --git a/raylab/policy/modules/model/stochastic/builders.py b/raylab/policy/modules/model/stochastic/builders.py new file mode 100644 index 00000000..aa90d52a --- /dev/null +++ b/raylab/policy/modules/model/stochastic/builders.py @@ -0,0 +1,90 @@ +"""Constructors for stochastic dynamics models.""" +from dataclasses import dataclass +from dataclasses import field + +from dataclasses_json import DataClassJsonMixin +from gym.spaces import Box + +from .ensemble import ForkedStochasticModelEnsemble +from .ensemble import StochasticModelEnsemble +from .single import MLPModel +from .single import ResidualMLPModel + +ModelSpec = MLPModel.spec_cls + + +@dataclass +class Spec(DataClassJsonMixin): + """Specifications for stochastic dynamics model. + + Args: + network: Specifications for stochastic model network + input_dependent_scale: Whether to parameterize the Gaussian standard + deviation as a function of the state and action + residual: Whether to build model as a residual one, i.e., that + predicts the change in state rather than the next state itself + initializer: Optional dictionary with mandatory `type` key corresponding + to the initializer function name in `torch.nn.init` and optional + keyword arguments. Used to initialize the model's Linear layers. + """ + + network: ModelSpec = field(default_factory=ModelSpec) + input_dependent_scale: bool = True + residual: bool = True + initializer: dict = field(default_factory=dict) + + +def build(obs_space: Box, action_space: Box, spec: Spec) -> MLPModel: + """Construct stochastic dynamics model. + + Args: + obs_space: Observation space + action_space: Action space + spec: Specifications for stochastic dynamics model + + Returns: + A stochastic dynamics model + """ + cls = ResidualMLPModel if spec.residual else MLPModel + model = cls(obs_space, action_space, spec.network, spec.input_dependent_scale) + model.initialize_parameters(spec.initializer) + return model + + +@dataclass +class EnsembleSpec(Spec): + """Specifications for stochastic dynamics model ensemble. + + Args: + network: Specifications for stochastic model networks + ensemble_size: Number of models in the collection. 
+ residual: Whether to build each model as a residual one, i.e., that + predicts the change in state rather than the next state itself + parallelize: Whether to use an ensemble with parallelized `sample`, + `rsample`, and `log_prob` methods + initializer: Optional dictionary with mandatory `type` key corresponding + to the initializer function name in `torch.nn.init` and optional + keyword arguments. Used to initialize the models' Linear layers. + """ + + ensemble_size: int = 1 + parallelize: bool = False + + +def build_ensemble( + obs_space: Box, action_space: Box, spec: EnsembleSpec +) -> StochasticModelEnsemble: + """Construct stochastic dynamics model ensemble. + + Args: + obs_space: Observation space + action_space: Action space + spec: Specifications for stochastic dynamics model ensemble + + Returns: + A stochastic dynamics model ensemble + """ + models = [build(obs_space, action_space, spec) for _ in range(spec.ensemble_size)] + cls = ForkedStochasticModelEnsemble if spec.parallelize else StochasticModelEnsemble + ensemble = cls(models) + return ensemble diff --git a/raylab/policy/modules/model/stochastic/ensemble.py b/raylab/policy/modules/model/stochastic/ensemble.py new file mode 100644 index 00000000..dae46a2a --- /dev/null +++ b/raylab/policy/modules/model/stochastic/ensemble.py @@ -0,0 +1,72 @@ +"""Network and configurations for modules with stochastic model ensembles.""" +from typing import List + +import torch +import torch.nn as nn + +from .single import StochasticModel + + +class StochasticModelEnsemble(nn.ModuleList): + """A static list of stochastic dynamics models. + + Args: + models: List of StochasticModel modules + """ + + # pylint:disable=abstract-method + + def __init__(self, models: List[StochasticModel]): + cls_name = type(self).__name__ + assert all( + isinstance(m, StochasticModel) for m in models + ), f"All modules in {cls_name} must be instances of StochasticModel." 
+        super().__init__(models)
+
+    @torch.jit.export
+    def sample(self, obs, action, sample_shape: List[int] = ()):
+        """Compute samples and likelihoods for each model in the ensemble."""
+        outputs = [m.sample(obs, action, sample_shape) for m in self]
+        sample = torch.stack([s for s, _ in outputs])
+        logp = torch.stack([p for _, p in outputs])
+        return sample, logp
+
+    @torch.jit.export
+    def rsample(self, obs, action, sample_shape: List[int] = ()):
+        """Compute reparameterized samples and likelihoods for each model."""
+        outputs = [m.rsample(obs, action, sample_shape) for m in self]
+        sample = torch.stack([s for s, _ in outputs])
+        logp = torch.stack([p for _, p in outputs])
+        return sample, logp
+
+    @torch.jit.export
+    def log_prob(self, obs, action, next_obs):
+        """Compute likelihoods for each model in the ensemble."""
+        return torch.stack([m.log_prob(obs, action, next_obs) for m in self])
+
+
+class ForkedStochasticModelEnsemble(StochasticModelEnsemble):
+    """Ensemble of stochastic models with parallelized methods."""
+
+    # pylint:disable=abstract-method,protected-access
+
+    @torch.jit.export
+    def sample(self, obs, action, sample_shape: List[int] = ()):
+        futures = [torch.jit._fork(m.sample, obs, action, sample_shape) for m in self]
+        outputs = [torch.jit._wait(f) for f in futures]
+        sample = torch.stack([s for s, _ in outputs])
+        logp = torch.stack([p for _, p in outputs])
+        return sample, logp
+
+    @torch.jit.export
+    def rsample(self, obs, action, sample_shape: List[int] = ()):
+        futures = [torch.jit._fork(m.rsample, obs, action, sample_shape) for m in self]
+        outputs = [torch.jit._wait(f) for f in futures]
+        sample = torch.stack([s for s, _ in outputs])
+        logp = torch.stack([p for _, p in outputs])
+        return sample, logp
+
+    @torch.jit.export
+    def log_prob(self, obs, action, next_obs):
+        futures = [torch.jit._fork(m.log_prob, obs, action, next_obs) for m in self]
+        return torch.stack([torch.jit._wait(f) for f in futures])
diff --git a/raylab/policy/modules/model/stochastic/single.py b/raylab/policy/modules/model/stochastic/single.py
new file mode 100644
index 00000000..0b214159
--- /dev/null
+++ b/raylab/policy/modules/model/stochastic/single.py
@@ -0,0 +1,180 @@
+"""NN modules for stochastic dynamics estimation."""
+from typing import List
+
+import torch
+import torch.nn as nn
+from gym.spaces import Box
+
+import raylab.pytorch.nn as nnx
+import raylab.pytorch.nn.distributions as ptd
+from raylab.policy.modules.networks.mlp import StateActionMLP
+
+
+class StochasticModel(nn.Module):
+    """Represents a stochastic model as a conditional distribution module."""
+
+    def __init__(
+        self, params_module: nn.Module, dist_module: ptd.ConditionalDistribution
+    ):
+        super().__init__()
+        self.params = params_module
+        self.dist = dist_module
+
+    def forward(self, obs, action):  # pylint:disable=arguments-differ
+        return self.params(obs, action)
+
+    @torch.jit.export
+    def sample(self, obs, action, sample_shape: List[int] = ()):
+        """
+        Generates a sample_shape shaped sample or sample_shape shaped batch of
+        samples if the distribution parameters are batched. Returns a (sample, log_prob)
+        pair.
+        """
+        params = self(obs, action)
+        return self.dist.sample(params, sample_shape)
+
+    @torch.jit.export
+    def rsample(self, obs, action, sample_shape: List[int] = ()):
+        """
+        Generates a sample_shape shaped reparameterized sample or sample_shape
+        shaped batch of reparameterized samples if the distribution parameters
+        are batched. Returns a (rsample, log_prob) pair.
+ """ + params = self(obs, action) + return self.dist.rsample(params, sample_shape) + + @torch.jit.export + def log_prob(self, obs, action, next_obs): + """ + Returns the log probability density/mass function evaluated at `next_obs`. + """ + params = self(obs, action) + return self.dist.log_prob(next_obs, params) + + @torch.jit.export + def cdf(self, obs, action, next_obs): + """Returns the cumulative density/mass function evaluated at `next_obs`.""" + params = self(obs, action) + return self.dist.cdf(next_obs, params) + + @torch.jit.export + def icdf(self, obs, action, prob): + """Returns the inverse cumulative density/mass function evaluated at `prob`.""" + params = self(obs, action) + return self.dist.icdf(prob, params) + + @torch.jit.export + def entropy(self, obs, action): + """Returns entropy of distribution.""" + params = self(obs, action) + return self.dist.entropy(params) + + @torch.jit.export + def perplexity(self, obs, action): + """Returns perplexity of distribution.""" + params = self(obs, action) + return self.dist.perplexity(params) + + @torch.jit.export + def reproduce(self, obs, action, next_obs): + """Produce a reparametrized sample with the same value as `next_obs`.""" + params = self(obs, action) + return self.dist.reproduce(next_obs, params) + + +class ResidualMixin: + """Overrides StochasticModel interface to model state transition residuals.""" + + # pylint:disable=missing-function-docstring,not-callable + + @torch.jit.export + def sample(self, obs, action, sample_shape: List[int] = ()): + params = self(obs, action) + res, log_prob = self.dist.sample(params, sample_shape) + return obs + res, log_prob + + @torch.jit.export + def rsample(self, obs, action, sample_shape: List[int] = ()): + params = self(obs, action) + res, log_prob = self.dist.rsample(params, sample_shape) + return obs + res, log_prob + + @torch.jit.export + def log_prob(self, obs, action, next_obs): + params = self(obs, action) + return self.dist.log_prob(next_obs - obs, params) + + @torch.jit.export + def cdf(self, obs, action, next_obs): + params = self(obs, action) + return self.dist.cdf(next_obs - obs, params) + + @torch.jit.export + def icdf(self, obs, action, prob): + params = self(obs, action) + residual = self.dist.icdf(prob, params) + return obs + residual + + @torch.jit.export + def reproduce(self, obs, action, next_obs): + params = self(obs, action) + sample_, log_prob_ = self.dist.reproduce(next_obs - obs, params) + return obs + sample_, log_prob_ + + +class DynamicsParams(nn.Module): + """Neural network mapping state-action pairs to distribution parameters. 
+ + Args: + encoder: Module mapping state-action pairs to 1D features + params: Module mapping 1D features to distribution parameters + """ + + def __init__(self, encoder: nn.Module, params: nn.Module): + super().__init__() + self.encoder = encoder + self.params = params + + def forward(self, obs, actions): # pylint:disable=arguments-differ + return self.params(self.encoder(obs, actions)) + + +MLPSpec = StateActionMLP.spec_cls + + +class MLPModel(StochasticModel): + """Stochastic model with multilayer perceptron state-action encoder.""" + + spec_cls = MLPSpec + + def __init__( + self, + obs_space: Box, + action_space: Box, + spec: MLPSpec, + input_dependent_scale: bool, + ): + encoder = StateActionMLP(obs_space, action_space, spec) + params = nnx.NormalParams( + encoder.out_features, + obs_space.shape[0], + input_dependent_scale=input_dependent_scale, + ) + params = DynamicsParams(encoder, params) + dist = ptd.Independent(ptd.Normal(), reinterpreted_batch_ndims=1) + super().__init__(params, dist) + self.encoder = encoder + + def initialize_parameters(self, initializer_spec: dict): + """Initialize all encoder parameters. + + Args: + initializer_spec: Dictionary with mandatory `name` key corresponding + to the initializer function name in `torch.nn.init` and optional + keyword arguments. + """ + self.encoder.initialize_parameters(initializer_spec) + + +class ResidualMLPModel(ResidualMixin, MLPModel): + """Residual stochastic multilayer perceptron model.""" diff --git a/raylab/modules/networks/__init__.py b/raylab/policy/modules/networks/__init__.py similarity index 100% rename from raylab/modules/networks/__init__.py rename to raylab/policy/modules/networks/__init__.py diff --git a/raylab/policy/modules/networks/mlp.py b/raylab/policy/modules/networks/mlp.py new file mode 100644 index 00000000..859060db --- /dev/null +++ b/raylab/policy/modules/networks/mlp.py @@ -0,0 +1,156 @@ +# pylint:disable=missing-module-docstring +from dataclasses import dataclass +from dataclasses import field +from typing import Dict +from typing import List +from typing import Optional + +import torch +import torch.nn as nn +from dataclasses_json import DataClassJsonMixin +from gym.spaces import Box + +import raylab.pytorch.nn as nnx +from raylab.pytorch.nn.init import initialize_ +from raylab.pytorch.nn.utils import get_activation + + +@dataclass +class StateMLPSpec(DataClassJsonMixin): + """Specifications for creating a multilayer perceptron. + + Args: + units: Number of units in each hidden layer + activation: Nonlinearity following each linear layer + layer_norm: Whether to apply layer normalization between each linear layer + and following activation + """ + + units: List[int] = field(default_factory=list) + activation: Optional[str] = None + layer_norm: bool = False + + +class StateMLP(nnx.FullyConnected): + """Multilayer perceptron for encoding state inputs.""" + + spec_cls = StateMLPSpec + + def __init__(self, obs_space: Box, spec: StateMLPSpec): + obs_size = obs_space.shape[0] + super().__init__( + obs_size, spec.units, spec.activation, layer_norm=spec.layer_norm, + ) + self.spec = spec + + def initialize_parameters(self, initializer_spec: dict): + """Initialize all Linear models in the encoder. + + Uses `raylab.pytorch.nn.init.initialize_` to create an initializer + function. + + Args: + initializer_spec: Dictionary with mandatory `name` key corresponding + to the initializer function name in `torch.nn.init` and optional + keyword arguments. 
+ """ + initializer = initialize_(activation=self.spec.activation, **initializer_spec) + self.apply(initializer) + + +@dataclass +class StateActionMLPSpec(DataClassJsonMixin): + """Specifications for building an MLP with state and action inputs. + + Args: + units: Number of units in each hidden layer + activation: Nonlinearity following each linear layer + delay_action: Whether to apply an initial preprocessing layer on the + observation before concatenating the action to the input. + """ + + units: List[int] = field(default_factory=list) + activation: Optional[str] = None + delay_action: bool = False + + +class StateActionMLP(nnx.StateActionEncoder): + """Multilayer perceptron for encoding state-action inputs.""" + + spec_cls = StateActionMLPSpec + + def __init__(self, obs_space: Box, action_space: Box, spec: StateActionMLPSpec): + obs_size = obs_space.shape[0] + action_size = action_space.shape[0] + + super().__init__( + obs_size, + action_size, + units=spec.units, + activation=spec.activation, + delay_action=spec.delay_action, + ) + self.spec = spec + + def initialize_parameters(self, initializer_spec: dict): + """Initialize all Linear models in the encoder. + + Uses `raylab.pytorch.nn.init.initialize_` to create an initializer + function. + + Args: + initializer_spec: Dictionary with mandatory `name` key corresponding + to the initializer function name in `torch.nn.init` and optional + keyword arguments. + """ + initializer = initialize_(activation=self.spec.activation, **initializer_spec) + self.apply(initializer) + + +class MLP(nn.Module): + """A general purpose Multi-Layer Perceptron.""" + + def __init__( + self, + in_features, + out_features, + hidden_features, + state_features=None, + num_blocks=2, + activation="ReLU", + activate_output=False, + ): + # pylint:disable=too-many-arguments + super().__init__() + activation = get_activation(activation) + self.stateful = bool(state_features) + if self.stateful: + self.initial_layer = nn.Linear( + in_features + state_features, hidden_features + ) + else: + self.initial_layer = nn.Linear(in_features, hidden_features) + + layers = [activation()] + layers += [ + layer + for _ in range(num_blocks) + for layer in (nn.Linear(hidden_features, hidden_features), activation()) + ] + layers += [nn.Linear(hidden_features, out_features)] + + if activate_output: + layers += [activation()] + + self.sequential = nn.Sequential(*layers) + + def forward(self, inputs, params: Optional[Dict[str, torch.Tensor]] = None): + # pylint:disable=arguments-differ + if self.stateful: + if params is None: + raise ValueError("Parameters required for stateful mlp.") + out = self.initial_layer(torch.cat([inputs, params["state"]], dim=-1)) + else: + out = self.initial_layer(inputs) + + return self.sequential(out) diff --git a/raylab/modules/networks/resnet.py b/raylab/policy/modules/networks/resnet.py similarity index 100% rename from raylab/modules/networks/resnet.py rename to raylab/policy/modules/networks/resnet.py diff --git a/raylab/policy/modules/sac.py b/raylab/policy/modules/sac.py new file mode 100644 index 00000000..02e179f7 --- /dev/null +++ b/raylab/policy/modules/sac.py @@ -0,0 +1,69 @@ +"""NN architecture used in Soft Actor-Critic.""" +from dataclasses import dataclass +from dataclasses import field + +import torch.nn as nn +from dataclasses_json import DataClassJsonMixin +from gym.spaces import Box + +from .actor.stochastic import StochasticActor +from .critic.action_value import ActionValueCritic + +ActorSpec = StochasticActor.spec_cls +CriticSpec = 
ActionValueCritic.spec_cls + + +@dataclass +class SACSpec(DataClassJsonMixin): + """Specifications for SAC modules + + Args: + actor: Specifications for stochastic policy and entropy coefficient + critic: Specifications for action-value estimators + initializer: Optional dictionary with mandatory `type` key corresponding + to the initializer function name in `torch.nn.init` and optional + keyword arguments. Overrides actor and critic initializer + specifications. + """ + + actor: ActorSpec = field(default_factory=ActorSpec) + critic: CriticSpec = field(default_factory=CriticSpec) + initializer: dict = field(default_factory=dict) + + def __post_init__(self): + # Top-level initializer options take precedence over individual + # component's options + if self.initializer: + self.actor.initializer = self.initializer + self.critic.initializer = self.initializer + + +class SAC(nn.Module): + """NN module for Soft Actor-Critic algorithms. + + Args: + obs_space: Observation space + action_space: Action space + spec: Specifications for SAC modules + + Attributes: + actor (StochasticPolicy): Stochastic policy to be learned + alpha (Alpha): Entropy bonus coefficient + critics (QValueEnsemble): The action-value estimators to be learned + target_critics (QValueEnsemble): The action-value estimators used for + bootstrapping in Q-Learning + spec_cls: Expected class of `spec` init argument + """ + + # pylint:disable=abstract-method + spec_cls = SACSpec + + def __init__(self, obs_space: Box, action_space: Box, spec: SACSpec): + super().__init__() + actor = StochasticActor(obs_space, action_space, spec.actor) + self.actor = actor.policy + self.alpha = actor.alpha + + critic = ActionValueCritic(obs_space, action_space, spec.critic) + self.critics = critic.q_values + self.target_critics = critic.target_q_values diff --git a/tests/agents/naf/__init__.py b/raylab/policy/modules/v0/__init__.py similarity index 100% rename from tests/agents/naf/__init__.py rename to raylab/policy/modules/v0/__init__.py diff --git a/raylab/modules/abstract.py b/raylab/policy/modules/v0/abstract.py similarity index 100% rename from raylab/modules/abstract.py rename to raylab/policy/modules/v0/abstract.py diff --git a/raylab/policy/modules/v0/catalog.py b/raylab/policy/modules/v0/catalog.py new file mode 100644 index 00000000..e72bcb8c --- /dev/null +++ b/raylab/policy/modules/v0/catalog.py @@ -0,0 +1,50 @@ +"""Registry of old modules for PyTorch policies.""" +import torch.nn as nn +from gym.spaces import Space + +from .ddpg_module import DDPGModule +from .maxent_model_based import MaxEntModelBased +from .model_based_ddpg import ModelBasedDDPG +from .model_based_sac import ModelBasedSAC +from .naf_module import NAFModule +from .nfmbrl import NFMBRL +from .off_policy_nfac import OffPolicyNFAC +from .on_policy_actor_critic import OnPolicyActorCritic +from .on_policy_nfac import OnPolicyNFAC +from .sac_module import SACModule +from .simple_model_based import SimpleModelBased +from .svg_module import SVGModule +from .svg_realnvp_actor import SVGRealNVPActor +from .trpo_tang2018 import TRPOTang2018 + +MODULES = { + cls.__name__: cls + for cls in ( + NAFModule, + DDPGModule, + SACModule, + SimpleModelBased, + SVGModule, + MaxEntModelBased, + ModelBasedDDPG, + ModelBasedSAC, + NFMBRL, + OnPolicyActorCritic, + OnPolicyNFAC, + OffPolicyNFAC, + TRPOTang2018, + SVGRealNVPActor, + ) +} + + +def get_module(obs_space: Space, action_space: Space, config: dict) -> nn.Module: + """Retrieve and construct module of given name. 
+ + Args: + obs_space: Observation space + action_space: Action space + config: Configurations for module construction and initialization + """ + type_ = config.pop("type") + return MODULES[type_](obs_space, action_space, config) diff --git a/raylab/modules/ddpg_module.py b/raylab/policy/modules/v0/ddpg_module.py similarity index 100% rename from raylab/modules/ddpg_module.py rename to raylab/policy/modules/v0/ddpg_module.py diff --git a/raylab/modules/maxent_model_based.py b/raylab/policy/modules/v0/maxent_model_based.py similarity index 100% rename from raylab/modules/maxent_model_based.py rename to raylab/policy/modules/v0/maxent_model_based.py diff --git a/raylab/modules/mixins/__init__.py b/raylab/policy/modules/v0/mixins/__init__.py similarity index 100% rename from raylab/modules/mixins/__init__.py rename to raylab/policy/modules/v0/mixins/__init__.py diff --git a/raylab/modules/mixins/action_value_mixin.py b/raylab/policy/modules/v0/mixins/action_value_mixin.py similarity index 100% rename from raylab/modules/mixins/action_value_mixin.py rename to raylab/policy/modules/v0/mixins/action_value_mixin.py diff --git a/raylab/modules/mixins/deterministic_actor_mixin.py b/raylab/policy/modules/v0/mixins/deterministic_actor_mixin.py similarity index 100% rename from raylab/modules/mixins/deterministic_actor_mixin.py rename to raylab/policy/modules/v0/mixins/deterministic_actor_mixin.py diff --git a/raylab/modules/mixins/normalizing_flow_actor_mixin.py b/raylab/policy/modules/v0/mixins/normalizing_flow_actor_mixin.py similarity index 98% rename from raylab/modules/mixins/normalizing_flow_actor_mixin.py rename to raylab/policy/modules/v0/mixins/normalizing_flow_actor_mixin.py index fb70318f..88a9a610 100644 --- a/raylab/modules/mixins/normalizing_flow_actor_mixin.py +++ b/raylab/policy/modules/v0/mixins/normalizing_flow_actor_mixin.py @@ -6,11 +6,11 @@ import torch.nn as nn from ray.rllib.utils import override +import raylab.policy.modules.networks as networks import raylab.pytorch.nn as nnx import raylab.pytorch.nn.distributions as ptd from raylab.utils.dictionaries import deep_merge -from .. import networks from .stochastic_actor_mixin import StochasticPolicy diff --git a/raylab/modules/mixins/normalizing_flow_model_mixin.py b/raylab/policy/modules/v0/mixins/normalizing_flow_model_mixin.py similarity index 99% rename from raylab/modules/mixins/normalizing_flow_model_mixin.py rename to raylab/policy/modules/v0/mixins/normalizing_flow_model_mixin.py index 4f73ae7c..031cc061 100644 --- a/raylab/modules/mixins/normalizing_flow_model_mixin.py +++ b/raylab/policy/modules/v0/mixins/normalizing_flow_model_mixin.py @@ -5,11 +5,11 @@ import torch.nn as nn from ray.rllib.utils import override +import raylab.policy.modules.networks as networks import raylab.pytorch.nn as nnx import raylab.pytorch.nn.distributions as ptd from raylab.utils.dictionaries import deep_merge -from .. 
import networks from .stochastic_model_mixin import StochasticModel from .stochastic_model_mixin import StochasticModelMixin diff --git a/raylab/modules/mixins/state_value_mixin.py b/raylab/policy/modules/v0/mixins/state_value_mixin.py similarity index 100% rename from raylab/modules/mixins/state_value_mixin.py rename to raylab/policy/modules/v0/mixins/state_value_mixin.py diff --git a/raylab/modules/mixins/stochastic_actor_mixin.py b/raylab/policy/modules/v0/mixins/stochastic_actor_mixin.py similarity index 100% rename from raylab/modules/mixins/stochastic_actor_mixin.py rename to raylab/policy/modules/v0/mixins/stochastic_actor_mixin.py diff --git a/raylab/modules/mixins/stochastic_model_mixin.py b/raylab/policy/modules/v0/mixins/stochastic_model_mixin.py similarity index 100% rename from raylab/modules/mixins/stochastic_model_mixin.py rename to raylab/policy/modules/v0/mixins/stochastic_model_mixin.py diff --git a/raylab/modules/mixins/svg_model_mixin.py b/raylab/policy/modules/v0/mixins/svg_model_mixin.py similarity index 95% rename from raylab/modules/mixins/svg_model_mixin.py rename to raylab/policy/modules/v0/mixins/svg_model_mixin.py index cb17c754..5adbb3b7 100644 --- a/raylab/modules/mixins/svg_model_mixin.py +++ b/raylab/policy/modules/v0/mixins/svg_model_mixin.py @@ -54,7 +54,7 @@ def make_param(in_features): kwargs = dict(event_size=1, input_dependent_scale=False) return nnx.NormalParams(in_features, **kwargs) - self.params = nn.ModuleList([make_param(l.out_features) for l in self.logits]) + self.params = nn.ModuleList([make_param(m.out_features) for m in self.logits]) @override(nn.Module) def forward(self, obs, act): # pylint: disable=arguments-differ diff --git a/raylab/modules/model_based_ddpg.py b/raylab/policy/modules/v0/model_based_ddpg.py similarity index 100% rename from raylab/modules/model_based_ddpg.py rename to raylab/policy/modules/v0/model_based_ddpg.py diff --git a/raylab/modules/model_based_sac.py b/raylab/policy/modules/v0/model_based_sac.py similarity index 100% rename from raylab/modules/model_based_sac.py rename to raylab/policy/modules/v0/model_based_sac.py diff --git a/raylab/modules/naf_module.py b/raylab/policy/modules/v0/naf_module.py similarity index 93% rename from raylab/modules/naf_module.py rename to raylab/policy/modules/v0/naf_module.py index e57f745b..5f50c61b 100644 --- a/raylab/modules/naf_module.py +++ b/raylab/policy/modules/v0/naf_module.py @@ -5,14 +5,13 @@ import torch.nn as nn from ray.rllib.utils import override +from raylab.policy.modules.actor.policy.deterministic import DeterministicPolicy from raylab.pytorch.nn import FullyConnected from raylab.pytorch.nn import NormalizedLinear from raylab.pytorch.nn import TanhSquash from raylab.pytorch.nn import TrilMatrix from raylab.utils.dictionaries import deep_merge -from .mixins import DeterministicPolicy - BASE_CONFIG = { "double_q": False, @@ -68,8 +67,7 @@ def _make_encoder(obs_space, config): def _make_actor(self, obs_space, action_space, config): naf = self.critics[0] - mods = nn.ModuleList([naf.logits, naf.pre_act, naf.squash]) - actor = DeterministicPolicy(mods) + actor = DeterministicPolicy(naf.logits, naf.pre_act, naf.squash) behavior = actor if config["perturbed_policy"]: if not config["encoder"].get("layer_norm"): @@ -77,7 +75,9 @@ def _make_actor(self, obs_space, action_space, config): "'layer_norm' is deactivated even though a perturbed policy was " "requested. For optimal stability, set 'layer_norm': True." 
) - behavior = DeterministicPolicy.from_scratch(obs_space, action_space, config) + logits = self._make_encoder(obs_space, config) + _naf = NAF(logits, action_space, config) + behavior = DeterministicPolicy(_naf.logits, _naf.pre_act, _naf.squash) return {"actor": actor, "behavior": behavior} diff --git a/raylab/modules/nfmbrl.py b/raylab/policy/modules/v0/nfmbrl.py similarity index 100% rename from raylab/modules/nfmbrl.py rename to raylab/policy/modules/v0/nfmbrl.py diff --git a/raylab/modules/off_policy_nfac.py b/raylab/policy/modules/v0/off_policy_nfac.py similarity index 100% rename from raylab/modules/off_policy_nfac.py rename to raylab/policy/modules/v0/off_policy_nfac.py diff --git a/raylab/modules/on_policy_actor_critic.py b/raylab/policy/modules/v0/on_policy_actor_critic.py similarity index 100% rename from raylab/modules/on_policy_actor_critic.py rename to raylab/policy/modules/v0/on_policy_actor_critic.py diff --git a/raylab/modules/on_policy_nfac.py b/raylab/policy/modules/v0/on_policy_nfac.py similarity index 100% rename from raylab/modules/on_policy_nfac.py rename to raylab/policy/modules/v0/on_policy_nfac.py diff --git a/raylab/modules/sac_module.py b/raylab/policy/modules/v0/sac_module.py similarity index 100% rename from raylab/modules/sac_module.py rename to raylab/policy/modules/v0/sac_module.py diff --git a/raylab/modules/simple_model_based.py b/raylab/policy/modules/v0/simple_model_based.py similarity index 100% rename from raylab/modules/simple_model_based.py rename to raylab/policy/modules/v0/simple_model_based.py diff --git a/raylab/modules/svg_module.py b/raylab/policy/modules/v0/svg_module.py similarity index 100% rename from raylab/modules/svg_module.py rename to raylab/policy/modules/v0/svg_module.py diff --git a/raylab/modules/svg_realnvp_actor.py b/raylab/policy/modules/v0/svg_realnvp_actor.py similarity index 100% rename from raylab/modules/svg_realnvp_actor.py rename to raylab/policy/modules/v0/svg_realnvp_actor.py diff --git a/raylab/modules/trpo_tang2018.py b/raylab/policy/modules/v0/trpo_tang2018.py similarity index 99% rename from raylab/modules/trpo_tang2018.py rename to raylab/policy/modules/v0/trpo_tang2018.py index a93e3957..e83f32c0 100644 --- a/raylab/modules/trpo_tang2018.py +++ b/raylab/policy/modules/v0/trpo_tang2018.py @@ -7,6 +7,7 @@ from ray.rllib.utils import merge_dicts from ray.rllib.utils import override +import raylab.policy.modules.networks as networks from raylab.pytorch.nn import FullyConnected from raylab.pytorch.nn.distributions import flows from raylab.pytorch.nn.distributions import Independent @@ -16,7 +17,6 @@ from raylab.pytorch.nn.distributions.flows import TanhSquashTransform from raylab.pytorch.nn.init import initialize_ -from . import networks from .abstract import AbstractActorCritic from .mixins import StateValueMixin from .mixins import StochasticActorMixin diff --git a/raylab/policy/target_networks_mixin.py b/raylab/policy/target_networks_mixin.py deleted file mode 100644 index 6638e7c0..00000000 --- a/raylab/policy/target_networks_mixin.py +++ /dev/null @@ -1,20 +0,0 @@ -# pylint: disable=missing-docstring -# pylint: enable=missing-docstring -from raylab.pytorch.nn.utils import update_polyak - - -class TargetNetworksMixin: - """Adds method to update target networks by name.""" - - # pylint: disable=too-few-public-methods - - def update_targets(self, module, target_module): - """Update target networks through one step of polyak averaging. 
- - Arguments: - module (str): name of primary module in the policy's module dict - target_module (str): name of target module in the policy's module dict - """ - update_polyak( - self.module[module], self.module[target_module], self.config["polyak"] - ) diff --git a/raylab/policy/torch_policy.py b/raylab/policy/torch_policy.py index 055efb36..a866d809 100644 --- a/raylab/policy/torch_policy.py +++ b/raylab/policy/torch_policy.py @@ -22,11 +22,10 @@ from torch.optim import Optimizer from raylab.agents import Trainer -from raylab.modules.catalog import get_module from raylab.pytorch.utils import convert_to_tensor from raylab.utils.dictionaries import deep_merge -from .action_dist import WrapModuleDist +from .modules.catalog import get_module from .optimizer_collection import OptimizerCollection @@ -34,6 +33,8 @@ class TorchPolicy(Policy): """A Policy that uses PyTorch as a backend. Attributes: + dist_class: Action distribution class for computing actions. Must be set + by subclasses before calling `__init__`. device: Device in which the parameter tensors reside. All input samples will be converted to tensors and moved to this device module: The policy's neural network module. Should be compilable to @@ -53,6 +54,8 @@ def __init__(self, observation_space: Space, action_space: Space, config: dict): whitelist=Trainer._allow_unknown_subkeys, override_all_if_type_changes=Trainer._override_all_subkeys_if_type_changes, ) + # Allow subclasses to set `dist_class` before calling init + action_dist = getattr(self, "dist_class", None) super().__init__(observation_space, action_space, config) self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") @@ -64,10 +67,19 @@ def __init__(self, observation_space: Space, action_space: Space, config: dict): self.optimizers[name] = optimizer # === Policy attributes === - self.dist_class = WrapModuleDist + self.dist_class = action_dist + self.dist_class.check_model_compat(self.module) self.framework = "torch" # Needed to create exploration self.exploration = self._create_exploration() + @property + def model(self): + """The policy's NN module. + + Mostly for compatibility with RLlib's API. + """ + return self.module + @staticmethod @abstractmethod def get_default_config() -> dict: @@ -142,7 +154,9 @@ def compute_actions( self.convert_to_tensor([1]), ) + # pylint:disable=not-callable action_dist = self.dist_class(dist_inputs, self.module) + # pylint:enable=not-callable actions, logp = self.exploration.get_exploration_action( action_distribution=action_dist, timestep=timestep, explore=explore ) @@ -243,7 +257,9 @@ def compute_log_likelihoods( state_batches, self.convert_to_tensor([1]), ) + # pylint:disable=not-callable action_dist = self.dist_class(dist_inputs, self.module) + # pylint:enable=not-callable log_likelihoods = action_dist.logp(input_dict[SampleBatch.ACTIONS]) return log_likelihoods diff --git a/raylab/pytorch/nn/init.py b/raylab/pytorch/nn/init.py index 3f1f43c7..fe0f1a5f 100644 --- a/raylab/pytorch/nn/init.py +++ b/raylab/pytorch/nn/init.py @@ -1,16 +1,23 @@ """Utilities for module initialization.""" import functools import inspect +from typing import Callable +from typing import Optional +from typing import Union import torch.nn as nn +from torch import Tensor -def get_initializer(name): +def get_initializer(name: Optional[str]) -> Callable[[Tensor], None]: """Return initializer function given its name. Arguments: - name (str): the initializer function's name + name: The initializer function's name. 
If None, returns a no-op callable """ + if name is None: + return lambda _: None + name_ = name + "_" if name in dir(nn.init) and name_ in dir(nn.init): func = getattr(nn.init, name_) @@ -27,16 +34,19 @@ def get_initializer(name): } -def initialize_(name, activation=None, **options): +def initialize_( + name: Optional[str] = None, activation: Union[str, dict] = None, **options +) -> Callable[[nn.Module], None]: """Return a callable to apply an initializer with the given name and options. If `gain` is part of the initializer's argspec and is not specified in options, - the recommended value from `nn.init.calculate_gain` is used. + the recommended value from `torch.nn.init.calculate_gain` is used. Arguments: - name (str): name of initializer function - activation (str, dict): activation function following linear layer, optional - **options: keyword arguments to be passed to the initializer + name: Initializer function name + activation: Optional specification of the activation function that + follows linear layers + **options: Keyword arguments to pass to the initializer Returns: A callable to be used with `nn.Module.apply`. diff --git a/raylab/pytorch/nn/modules/tanh_squash.py b/raylab/pytorch/nn/modules/tanh_squash.py index 3bbfd968..909c283b 100644 --- a/raylab/pytorch/nn/modules/tanh_squash.py +++ b/raylab/pytorch/nn/modules/tanh_squash.py @@ -1,16 +1,22 @@ # pylint: disable=missing-docstring +import torch import torch.nn as nn -from ray.rllib.utils import override +from torch import Tensor class TanhSquash(nn.Module): """Neural network module squashing vectors to specified range using Tanh.""" - def __init__(self, low, high): + def __init__(self, low: Tensor, high: Tensor): super().__init__() self.register_buffer("loc", (high + low) / 2) self.register_buffer("scale", (high - low) / 2) - @override(nn.Module) - def forward(self, inputs): # pylint: disable=arguments-differ + def forward(self, inputs: Tensor, reverse: bool = False) -> Tensor: + # pylint: disable=arguments-differ + if reverse: + inputs = (inputs - self.loc) / self.scale + to_log1 = torch.clamp(1 + inputs, min=1.1754943508222875e-38) + to_log2 = torch.clamp(1 - inputs, min=1.1754943508222875e-38) + return (torch.log(to_log1) - torch.log(to_log2)) / 2 return self.loc + inputs.tanh() * self.scale diff --git a/raylab/utils/annotations.py b/raylab/utils/annotations.py index 1a124fe0..c19b5e02 100644 --- a/raylab/utils/annotations.py +++ b/raylab/utils/annotations.py @@ -1,9 +1,11 @@ """Collection of type annotations.""" from typing import Callable +from typing import Dict from typing import Tuple from torch import Tensor +TensorDict = Dict[str, Tensor] RewardFn = Callable[[Tensor, Tensor, Tensor], Tensor] TerminationFn = Callable[[Tensor, Tensor, Tensor], Tensor] DynamicsFn = Callable[[Tensor, Tensor], Tuple[Tensor, Tensor]] diff --git a/raylab/utils/exploration/base.py b/raylab/utils/exploration/base.py new file mode 100644 index 00000000..320d45d5 --- /dev/null +++ b/raylab/utils/exploration/base.py @@ -0,0 +1,54 @@ +"""Base implementations for all exploration strategies.""" +from abc import ABCMeta +from abc import abstractmethod +from typing import Optional + +import torch.nn as nn +from ray.rllib.utils.exploration import Exploration + +Model = Optional[nn.Module] + + +class IncompatibleExplorationError(Exception): + """Exception raised for incompatible exploration and NN module. 
+ + Args: + exp_cls: Exploration class + module: NN module + err: AssertionError explaining the reason why exploration and module are + incompatible + + Attributes: + message: Human-readable text explaining what caused the incompatibility + """ + + def __init__(self, exp_cls: type, module: Model, err: Exception): + msg = ( + f"Exploration type {exp_cls} is incompatible with NN module of type" + f" {type(module)}. Reason:\n" + f" {err}" + ) + super().__init__(msg) + self.message = msg + + + class BaseExploration(Exploration, metaclass=ABCMeta): + """Base class for exploration objects.""" + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + try: + self.check_model_compat(self.model) + except AssertionError as err: + raise IncompatibleExplorationError(type(self), self.model, err) from err + + @classmethod + @abstractmethod + def check_model_compat(cls, model: Model): + """Assert the given NN module is compatible with the exploration. + + Raises: + AssertionError: If `model` is incompatible with the + exploration class + """ diff --git a/raylab/utils/exploration/gaussian_noise.py b/raylab/utils/exploration/gaussian_noise.py index b57e9fe1..8280cfb0 100644 --- a/raylab/utils/exploration/gaussian_noise.py +++ b/raylab/utils/exploration/gaussian_noise.py @@ -1,29 +1,36 @@ # pylint:disable=missing-module-docstring +from typing import Tuple + import torch -from ray.rllib.utils import override +from torch import Tensor -from raylab.pytorch.nn.distributions.flows import TanhSquashTransform +from raylab.policy.action_dist import BaseActionDist +from raylab.policy.modules.actor.policy.deterministic import DeterministicPolicy +from .base import Model from .random_uniform import RandomUniform class GaussianNoise(RandomUniform): - """Adds fixed additive gaussian exploration noise to actions before squashing. + """Adds fixed additive gaussian exploration noise to actions. Args: noise_stddev (float): Standard deviation of the Gaussian samples.
""" + valid_behavior_cls = DeterministicPolicy + def __init__(self, *args, noise_stddev=None, **kwargs): super().__init__(*args, **kwargs) self._noise_stddev = noise_stddev - self._squash = TanhSquashTransform( - low=torch.as_tensor(self.action_space.low), - high=torch.as_tensor(self.action_space.high), - ) - @override(RandomUniform) - def get_exploration_action(self, *, action_distribution, timestep, explore=True): + def get_exploration_action( + self, + *, + action_distribution: BaseActionDist, + timestep: int, + explore: bool = True, + ): if explore: if timestep < self._pure_exploration_steps: return super().get_exploration_action( @@ -31,12 +38,28 @@ def get_exploration_action(self, *, action_distribution, timestep, explore=True) timestep=timestep, explore=explore, ) - return self._get_gaussian_perturbed_actions(action_distribution) + return self._inject_gaussian_noise(action_distribution) return action_distribution.deterministic_sample() - def _get_gaussian_perturbed_actions(self, action_distribution): - module, inputs = action_distribution.model, action_distribution.inputs - actions = module.actor(**inputs) - pre_squash, _ = self._squash(actions, reverse=True) - noise = torch.randn_like(pre_squash) * self._noise_stddev - return self._squash(pre_squash + noise)[0], None + def _inject_gaussian_noise( + self, action_distribution: BaseActionDist + ) -> Tuple[Tensor, None]: + model, inputs = action_distribution.model, action_distribution.inputs + unconstrained_action = model.behavior.unconstrained_action(**inputs) + unconstrained_action += ( + torch.randn_like(unconstrained_action) * self._noise_stddev + ) + action = model.behavior.squash_action(unconstrained_action) + return action, None + + @classmethod + def check_model_compat(cls, model: Model): + RandomUniform.check_model_compat(model) + assert model is not None, f"{cls} exploration needs access to the NN." + assert hasattr( + model, "behavior" + ), f"NN model {type(model)} has no behavior attribute." + assert isinstance(model.behavior, cls.valid_behavior_cls), ( + f"Expected behavior to be an instance of {cls.valid_behavior_cls};" + " found {type(model.behavior)} instead." + ) diff --git a/raylab/utils/exploration/parameter_noise.py b/raylab/utils/exploration/parameter_noise.py index d544a004..4fe200cc 100644 --- a/raylab/utils/exploration/parameter_noise.py +++ b/raylab/utils/exploration/parameter_noise.py @@ -6,22 +6,22 @@ import torch from ray.rllib import SampleBatch from ray.rllib.models.action_dist import ActionDistribution -from ray.rllib.utils import override -from ray.rllib.utils.exploration import Exploration -from ray.rllib.utils.torch_ops import convert_to_non_torch_type from raylab.policy import TorchPolicy -from raylab.pytorch.nn.distributions.flows import TanhSquashTransform from raylab.pytorch.nn.utils import perturb_params from raylab.utils.param_noise import AdaptiveParamNoiseSpec from raylab.utils.param_noise import ddpg_distance_metric +from .base import Model from .random_uniform import RandomUniform class ParameterNoise(RandomUniform): """Adds adaptive parameter noise exploration schedule to a Policy. + Expects `actor` attribute of `policy.module` to be an instance of + `raylab.policy.modules.actor.policy.deterministic.DeterministicPolicy`. + Args: param_noise_spec: Arguments for `AdaptiveParamNoiseSpec`. 
""" @@ -30,20 +30,14 @@ def __init__(self, *args, param_noise_spec: dict = None, **kwargs): super().__init__(*args, **kwargs) param_noise_spec = param_noise_spec or {} self._param_noise_spec = AdaptiveParamNoiseSpec(**param_noise_spec) - self._squash = TanhSquashTransform( - low=torch.as_tensor(self.action_space.low), - high=torch.as_tensor(self.action_space.high), - ) - @override(RandomUniform) def get_exploration_action( self, *, action_distribution: ActionDistribution, timestep: int, - explore: bool = True + explore: bool = True, ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: - model, inputs = action_distribution.model, action_distribution.inputs if explore: if timestep < self._pure_exploration_steps: return super().get_exploration_action( @@ -51,17 +45,16 @@ def get_exploration_action( timestep=timestep, explore=explore, ) - return model.behavior(**inputs), None - return model.actor(**inputs), None + return action_distribution.sample() + return action_distribution.deterministic_sample() - @override(Exploration) def on_episode_start( self, policy: TorchPolicy, *, environment: Any = None, episode: Any = None, - tf_sess: Any = None + tf_sess: Any = None, ): # pylint:disable=unused-argument perturb_params( @@ -71,14 +64,12 @@ def on_episode_start( ) @torch.no_grad() - @override(Exploration) def postprocess_trajectory( self, policy: TorchPolicy, sample_batch: SampleBatch, tf_sess: Any = None ): self.update_parameter_noise(policy, sample_batch) return sample_batch - @override(Exploration) def get_info(self) -> dict: return {"param_noise_stddev": self._param_noise_spec.curr_stddev} @@ -87,12 +78,21 @@ def update_parameter_noise(self, policy: TorchPolicy, sample_batch: SampleBatch) module = policy.module cur_obs = policy.convert_to_tensor(sample_batch[SampleBatch.CUR_OBS]) actions = policy.convert_to_tensor(sample_batch[SampleBatch.ACTIONS]) - target_actions = module.actor(cur_obs) - unsquashed_acts, _ = self._squash(actions, reverse=True) - unsquashed_targs, _ = self._squash(target_actions, reverse=True) - noisy, target = map( - convert_to_non_torch_type, (unsquashed_acts, unsquashed_targs) - ) + noisy = module.actor.unsquash_action(actions) + target = module.actor.unconstrained_action(cur_obs) + noisy, target = map(lambda x: x.cpu().detach().numpy(), (noisy, target)) + distance = ddpg_distance_metric(noisy, target) self._param_noise_spec.adapt(distance) + + @classmethod + def check_model_compat(cls, model: Model): + assert ( + model is not None + ), f"Need to pass the model to {cls} to check compatibility." + actor, behavior = model.actor, model.behavior + assert set(actor.parameters()).isdisjoint(set(behavior.parameters())), ( + "Target and behavior policy cannot share parameters in parameter " + "noise exploration." 
+ ) diff --git a/raylab/utils/exploration/random_uniform.py b/raylab/utils/exploration/random_uniform.py index 10014476..2ce3192c 100644 --- a/raylab/utils/exploration/random_uniform.py +++ b/raylab/utils/exploration/random_uniform.py @@ -1,12 +1,13 @@ # pylint:disable=missing-module-docstring import numpy as np -from ray.rllib.utils import override -from ray.rllib.utils.exploration import Exploration import raylab.pytorch.utils as ptu +from .base import BaseExploration +from .base import Model -class RandomUniform(Exploration): + +class RandomUniform(BaseExploration): """Samples actions from the Gym action space Args: @@ -28,7 +29,6 @@ def __init__(self, *args, pure_exploration_steps=0, **kwargs): ) self._pure_exploration_steps = pure_exploration_steps - @override(Exploration) def get_exploration_action(self, *, action_distribution, timestep, explore=True): # pylint:disable=unused-argument if explore: @@ -43,3 +43,7 @@ def get_exploration_action(self, *, action_distribution, timestep, explore=True) ) return acts, logp return action_distribution.deterministic_sample() + + @classmethod + def check_model_compat(cls, model: Model): + pass diff --git a/tests/conftest.py b/tests/conftest.py index b325a623..6575ab4e 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,14 +1,8 @@ -# pylint: disable=missing-docstring,redefined-outer-name,protected-access -import logging - -import gym -import gym.spaces as spaces +# pylint:disable=missing-docstring,redefined-outer-name,protected-access import pytest from .mock_env import MockEnv -gym.logger.set_level(logging.ERROR) - # Test setup from: # https://docs.pytest.org/en/latest/example/simple.html#control-skipping-of-tests-according-to-command-line-option @@ -40,9 +34,16 @@ def init_ray(): ray.shutdown() +@pytest.fixture(autouse=True, scope="session") +def disable_gym_logger_warnings(): + import logging + import gym + + gym.logger.set_level(logging.ERROR) + + @pytest.fixture(autouse=True, scope="session") def register_envs(): - # pylint:disable=import-outside-toplevel import raylab from raylab.envs.registry import ENVS @@ -51,33 +52,3 @@ def _mock_env_maker(config): ENVS["MockEnv"] = _mock_env_maker raylab.register_all_environments() - - -@pytest.fixture(scope="module", params=((1,), (4,)), ids=("Obs1Dim", "Obs4Dim")) -def obs_space(request): - return spaces.Box(-10, 10, shape=request.param) - - -@pytest.fixture(scope="module", params=((1,), (4,)), ids=("Act1Dim", "Act4Dim")) -def action_space(request): - return spaces.Box(-1, 1, shape=request.param) - - -@pytest.fixture(scope="module") -def envs(): - from raylab.envs.registry import ENVS # pylint:disable=import-outside-toplevel - - return ENVS.copy() - - -ENV_IDS = ("MockEnv", "Navigation", "Reservoir", "HVAC", "MountainCarContinuous-v0") - - -@pytest.fixture(params=ENV_IDS) -def env_name(request): - return request.param - - -@pytest.fixture -def env_creator(envs, env_name): - return envs[env_name] diff --git a/tests/general/conftest.py b/tests/general/conftest.py deleted file mode 100644 index e264fbb1..00000000 --- a/tests/general/conftest.py +++ /dev/null @@ -1,16 +0,0 @@ -# pylint: disable=missing-docstring,redefined-outer-name,protected-access -import pytest - -from raylab.agents.registry import AGENTS - -TRAINER_NAMES, TRAINER_IMPORTS = zip(*AGENTS.items()) - - -@pytest.fixture(scope="module", params=TRAINER_IMPORTS, ids=TRAINER_NAMES) -def trainer_cls(request): - return request.param() - - -@pytest.fixture -def policy_cls(trainer_cls): - return trainer_cls._policy diff --git 
a/tests/general/test_rollout.py b/tests/general/test_rollout.py deleted file mode 100644 index c66b6fca..00000000 --- a/tests/general/test_rollout.py +++ /dev/null @@ -1,31 +0,0 @@ -# pylint: disable=missing-docstring,redefined-outer-name,protected-access -import pytest -from ray.rllib import RolloutWorker - - -@pytest.fixture -def worker_kwargs(): - return {"rollout_fragment_length": 200, "batch_mode": "truncate_episodes"} - - -def test_compute_single_action(envs, env_name, policy_cls): - env = envs[env_name]({}) - policy = policy_cls(env.observation_space, env.action_space, {"env": env_name}) - - obs = env.observation_space.sample() - action, states, info = policy.compute_single_action(obs, []) - assert action in env.action_space - assert isinstance(states, list) - assert isinstance(info, dict) - - -def test_policy_in_rollout_worker(envs, env_name, policy_cls, worker_kwargs): - env_creator = envs[env_name] - policy_config = {"env": env_name} - worker = RolloutWorker( - env_creator=env_creator, - policy=policy_cls, - policy_config=policy_config, - **worker_kwargs - ) - worker.sample() diff --git a/tests/general/test_worker.py b/tests/general/test_worker.py deleted file mode 100644 index e78afdf9..00000000 --- a/tests/general/test_worker.py +++ /dev/null @@ -1,20 +0,0 @@ -# pylint: disable=missing-docstring,redefined-outer-name,protected-access -import pytest -from ray.rllib import RolloutWorker -from ray.rllib import SampleBatch - - -@pytest.fixture -def worker(envs, env_name, policy_cls): - return RolloutWorker( - env_creator=envs[env_name], - policy=policy_cls, - policy_config={"env": env_name}, - rollout_fragment_length=200, - batch_mode="truncate_episodes", - ) - - -def test_collect_traj(worker): - traj = worker.sample() - assert isinstance(traj, SampleBatch) diff --git a/tests/agents/sac/__init__.py b/tests/raylab/__init__.py similarity index 100% rename from tests/agents/sac/__init__.py rename to tests/raylab/__init__.py diff --git a/tests/agents/sop/__init__.py b/tests/raylab/agents/__init__.py similarity index 100% rename from tests/agents/sop/__init__.py rename to tests/raylab/agents/__init__.py diff --git a/tests/agents/conftest.py b/tests/raylab/agents/conftest.py similarity index 100% rename from tests/agents/conftest.py rename to tests/raylab/agents/conftest.py diff --git a/tests/agents/svg/__init__.py b/tests/raylab/agents/mage/__init__.py similarity index 100% rename from tests/agents/svg/__init__.py rename to tests/raylab/agents/mage/__init__.py diff --git a/tests/agents/mage/test_policy.py b/tests/raylab/agents/mage/test_policy.py similarity index 86% rename from tests/agents/mage/test_policy.py rename to tests/raylab/agents/mage/test_policy.py index 0234f1c6..e436ae56 100644 --- a/tests/agents/mage/test_policy.py +++ b/tests/raylab/agents/mage/test_policy.py @@ -1,14 +1,12 @@ # pylint: disable=missing-docstring,redefined-outer-name,protected-access -from unittest import mock - import numpy as np import pytest import torch from raylab.agents.mage import MAGETorchPolicy -from raylab.losses import DeterministicPolicyGradient -from raylab.losses import MAGE -from raylab.losses import ModelEnsembleMLE +from raylab.policy.losses import DeterministicPolicyGradient +from raylab.policy.losses import MAGE +from raylab.policy.losses import ModelEnsembleMLE from raylab.utils.debug import fake_batch @@ -85,8 +83,8 @@ def test_learn_on_batch(policy, samples): assert np.isfinite(info["grad_norm(critics)"]) -def test_compile(policy): - with mock.patch("raylab.losses.MAGE.compile") as 
mocked_method: - policy.compile() - assert isinstance(policy.module, torch.jit.ScriptModule) - assert mocked_method.called +def test_compile(policy, mocker): + method = mocker.spy(MAGE, "compile") + policy.compile() + assert isinstance(policy.module, torch.jit.ScriptModule) + assert method.called diff --git a/tests/agents/mage/test_trainer.py b/tests/raylab/agents/mage/test_trainer.py similarity index 100% rename from tests/agents/mage/test_trainer.py rename to tests/raylab/agents/mage/test_trainer.py diff --git a/tests/cli/__init__.py b/tests/raylab/agents/mapo/__init__.py similarity index 100% rename from tests/cli/__init__.py rename to tests/raylab/agents/mapo/__init__.py diff --git a/tests/agents/mapo/test_policy.py b/tests/raylab/agents/mapo/test_policy.py similarity index 84% rename from tests/agents/mapo/test_policy.py rename to tests/raylab/agents/mapo/test_policy.py index 0bc5f025..62a29fcd 100644 --- a/tests/agents/mapo/test_policy.py +++ b/tests/raylab/agents/mapo/test_policy.py @@ -1,6 +1,5 @@ # pylint: disable=missing-docstring,redefined-outer-name,protected-access import copy -from unittest import mock import numpy as np import pytest @@ -9,11 +8,11 @@ from raylab.agents.mapo import MAPOTorchPolicy from raylab.agents.sac import SACTorchPolicy -from raylab.losses import MAPO -from raylab.losses import SoftCDQLearning -from raylab.losses import SPAML from raylab.policy import EnvFnMixin from raylab.policy import ModelTrainingMixin +from raylab.policy.losses import MAPO +from raylab.policy.losses import SoftCDQLearning +from raylab.policy.losses import SPAML from raylab.utils.debug import fake_batch @@ -79,11 +78,11 @@ def test_learn_on_batch(policy, sample_batch): assert all(not torch.allclose(o, n) for o, n in zip(old_params, new_params)) -def test_compile(policy): - with mock.patch("raylab.losses.MAPO.compile") as mapo, mock.patch( - "raylab.losses.SPAML.compile" - ) as spaml: - policy.compile() - assert isinstance(policy.module, torch.jit.ScriptModule) - assert mapo.called - assert spaml.called +def test_compile(policy, mocker): + mapo = mocker.patch("raylab.policy.losses.MAPO.compile") + spaml = mocker.patch("raylab.policy.losses.SPAML.compile") + + policy.compile() + assert isinstance(policy.module, torch.jit.ScriptModule) + assert mapo.called + assert spaml.called diff --git a/tests/agents/mapo/test_trainer.py b/tests/raylab/agents/mapo/test_trainer.py similarity index 100% rename from tests/agents/mapo/test_trainer.py rename to tests/raylab/agents/mapo/test_trainer.py diff --git a/tests/envs/__init__.py b/tests/raylab/agents/mbpo/__init__.py similarity index 100% rename from tests/envs/__init__.py rename to tests/raylab/agents/mbpo/__init__.py diff --git a/tests/agents/mbpo/conftest.py b/tests/raylab/agents/mbpo/conftest.py similarity index 100% rename from tests/agents/mbpo/conftest.py rename to tests/raylab/agents/mbpo/conftest.py diff --git a/tests/agents/mbpo/test_policy.py b/tests/raylab/agents/mbpo/test_policy.py similarity index 82% rename from tests/agents/mbpo/test_policy.py rename to tests/raylab/agents/mbpo/test_policy.py index f887e931..e792ed94 100644 --- a/tests/agents/mbpo/test_policy.py +++ b/tests/raylab/agents/mbpo/test_policy.py @@ -22,7 +22,7 @@ def config(ensemble_size): "patience_epochs": 5, }, "model_sampling": {"rollout_schedule": [(0, 10)], "num_elites": 1}, - "module": {"ensemble_size": ensemble_size}, + "module": {"model": {"ensemble_size": ensemble_size}}, } @@ -32,10 +32,8 @@ def policy(policy_cls, config): def test_policy_creation(policy): - 
assert "models" in policy.module - assert "actor" in policy.module - assert "critics" in policy.module - assert "alpha" in policy.module + for attr in "models actor alpha critics".split(): + assert hasattr(policy.module, attr) assert "models" in policy.optimizers assert "actor" in policy.optimizers diff --git a/tests/agents/mbpo/test_trainer.py b/tests/raylab/agents/mbpo/test_trainer.py similarity index 100% rename from tests/agents/mbpo/test_trainer.py rename to tests/raylab/agents/mbpo/test_trainer.py diff --git a/tests/envs/environments/__init__.py b/tests/raylab/agents/naf/__init__.py similarity index 100% rename from tests/envs/environments/__init__.py rename to tests/raylab/agents/naf/__init__.py diff --git a/tests/agents/naf/test_policy.py b/tests/raylab/agents/naf/test_policy.py similarity index 100% rename from tests/agents/naf/test_policy.py rename to tests/raylab/agents/naf/test_policy.py diff --git a/tests/general/__init__.py b/tests/raylab/agents/sac/__init__.py similarity index 100% rename from tests/general/__init__.py rename to tests/raylab/agents/sac/__init__.py diff --git a/tests/agents/sac/conftest.py b/tests/raylab/agents/sac/conftest.py similarity index 100% rename from tests/agents/sac/conftest.py rename to tests/raylab/agents/sac/conftest.py diff --git a/tests/agents/sac/test_actor.py b/tests/raylab/agents/sac/test_actor.py similarity index 60% rename from tests/agents/sac/test_actor.py rename to tests/raylab/agents/sac/test_actor.py index 97cbd8e9..9ce44bab 100644 --- a/tests/agents/sac/test_actor.py +++ b/tests/raylab/agents/sac/test_actor.py @@ -5,16 +5,23 @@ @pytest.fixture(params=(True, False)) def input_dependent_scale(request): - return {"module": {"actor": {"input_dependent_scale": request.param}}} + return request.param @pytest.fixture(params=(True, False)) -def clipped_double_q(request): - return {"clipped_double_q": request.param} - - -def test_actor_loss(policy_and_batch_fn, clipped_double_q, input_dependent_scale): - policy, batch = policy_and_batch_fn({**clipped_double_q, **input_dependent_scale}) +def double_q(request): + return request.param + + +def test_actor_loss(policy_and_batch_fn, double_q, input_dependent_scale): + policy, batch = policy_and_batch_fn( + { + "module": { + "actor": {"input_dependent_scale": input_dependent_scale}, + "critic": {"double_q": double_q}, + } + } + ) loss, info = policy.loss_actor(batch) assert loss.shape == () diff --git a/tests/agents/sac/test_critics.py b/tests/raylab/agents/sac/test_critics.py similarity index 82% rename from tests/agents/sac/test_critics.py rename to tests/raylab/agents/sac/test_critics.py index 551b537c..7ac909a2 100644 --- a/tests/agents/sac/test_critics.py +++ b/tests/raylab/agents/sac/test_critics.py @@ -5,17 +5,17 @@ from ray.rllib import SampleBatch import raylab.utils.dictionaries as dutil -from raylab.losses import SoftCDQLearning +from raylab.policy.losses import SoftCDQLearning @pytest.fixture(params=(True, False)) -def clipped_double_q(request): +def double_q(request): return request.param @pytest.fixture -def policy_and_batch(policy_and_batch_fn, clipped_double_q): - config = {"clipped_double_q": clipped_double_q, "polyak": 0.5} +def policy_and_batch(policy_and_batch_fn, double_q): + config = {"module": {"critic": {"double_q": double_q}}, "polyak": 0.5} return policy_and_batch_fn(config) @@ -75,14 +75,8 @@ def test_critic_loss(policy_and_batch): ) -def test_target_params_update(policy_and_batch): +def test_target_net_init(policy_and_batch): policy, _ = policy_and_batch params = 
list(policy.module.critics.parameters()) target_params = list(policy.module.target_critics.parameters()) assert all(torch.allclose(p, q) for p, q in zip(params, target_params)) - - old_params = [p.clone() for p in target_params] - for param in params: - param.data.add_(torch.ones_like(param)) - policy.update_targets("critics", "target_critics") - assert all(not torch.allclose(p, q) for p, q in zip(target_params, old_params)) diff --git a/tests/agents/sac/test_entropy_coeff.py b/tests/raylab/agents/sac/test_entropy_coeff.py similarity index 100% rename from tests/agents/sac/test_entropy_coeff.py rename to tests/raylab/agents/sac/test_entropy_coeff.py diff --git a/tests/losses/__init__.py b/tests/raylab/agents/sop/__init__.py similarity index 100% rename from tests/losses/__init__.py rename to tests/raylab/agents/sop/__init__.py diff --git a/tests/agents/sop/test_policy.py b/tests/raylab/agents/sop/test_policy.py similarity index 72% rename from tests/agents/sop/test_policy.py rename to tests/raylab/agents/sop/test_policy.py index 4dc4118c..94453900 100644 --- a/tests/agents/sop/test_policy.py +++ b/tests/raylab/agents/sop/test_policy.py @@ -7,13 +7,13 @@ @pytest.fixture(params=(True, False)) -def clipped_double_q(request): +def double_q(request): return request.param @pytest.fixture -def config(clipped_double_q): - return {"clipped_double_q": clipped_double_q, "policy_delay": 2} +def config(double_q): + return {"module": {"critic": {"double_q": double_q}}, "policy_delay": 2} @pytest.fixture @@ -21,17 +21,11 @@ def policy(obs_space, action_space, config): return SOPTorchPolicy(obs_space, action_space, config) -def test_target_params_update(policy): +def test_target_critics_init(policy): params = list(policy.module.critics.parameters()) target_params = list(policy.module.target_critics.parameters()) assert all(torch.allclose(p, q) for p, q in zip(params, target_params)) - old_params = [p.clone() for p in target_params] - for param in params: - param.data.add_(torch.ones_like(param)) - policy.update_targets("critics", "target_critics") - assert all(not torch.allclose(p, q) for p, q in zip(target_params, old_params)) - @pytest.fixture def samples(obs_space, action_space): diff --git a/tests/modules/__init__.py b/tests/raylab/agents/svg/__init__.py similarity index 100% rename from tests/modules/__init__.py rename to tests/raylab/agents/svg/__init__.py diff --git a/tests/agents/svg/conftest.py b/tests/raylab/agents/svg/conftest.py similarity index 100% rename from tests/agents/svg/conftest.py rename to tests/raylab/agents/svg/conftest.py diff --git a/tests/agents/svg/test_rollout_module.py b/tests/raylab/agents/svg/test_rollout_module.py similarity index 100% rename from tests/agents/svg/test_rollout_module.py rename to tests/raylab/agents/svg/test_rollout_module.py diff --git a/tests/agents/svg/test_svg_one.py b/tests/raylab/agents/svg/test_svg_one.py similarity index 100% rename from tests/agents/svg/test_svg_one.py rename to tests/raylab/agents/svg/test_svg_one.py diff --git a/tests/agents/svg/test_value_function.py b/tests/raylab/agents/svg/test_value_function.py similarity index 77% rename from tests/agents/svg/test_value_function.py rename to tests/raylab/agents/svg/test_value_function.py index 7a74bb85..ed0f88a5 100644 --- a/tests/agents/svg/test_value_function.py +++ b/tests/raylab/agents/svg/test_value_function.py @@ -40,16 +40,3 @@ def test_importance_sampling_weighted_loss(policy_and_batch): assert all(p.grad is None for p in other_params) assert "loss(critic)" in info - - -def 
test_target_params_update(policy_and_batch): - policy, _ = policy_and_batch - - old_params = [p.clone() for p in policy.module.target_critic.parameters()] - for param in policy.module.critic.parameters(): - param.data.add_(torch.ones_like(param)) - policy.update_targets("critic", "target_critic") - assert all( - not torch.allclose(p, p_) - for p, p_ in zip(policy.module.target_critic.parameters(), old_params) - ) diff --git a/tests/general/test_trainer.py b/tests/raylab/agents/test_registry.py similarity index 52% rename from tests/general/test_trainer.py rename to tests/raylab/agents/test_registry.py index 331bc592..5dfe46a8 100644 --- a/tests/general/test_trainer.py +++ b/tests/raylab/agents/test_registry.py @@ -1,7 +1,23 @@ -# pylint: disable=missing-docstring,redefined-outer-name,protected-access +# pylint:disable=missing-docstring,redefined-outer-name,protected-access from collections import defaultdict import pytest +from ray.rllib import RolloutWorker +from ray.rllib import SampleBatch + +from raylab.agents.registry import AGENTS + +TRAINER_NAMES, TRAINER_IMPORTS = zip(*AGENTS.items()) + + +@pytest.fixture(scope="module", params=TRAINER_IMPORTS, ids=TRAINER_NAMES) +def trainer_cls(request): + return request.param() + + +@pytest.fixture +def policy_cls(trainer_cls): + return trainer_cls._policy CONFIG = defaultdict( @@ -54,3 +70,34 @@ def test_trainer_eval(trainer): def test_trainer_restore(trainer): obj = trainer.save_to_object() trainer.restore_from_object(obj) + + +@pytest.fixture +def worker_kwargs(): + return {"rollout_fragment_length": 200, "batch_mode": "truncate_episodes"} + + +@pytest.fixture +def worker(envs, env_name, policy_cls, worker_kwargs): + return RolloutWorker( + env_creator=envs[env_name], + policy=policy_cls, + policy_config={"env": env_name}, + **worker_kwargs, + ) + + +def test_compute_single_action(envs, env_name, policy_cls): + env = envs[env_name]({}) + policy = policy_cls(env.observation_space, env.action_space, {"env": env_name}) + + obs = env.observation_space.sample() + action, states, info = policy.compute_single_action(obs, []) + assert action in env.action_space + assert isinstance(states, list) + assert isinstance(info, dict) + + +def test_policy_in_rollout_worker(worker): + traj = worker.sample() + assert isinstance(traj, SampleBatch) diff --git a/tests/agents/test_trainer.py b/tests/raylab/agents/test_trainer.py similarity index 100% rename from tests/agents/test_trainer.py rename to tests/raylab/agents/test_trainer.py diff --git a/tests/modules/networks/__init__.py b/tests/raylab/cli/__init__.py similarity index 100% rename from tests/modules/networks/__init__.py rename to tests/raylab/cli/__init__.py diff --git a/tests/cli/test_cli.py b/tests/raylab/cli/test_cli.py similarity index 100% rename from tests/cli/test_cli.py rename to tests/raylab/cli/test_cli.py diff --git a/tests/raylab/conftest.py b/tests/raylab/conftest.py new file mode 100644 index 00000000..df25a1d2 --- /dev/null +++ b/tests/raylab/conftest.py @@ -0,0 +1,38 @@ +# pylint: disable=missing-docstring,redefined-outer-name,protected-access +import gym.spaces as spaces +import pytest + + +@pytest.fixture(scope="module", params=((1,), (4,)), ids=("Obs1Dim", "Obs4Dim")) +def obs_space(request): + return spaces.Box(-10, 10, shape=request.param) + + +@pytest.fixture(scope="module", params=((1,), (4,)), ids=("Act1Dim", "Act4Dim")) +def action_space(request): + return spaces.Box(-1, 1, shape=request.param) + + +@pytest.fixture(scope="module") +def envs(): + from raylab.envs.registry import 
ENVS # pylint:disable=import-outside-toplevel + + return ENVS.copy() + + +@pytest.fixture( + params=""" + MockEnv + Navigation + Reservoir + HVAC + MountainCarContinuous-v0 + """.split() +) +def env_name(request): + return request.param + + +@pytest.fixture +def env_creator(envs, env_name): + return envs[env_name] diff --git a/tests/policy/__init__.py b/tests/raylab/envs/__init__.py similarity index 100% rename from tests/policy/__init__.py rename to tests/raylab/envs/__init__.py diff --git a/tests/envs/conftest.py b/tests/raylab/envs/conftest.py similarity index 100% rename from tests/envs/conftest.py rename to tests/raylab/envs/conftest.py diff --git a/tests/policy/model_based/__init__.py b/tests/raylab/envs/environments/__init__.py similarity index 100% rename from tests/policy/model_based/__init__.py rename to tests/raylab/envs/environments/__init__.py diff --git a/tests/envs/environments/test_cartpole_swingup.py b/tests/raylab/envs/environments/test_cartpole_swingup.py similarity index 100% rename from tests/envs/environments/test_cartpole_swingup.py rename to tests/raylab/envs/environments/test_cartpole_swingup.py diff --git a/tests/envs/environments/test_hvac.py b/tests/raylab/envs/environments/test_hvac.py similarity index 100% rename from tests/envs/environments/test_hvac.py rename to tests/raylab/envs/environments/test_hvac.py diff --git a/tests/envs/environments/test_navigation.py b/tests/raylab/envs/environments/test_navigation.py similarity index 100% rename from tests/envs/environments/test_navigation.py rename to tests/raylab/envs/environments/test_navigation.py diff --git a/tests/envs/environments/test_reservoir.py b/tests/raylab/envs/environments/test_reservoir.py similarity index 100% rename from tests/envs/environments/test_reservoir.py rename to tests/raylab/envs/environments/test_reservoir.py diff --git a/tests/envs/test_basic.py b/tests/raylab/envs/test_basic.py similarity index 100% rename from tests/envs/test_basic.py rename to tests/raylab/envs/test_basic.py diff --git a/tests/envs/test_gaussian_random_walks.py b/tests/raylab/envs/test_gaussian_random_walks.py similarity index 100% rename from tests/envs/test_gaussian_random_walks.py rename to tests/raylab/envs/test_gaussian_random_walks.py diff --git a/tests/envs/test_rewards.py b/tests/raylab/envs/test_rewards.py similarity index 100% rename from tests/envs/test_rewards.py rename to tests/raylab/envs/test_rewards.py diff --git a/tests/envs/test_termination.py b/tests/raylab/envs/test_termination.py similarity index 100% rename from tests/envs/test_termination.py rename to tests/raylab/envs/test_termination.py diff --git a/tests/pytorch/__init__.py b/tests/raylab/policy/__init__.py similarity index 100% rename from tests/pytorch/__init__.py rename to tests/raylab/policy/__init__.py diff --git a/tests/pytorch/nn/__init__.py b/tests/raylab/policy/losses/__init__.py similarity index 100% rename from tests/pytorch/nn/__init__.py rename to tests/raylab/policy/losses/__init__.py diff --git a/tests/losses/conftest.py b/tests/raylab/policy/losses/conftest.py similarity index 89% rename from tests/losses/conftest.py rename to tests/raylab/policy/losses/conftest.py index a9623a81..11f0de24 100644 --- a/tests/losses/conftest.py +++ b/tests/raylab/policy/losses/conftest.py @@ -5,10 +5,12 @@ import raylab.pytorch.nn as nnx import raylab.pytorch.nn.distributions as ptd -from raylab.modules.mixins.action_value_mixin import ActionValueFunction -from raylab.modules.mixins.deterministic_actor_mixin import DeterministicPolicy -from 
raylab.modules.mixins.stochastic_actor_mixin import StochasticPolicy -from raylab.modules.mixins.stochastic_model_mixin import StochasticModelMixin +from raylab.policy.modules.v0.mixins.action_value_mixin import ActionValueFunction +from raylab.policy.modules.v0.mixins.deterministic_actor_mixin import ( + DeterministicPolicy, +) +from raylab.policy.modules.v0.mixins.stochastic_actor_mixin import StochasticPolicy +from raylab.policy.modules.v0.mixins.stochastic_model_mixin import StochasticModelMixin from raylab.utils.debug import fake_batch diff --git a/tests/losses/test_cdq_learning.py b/tests/raylab/policy/losses/test_cdq_learning.py similarity index 97% rename from tests/losses/test_cdq_learning.py rename to tests/raylab/policy/losses/test_cdq_learning.py index da70fa34..b170a3aa 100644 --- a/tests/losses/test_cdq_learning.py +++ b/tests/raylab/policy/losses/test_cdq_learning.py @@ -5,7 +5,7 @@ from ray.rllib import SampleBatch import raylab.utils.dictionaries as dutil -from raylab.losses import ClippedDoubleQLearning +from raylab.policy.losses import ClippedDoubleQLearning @pytest.fixture diff --git a/tests/losses/test_mage.py b/tests/raylab/policy/losses/test_mage.py similarity index 96% rename from tests/losses/test_mage.py rename to tests/raylab/policy/losses/test_mage.py index 7ca82577..536a43e9 100644 --- a/tests/losses/test_mage.py +++ b/tests/raylab/policy/losses/test_mage.py @@ -2,8 +2,8 @@ import pytest import torch -from raylab.losses.mage import MAGE -from raylab.losses.mage import MAGEModules +from raylab.policy.losses.mage import MAGE +from raylab.policy.losses.mage import MAGEModules @pytest.fixture diff --git a/tests/losses/test_mapo.py b/tests/raylab/policy/losses/test_mapo.py similarity index 93% rename from tests/losses/test_mapo.py rename to tests/raylab/policy/losses/test_mapo.py index 4332e596..aadb1e72 100644 --- a/tests/losses/test_mapo.py +++ b/tests/raylab/policy/losses/test_mapo.py @@ -2,10 +2,10 @@ import pytest import torch -from raylab.losses import DAPO -from raylab.losses import MAPO -from raylab.losses.abstract import Loss -from raylab.losses.mixins import EnvFunctionsMixin +from raylab.policy.losses import DAPO +from raylab.policy.losses import MAPO +from raylab.policy.losses.abstract import Loss +from raylab.policy.losses.mixins import EnvFunctionsMixin @pytest.fixture diff --git a/tests/losses/test_mle.py b/tests/raylab/policy/losses/test_mle.py similarity index 92% rename from tests/losses/test_mle.py rename to tests/raylab/policy/losses/test_mle.py index fd7365c8..ea6d6d6b 100644 --- a/tests/losses/test_mle.py +++ b/tests/raylab/policy/losses/test_mle.py @@ -2,7 +2,7 @@ import pytest import torch -from raylab.losses import ModelEnsembleMLE +from raylab.policy.losses import ModelEnsembleMLE @pytest.fixture diff --git a/tests/losses/test_paml.py b/tests/raylab/policy/losses/test_paml.py similarity index 96% rename from tests/losses/test_paml.py rename to tests/raylab/policy/losses/test_paml.py index 5c981731..62f1101f 100644 --- a/tests/losses/test_paml.py +++ b/tests/raylab/policy/losses/test_paml.py @@ -3,9 +3,9 @@ import torch from ray.rllib import SampleBatch -from raylab.losses import ModelEnsembleMLE -from raylab.losses import SPAML -from raylab.losses.abstract import Loss +from raylab.policy.losses import ModelEnsembleMLE +from raylab.policy.losses import SPAML +from raylab.policy.losses.abstract import Loss @pytest.fixture diff --git a/tests/pytorch/nn/distributions/__init__.py b/tests/raylab/policy/model_based/__init__.py similarity index 
100% rename from tests/pytorch/nn/distributions/__init__.py rename to tests/raylab/policy/model_based/__init__.py diff --git a/tests/raylab/policy/model_based/conftest.py b/tests/raylab/policy/model_based/conftest.py new file mode 100644 index 00000000..209986d9 --- /dev/null +++ b/tests/raylab/policy/model_based/conftest.py @@ -0,0 +1,28 @@ +# pylint:disable=missing-docstring,redefined-outer-name,protected-access +import pytest + +from raylab.policy import TorchPolicy +from raylab.policy.action_dist import BaseActionDist + + +@pytest.fixture(scope="module") +def action_dist(): + class ActionDist(BaseActionDist): + # pylint:disable=abstract-method + @classmethod + def _check_model_compat(cls, *args, **kwargs): + pass + + return ActionDist + + +@pytest.fixture(scope="module") +def base_policy_cls(action_dist, obs_space, action_space): + class Policy(TorchPolicy): + # pylint:disable=abstract-method + dist_class = action_dist + + def __init__(self, config): + super().__init__(obs_space, action_space, config) + + return Policy diff --git a/tests/policy/model_based/test_envfn_mixin.py b/tests/raylab/policy/model_based/test_envfn_mixin.py similarity index 58% rename from tests/policy/model_based/test_envfn_mixin.py rename to tests/raylab/policy/model_based/test_envfn_mixin.py index e7be8810..6e0dde75 100644 --- a/tests/policy/model_based/test_envfn_mixin.py +++ b/tests/raylab/policy/model_based/test_envfn_mixin.py @@ -1,27 +1,28 @@ # pylint:disable=missing-docstring,redefined-outer-name,protected-access import math -from unittest import mock import pytest import torch from raylab.envs import get_reward_fn -from raylab.envs import get_termination_fn from raylab.policy import EnvFnMixin -from raylab.policy import TorchPolicy from raylab.utils.debug import fake_space_samples -class DummyPolicy(EnvFnMixin, TorchPolicy): - # pylint:disable=all - @staticmethod - def get_default_config(): - return {"module": {"type": "OnPolicyActorCritic"}} +@pytest.fixture +def policy_cls(base_policy_cls): + class Policy(EnvFnMixin, base_policy_cls): + @staticmethod + def get_default_config(): + return {"module": {"type": "OnPolicyActorCritic"}} + + return Policy @pytest.fixture def reward_fn(): - def func(obs, act, new_obs): + def func(*args): + act = args[1] return act.norm(p=1, dim=-1) return func @@ -29,7 +30,7 @@ def func(obs, act, new_obs): @pytest.fixture def termination_fn(): - def func(obs, act, new_obs): + def func(obs, *_): return torch.randn(obs.shape[:-1]) > 0 return func @@ -37,7 +38,7 @@ def func(obs, act, new_obs): @pytest.fixture def dynamics_fn(): - def func(obs, act): + def func(obs, _): sample = torch.randn_like(obs) log_prob = torch.sum( -(sample ** 2) / 2 @@ -51,8 +52,8 @@ def func(obs, act): @pytest.fixture -def policy(obs_space, action_space): - return DummyPolicy(obs_space, action_space, {}) +def policy(policy_cls): + return policy_cls({}) def test_init(policy): @@ -61,7 +62,8 @@ def test_init(policy): assert hasattr(policy, "dynamics_fn") -def test_set_reward_from_config(policy, envs): # pylint:disable=unused-argument +def test_set_reward_from_config(policy, mocker): + obs_space, action_space = policy.observation_space, policy.action_space batch_size = 10 obs = fake_space_samples(obs_space, batch_size=batch_size) @@ -69,9 +71,9 @@ def test_set_reward_from_config(policy, envs): # pylint:disable=unused-argument new_obs = fake_space_samples(obs_space, batch_size=batch_size) obs, act, new_obs = map(policy.convert_to_tensor, (obs, act, new_obs)) - with 
mock.patch("raylab.policy.EnvFnMixin._set_reward_hook") as hook: - policy.set_reward_from_config("MockEnv", {}) - assert hook.called + hook = mocker.spy(EnvFnMixin, "_set_reward_hook") + policy.set_reward_from_config("MockEnv", {}) + assert hook.called original_fn = get_reward_fn("MockEnv", {}) expected_rew = original_fn(obs, act, new_obs) @@ -80,7 +82,7 @@ def test_set_reward_from_config(policy, envs): # pylint:disable=unused-argument assert torch.allclose(rew, expected_rew) -def test_set_termination_from_config(policy, envs): # pylint:disable=unused-argument +def test_set_termination_from_config(policy, mocker): obs_space, action_space = policy.observation_space, policy.action_space batch_size = 10 obs = fake_space_samples(obs_space, batch_size=batch_size) @@ -88,9 +90,9 @@ def test_set_termination_from_config(policy, envs): # pylint:disable=unused-arg new_obs = fake_space_samples(obs_space, batch_size=batch_size) obs, act, new_obs = map(policy.convert_to_tensor, (obs, act, new_obs)) - with mock.patch("raylab.policy.EnvFnMixin._set_termination_hook") as hook: - policy.set_termination_from_config("MockEnv", {}) - assert hook.called + hook = mocker.spy(EnvFnMixin, "_set_termination_hook") + policy.set_termination_from_config("MockEnv", {}) + assert hook.called done = policy.termination_fn(obs, act, new_obs) assert torch.is_tensor(done) @@ -98,28 +100,28 @@ def test_set_termination_from_config(policy, envs): # pylint:disable=unused-arg assert done.shape == obs.shape[:-1] -def test_set_reward_from_callable(policy, reward_fn): - with mock.patch("raylab.policy.EnvFnMixin._set_reward_hook") as hook: - policy.set_reward_from_callable(reward_fn) - assert hook.called +def test_set_reward_from_callable(policy, reward_fn, mocker): + hook = mocker.spy(EnvFnMixin, "_set_reward_hook") + policy.set_reward_from_callable(reward_fn) + assert hook.called assert hasattr(policy, "reward_fn") assert policy.reward_fn is reward_fn -def test_set_termination_from_callable(policy, termination_fn): - with mock.patch("raylab.policy.EnvFnMixin._set_termination_hook") as hook: - policy.set_termination_from_callable(termination_fn) - assert hook.called +def test_set_termination_from_callable(policy, termination_fn, mocker): + hook = mocker.spy(EnvFnMixin, "_set_termination_hook") + policy.set_termination_from_callable(termination_fn) + assert hook.called assert hasattr(policy, "termination_fn") assert policy.termination_fn is termination_fn -def test_set_dynamics_from_callable(policy, dynamics_fn): - with mock.patch("raylab.policy.EnvFnMixin._set_dynamics_hook") as hook: - policy.set_dynamics_from_callable(dynamics_fn) - assert hook.called +def test_set_dynamics_from_callable(policy, dynamics_fn, mocker): + hook = mocker.spy(EnvFnMixin, "_set_dynamics_hook") + policy.set_dynamics_from_callable(dynamics_fn) + assert hook.called assert hasattr(policy, "dynamics_fn") assert policy.dynamics_fn is dynamics_fn diff --git a/tests/policy/model_based/test_sampling_mixin.py b/tests/raylab/policy/model_based/test_sampling_mixin.py similarity index 80% rename from tests/policy/model_based/test_sampling_mixin.py rename to tests/raylab/policy/model_based/test_sampling_mixin.py index 0b535603..31479549 100644 --- a/tests/policy/model_based/test_sampling_mixin.py +++ b/tests/raylab/policy/model_based/test_sampling_mixin.py @@ -15,32 +15,32 @@ ROLLOUT_SCHEDULE = ([(0, 1), (200, 10)], [(7, 2)]) -class DummyPolicy(ModelSamplingMixin, TorchPolicy): - # pylint:disable=all - def __init__(self, *args, **kwargs): - super().__init__(*args, 
**kwargs) +@pytest.fixture(scope="module") +def policy_cls(base_policy_cls): + class Policy(ModelSamplingMixin, base_policy_cls): + # pylint:disable=all - def reward_fn(obs, act, new_obs): - return act.norm(p=1, dim=-1) + def __init__(self, config): + super().__init__(config) - def termination_fn(obs, act, new_obs): - return torch.randn(obs.shape[:-1]) > 0 + def reward_fn(obs, act, new_obs): + return act.norm(p=1, dim=-1) - self.reward_fn = reward_fn - self.termination_fn = termination_fn + def termination_fn(obs, act, new_obs): + return torch.randn(obs.shape[:-1]) > 0 - @staticmethod - def get_default_config(): - return { - "model_sampling": ModelSamplingMixin.model_sampling_defaults(), - "module": {"type": "ModelBasedSAC"}, - "seed": None, - } + self.reward_fn = reward_fn + self.termination_fn = termination_fn + @staticmethod + def get_default_config(): + return { + "model_sampling": ModelSamplingMixin.model_sampling_defaults(), + "module": {"type": "ModelBasedSAC"}, + "seed": None, + } -@pytest.fixture(scope="module") -def policy_cls(obs_space, action_space): - return functools.partial(DummyPolicy, obs_space, action_space) + return Policy @pytest.fixture( diff --git a/tests/policy/model_based/test_training_mixin.py b/tests/raylab/policy/model_based/test_training_mixin.py similarity index 82% rename from tests/policy/model_based/test_training_mixin.py rename to tests/raylab/policy/model_based/test_training_mixin.py index b2310960..7a1d4d4f 100644 --- a/tests/policy/model_based/test_training_mixin.py +++ b/tests/raylab/policy/model_based/test_training_mixin.py @@ -8,7 +8,6 @@ from raylab.policy import ModelTrainingMixin from raylab.policy import OptimizerCollection -from raylab.policy import TorchPolicy from raylab.policy.model_based.training_mixin import Evaluator from raylab.pytorch.optim import build_optimizer from raylab.utils.debug import fake_batch @@ -24,25 +23,6 @@ def __call__(self, _): return losses, {"loss(models)": losses.mean().item()} -class DummyPolicy(ModelTrainingMixin, TorchPolicy): - # pylint:disable=abstract-method - def __init__(self, observation_space, action_space, config): - super().__init__(observation_space, action_space, config) - loss = DummyLoss() - loss.ensemble_size = len(self.module.models) - self.loss_model = loss - - @staticmethod - def get_default_config(): - return { - "model_training": ModelTrainingMixin.model_training_defaults(), - "module": {"type": "ModelBasedSAC"}, - } - - def make_optimizers(self): - return {"models": build_optimizer(self.module.models, {"type": "Adam"})} - - @pytest.fixture def train_samples(obs_space, action_space): return fake_batch(obs_space, action_space, batch_size=80) @@ -54,8 +34,26 @@ def eval_samples(obs_space, action_space): @pytest.fixture(scope="module") -def policy_cls(obs_space, action_space): - return functools.partial(DummyPolicy, obs_space, action_space) +def policy_cls(base_policy_cls): + class Policy(ModelTrainingMixin, base_policy_cls): + # pylint:disable=abstract-method + def __init__(self, config): + super().__init__(config) + loss = DummyLoss() + loss.ensemble_size = len(self.module.models) + self.loss_model = loss + + @staticmethod + def get_default_config(): + return { + "model_training": ModelTrainingMixin.model_training_defaults(), + "module": {"type": "ModelBasedSAC"}, + } + + def make_optimizers(self): + return {"models": build_optimizer(self.module.models, {"type": "Adam"})} + + return Policy @pytest.fixture(scope="module", params=(1, 4), ids=lambda s: f"Ensemble({s})") diff --git 
a/tests/pytorch/nn/distributions/flows/__init__.py b/tests/raylab/policy/modules/__init__.py similarity index 100% rename from tests/pytorch/nn/distributions/flows/__init__.py rename to tests/raylab/policy/modules/__init__.py diff --git a/tests/pytorch/nn/modules/__init__.py b/tests/raylab/policy/modules/actor/__init__.py similarity index 100% rename from tests/pytorch/nn/modules/__init__.py rename to tests/raylab/policy/modules/actor/__init__.py diff --git a/tests/raylab/policy/modules/actor/conftest.py b/tests/raylab/policy/modules/actor/conftest.py new file mode 100644 index 00000000..09782b58 --- /dev/null +++ b/tests/raylab/policy/modules/actor/conftest.py @@ -0,0 +1,45 @@ +# pylint: disable=missing-docstring,redefined-outer-name,protected-access +import pytest +import torch +from gym.spaces import Box +from gym.spaces import Discrete + +from raylab.utils.debug import fake_batch + + +DISC_SPACES = (Discrete(2), Discrete(8)) +CONT_SPACES = (Box(-1, 1, shape=(1,)), Box(-1, 1, shape=(3,))) +ACTION_SPACES = CONT_SPACES + DISC_SPACES + + +@pytest.fixture(params=DISC_SPACES, ids=(repr(a) for a in DISC_SPACES)) +def disc_space(request): + return request.param + + +@pytest.fixture(params=CONT_SPACES, ids=(repr(a) for a in CONT_SPACES)) +def cont_space(request): + return request.param + + +@pytest.fixture(params=ACTION_SPACES, ids=(repr(a) for a in ACTION_SPACES)) +def action_space(request): + return request.param + + +@pytest.fixture +def disc_batch(obs_space, disc_space): + samples = fake_batch(obs_space, disc_space, batch_size=32) + return {k: torch.from_numpy(v) for k, v in samples.items()} + + +@pytest.fixture +def cont_batch(obs_space, cont_space): + samples = fake_batch(obs_space, cont_space, batch_size=32) + return {k: torch.from_numpy(v) for k, v in samples.items()} + + +@pytest.fixture +def batch(obs_space, action_space): + samples = fake_batch(obs_space, action_space, batch_size=32) + return {k: torch.from_numpy(v) for k, v in samples.items()} diff --git a/tests/utils/__init__.py b/tests/raylab/policy/modules/actor/policy/__init__.py similarity index 100% rename from tests/utils/__init__.py rename to tests/raylab/policy/modules/actor/policy/__init__.py diff --git a/tests/raylab/policy/modules/actor/policy/test_deterministic.py b/tests/raylab/policy/modules/actor/policy/test_deterministic.py new file mode 100644 index 00000000..528e1ef1 --- /dev/null +++ b/tests/raylab/policy/modules/actor/policy/test_deterministic.py @@ -0,0 +1,41 @@ +# pylint:disable=missing-docstring,redefined-outer-name,protected-access +import pytest +import torch +from ray.rllib import SampleBatch + + +@pytest.fixture(scope="module") +def module_cls(): + from raylab.policy.modules.actor.policy.deterministic import MLPDeterministicPolicy + + return MLPDeterministicPolicy + + +@pytest.fixture(params=(0.1, 1.2), ids=lambda x: f"NormBeta({x})") +def norm_beta(request): + return request.param + + +@pytest.fixture +def spec(module_cls): + return module_cls.spec_cls() + + +@pytest.fixture +def action_space(cont_space): + return cont_space + + +@pytest.fixture +def module(module_cls, obs_space, action_space, spec, norm_beta): + return module_cls(obs_space, action_space, spec, norm_beta) + + +def test_unconstrained_action(module, batch, action_space, norm_beta): + action_dim = action_space.shape[0] + + policy_out = module.unconstrained_action(batch[SampleBatch.CUR_OBS]) + norms = policy_out.norm(p=1, dim=-1, keepdim=True) / action_dim + assert policy_out.shape[-1] == action_dim + assert policy_out.dtype == torch.float32 
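+ # The average absolute value per action dimension of each unconstrained action (its L1 norm divided by the action dimension) should stay within `norm_beta`, up to float32 epsilon.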
+ assert (norms <= (norm_beta + torch.finfo(torch.float32).eps)).all() diff --git a/tests/raylab/policy/modules/actor/policy/test_stochastic.py b/tests/raylab/policy/modules/actor/policy/test_stochastic.py new file mode 100644 index 00000000..1cc8da67 --- /dev/null +++ b/tests/raylab/policy/modules/actor/policy/test_stochastic.py @@ -0,0 +1,128 @@ +# pylint:disable=missing-docstring,redefined-outer-name,protected-access +import pytest +import torch +from ray.rllib import SampleBatch + + +@pytest.fixture(scope="module") +def base_cls(): + from raylab.policy.modules.actor.policy.stochastic import MLPStochasticPolicy + + return MLPStochasticPolicy + + +@pytest.fixture(scope="module") +def cont_cls(): + from raylab.policy.modules.actor.policy.stochastic import MLPContinuousPolicy + + return MLPContinuousPolicy + + +@pytest.fixture(scope="module") +def disc_cls(): + from raylab.policy.modules.actor.policy.stochastic import MLPDiscretePolicy + + return MLPDiscretePolicy + + +@pytest.fixture +def spec(base_cls): + return base_cls.spec_cls() + + +@pytest.fixture(params=(True, False), ids=lambda x: f"InputDependentScale({x})") +def input_dependent_scale(request): + return request.param + + +@pytest.fixture +def cont_policy(cont_cls, obs_space, cont_space, spec, input_dependent_scale): + return cont_cls(obs_space, cont_space, spec, input_dependent_scale) + + +@pytest.fixture +def disc_policy(disc_cls, obs_space, disc_space, spec): + return disc_cls(obs_space, disc_space, spec) + + +def test_continuous_sample(cont_policy, cont_batch): + policy, batch = cont_policy, cont_batch + action = batch[SampleBatch.ACTIONS] + + sampler = policy.rsample + samples, logp = sampler(batch[SampleBatch.CUR_OBS]) + samples_, _ = sampler(batch[SampleBatch.CUR_OBS]) + assert samples.shape == action.shape + assert samples.dtype == action.dtype + assert logp.shape == batch[SampleBatch.REWARDS].shape + assert logp.dtype == batch[SampleBatch.REWARDS].dtype + assert not torch.allclose(samples, samples_) + + +def test_discrete_sample(disc_policy, disc_batch): + policy, batch = disc_policy, disc_batch + action = batch[SampleBatch.ACTIONS] + + sampler = policy.sample + samples, logp = sampler(batch[SampleBatch.CUR_OBS]) + samples_, _ = sampler(batch[SampleBatch.CUR_OBS]) + assert samples.shape == action.shape + assert samples.dtype == action.dtype + assert logp.shape == batch[SampleBatch.REWARDS].shape + assert logp.dtype == batch[SampleBatch.REWARDS].dtype + assert not torch.allclose(samples, samples_) + + +def test_continuous_params(cont_policy, cont_batch): + policy, batch = cont_policy, cont_batch + params = policy(batch[SampleBatch.CUR_OBS]) + assert "loc" in params + assert "scale" in params + + loc, scale = params["loc"], params["scale"] + action = batch[SampleBatch.ACTIONS] + assert loc.shape == action.shape + assert scale.shape == action.shape + assert loc.dtype == torch.float32 + assert scale.dtype == torch.float32 + + pi_params = set(policy.parameters()) + for par in pi_params: + par.grad = None + loc.mean().backward() + assert any(p.grad is not None for p in pi_params) + + for par in pi_params: + par.grad = None + policy(batch[SampleBatch.CUR_OBS])["scale"].mean().backward() + assert any(p.grad is not None for p in pi_params) + + +def test_discrete_params(disc_policy, disc_space, disc_batch): + policy, batch = disc_policy, disc_batch + + params = policy(batch[SampleBatch.CUR_OBS]) + assert "logits" in params + logits = params["logits"] + assert logits.shape[-1] == disc_space.n + + pi_params = set(policy.parameters()) 
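+ # Zero out any stale gradients, then backprop through the logits to confirm the distribution parameters are differentiable w.r.t. the policy's parameters.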
+ for par in pi_params: + par.grad = None + logits.mean().backward() + assert any(p.grad is not None for p in pi_params) + + +def test_reproduce(cont_policy, cont_batch): + policy, batch = cont_policy, cont_batch + + acts = batch[SampleBatch.ACTIONS] + acts_, logp_ = policy.reproduce(batch[SampleBatch.CUR_OBS], acts) + assert acts_.shape == acts.shape + assert acts_.dtype == acts.dtype + assert torch.allclose(acts_, acts, atol=1e-5) + assert logp_.shape == batch[SampleBatch.REWARDS].shape + + acts_.mean().backward() + pi_params = set(policy.parameters()) + assert all(p.grad is not None for p in pi_params) diff --git a/tests/raylab/policy/modules/actor/test_deterministic.py b/tests/raylab/policy/modules/actor/test_deterministic.py new file mode 100644 index 00000000..eae71899 --- /dev/null +++ b/tests/raylab/policy/modules/actor/test_deterministic.py @@ -0,0 +1,84 @@ +# pylint: disable=missing-docstring,redefined-outer-name,protected-access +import pytest +import torch +from ray.rllib import SampleBatch + + +@pytest.fixture +def action_space(cont_space): + return cont_space + + +@pytest.fixture +def batch(cont_batch): + return cont_batch + + +@pytest.fixture(scope="module") +def module_cls(): + from raylab.policy.modules.actor.deterministic import DeterministicActor + + return DeterministicActor + + +@pytest.fixture(params=(True, False), ids=lambda x: f"SeparateTargetPolicy({x})") +def separate_target_policy(request): + return request.param + + +@pytest.fixture(params=(True, False), ids=lambda x: f"SeparateBehavior({x})") +def separate_behavior(request): + return request.param + + +@pytest.fixture +def spec(module_cls, separate_behavior, separate_target_policy): + return module_cls.spec_cls( + separate_behavior=separate_behavior, + separate_target_policy=separate_target_policy, + ) + + +@pytest.fixture +def module(module_cls, obs_space, action_space, spec): + return module_cls(obs_space, action_space, spec) + + +def test_module_creation(module): + for attr in "policy behavior target_policy".split(): + assert hasattr(module, attr) + + policy, target_policy = module.policy, module.target_policy + assert all( + torch.allclose(p, p_) + for p, p_ in zip(policy.parameters(), target_policy.parameters()) + ) + + +def test_separate_behavior(module_cls, obs_space, action_space): + spec = module_cls.spec_cls(separate_behavior=True) + module = module_cls(obs_space, action_space, spec) + + assert all( + torch.allclose(p, n) + for p, n in zip(module.policy.parameters(), module.behavior.parameters()) + ) + + +def test_separate_target_policy(module, spec): + policy, target = module.policy, module.target_policy + + if spec.separate_target_policy: + assert all(p is not t for p, t in zip(policy.parameters(), target.parameters())) + else: + assert all(p is t for p, t in zip(policy.parameters(), target.parameters())) + + +def test_behavior(module, batch): + action = batch[SampleBatch.ACTIONS] + + samples = module.behavior(batch[SampleBatch.CUR_OBS]) + samples_ = module.behavior(batch[SampleBatch.CUR_OBS]) + assert samples.shape == action.shape + assert samples.dtype == torch.float32 + assert torch.allclose(samples, samples_) diff --git a/tests/raylab/policy/modules/actor/test_stochastic.py b/tests/raylab/policy/modules/actor/test_stochastic.py new file mode 100644 index 00000000..8ebf382b --- /dev/null +++ b/tests/raylab/policy/modules/actor/test_stochastic.py @@ -0,0 +1,48 @@ +# pylint: disable=missing-docstring,redefined-outer-name,protected-access +import pytest +import torch + + 
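+# Smoke tests for the StochasticActor container: the fixtures below build the module for discrete and continuous action spaces, optionally TorchScripted, and `test_init` checks that it exposes the `policy` and `alpha` submodules.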
+@pytest.fixture(scope="module") +def module_cls(): + from raylab.policy.modules.actor.stochastic import StochasticActor + + return StochasticActor + + +@pytest.fixture(params=(True, False), ids=lambda x: f"InputDependentScale({x})") +def input_dependent_scale(request): + return request.param + + +@pytest.fixture +def cont_spec(module_cls, input_dependent_scale): + return module_cls.spec_cls(input_dependent_scale=input_dependent_scale) + + +@pytest.fixture +def spec(module_cls): + return module_cls.spec_cls() + + +@pytest.fixture +def disc_module(module_cls, obs_space, disc_space, spec, torch_script): + mod = module_cls(obs_space, disc_space, spec) + return torch.jit.script(mod) if torch_script else mod + + +@pytest.fixture +def cont_module(module_cls, obs_space, cont_space, spec, torch_script): + mod = module_cls(obs_space, cont_space, spec) + return torch.jit.script(mod) if torch_script else mod + + +@pytest.fixture +def module(module_cls, obs_space, action_space, spec, torch_script): + mod = module_cls(obs_space, action_space, spec) + return torch.jit.script(mod) if torch_script else mod + + +def test_init(module): + for attr in "policy alpha".split(): + assert hasattr(module, attr) diff --git a/tests/raylab/policy/modules/conftest.py b/tests/raylab/policy/modules/conftest.py new file mode 100644 index 00000000..3156cba9 --- /dev/null +++ b/tests/raylab/policy/modules/conftest.py @@ -0,0 +1,20 @@ +# pylint:disable=missing-docstring,redefined-outer-name,protected-access +import pytest +import torch + + +@pytest.fixture( + params=(pytest.param(True, marks=pytest.mark.slow), False), + ids=("TorchScript", "Eager"), + scope="module", +) +def torch_script(request): + return request.param + + +@pytest.fixture(scope="module") +def batch(obs_space, action_space): + from raylab.utils.debug import fake_batch + + samples = fake_batch(obs_space, action_space, batch_size=32) + return {k: torch.from_numpy(v) for k, v in samples.items()} diff --git a/tests/raylab/policy/modules/critic/__init__.py b/tests/raylab/policy/modules/critic/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/raylab/policy/modules/critic/test_action_value.py b/tests/raylab/policy/modules/critic/test_action_value.py new file mode 100644 index 00000000..3a2f7976 --- /dev/null +++ b/tests/raylab/policy/modules/critic/test_action_value.py @@ -0,0 +1,59 @@ +# pylint: disable=missing-docstring,redefined-outer-name,protected-access +import pytest +import torch +from ray.rllib import SampleBatch + + +@pytest.fixture(scope="module") +def module_cls(): + from raylab.policy.modules.critic.action_value import ActionValueCritic + + return ActionValueCritic + + +@pytest.fixture(params=(True, False), ids="DoubleQ SingleQ".split()) +def double_q(request): + return request.param + + +@pytest.fixture(params=(True, False), ids=lambda x: "Parallelize({x})") +def parallelize(request): + return request.param + + +@pytest.fixture +def spec(module_cls, double_q, parallelize): + return module_cls.spec_cls(double_q=double_q, parallelize=parallelize) + + +@pytest.fixture +def module(module_cls, obs_space, action_space, spec): + return module_cls(obs_space, action_space, spec) + + +def test_module_creation(module, batch, spec): + double_q = spec.double_q + + for attr in "q_values target_q_values".split(): + assert hasattr(module, attr) + expected_n_critics = 2 if double_q else 1 + assert len(module.q_values) == expected_n_critics + + q_values, targets = module.q_values, module.target_q_values + vals = [ + 
m(batch[SampleBatch.CUR_OBS], batch[SampleBatch.ACTIONS]) + for ensemble in (q_values, targets) + for m in ensemble + ] + for val in vals: + assert val.shape[-1] == 1 + assert val.dtype == torch.float32 + + assert all( + torch.allclose(p, t) + for p, t in zip(q_values.parameters(), targets.parameters()) + ) + + +def test_script(module): + torch.jit.script(module) diff --git a/tests/raylab/policy/modules/model/__init__.py b/tests/raylab/policy/modules/model/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/raylab/policy/modules/model/stochastic/__init__.py b/tests/raylab/policy/modules/model/stochastic/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/raylab/policy/modules/model/stochastic/conftest.py b/tests/raylab/policy/modules/model/stochastic/conftest.py new file mode 100644 index 00000000..f544d2cc --- /dev/null +++ b/tests/raylab/policy/modules/model/stochastic/conftest.py @@ -0,0 +1,16 @@ +# pylint: disable=missing-docstring,redefined-outer-name,protected-access +import pytest +from ray.rllib import SampleBatch + + +@pytest.fixture +def log_prob_inputs(batch): + return [ + batch[k] + for k in (SampleBatch.CUR_OBS, SampleBatch.ACTIONS, SampleBatch.NEXT_OBS) + ] + + +@pytest.fixture +def sample_inputs(batch): + return [batch[k] for k in (SampleBatch.CUR_OBS, SampleBatch.ACTIONS)] diff --git a/tests/raylab/policy/modules/model/stochastic/test_ensemble.py b/tests/raylab/policy/modules/model/stochastic/test_ensemble.py new file mode 100644 index 00000000..0ee8ae0e --- /dev/null +++ b/tests/raylab/policy/modules/model/stochastic/test_ensemble.py @@ -0,0 +1,44 @@ +# pylint: disable=missing-docstring,redefined-outer-name,protected-access +import pytest +import torch + + +@pytest.fixture(scope="module", params=(True, False), ids=lambda x: f"Forked({x})") +def module_cls(request): + from raylab.policy.modules.model.stochastic.ensemble import StochasticModelEnsemble + from raylab.policy.modules.model.stochastic.ensemble import ( + ForkedStochasticModelEnsemble, + ) + + return ForkedStochasticModelEnsemble if request.param else StochasticModelEnsemble + + +@pytest.fixture(params=(1, 4), ids=lambda x: f"Ensemble({x})") +def ensemble_size(request): + return request.param + + +@pytest.fixture +def build_single(obs_space, action_space): + from raylab.policy.modules.model.stochastic.single import MLPModel + + spec = MLPModel.spec_cls() + input_dependent_scale = True + + return lambda: MLPModel(obs_space, action_space, spec, input_dependent_scale) + + +@pytest.fixture +def module(module_cls, build_single, ensemble_size, torch_script): + models = [build_single() for _ in range(ensemble_size)] + + module = module_cls(models) + return torch.jit.script(module) if torch_script else module + + +def test_log_prob(module, log_prob_inputs, ensemble_size): + obs = log_prob_inputs[0] + log_prob = module.log_prob(*log_prob_inputs) + + assert torch.is_tensor(log_prob) + assert log_prob.shape == (ensemble_size,) + obs.shape[:-1] diff --git a/tests/raylab/policy/modules/model/stochastic/test_single.py b/tests/raylab/policy/modules/model/stochastic/test_single.py new file mode 100644 index 00000000..d4a25239 --- /dev/null +++ b/tests/raylab/policy/modules/model/stochastic/test_single.py @@ -0,0 +1,102 @@ +# pylint: disable=missing-docstring,redefined-outer-name,protected-access +import pytest +import torch +from ray.rllib import SampleBatch + + +@pytest.fixture(scope="module", params=(True, False), ids=lambda x: f"Residual({x})") +def module_cls(request): + from 
raylab.policy.modules.model.stochastic.single import MLPModel + from raylab.policy.modules.model.stochastic.single import ResidualMLPModel + + return ResidualMLPModel if request.param else MLPModel + + +@pytest.fixture +def spec(module_cls): + return module_cls.spec_cls() + + +@pytest.fixture(params=(True, False), ids=lambda x: f"InputDependentScale({x})") +def input_dependent_scale(request): + return request.param + + +@pytest.fixture +def module(module_cls, obs_space, action_space, spec, input_dependent_scale): + return module_cls(obs_space, action_space, spec, input_dependent_scale) + + +def test_sample(module, batch): + new_obs = batch[SampleBatch.NEXT_OBS] + sampler = module.rsample + inputs = (batch[SampleBatch.CUR_OBS], batch[SampleBatch.ACTIONS]) + + samples, logp = sampler(*inputs) + samples_, _ = sampler(*inputs) + assert samples.shape == new_obs.shape + assert samples.dtype == new_obs.dtype + assert logp.shape == batch[SampleBatch.REWARDS].shape + assert logp.dtype == batch[SampleBatch.REWARDS].dtype + assert not torch.allclose(samples, samples_) + + +def test_params(module, batch): + inputs = (batch[SampleBatch.CUR_OBS], batch[SampleBatch.ACTIONS]) + new_obs = batch[SampleBatch.NEXT_OBS] + + params = module(*inputs) + assert "loc" in params + assert "scale" in params + + loc, scale = params["loc"], params["scale"] + assert loc.shape == new_obs.shape + assert scale.shape == new_obs.shape + assert loc.dtype == torch.float32 + assert scale.dtype == torch.float32 + + params = set(module.parameters()) + for par in params: + par.grad = None + loc.mean().backward() + assert any(p.grad is not None for p in params) + + for par in params: + par.grad = None + module(*inputs)["scale"].mean().backward() + assert any(p.grad is not None for p in params) + + +def test_log_prob(module, batch): + logp = module.log_prob( + batch[SampleBatch.CUR_OBS], + batch[SampleBatch.ACTIONS], + batch[SampleBatch.NEXT_OBS], + ) + + assert torch.is_tensor(logp) + assert logp.shape == batch[SampleBatch.REWARDS].shape + + logp.sum().backward() + assert all(p.grad is not None for p in module.parameters()) + + +def test_reproduce(module, batch): + obs, act, new_obs = [ + batch[k] + for k in (SampleBatch.CUR_OBS, SampleBatch.ACTIONS, SampleBatch.NEXT_OBS) + ] + + new_obs_, logp_ = module.reproduce(obs, act, new_obs) + assert new_obs_.shape == new_obs.shape + assert new_obs_.dtype == new_obs.dtype + assert torch.allclose(new_obs_, new_obs, atol=1e-5) + assert logp_.shape == batch[SampleBatch.REWARDS].shape + + new_obs_.mean().backward() + params = set(module.parameters()) + assert all(p.grad is not None for p in params) + + +def test_script(module): + torch.jit.script(module) diff --git a/tests/raylab/policy/modules/networks/__init__.py b/tests/raylab/policy/modules/networks/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/modules/networks/test_mlp.py b/tests/raylab/policy/modules/networks/test_mlp.py similarity index 90% rename from tests/modules/networks/test_mlp.py rename to tests/raylab/policy/modules/networks/test_mlp.py index c7c01d98..71ff0a0f 100644 --- a/tests/modules/networks/test_mlp.py +++ b/tests/raylab/policy/modules/networks/test_mlp.py @@ -2,7 +2,7 @@ import pytest import torch -from raylab.modules.networks import MLP +from raylab.policy.modules.networks import MLP PARAMS = (None, {}, {"state": torch.randn(10, 4)}) diff --git a/tests/modules/networks/test_resnet.py b/tests/raylab/policy/modules/networks/test_resnet.py similarity index 89% rename from 
tests/modules/networks/test_resnet.py rename to tests/raylab/policy/modules/networks/test_resnet.py index b64d04bd..77d03d9d 100644 --- a/tests/modules/networks/test_resnet.py +++ b/tests/raylab/policy/modules/networks/test_resnet.py @@ -2,7 +2,7 @@ import pytest import torch -from raylab.modules.networks import ResidualNet +from raylab.policy.modules.networks import ResidualNet PARAMS = (None, {}, {"state": torch.randn(10, 4)}) diff --git a/tests/raylab/policy/modules/test_ddpg.py b/tests/raylab/policy/modules/test_ddpg.py new file mode 100644 index 00000000..7cd2c53d --- /dev/null +++ b/tests/raylab/policy/modules/test_ddpg.py @@ -0,0 +1,34 @@ +# pylint: disable=missing-docstring,redefined-outer-name,protected-access +import pytest +import torch +import torch.nn as nn + +from raylab.policy.modules.ddpg import DDPG + + +@pytest.fixture +def spec_cls(): + return DDPG.spec_cls + + +@pytest.fixture +def module(obs_space, action_space, spec_cls): + return DDPG(obs_space, action_space, spec_cls()) + + +def test_spec(spec_cls): + default_config = spec_cls().to_dict() + + for key in ["actor", "critic", "initializer"]: + assert key in default_config + + +def test_init(module): + assert isinstance(module, nn.Module) + + for attr in "actor behavior target_actor critics target_critics".split(): + assert hasattr(module, attr) + + +def test_script(module): + torch.jit.script(module) diff --git a/tests/raylab/policy/modules/test_sac.py b/tests/raylab/policy/modules/test_sac.py new file mode 100644 index 00000000..aa3f611f --- /dev/null +++ b/tests/raylab/policy/modules/test_sac.py @@ -0,0 +1,34 @@ +# pylint: disable=missing-docstring,redefined-outer-name,protected-access +import pytest +import torch +import torch.nn as nn + +from raylab.policy.modules.sac import SAC + + +@pytest.fixture +def spec_cls(): + return SAC.spec_cls + + +@pytest.fixture +def module(obs_space, action_space, spec_cls): + return SAC(obs_space, action_space, spec_cls()) + + +def test_spec(spec_cls): + default_config = spec_cls().to_dict() + + for key in ["actor", "critic", "initializer"]: + assert key in default_config + + +def test_init(module): + assert isinstance(module, nn.Module) + + for attr in "actor alpha critics target_critics".split(): + assert hasattr(module, attr) + + +def test_script(module): + torch.jit.script(module) diff --git a/tests/raylab/policy/modules/v0/__init__.py b/tests/raylab/policy/modules/v0/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/modules/conftest.py b/tests/raylab/policy/modules/v0/conftest.py similarity index 92% rename from tests/modules/conftest.py rename to tests/raylab/policy/modules/v0/conftest.py index c8fcbee4..77f8c204 100644 --- a/tests/modules/conftest.py +++ b/tests/raylab/policy/modules/v0/conftest.py @@ -1,4 +1,4 @@ -# pylint: disable=missing-docstring,redefined-outer-name,protected-access +# pylint:disable=missing-docstring,redefined-outer-name,protected-access from functools import partial import pytest diff --git a/tests/modules/test_action_value_mixin.py b/tests/raylab/policy/modules/v0/test_action_value_mixin.py similarity index 96% rename from tests/modules/test_action_value_mixin.py rename to tests/raylab/policy/modules/v0/test_action_value_mixin.py index b5a3e97e..18af7f92 100644 --- a/tests/modules/test_action_value_mixin.py +++ b/tests/raylab/policy/modules/v0/test_action_value_mixin.py @@ -6,7 +6,7 @@ import torch.nn as nn from ray.rllib import SampleBatch -from raylab.modules.mixins import ActionValueMixin +from 
raylab.policy.modules.v0.mixins import ActionValueMixin class DummyModule(ActionValueMixin, nn.ModuleDict): diff --git a/tests/modules/test_deterministic_actor_mixin.py b/tests/raylab/policy/modules/v0/test_deterministic_actor_mixin.py similarity index 98% rename from tests/modules/test_deterministic_actor_mixin.py rename to tests/raylab/policy/modules/v0/test_deterministic_actor_mixin.py index db583b75..0297762b 100644 --- a/tests/modules/test_deterministic_actor_mixin.py +++ b/tests/raylab/policy/modules/v0/test_deterministic_actor_mixin.py @@ -5,7 +5,7 @@ from ray.rllib import SampleBatch from ray.rllib.utils import merge_dicts -from raylab.modules.mixins import DeterministicActorMixin +from raylab.policy.modules.v0.mixins import DeterministicActorMixin BASE_CONFIG = { diff --git a/tests/modules/test_naf_module.py b/tests/raylab/policy/modules/v0/test_naf_module.py similarity index 91% rename from tests/modules/test_naf_module.py rename to tests/raylab/policy/modules/v0/test_naf_module.py index ae1b80aa..a5ec2c51 100644 --- a/tests/modules/test_naf_module.py +++ b/tests/raylab/policy/modules/v0/test_naf_module.py @@ -4,7 +4,7 @@ import pytest import torch -from raylab.modules.naf_module import NAFModule +from raylab.policy.modules.v0.naf_module import NAFModule @pytest.fixture(params=(True, False), ids=("Double Q", "Single Q")) diff --git a/tests/modules/test_normalizing_flow_actor_mixin.py b/tests/raylab/policy/modules/v0/test_normalizing_flow_actor_mixin.py similarity index 98% rename from tests/modules/test_normalizing_flow_actor_mixin.py rename to tests/raylab/policy/modules/v0/test_normalizing_flow_actor_mixin.py index 856d686d..e8648b81 100644 --- a/tests/modules/test_normalizing_flow_actor_mixin.py +++ b/tests/raylab/policy/modules/v0/test_normalizing_flow_actor_mixin.py @@ -6,7 +6,7 @@ from gym.spaces import Box from ray.rllib import SampleBatch -from raylab.modules.mixins import NormalizingFlowActorMixin +from raylab.policy.modules.v0.mixins import NormalizingFlowActorMixin from .utils import make_batch from .utils import make_module diff --git a/tests/modules/test_normalizing_flow_model_mixin.py b/tests/raylab/policy/modules/v0/test_normalizing_flow_model_mixin.py similarity index 98% rename from tests/modules/test_normalizing_flow_model_mixin.py rename to tests/raylab/policy/modules/v0/test_normalizing_flow_model_mixin.py index 5a2d6386..534c6cbd 100644 --- a/tests/modules/test_normalizing_flow_model_mixin.py +++ b/tests/raylab/policy/modules/v0/test_normalizing_flow_model_mixin.py @@ -7,7 +7,7 @@ from gym.spaces import Box from ray.rllib import SampleBatch -from raylab.modules.mixins import NormalizingFlowModelMixin +from raylab.policy.modules.v0.mixins import NormalizingFlowModelMixin from .utils import make_batch from .utils import make_module diff --git a/tests/modules/test_state_value_mixin.py b/tests/raylab/policy/modules/v0/test_state_value_mixin.py similarity index 96% rename from tests/modules/test_state_value_mixin.py rename to tests/raylab/policy/modules/v0/test_state_value_mixin.py index 229ea3fc..b30368c5 100644 --- a/tests/modules/test_state_value_mixin.py +++ b/tests/raylab/policy/modules/v0/test_state_value_mixin.py @@ -6,7 +6,7 @@ import torch.nn as nn from ray.rllib import SampleBatch -from raylab.modules.mixins import StateValueMixin +from raylab.policy.modules.v0.mixins import StateValueMixin class DummyModule(StateValueMixin, nn.ModuleDict): diff --git a/tests/modules/test_stochastic_actor_mixin.py 
b/tests/raylab/policy/modules/v0/test_stochastic_actor_mixin.py similarity index 98% rename from tests/modules/test_stochastic_actor_mixin.py rename to tests/raylab/policy/modules/v0/test_stochastic_actor_mixin.py index 65025e60..4030da2e 100644 --- a/tests/modules/test_stochastic_actor_mixin.py +++ b/tests/raylab/policy/modules/v0/test_stochastic_actor_mixin.py @@ -7,7 +7,7 @@ from gym.spaces import Discrete from ray.rllib import SampleBatch -from raylab.modules.mixins import StochasticActorMixin +from raylab.policy.modules.v0.mixins import StochasticActorMixin from .utils import make_batch from .utils import make_module diff --git a/tests/modules/test_stochastic_model_mixin.py b/tests/raylab/policy/modules/v0/test_stochastic_model_mixin.py similarity index 98% rename from tests/modules/test_stochastic_model_mixin.py rename to tests/raylab/policy/modules/v0/test_stochastic_model_mixin.py index ba2fe89f..faaeb3c1 100644 --- a/tests/modules/test_stochastic_model_mixin.py +++ b/tests/raylab/policy/modules/v0/test_stochastic_model_mixin.py @@ -4,7 +4,7 @@ import torch.nn as nn from ray.rllib import SampleBatch -from raylab.modules.mixins import StochasticModelMixin +from raylab.policy.modules.v0.mixins import StochasticModelMixin class DummyModule(StochasticModelMixin, nn.ModuleDict): diff --git a/tests/modules/test_svg_module.py b/tests/raylab/policy/modules/v0/test_svg_module.py similarity index 95% rename from tests/modules/test_svg_module.py rename to tests/raylab/policy/modules/v0/test_svg_module.py index 9595b078..3271a045 100644 --- a/tests/modules/test_svg_module.py +++ b/tests/raylab/policy/modules/v0/test_svg_module.py @@ -3,7 +3,7 @@ import torch from ray.rllib import SampleBatch -from raylab.modules.svg_module import SVGModule +from raylab.policy.modules.v0.svg_module import SVGModule @pytest.fixture diff --git a/tests/modules/test_trpo_extensions.py b/tests/raylab/policy/modules/v0/test_trpo_extensions.py similarity index 97% rename from tests/modules/test_trpo_extensions.py rename to tests/raylab/policy/modules/v0/test_trpo_extensions.py index edd48d08..a520fd89 100644 --- a/tests/modules/test_trpo_extensions.py +++ b/tests/raylab/policy/modules/v0/test_trpo_extensions.py @@ -5,7 +5,7 @@ from gym.spaces import Box from ray.rllib import SampleBatch -from raylab.modules.catalog import TRPOTang2018 +from raylab.policy.modules.v0.trpo_tang2018 import TRPOTang2018 from .utils import make_batch from .utils import make_module diff --git a/tests/modules/utils.py b/tests/raylab/policy/modules/v0/utils.py similarity index 100% rename from tests/modules/utils.py rename to tests/raylab/policy/modules/v0/utils.py diff --git a/tests/policy/test_optimizer_collection.py b/tests/raylab/policy/test_optimizer_collection.py similarity index 100% rename from tests/policy/test_optimizer_collection.py rename to tests/raylab/policy/test_optimizer_collection.py diff --git a/tests/raylab/pytorch/__init__.py b/tests/raylab/pytorch/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/pytorch/conftest.py b/tests/raylab/pytorch/conftest.py similarity index 100% rename from tests/pytorch/conftest.py rename to tests/raylab/pytorch/conftest.py diff --git a/tests/raylab/pytorch/nn/__init__.py b/tests/raylab/pytorch/nn/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/raylab/pytorch/nn/distributions/__init__.py b/tests/raylab/pytorch/nn/distributions/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/pytorch/nn/distributions/conftest.py 
b/tests/raylab/pytorch/nn/distributions/conftest.py similarity index 100% rename from tests/pytorch/nn/distributions/conftest.py rename to tests/raylab/pytorch/nn/distributions/conftest.py diff --git a/tests/raylab/pytorch/nn/distributions/flows/__init__.py b/tests/raylab/pytorch/nn/distributions/flows/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/pytorch/nn/distributions/flows/test_affine_constant.py b/tests/raylab/pytorch/nn/distributions/flows/test_affine_constant.py similarity index 100% rename from tests/pytorch/nn/distributions/flows/test_affine_constant.py rename to tests/raylab/pytorch/nn/distributions/flows/test_affine_constant.py diff --git a/tests/pytorch/nn/distributions/flows/test_couplings.py b/tests/raylab/pytorch/nn/distributions/flows/test_couplings.py similarity index 94% rename from tests/pytorch/nn/distributions/flows/test_couplings.py rename to tests/raylab/pytorch/nn/distributions/flows/test_couplings.py index aff975f4..69e7088a 100644 --- a/tests/pytorch/nn/distributions/flows/test_couplings.py +++ b/tests/raylab/pytorch/nn/distributions/flows/test_couplings.py @@ -2,8 +2,8 @@ import pytest import torch -from raylab.modules.networks import MLP -from raylab.modules.networks import ResidualNet +from raylab.policy.modules.networks import MLP +from raylab.policy.modules.networks import ResidualNet from raylab.pytorch.nn.distributions.flows.coupling import AdditiveCouplingTransform from raylab.pytorch.nn.distributions.flows.coupling import AffineCouplingTransform from raylab.pytorch.nn.distributions.flows.coupling import PiecewiseRQSCouplingTransform diff --git a/tests/pytorch/nn/distributions/flows/test_flow_distribution.py b/tests/raylab/pytorch/nn/distributions/flows/test_flow_distribution.py similarity index 100% rename from tests/pytorch/nn/distributions/flows/test_flow_distribution.py rename to tests/raylab/pytorch/nn/distributions/flows/test_flow_distribution.py diff --git a/tests/pytorch/nn/distributions/flows/test_maf.py b/tests/raylab/pytorch/nn/distributions/flows/test_maf.py similarity index 100% rename from tests/pytorch/nn/distributions/flows/test_maf.py rename to tests/raylab/pytorch/nn/distributions/flows/test_maf.py diff --git a/tests/pytorch/nn/distributions/test_categorical.py b/tests/raylab/pytorch/nn/distributions/test_categorical.py similarity index 100% rename from tests/pytorch/nn/distributions/test_categorical.py rename to tests/raylab/pytorch/nn/distributions/test_categorical.py diff --git a/tests/pytorch/nn/distributions/test_normal.py b/tests/raylab/pytorch/nn/distributions/test_normal.py similarity index 100% rename from tests/pytorch/nn/distributions/test_normal.py rename to tests/raylab/pytorch/nn/distributions/test_normal.py diff --git a/tests/pytorch/nn/distributions/test_transforms.py b/tests/raylab/pytorch/nn/distributions/test_transforms.py similarity index 100% rename from tests/pytorch/nn/distributions/test_transforms.py rename to tests/raylab/pytorch/nn/distributions/test_transforms.py diff --git a/tests/pytorch/nn/distributions/test_uniform.py b/tests/raylab/pytorch/nn/distributions/test_uniform.py similarity index 100% rename from tests/pytorch/nn/distributions/test_uniform.py rename to tests/raylab/pytorch/nn/distributions/test_uniform.py diff --git a/tests/pytorch/nn/distributions/utils.py b/tests/raylab/pytorch/nn/distributions/utils.py similarity index 100% rename from tests/pytorch/nn/distributions/utils.py rename to tests/raylab/pytorch/nn/distributions/utils.py diff --git 
a/tests/raylab/pytorch/nn/modules/__init__.py b/tests/raylab/pytorch/nn/modules/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/pytorch/nn/modules/test_action_output.py b/tests/raylab/pytorch/nn/modules/test_action_output.py similarity index 100% rename from tests/pytorch/nn/modules/test_action_output.py rename to tests/raylab/pytorch/nn/modules/test_action_output.py diff --git a/tests/pytorch/nn/modules/test_fully_connected.py b/tests/raylab/pytorch/nn/modules/test_fully_connected.py similarity index 100% rename from tests/pytorch/nn/modules/test_fully_connected.py rename to tests/raylab/pytorch/nn/modules/test_fully_connected.py diff --git a/tests/pytorch/nn/modules/test_gaussian_noise.py b/tests/raylab/pytorch/nn/modules/test_gaussian_noise.py similarity index 100% rename from tests/pytorch/nn/modules/test_gaussian_noise.py rename to tests/raylab/pytorch/nn/modules/test_gaussian_noise.py diff --git a/tests/pytorch/nn/modules/test_lambd.py b/tests/raylab/pytorch/nn/modules/test_lambd.py similarity index 100% rename from tests/pytorch/nn/modules/test_lambd.py rename to tests/raylab/pytorch/nn/modules/test_lambd.py diff --git a/tests/modules/test_made.py b/tests/raylab/pytorch/nn/modules/test_made.py similarity index 100% rename from tests/modules/test_made.py rename to tests/raylab/pytorch/nn/modules/test_made.py diff --git a/tests/pytorch/nn/modules/test_normalized_linear.py b/tests/raylab/pytorch/nn/modules/test_normalized_linear.py similarity index 100% rename from tests/pytorch/nn/modules/test_normalized_linear.py rename to tests/raylab/pytorch/nn/modules/test_normalized_linear.py diff --git a/tests/pytorch/nn/modules/test_tanh_squash.py b/tests/raylab/pytorch/nn/modules/test_tanh_squash.py similarity index 50% rename from tests/pytorch/nn/modules/test_tanh_squash.py rename to tests/raylab/pytorch/nn/modules/test_tanh_squash.py index b1913aef..cc10cc32 100644 --- a/tests/pytorch/nn/modules/test_tanh_squash.py +++ b/tests/raylab/pytorch/nn/modules/test_tanh_squash.py @@ -13,29 +13,39 @@ def low_high(request): @pytest.fixture -def maker(torch_script): - def factory(*args, **kwargs): - module = TanhSquash(*args, **kwargs) - return torch.jit.script(module) if torch_script else module +def squash(low_high): + low, high = low_high + return TanhSquash(low, high) + + +@pytest.fixture +def module(squash, torch_script): + if torch_script: + return torch.jit.script(squash) + return squash + - return factory +@pytest.fixture +def inputs(low_high): + low, _ = low_high + return torch.randn(10, *low.shape) -def test_squash_to_range(maker, low_high): +def test_squash_to_range(module, low_high, inputs): low, high = low_high - module = maker(low, high) - inputs = torch.randn(10, *low.shape) output = module(inputs) assert (output <= high).all() assert (output >= low).all() -def test_propagates_gradients(maker, low_high): - low, high = low_high - module = maker(low, high) - - inputs = torch.randn(10, *low.shape, requires_grad=True) +def test_propagates_gradients(module, inputs): + inputs.requires_grad_() module(inputs).mean().backward() assert inputs.grad is not None assert (inputs.grad != 0).any() + + +def test_reverse(module, inputs): + squashed = module(inputs) + assert torch.allclose(module(squashed, reverse=True), inputs, atol=1e-6) diff --git a/tests/pytorch/nn/modules/test_tril_matrix.py b/tests/raylab/pytorch/nn/modules/test_tril_matrix.py similarity index 100% rename from tests/pytorch/nn/modules/test_tril_matrix.py rename to 
tests/raylab/pytorch/nn/modules/test_tril_matrix.py diff --git a/tests/raylab/utils/__init__.py b/tests/raylab/utils/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/utils/test_replay.py b/tests/raylab/utils/test_replay.py similarity index 100% rename from tests/utils/test_replay.py rename to tests/raylab/utils/test_replay.py