Merge branch 'release/0.8.5'
0xangelo committed Jun 25, 2020
2 parents 49c6564 + b1b2cb9 commit ec2d4b8
Showing 14 changed files with 275 additions and 86 deletions.
3 changes: 1 addition & 2 deletions .github/workflows/poetry-publish.yml
@@ -4,8 +4,7 @@
 name: Poetry publish

 on:
-  push:
-    branches: master
+  push:
     tags:
       - 'v*.*.*'

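With this change the publish workflow is triggered only by pushing a tag matching 'v*.*.*', no longer by pushes to master. A release like this one would then presumably be cut by tagging the merge commit, e.g. git tag v0.8.5 followed by git push origin v0.8.5 (the tag name is inferred from the version bump in pyproject.toml below; the commands themselves are not part of this diff).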
2 changes: 1 addition & 1 deletion examples/MAPO/swingup.py
@@ -28,7 +28,7 @@ def get_config():
             # model-aware deterministic policy gradient
             "model_samples": 1,
             # Whether to use the environment's true model to sample states
-            "true_model": True,
+            "true_model": False,
         },
         # PyTorch optimizers to use
         "torch_optimizer": {
118 changes: 117 additions & 1 deletion poetry.lock

Some generated files are not rendered by default.

5 changes: 4 additions & 1 deletion pyproject.toml
@@ -1,6 +1,6 @@
 [tool.poetry]
 name = "raylab"
-version = "0.8.4"
+version = "0.8.5"
 description = "Reinforcement learning algorithms in RLlib and PyTorch."
 authors = ["Ângelo Gregório Lovatto <[email protected]>"]
 license = "MIT"
@@ -39,6 +39,9 @@ mypy = "^0.782"
 coverage = "^5.1"
 ipython = "^7.15.0"
 poetry-version = "^0.1.5"
+pytest-mock = "^3.1.1"
+pytest-sugar = "^0.9.3"
+auto-changelog = "^0.5.1"

 [tool.poetry.scripts]
 raylab = "raylab.cli:raylab"
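The three new entries sit alongside the other dev dependencies (mypy, coverage, ipython, poetry-version). They would presumably have been added with Poetry's dev flag, e.g. poetry add --dev pytest-mock pytest-sugar auto-changelog; the exact command is an assumption, since the diff only shows the resulting pyproject.toml (and poetry.lock) entries.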
14 changes: 2 additions & 12 deletions raylab/agents/mage/policy.py
@@ -41,20 +41,10 @@ def get_default_config():

         return DEFAULT_CONFIG

-    def set_reward_from_config(self, *args, **kwargs):
-        super().set_reward_from_config(*args, **kwargs)
+    def _set_reward_hook(self):
         self.loss_critic.set_reward_fn(self.reward_fn)

-    def set_reward_from_callable(self, *args, **kwargs):
-        super().set_reward_from_callable(*args, **kwargs)
-        self.loss_critic.set_reward_fn(self.reward_fn)
-
-    def set_termination_from_config(self, *args, **kwargs):
-        super().set_termination_from_config(*args, **kwargs)
-        self.loss_critic.set_termination_fn(self.termination_fn)
-
-    def set_termination_from_callable(self, *args, **kwargs):
-        super().set_termination_from_callable(*args, **kwargs)
+    def _set_termination_hook(self):
         self.loss_critic.set_termination_fn(self.termination_fn)

     def make_optimizers(self):
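This diff (and the MAPO and SVG diffs below) replaces overrides of the public setters with protected hook methods, so each policy only declares how it reacts when an environment function is attached. A minimal sketch of the pattern, assuming EnvFnMixin's public setters store the function and then call the matching hook (the hook names come from the diff; the base-class bodies below are illustrative, not raylab's actual implementation):

class EnvFnMixin:
    # Illustrative template-method base: public setters store the function,
    # then call a no-op hook that concrete policies override.

    def set_reward_from_callable(self, reward_fn):
        self.reward_fn = reward_fn
        self._set_reward_hook()

    def set_termination_from_callable(self, termination_fn):
        self.termination_fn = termination_fn
        self._set_termination_hook()

    def set_dynamics_from_callable(self, dynamics_fn):
        self.dynamics_fn = dynamics_fn
        self._set_dynamics_hook()

    def _set_reward_hook(self):
        pass  # concrete policies override only the hooks they need

    def _set_termination_hook(self):
        pass

    def _set_dynamics_hook(self):
        pass

Under this scheme the MAGE policy above only has to forward reward_fn and termination_fn to its critic loss, instead of repeating the super() call in four separate overrides.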
21 changes: 3 additions & 18 deletions raylab/agents/mapo/policy.py
@@ -33,32 +33,17 @@ def __init__(self, observation_space, action_space, config):
         self.loss_actor.grad_estimator = self.config["losses"]["grad_estimator"]

     @override(EnvFnMixin)
-    def set_reward_from_config(self, *args, **kwargs):
-        super().set_reward_from_config(*args, **kwargs)
+    def _set_reward_hook(self):
         self.loss_model.set_reward_fn(self.reward_fn)
         self.loss_actor.set_reward_fn(self.reward_fn)

     @override(EnvFnMixin)
-    def set_termination_from_config(self, *args, **kwargs):
-        super().set_termination_from_config(*args, **kwargs)
+    def _set_termination_hook(self):
         self.loss_model.set_termination_fn(self.termination_fn)
         self.loss_actor.set_termination_fn(self.termination_fn)

-    @override(EnvFnMixin)
-    def set_reward_from_callable(self, *args, **kwargs):
-        super().set_reward_from_callable(*args, **kwargs)
-        self.loss_model.set_reward_fn(self.reward_fn)
-        self.loss_actor.set_reward_fn(self.reward_fn)
-
-    @override(EnvFnMixin)
-    def set_termination_from_callable(self, *args, **kwargs):
-        super().set_termination_from_callable(*args, **kwargs)
-        self.loss_model.set_termination_fn(self.termination_fn)
-        self.loss_actor.set_termination_fn(self.termination_fn)
-
     @override(EnvFnMixin)
-    def set_dynamics_from_callable(self, *args, **kwargs):
-        super().set_dynamics_from_callable(*args, **kwargs)
+    def _set_dynamics_hook(self):
         self.loss_actor = DAPO(self.dynamics_fn, self.module.actor, self.module.critics)
         self.loss_actor.gamma = self.config["gamma"]
         self.loss_actor.dynamics_samples = self.config["losses"]["model_samples"]
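In the MAPO policy the dynamics hook does more than forward a function: once a dynamics_fn is attached (via the public setter, e.g. set_dynamics_from_callable as in the pre-change code), _set_dynamics_hook rebuilds loss_actor as a DAPO loss around that dynamics function and re-applies the gamma and model_samples settings from the config.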
3 changes: 1 addition & 2 deletions raylab/agents/svg/inf/policy.py
@@ -28,8 +28,7 @@ def __init__(self, observation_space, action_space, config):
         )

     @override(EnvFnMixin)
-    def set_reward_from_config(self, env_name: str, env_config: dict):
-        super().set_reward_from_config(env_name, env_config)
+    def _set_reward_hook(self):
         self.loss_actor.set_reward_fn(self.reward_fn)

     @staticmethod
3 changes: 1 addition & 2 deletions raylab/agents/svg/one/policy.py
@@ -29,8 +29,7 @@ def __init__(self, *args, **kwargs):
         self.loss_actor.gamma = self.config["gamma"]

     @override(EnvFnMixin)
-    def set_reward_from_config(self, env_name: str, env_config: dict):
-        super().set_reward_from_config(env_name, env_config)
+    def _set_reward_hook(self):
         self.loss_actor.set_reward_fn(self.reward_fn)

     @staticmethod
3 changes: 1 addition & 2 deletions raylab/agents/svg/soft/policy.py
@@ -43,8 +43,7 @@ def __init__(self, observation_space, action_space, config):
         )

     @override(EnvFnMixin)
-    def set_reward_from_config(self, env_name: str, env_config: dict):
-        super().set_reward_from_config(env_name, env_config)
+    def _set_reward_hook(self):
         self.loss_actor.set_reward_fn(self.reward_fn)

     @staticmethod
(Diffs for the remaining changed files are not shown.)
