refactor(policy): remove clipped_double_q top-level config
Enable Clipped Double Q-Learning by configuring double Q
in the NN module.

Signed-off-by: Ângelo Lovatto <[email protected]>
0xangelo committed Jul 4, 2020
1 parent 8d99815 commit dd0b9c9
Showing 10 changed files with 32 additions and 49 deletions.
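
In short, the toggle moves from a top-level policy key into the critic section of the NN module config. Below is a minimal before/after sketch of the change; the keys are taken from the diffs that follow, the names old_config and new_config are illustrative, and the surrounding trainer setup is omitted.

# Before this commit: Clipped Double Q-Learning was a top-level policy option.
old_config = {
    "clipped_double_q": True,  # removed in this commit
    "module": {"type": "SAC"},
}

# After this commit: the same behavior is enabled on the critic inside the
# NN module config.
new_config = {
    "module": {"type": "SAC", "critic": {"double_q": True}},
}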
raylab/agents/mage/trainer.py (1 addition, 4 deletions)
@@ -16,9 +16,6 @@
         "policy_improvements": 10,
         "real_data_ratio": 1,
         # === MAGETorchPolicy ===
-        # Clipped Double Q-Learning: use the minimun of two target Q functions
-        # as the next action-value in the target for fitted Q iteration
-        "clipped_double_q": True,
         # TD error regularization for MAGE loss
         "lambda": 0.05,
         # PyTorch optimizers to use
@@ -39,7 +36,7 @@
             patience_epochs=None,
             improvement_threshold=None,
         ).to_dict(),
-        "module": {"type": "MBDDPG"},
+        "module": {"type": "MBDDPG", "critic": {"double_q": True}},
         # === Exploration Settings ===
         # Default exploration behavior, iff `explore`=None is passed into
         # compute_action(s).
raylab/agents/mapo/trainer.py (1 addition, 1 deletion)
@@ -21,6 +21,7 @@
                 "parallelize": False,
                 "residual": True,
             },
+            "critic": {"double_q": True},
         },
         "losses": {
             # Gradient estimator for optimizing expectations. Possible types include
@@ -44,7 +45,6 @@
         },
         # === SACTorchPolicy ===
         "target_entropy": "auto",
-        "clipped_double_q": True,
         # === TargetNetworksMixin ===
         "polyak": 0.995,
         # === ModelTrainingMixin ===
raylab/agents/mbpo/trainer.py (4 additions, 2 deletions)
@@ -22,7 +22,10 @@
                 "encoder": {"units": (128, 128), "activation": "Swish"},
                 "input_dependent_scale": True,
             },
-            "critic": {"encoder": {"units": (128, 128), "activation": "Swish"}},
+            "critic": {
+                "double_q": True,
+                "encoder": {"units": (128, 128), "activation": "Swish"},
+            },
             "entropy": {"initial_alpha": 0.05},
         },
         "torch_optimizer": {
@@ -33,7 +36,6 @@
         },
         # === SACTorchPolicy ===
         "target_entropy": "auto",
-        "clipped_double_q": True,
         "polyak": 0.995,
         # === ModelTrainingMixin ===
         "model_training": TrainingSpec().to_dict(),
raylab/agents/sac/policy.py (7 deletions)
@@ -44,13 +44,6 @@ def get_default_config():

         return DEFAULT_CONFIG

-    @override(TorchPolicy)
-    def make_module(self, obs_space, action_space, config):
-        module_config = config["module"]
-        module_config.setdefault("critic", {})
-        module_config["critic"]["double_q"] = config["clipped_double_q"]
-        return super().make_module(obs_space, action_space, config)
-
     @override(TorchPolicy)
     def make_optimizers(self):
         config = self.config["torch_optimizer"]
raylab/agents/sac/trainer.py (1 addition, 4 deletions)
@@ -15,9 +15,6 @@
         # If "auto", will use the heuristic provided in the SAC paper:
         # H = -dim(A), where A is the action space
         "target_entropy": None,
-        # === Twin Delayed DDPG (TD3) tricks ===
-        # Clipped Double Q-Learning
-        "clipped_double_q": True,
         # === Optimization ===
         # PyTorch optimizers to use
         "torch_optimizer": {
@@ -28,7 +25,7 @@
         # Interpolation factor in polyak averaging for target networks.
         "polyak": 0.995,
         # === Network ===
-        "module": {"type": "SAC"},
+        "module": {"type": "SAC", "critic": {"double_q": True}},
         # === Exploration Settings ===
         # Default exploration behavior, iff `explore`=None is passed into
         # compute_action(s).
raylab/agents/sop/policy.py (14 deletions)
@@ -37,20 +37,6 @@ def get_default_config():

         return DEFAULT_CONFIG

-    @override(TorchPolicy)
-    def make_module(self, obs_space, action_space, config):
-        module_config = config["module"]
-        module_config.setdefault("critic", {})
-        module_config["critic"]["double_q"] = config["clipped_double_q"]
-        module_config.setdefault("actor", {})
-        if (
-            config["exploration_config"]["type"]
-            == "raylab.utils.exploration.ParameterNoise"
-        ):
-            module_config["actor"]["parameter_noise"] = True
-        # pylint:disable=no-member
-        return super().make_module(obs_space, action_space, config)
-
     @override(TorchPolicy)
     def make_optimizers(self):
         config = self.config["torch_optimizer"]
raylab/agents/sop/trainer.py (5 additions, 4 deletions)
@@ -8,9 +8,6 @@
 DEFAULT_CONFIG = with_base_config(
     {
         # === SOPTorchPolicy ===
-        # Clipped Double Q-Learning: use the minimun of two target Q functions
-        # as the next action-value in the target for fitted Q iteration
-        "clipped_double_q": True,
         # PyTorch optimizers to use
         "torch_optimizer": {
             "actor": {"type": "Adam", "lr": 1e-3},
@@ -20,7 +17,11 @@
         "polyak": 0.995,
         # Update policy every this number of calls to `learn_on_batch`
         "policy_delay": 1,
-        "module": {"type": "DDPG", "actor": {"separate_behavior": True}},
+        "module": {
+            "type": "DDPG",
+            "actor": {"separate_behavior": True},
+            "critic": {"double_q": True},
+        },
         # === Exploration Settings ===
         # Default exploration behavior, iff `explore`=None is passed into
         # compute_action(s).
tests/raylab/agents/sac/test_actor.py (14 additions, 7 deletions)
@@ -5,16 +5,23 @@

 @pytest.fixture(params=(True, False))
 def input_dependent_scale(request):
-    return {"module": {"actor": {"input_dependent_scale": request.param}}}
+    return request.param


 @pytest.fixture(params=(True, False))
-def clipped_double_q(request):
-    return {"clipped_double_q": request.param}
-
-
-def test_actor_loss(policy_and_batch_fn, clipped_double_q, input_dependent_scale):
-    policy, batch = policy_and_batch_fn({**clipped_double_q, **input_dependent_scale})
+def double_q(request):
+    return request.param
+
+
+def test_actor_loss(policy_and_batch_fn, double_q, input_dependent_scale):
+    policy, batch = policy_and_batch_fn(
+        {
+            "module": {
+                "actor": {"input_dependent_scale": input_dependent_scale},
+                "critic": {"double_q": double_q},
+            }
+        }
+    )
     loss, info = policy.loss_actor(batch)

     assert loss.shape == ()
tests/raylab/agents/sac/test_critics.py (3 additions, 3 deletions)
@@ -9,13 +9,13 @@


 @pytest.fixture(params=(True, False))
-def clipped_double_q(request):
+def double_q(request):
     return request.param


 @pytest.fixture
-def policy_and_batch(policy_and_batch_fn, clipped_double_q):
-    config = {"clipped_double_q": clipped_double_q, "polyak": 0.5}
+def policy_and_batch(policy_and_batch_fn, double_q):
+    config = {"module": {"critic": {"double_q": double_q}}, "polyak": 0.5}
     return policy_and_batch_fn(config)


tests/raylab/agents/sop/test_policy.py (3 additions, 3 deletions)
@@ -7,13 +7,13 @@


 @pytest.fixture(params=(True, False))
-def clipped_double_q(request):
+def double_q(request):
     return request.param


 @pytest.fixture
-def config(clipped_double_q):
-    return {"clipped_double_q": clipped_double_q, "policy_delay": 2}
+def config(double_q):
+    return {"module": {"critic": {"double_q": double_q}}, "policy_delay": 2}


 @pytest.fixture
