5 changes: 3 additions & 2 deletions rllib/agents/a3c/a3c_torch_policy.py
@@ -1,8 +1,8 @@
 import ray
 from ray.rllib.evaluation.postprocessing import compute_advantages, \
     Postprocessing
+from ray.rllib.policy.policy_template import build_policy_class
 from ray.rllib.policy.sample_batch import SampleBatch
-from ray.rllib.policy.torch_policy_template import build_torch_policy
 from ray.rllib.utils.framework import try_import_torch

 torch, nn = try_import_torch()
@@ -84,8 +84,9 @@ def _value(self, obs):
         return self.model.value_function()[0]


-A3CTorchPolicy = build_torch_policy(
+A3CTorchPolicy = build_policy_class(
     name="A3CTorchPolicy",
+    framework="torch",
     get_default_config=lambda: ray.rllib.agents.a3c.a3c.DEFAULT_CONFIG,
     loss_fn=actor_critic_loss,
     stats_fn=loss_and_entropy_stats,
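For orientation, the same mechanical change repeats across the policy files below: the torch-only `build_torch_policy` template is replaced by the framework-agnostic `build_policy_class`, which takes an explicit `framework="torch"` argument. A minimal sketch of the new-style call follows, using a hypothetical policy name and a toy loss function, and assuming only the keyword arguments that these hunks actually exercise:

import ray
from ray.rllib.policy.policy_template import build_policy_class
from ray.rllib.policy.sample_batch import SampleBatch


def toy_loss(policy, model, dist_class, train_batch):
    # Toy objective: maximize the log-likelihood of the sampled actions.
    logits, _ = model.from_batch(train_batch)
    action_dist = dist_class(logits, model)
    return -action_dist.logp(train_batch[SampleBatch.ACTIONS]).mean()


# Hypothetical policy; not one of the policies touched by this PR.
MyTorchPolicy = build_policy_class(
    name="MyTorchPolicy",
    framework="torch",  # new explicit argument replacing the torch-only template
    get_default_config=lambda: ray.rllib.agents.pg.pg.DEFAULT_CONFIG,
    loss_fn=toy_loss,
)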
5 changes: 3 additions & 2 deletions rllib/agents/ars/ars_torch_policy.py
@@ -4,10 +4,11 @@
 import ray
 from ray.rllib.agents.es.es_torch_policy import after_init, before_init, \
     make_model_and_action_dist
-from ray.rllib.policy.torch_policy_template import build_torch_policy
+from ray.rllib.policy.policy_template import build_policy_class

-ARSTorchPolicy = build_torch_policy(
+ARSTorchPolicy = build_policy_class(
     name="ARSTorchPolicy",
+    framework="torch",
     loss_fn=None,
     get_default_config=lambda: ray.rllib.agents.ars.ars.DEFAULT_CONFIG,
     before_init=before_init,
2 changes: 1 addition & 1 deletion rllib/agents/ddpg/ddpg_torch_model.py
@@ -2,7 +2,7 @@

 from ray.rllib.models.torch.misc import SlimFC
 from ray.rllib.models.torch.torch_modelv2 import TorchModelV2
-from ray.rllib.utils.framework import try_import_torch, get_activation_fn
+from ray.rllib.utils.framework import get_activation_fn, try_import_torch

 torch, nn = try_import_torch()

5 changes: 3 additions & 2 deletions rllib/agents/ddpg/ddpg_torch_policy.py
@@ -7,8 +7,8 @@
 from ray.rllib.agents.dqn.dqn_tf_policy import postprocess_nstep_and_prio, \
     PRIO_WEIGHTS
 from ray.rllib.models.torch.torch_action_dist import TorchDeterministic
+from ray.rllib.policy.policy_template import build_policy_class
 from ray.rllib.policy.sample_batch import SampleBatch
-from ray.rllib.policy.torch_policy_template import build_torch_policy
 from ray.rllib.utils.framework import try_import_torch
 from ray.rllib.utils.torch_ops import huber_loss, l2_loss

@@ -264,8 +264,9 @@ def setup_late_mixins(policy, obs_space, action_space, config):
     TargetNetworkMixin.__init__(policy)


-DDPGTorchPolicy = build_torch_policy(
+DDPGTorchPolicy = build_policy_class(
     name="DDPGTorchPolicy",
+    framework="torch",
     loss_fn=ddpg_actor_critic_loss,
     get_default_config=lambda: ray.rllib.agents.ddpg.ddpg.DEFAULT_CONFIG,
     stats_fn=build_ddpg_stats,
5 changes: 3 additions & 2 deletions rllib/agents/dqn/dqn_torch_policy.py
@@ -14,9 +14,9 @@
 from ray.rllib.models.torch.torch_action_dist import (TorchCategorical,
     TorchDistributionWrapper)
 from ray.rllib.policy.policy import Policy
+from ray.rllib.policy.policy_template import build_policy_class
 from ray.rllib.policy.sample_batch import SampleBatch
 from ray.rllib.policy.torch_policy import LearningRateSchedule
-from ray.rllib.policy.torch_policy_template import build_torch_policy
 from ray.rllib.utils.error import UnsupportedSpaceException
 from ray.rllib.utils.exploration.parameter_noise import ParameterNoise
 from ray.rllib.utils.framework import try_import_torch
@@ -384,8 +384,9 @@ def extra_action_out_fn(policy: Policy, input_dict, state_batches, model,
     return {"q_values": policy.q_values}


-DQNTorchPolicy = build_torch_policy(
+DQNTorchPolicy = build_policy_class(
     name="DQNTorchPolicy",
+    framework="torch",
     loss_fn=build_q_losses,
     get_default_config=lambda: ray.rllib.agents.dqn.dqn.DEFAULT_CONFIG,
     make_model_and_action_dist=build_q_model_and_distribution,
5 changes: 3 additions & 2 deletions rllib/agents/dqn/simple_q_torch_policy.py
@@ -11,8 +11,8 @@
 from ray.rllib.models.torch.torch_action_dist import TorchCategorical, \
     TorchDistributionWrapper
 from ray.rllib.policy import Policy
+from ray.rllib.policy.policy_template import build_policy_class
 from ray.rllib.policy.sample_batch import SampleBatch
-from ray.rllib.policy.torch_policy_template import build_torch_policy
 from ray.rllib.utils.framework import try_import_torch
 from ray.rllib.utils.torch_ops import huber_loss
 from ray.rllib.utils.typing import TensorType, TrainerConfigDict
@@ -127,8 +127,9 @@ def setup_late_mixins(policy: Policy, obs_space: gym.spaces.Space,
     TargetNetworkMixin.__init__(policy, obs_space, action_space, config)


-SimpleQTorchPolicy = build_torch_policy(
+SimpleQTorchPolicy = build_policy_class(
     name="SimpleQPolicy",
+    framework="torch",
     loss_fn=build_q_losses,
     get_default_config=lambda: ray.rllib.agents.dqn.dqn.DEFAULT_CONFIG,
     extra_action_out_fn=extra_action_out_fn,
9 changes: 5 additions & 4 deletions rllib/agents/dreamer/dreamer_torch_policy.py
@@ -1,11 +1,11 @@
 import logging

 import ray
-from ray.rllib.policy.torch_policy_template import build_torch_policy
 from ray.rllib.agents.a3c.a3c_torch_policy import apply_grad_clipping
-from ray.rllib.utils.framework import try_import_torch
-from ray.rllib.models.catalog import ModelCatalog
 from ray.rllib.agents.dreamer.utils import FreezeParameters
+from ray.rllib.models.catalog import ModelCatalog
+from ray.rllib.policy.policy_template import build_policy_class
+from ray.rllib.utils.framework import try_import_torch

 torch, nn = try_import_torch()
 if torch:
@@ -236,8 +236,9 @@ def dreamer_optimizer_fn(policy, config):
     return (model_opt, actor_opt, critic_opt)


-DreamerTorchPolicy = build_torch_policy(
+DreamerTorchPolicy = build_policy_class(
     name="DreamerTorchPolicy",
+    framework="torch",
     get_default_config=lambda: ray.rllib.agents.dreamer.dreamer.DEFAULT_CONFIG,
     action_sampler_fn=action_sampler_fn,
     loss_fn=dreamer_loss,
5 changes: 3 additions & 2 deletions rllib/agents/es/es_torch_policy.py
@@ -7,8 +7,8 @@

 import ray
 from ray.rllib.models import ModelCatalog
+from ray.rllib.policy.policy_template import build_policy_class
 from ray.rllib.policy.sample_batch import SampleBatch
-from ray.rllib.policy.torch_policy_template import build_torch_policy
 from ray.rllib.utils.filter import get_filter
 from ray.rllib.utils.framework import try_import_torch
 from ray.rllib.utils.spaces.space_utils import get_base_struct_from_space, \
@@ -126,8 +126,9 @@ def make_model_and_action_dist(policy, observation_space, action_space,
     return model, dist_class


-ESTorchPolicy = build_torch_policy(
+ESTorchPolicy = build_policy_class(
     name="ESTorchPolicy",
+    framework="torch",
     loss_fn=None,
     get_default_config=lambda: ray.rllib.agents.es.es.DEFAULT_CONFIG,
     before_init=before_init,
5 changes: 3 additions & 2 deletions rllib/agents/impala/vtrace_torch_policy.py
@@ -6,10 +6,10 @@
 from ray.rllib.agents.a3c.a3c_torch_policy import apply_grad_clipping
 import ray.rllib.agents.impala.vtrace_torch as vtrace
 from ray.rllib.models.torch.torch_action_dist import TorchCategorical
+from ray.rllib.policy.policy_template import build_policy_class
 from ray.rllib.policy.sample_batch import SampleBatch
 from ray.rllib.policy.torch_policy import LearningRateSchedule, \
     EntropyCoeffSchedule
-from ray.rllib.policy.torch_policy_template import build_torch_policy
 from ray.rllib.utils.framework import try_import_torch
 from ray.rllib.utils.torch_ops import explained_variance, global_norm, \
     sequence_mask
@@ -260,8 +260,9 @@ def setup_mixins(policy, obs_space, action_space, config):
     LearningRateSchedule.__init__(policy, config["lr"], config["lr_schedule"])


-VTraceTorchPolicy = build_torch_policy(
+VTraceTorchPolicy = build_policy_class(
     name="VTraceTorchPolicy",
+    framework="torch",
     loss_fn=build_vtrace_loss,
     get_default_config=lambda: ray.rllib.agents.impala.impala.DEFAULT_CONFIG,
     stats_fn=stats,
6 changes: 3 additions & 3 deletions rllib/agents/maml/maml_tf_policy.py
@@ -1,13 +1,13 @@
 import logging

 import ray
+from ray.rllib.agents.ppo.ppo_tf_policy import postprocess_ppo_gae, \
+    vf_preds_fetches, compute_and_clip_gradients, setup_config, \
+    ValueNetworkMixin
 from ray.rllib.evaluation.postprocessing import Postprocessing
 from ray.rllib.policy.sample_batch import SampleBatch
 from ray.rllib.policy.tf_policy_template import build_tf_policy
 from ray.rllib.utils import try_import_tf
-from ray.rllib.agents.ppo.ppo_tf_policy import postprocess_ppo_gae, \
-    vf_preds_fetches, compute_and_clip_gradients, setup_config, \
-    ValueNetworkMixin
 from ray.rllib.utils.framework import get_activation_fn

 tf1, tf, tfv = try_import_tf()
5 changes: 3 additions & 2 deletions rllib/agents/maml/maml_torch_policy.py
@@ -2,8 +2,8 @@

 import ray
 from ray.rllib.evaluation.postprocessing import Postprocessing
+from ray.rllib.policy.policy_template import build_policy_class
 from ray.rllib.policy.sample_batch import SampleBatch
-from ray.rllib.policy.torch_policy_template import build_torch_policy
 from ray.rllib.agents.ppo.ppo_tf_policy import postprocess_ppo_gae, \
     setup_config
 from ray.rllib.agents.ppo.ppo_torch_policy import vf_preds_fetches, \
@@ -347,8 +347,9 @@ def setup_mixins(policy, obs_space, action_space, config):
     KLCoeffMixin.__init__(policy, config)


-MAMLTorchPolicy = build_torch_policy(
+MAMLTorchPolicy = build_policy_class(
     name="MAMLTorchPolicy",
+    framework="torch",
     get_default_config=lambda: ray.rllib.agents.maml.maml.DEFAULT_CONFIG,
     loss_fn=maml_loss,
     stats_fn=maml_stats,
5 changes: 3 additions & 2 deletions rllib/agents/marwil/marwil_torch_policy.py
@@ -1,8 +1,8 @@
 import ray
 from ray.rllib.agents.marwil.marwil_tf_policy import postprocess_advantages
 from ray.rllib.evaluation.postprocessing import Postprocessing
+from ray.rllib.policy.policy_template import build_policy_class
 from ray.rllib.policy.sample_batch import SampleBatch
-from ray.rllib.policy.torch_policy_template import build_torch_policy
 from ray.rllib.utils.framework import try_import_torch
 from ray.rllib.utils.torch_ops import explained_variance

@@ -75,8 +75,9 @@ def setup_mixins(policy, obs_space, action_space, config):
     ValueNetworkMixin.__init__(policy)


-MARWILTorchPolicy = build_torch_policy(
+MARWILTorchPolicy = build_policy_class(
     name="MARWILTorchPolicy",
+    framework="torch",
     loss_fn=marwil_loss,
     get_default_config=lambda: ray.rllib.agents.marwil.marwil.DEFAULT_CONFIG,
     stats_fn=stats,
5 changes: 3 additions & 2 deletions rllib/agents/mbmpo/mbmpo_torch_policy.py
@@ -13,7 +13,7 @@
 from ray.rllib.models.modelv2 import ModelV2
 from ray.rllib.models.torch.torch_action_dist import TorchDistributionWrapper
 from ray.rllib.policy.policy import Policy
-from ray.rllib.policy.torch_policy_template import build_torch_policy
+from ray.rllib.policy.policy_template import build_policy_class
 from ray.rllib.utils.framework import try_import_torch
 from ray.rllib.utils.typing import TrainerConfigDict

@@ -76,8 +76,9 @@ def make_model_and_action_dist(

 # Build a child class of `TorchPolicy`, given the custom functions defined
 # above.
-MBMPOTorchPolicy = build_torch_policy(
+MBMPOTorchPolicy = build_policy_class(
     name="MBMPOTorchPolicy",
+    framework="torch",
     get_default_config=lambda: ray.rllib.agents.mbmpo.mbmpo.DEFAULT_CONFIG,
     make_model_and_action_dist=make_model_and_action_dist,
     loss_fn=maml_loss,
5 changes: 3 additions & 2 deletions rllib/agents/pg/pg_torch_policy.py
@@ -10,8 +10,8 @@
 from ray.rllib.models.torch.torch_action_dist import TorchDistributionWrapper
 from ray.rllib.models.modelv2 import ModelV2
 from ray.rllib.policy import Policy
+from ray.rllib.policy.policy_template import build_policy_class
 from ray.rllib.policy.sample_batch import SampleBatch
-from ray.rllib.policy.torch_policy_template import build_torch_policy
 from ray.rllib.utils.framework import try_import_torch
 from ray.rllib.utils.typing import TensorType

@@ -72,8 +72,9 @@ def pg_loss_stats(policy: Policy,
 # Build a child class of `TFPolicy`, given the extra options:
 # - trajectory post-processing function (to calculate advantages)
 # - PG loss function
-PGTorchPolicy = build_torch_policy(
+PGTorchPolicy = build_policy_class(
     name="PGTorchPolicy",
+    framework="torch",
     get_default_config=lambda: ray.rllib.agents.pg.pg.DEFAULT_CONFIG,
     loss_fn=pg_torch_loss,
     stats_fn=pg_loss_stats,
5 changes: 3 additions & 2 deletions rllib/agents/ppo/appo_torch_policy.py
@@ -23,9 +23,9 @@
 from ray.rllib.models.torch.torch_action_dist import \
     TorchDistributionWrapper, TorchCategorical
 from ray.rllib.policy.policy import Policy
+from ray.rllib.policy.policy_template import build_policy_class
 from ray.rllib.policy.sample_batch import SampleBatch
 from ray.rllib.policy.torch_policy import LearningRateSchedule
-from ray.rllib.policy.torch_policy_template import build_torch_policy
 from ray.rllib.utils.framework import try_import_torch
 from ray.rllib.utils.torch_ops import explained_variance, global_norm, \
     sequence_mask
@@ -322,8 +322,9 @@ def setup_late_mixins(policy: Policy, obs_space: gym.spaces.Space,

 # Build a child class of `TorchPolicy`, given the custom functions defined
 # above.
-AsyncPPOTorchPolicy = build_torch_policy(
+AsyncPPOTorchPolicy = build_policy_class(
     name="AsyncPPOTorchPolicy",
+    framework="torch",
     loss_fn=appo_surrogate_loss,
     stats_fn=stats,
     postprocess_fn=postprocess_trajectory,
12 changes: 7 additions & 5 deletions rllib/agents/ppo/ppo_torch_policy.py
@@ -14,10 +14,10 @@
 from ray.rllib.models.modelv2 import ModelV2
 from ray.rllib.models.torch.torch_action_dist import TorchDistributionWrapper
 from ray.rllib.policy.policy import Policy
+from ray.rllib.policy.policy_template import build_policy_class
 from ray.rllib.policy.sample_batch import SampleBatch
 from ray.rllib.policy.torch_policy import EntropyCoeffSchedule, \
     LearningRateSchedule
-from ray.rllib.policy.torch_policy_template import build_torch_policy
 from ray.rllib.utils.framework import try_import_torch
 from ray.rllib.utils.torch_ops import convert_to_torch_tensor, \
     explained_variance, sequence_mask
@@ -111,6 +111,9 @@ def reduce_mean_valid(t):
     policy._total_loss = total_loss
     policy._mean_policy_loss = mean_policy_loss
     policy._mean_vf_loss = mean_vf_loss
+    policy._vf_explained_var = explained_variance(
+        train_batch[Postprocessing.VALUE_TARGETS],
+        policy.model.value_function())
     policy._mean_entropy = mean_entropy
     policy._mean_kl = mean_kl

@@ -134,9 +137,7 @@ def kl_and_loss_stats(policy: Policy,
         "total_loss": policy._total_loss,
         "policy_loss": policy._mean_policy_loss,
         "vf_loss": policy._mean_vf_loss,
-        "vf_explained_var": explained_variance(
-            train_batch[Postprocessing.VALUE_TARGETS],
-            policy.model.value_function()),
+        "vf_explained_var": policy._vf_explained_var,
         "kl": policy._mean_kl,
         "entropy": policy._mean_entropy,
         "entropy_coeff": policy.entropy_coeff,
@@ -271,8 +272,9 @@ def setup_mixins(policy: Policy, obs_space: gym.spaces.Space,

 # Build a child class of `TorchPolicy`, given the custom functions defined
 # above.
-PPOTorchPolicy = build_torch_policy(
+PPOTorchPolicy = build_policy_class(
     name="PPOTorchPolicy",
+    framework="torch",
     get_default_config=lambda: ray.rllib.agents.ppo.ppo.DEFAULT_CONFIG,
     loss_fn=ppo_surrogate_loss,
     stats_fn=kl_and_loss_stats,
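Besides the template swap, the PPO hunks above also move the `vf_explained_var` computation out of the stats function and into the loss function, caching it on the policy object so the stats function no longer calls the model's value function a second time. A minimal sketch of that caching pattern, with illustrative function names that are not from the PR:

from ray.rllib.evaluation.postprocessing import Postprocessing
from ray.rllib.utils.torch_ops import explained_variance


def loss_side(policy, model, train_batch):
    # Called from inside the loss function, right after the forward pass:
    # compute the diagnostic once and stash it on the policy object.
    policy._vf_explained_var = explained_variance(
        train_batch[Postprocessing.VALUE_TARGETS],
        model.value_function())


def stats_side(policy, train_batch):
    # stats_fn only reads the cached value; no extra model call needed.
    return {"vf_explained_var": policy._vf_explained_var}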
2 changes: 1 addition & 1 deletion rllib/agents/qmix/qmix_policy.py
@@ -143,7 +143,7 @@ def forward(self,
         return loss, mask, masked_td_error, chosen_action_qvals, targets


-# TODO(sven): Make this a TorchPolicy child via `build_torch_policy`.
+# TODO(sven): Make this a TorchPolicy child via `build_policy_class`.
 class QMixTorchPolicy(Policy):
     """QMix impl. Assumes homogeneous agents for now.

5 changes: 3 additions & 2 deletions rllib/agents/sac/sac_torch_policy.py
@@ -17,8 +17,8 @@
 from ray.rllib.models.torch.torch_action_dist import \
     TorchDistributionWrapper, TorchDirichlet
 from ray.rllib.policy.policy import Policy
+from ray.rllib.policy.policy_template import build_policy_class
 from ray.rllib.policy.sample_batch import SampleBatch
-from ray.rllib.policy.torch_policy_template import build_torch_policy
 from ray.rllib.models.torch.torch_action_dist import (
     TorchCategorical, TorchSquashedGaussian, TorchDiagGaussian, TorchBeta)
 from ray.rllib.utils.framework import try_import_torch
@@ -480,8 +480,9 @@ def setup_late_mixins(policy: Policy, obs_space: gym.spaces.Space,

 # Build a child class of `TorchPolicy`, given the custom functions defined
 # above.
-SACTorchPolicy = build_torch_policy(
+SACTorchPolicy = build_policy_class(
     name="SACTorchPolicy",
+    framework="torch",
     loss_fn=actor_critic_loss,
     get_default_config=lambda: ray.rllib.agents.sac.sac.DEFAULT_CONFIG,
     stats_fn=stats,
5 changes: 3 additions & 2 deletions rllib/agents/slateq/slateq_torch_policy.py
@@ -11,8 +11,8 @@
     TorchDistributionWrapper)
 from ray.rllib.models.torch.torch_modelv2 import TorchModelV2
 from ray.rllib.policy.policy import Policy
+from ray.rllib.policy.policy_template import build_policy_class
 from ray.rllib.policy.sample_batch import SampleBatch
-from ray.rllib.policy.torch_policy_template import build_torch_policy
 from ray.rllib.utils.framework import try_import_torch
 from ray.rllib.utils.typing import (ModelConfigDict, TensorType,
     TrainerConfigDict)
@@ -403,8 +403,9 @@ def postprocess_fn_add_next_actions_for_sarsa(policy: Policy,
     return batch


-SlateQTorchPolicy = build_torch_policy(
+SlateQTorchPolicy = build_policy_class(
     name="SlateQTorchPolicy",
+    framework="torch",
     get_default_config=lambda: ray.rllib.agents.slateq.slateq.DEFAULT_CONFIG,

     # build model, loss functions, and optimizers