[Refactor] Deprecate interaction_mode #1067

Merged · 7 commits · Apr 18, 2023
2 changes: 1 addition & 1 deletion docs/source/reference/collectors.rst
@@ -54,7 +54,7 @@ Besides those compute parameters, users may choose to configure the following pa
- reset_at_each_iter: if :obj:`True`, the environment(s) will be reset after each batch collection
- split_trajs: if :obj:`True`, the trajectories will be split and delivered in a padded tensordict
along with a :obj:`"mask"` key that will point to a boolean mask representing the valid values.
- exploration_mode: the exploration strategy to be used with the policy.
- exploration_type: the exploration strategy to be used with the policy.
- reset_when_done: whether environments should be reset when reaching a done state.
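
As a point of reference, a minimal sketch of passing the renamed argument to a collector (not part of this PR; it assumes a Gym backend is installed and uses a toy linear policy with Pendulum-v1's 3 observation / 1 action dimensions):

import torch.nn as nn
from tensordict.nn import TensorDictModule
from torchrl.collectors import SyncDataCollector
from torchrl.envs.libs.gym import GymEnv
from torchrl.envs.utils import ExplorationType

# toy deterministic policy for Pendulum-v1 (3 observation dims -> 1 action dim)
policy = TensorDictModule(nn.Linear(3, 1), in_keys=["observation"], out_keys=["action"])

collector = SyncDataCollector(
    GymEnv("Pendulum-v1"),
    policy,
    frames_per_batch=64,
    total_frames=128,
    exploration_type=ExplorationType.RANDOM,  # previously: exploration_mode="random"
)
for batch in collector:
    print(batch["next", "reward"].mean().item())  # step rewards are stored under "next"
collector.shutdown()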


4 changes: 2 additions & 2 deletions examples/a2c/a2c.py
@@ -9,7 +9,7 @@
import torch.cuda
from hydra.core.config_store import ConfigStore
from torchrl.envs.transforms import RewardScaling
from torchrl.envs.utils import set_exploration_mode
from torchrl.envs.utils import ExplorationType, set_exploration_type
from torchrl.objectives.value import TD0Estimator
from torchrl.record.loggers import generate_exp_name, get_logger
from torchrl.trainers.helpers.collectors import (
@@ -95,7 +95,7 @@ def main(cfg: "DictConfig"):  # noqa: F821

loss_module = make_a2c_loss(model, cfg)
if cfg.gSDE:
with torch.no_grad(), set_exploration_mode("random"):
with torch.no_grad(), set_exploration_type(ExplorationType.RANDOM):
# get dimensions to build the parallel env
proof_td = model(proof_env.reset().to(device))
action_dim_gsde, state_dim_gsde = proof_td.get("_eps_gSDE").shape[-2:]
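
The same substitution recurs in the examples below; a standalone sketch of the new context-manager API (illustrative only, assuming a Gym backend; the string-to-enum mapping in the comments is inferred from this refactor):

import torch
from torchrl.envs.libs.gym import GymEnv
from torchrl.envs.utils import ExplorationType, set_exploration_type

# inferred mapping from the old string API to the new enum API:
#   set_exploration_mode("random") -> set_exploration_type(ExplorationType.RANDOM)
#   set_exploration_mode("mode")   -> set_exploration_type(ExplorationType.MODE)
#   set_exploration_mode("mean")   -> set_exploration_type(ExplorationType.MEAN)
env = GymEnv("Pendulum-v1")
with torch.no_grad(), set_exploration_type(ExplorationType.RANDOM):
    proof_td = env.rollout(max_steps=3)  # a short rollout under the new exploration context
print(proof_td)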
4 changes: 2 additions & 2 deletions examples/bandits/dqn.py
@@ -10,7 +10,7 @@
from torch import nn

from torchrl.envs.libs.openml import OpenMLEnv
from torchrl.envs.utils import set_exploration_mode
from torchrl.envs.utils import ExplorationType, set_exploration_type
from torchrl.modules import DistributionalQValueActor, EGreedyWrapper, MLP, QValueActor
from torchrl.objectives import DistributionalDQNLoss, DQNLoss

@@ -94,7 +94,7 @@
init_r = None
init_loss = None
for i in pbar:
with set_exploration_mode("random"):
with set_exploration_type(ExplorationType.RANDOM):
data = env.step(policy(env.reset()))
loss_vals = loss(data)
loss_val = sum(
4 changes: 2 additions & 2 deletions examples/ddpg/ddpg.py
@@ -10,7 +10,7 @@
from hydra.core.config_store import ConfigStore
from torchrl.envs import EnvCreator, ParallelEnv
from torchrl.envs.transforms import RewardScaling, TransformedEnv
from torchrl.envs.utils import set_exploration_mode
from torchrl.envs.utils import ExplorationType, set_exploration_type
from torchrl.modules import OrnsteinUhlenbeckProcessWrapper
from torchrl.record import VideoRecorder
from torchrl.record.loggers import generate_exp_name, get_logger
@@ -122,7 +122,7 @@ def main(cfg: "DictConfig"):  # noqa: F821
actor_model_explore.share_memory()

if cfg.gSDE:
with torch.no_grad(), set_exploration_mode("random"):
with torch.no_grad(), set_exploration_type(ExplorationType.RANDOM):
# get dimensions to build the parallel env
proof_td = actor_model_explore(proof_env.reset().to(device))
action_dim_gsde, state_dim_gsde = proof_td.get("_eps_gSDE").shape[-2:]
15 changes: 8 additions & 7 deletions examples/discrete_sac/discrete_sac.py
@@ -9,6 +9,7 @@
import torch
import torch.cuda
import tqdm
from tensordict.nn import InteractionType

from torch import nn, optim
from torchrl.collectors import SyncDataCollector
@@ -22,7 +23,7 @@
from torchrl.envs import EnvCreator, ParallelEnv

from torchrl.envs.libs.gym import GymEnv
from torchrl.envs.utils import set_exploration_mode
from torchrl.envs.utils import ExplorationType, set_exploration_type
from torchrl.modules import MLP, SafeModule
from torchrl.modules.distributions import OneHotCategorical

@@ -134,7 +135,7 @@ def env_factory(num_workers):
out_keys=["action"],
distribution_class=OneHotCategorical,
distribution_kwargs={},
default_interaction_mode="random",
default_interaction_type=InteractionType.RANDOM,
return_log_prob=False,
).to(device)

@@ -224,7 +225,7 @@ def env_factory(num_workers):
new_collected_epochs = len(np.unique(tensordict["collector"]["traj_ids"]))
if r0 is None:
r0 = (
tensordict["reward"].sum().item()
tensordict["next", "reward"].sum().item()
/ new_collected_epochs
/ cfg.env_per_collector
)
@@ -284,7 +285,7 @@ def env_factory(num_workers):
rewards.append(
(
i,
tensordict["reward"].sum().item()
tensordict["next", "reward"].sum().item()
/ cfg.env_per_collector
/ new_collected_epochs,
)
@@ -307,16 +308,16 @@ def env_factory(num_workers):
}
)

with set_exploration_mode(
"random"
with set_exploration_type(
ExplorationType.RANDOM
), torch.no_grad(): # TODO: exploration mode to mean causes nans

eval_rollout = test_env.rollout(
max_steps=cfg.max_frames_per_traj,
policy=actor,
auto_cast_to_device=True,
).clone()
eval_reward = eval_rollout["reward"].sum(-2).mean().item()
eval_reward = eval_rollout["next", "reward"].sum(-2).mean().item()
rewards_eval.append((i, eval_reward))
eval_str = f"eval cumulative reward: {rewards_eval[-1][1]: 4.4f} (init: {rewards_eval[0][1]: 4.4f})"
metrics.update({"test_reward": rewards_eval[-1][1]})
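
For context, a minimal sketch of a categorical actor built with the renamed keyword (not part of this diff; the 8-dimensional observation, 4 actions and plain linear head are made up for illustration):

import torch
from tensordict import TensorDict
from tensordict.nn import InteractionType, TensorDictModule
from torch import nn
from torchrl.modules import ProbabilisticActor
from torchrl.modules.distributions import OneHotCategorical

# toy categorical policy over 4 discrete actions
logits_net = TensorDictModule(nn.Linear(8, 4), in_keys=["observation"], out_keys=["logits"])
actor = ProbabilisticActor(
    module=logits_net,
    in_keys=["logits"],
    out_keys=["action"],
    distribution_class=OneHotCategorical,
    default_interaction_type=InteractionType.RANDOM,  # previously: default_interaction_mode="random"
    return_log_prob=False,
)
td = TensorDict({"observation": torch.randn(2, 8)}, batch_size=[2])
print(actor(td)["action"])  # one-hot random samples, since RANDOM is the default interaction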
8 changes: 5 additions & 3 deletions examples/distributed/collectors/multi_nodes/ray_train.py
@@ -196,7 +196,7 @@
optim.step()
optim.zero_grad()

logs["reward"].append(tensordict_data["reward"].mean().item())
logs["reward"].append(tensordict_data["next", "reward"].mean().item())
pbar.update(tensordict_data.numel() * frame_skip)
cum_reward_str = f"average reward={logs['reward'][-1]: 4.4f} (init={logs['reward'][0]: 4.4f})"
logs["step_count"].append(tensordict_data["step_count"].max().item())
@@ -206,8 +206,10 @@
with set_exploration_mode("mean"), torch.no_grad():
# execute a rollout with the trained policy
eval_rollout = env.rollout(1000, policy_module)
logs["eval reward"].append(eval_rollout["reward"].mean().item())
logs["eval reward (sum)"].append(eval_rollout["reward"].sum().item())
logs["eval reward"].append(eval_rollout["next", "reward"].mean().item())
logs["eval reward (sum)"].append(
eval_rollout["next", "reward"].sum().item()
)
logs["eval step_count"].append(eval_rollout["step_count"].max().item())
eval_str = f"eval cumulative reward: {logs['eval reward (sum)'][-1]: 4.4f} (init: {logs['eval reward (sum)'][0]: 4.4f}), eval step-count: {logs['eval step_count'][-1]}"
del eval_rollout
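
These logging edits rely on rewards being nested under the "next" sub-tensordict; a small illustration of reading them back (assumes a Gym backend and CartPole-v1, neither of which appears in this file):

from torchrl.envs.libs.gym import GymEnv

env = GymEnv("CartPole-v1")
rollout = env.rollout(max_steps=5)  # random actions when no policy is passed
# reward, done flag and next observation are nested under "next"
reward = rollout["next", "reward"]  # shape: [num_steps, 1]
print(reward.sum().item())          # cumulative reward of the (at most 5-step) rollout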
2 changes: 1 addition & 1 deletion examples/dreamer/dreamer.py
@@ -235,7 +235,7 @@ def main(cfg: "DictConfig"):  # noqa: F821
replay_buffer.extend(tensordict.cpu())
logger.log_scalar(
"r_training",
tensordict["reward"].mean().detach().item(),
tensordict["next", "reward"].mean().detach().item(),
step=collected_frames,
)

14 changes: 8 additions & 6 deletions examples/iql/iql_online.py
@@ -20,7 +20,7 @@
from torchrl.data.replay_buffers.storages import LazyMemmapStorage
from torchrl.envs import EnvCreator, ParallelEnv
from torchrl.envs.libs.gym import GymEnv
from torchrl.envs.utils import set_exploration_mode
from torchrl.envs.utils import ExplorationType, set_exploration_type
from torchrl.modules import MLP, ProbabilisticActor, ValueOperator
from torchrl.modules.distributions import TanhNormal

@@ -147,7 +147,7 @@ def env_factory(num_workers):
module=actor_module,
distribution_class=dist_class,
distribution_kwargs=dist_kwargs,
default_interaction_mode="random",
default_interaction_type=ExplorationType.RANDOM,
return_log_prob=False,
)

@@ -247,7 +247,7 @@ def env_factory(num_workers):
collector.update_policy_weights_()

if r0 is None:
r0 = tensordict["reward"].sum(-1).mean().item()
r0 = tensordict["next", "reward"].sum(-1).mean().item()
pbar.update(tensordict.numel())

if "mask" in tensordict.keys():
@@ -293,7 +293,7 @@ def env_factory(num_workers):
if cfg.prb:
replay_buffer.update_priority(sampled_tensordict)

rewards.append((i, tensordict["reward"].sum().item() / cfg.env_per_collector))
rewards.append(
(i, tensordict["next", "reward"].sum().item() / cfg.env_per_collector)
)
train_log = {
"train_reward": rewards[-1][1],
"collected_frames": collected_frames,
@@ -309,13 +311,13 @@ def env_factory(num_workers):
for key, value in train_log.items():
logger.log_scalar(key, value, step=collected_frames)

with set_exploration_mode("mean"), torch.no_grad():
with set_exploration_type(ExplorationType.MEAN), torch.no_grad():
eval_rollout = test_env.rollout(
max_steps=cfg.max_frames_per_traj,
policy=model[0],
auto_cast_to_device=True,
).clone()
eval_reward = eval_rollout["reward"].sum(-2).mean().item()
eval_reward = eval_rollout["next", "reward"].sum(-2).mean().item()
rewards_eval.append((i, eval_reward))
eval_str = f"eval cumulative reward: {rewards_eval[-1][1]: 4.4f} (init: {rewards_eval[0][1]: 4.4f})"
logger.log_scalar("test_reward", rewards_eval[-1][1], step=collected_frames)
4 changes: 2 additions & 2 deletions examples/ppo/ppo.py
@@ -10,7 +10,7 @@
from hydra.core.config_store import ConfigStore
from torchrl.envs import EnvCreator, ParallelEnv
from torchrl.envs.transforms import RewardScaling, TransformedEnv
from torchrl.envs.utils import set_exploration_mode
from torchrl.envs.utils import ExplorationType, set_exploration_type
from torchrl.objectives.value import GAE
from torchrl.record import VideoRecorder
from torchrl.record.loggers import generate_exp_name, get_logger
@@ -98,7 +98,7 @@ def main(cfg: "DictConfig"):  # noqa: F821

loss_module = make_ppo_loss(model, cfg)
if cfg.gSDE:
with torch.no_grad(), set_exploration_mode("random"):
with torch.no_grad(), set_exploration_type(ExplorationType.RANDOM):
# get dimensions to build the parallel env
proof_td = model(proof_env.reset().to(device))
action_dim_gsde, state_dim_gsde = proof_td.get("_eps_gSDE").shape[-2:]
4 changes: 2 additions & 2 deletions examples/redq/redq.py
@@ -12,7 +12,7 @@
from hydra.core.config_store import ConfigStore
from torchrl.envs import EnvCreator, ParallelEnv
from torchrl.envs.transforms import RewardScaling, TransformedEnv
from torchrl.envs.utils import set_exploration_mode
from torchrl.envs.utils import ExplorationType, set_exploration_type
from torchrl.modules import OrnsteinUhlenbeckProcessWrapper
from torchrl.record import VideoRecorder
from torchrl.record.loggers import generate_exp_name, get_logger
@@ -134,7 +134,7 @@ def main(cfg: "DictConfig"):  # noqa: F821
actor_model_explore.share_memory()

if cfg.gSDE:
with torch.no_grad(), set_exploration_mode("random"):
with torch.no_grad(), set_exploration_type(ExplorationType.RANDOM):
# get dimensions to build the parallel env
proof_td = actor_model_explore(proof_env.reset().to(device))
action_dim_gsde, state_dim_gsde = proof_td.get("_eps_gSDE").shape[-2:]
4 changes: 2 additions & 2 deletions examples/sac/sac.py
@@ -12,7 +12,7 @@
from hydra.core.config_store import ConfigStore
from torchrl.envs import EnvCreator, ParallelEnv
from torchrl.envs.transforms import RewardScaling, TransformedEnv
from torchrl.envs.utils import set_exploration_mode
from torchrl.envs.utils import ExplorationType, set_exploration_type
from torchrl.modules import OrnsteinUhlenbeckProcessWrapper
from torchrl.record import VideoRecorder
from torchrl.record.loggers import generate_exp_name, get_logger
@@ -132,7 +132,7 @@ def main(cfg: "DictConfig"):  # noqa: F821
actor_model_explore.share_memory()

if cfg.gSDE:
with torch.no_grad(), set_exploration_mode("random"):
with torch.no_grad(), set_exploration_type(ExplorationType.RANDOM):
# get dimensions to build the parallel env
proof_td = actor_model_explore(proof_env.reset().to(device))
action_dim_gsde, state_dim_gsde = proof_td.get("_eps_gSDE").shape[-2:]
17 changes: 10 additions & 7 deletions examples/td3/td3.py
@@ -10,6 +10,7 @@
import torch
import torch.cuda
import tqdm
from tensordict.nn import InteractionType

from torch import nn, optim
from torchrl.collectors import MultiSyncDataCollector
@@ -26,7 +27,7 @@
)
from torchrl.envs.libs.gym import GymEnv
from torchrl.envs.transforms import RewardScaling
from torchrl.envs.utils import set_exploration_mode
from torchrl.envs.utils import ExplorationType, set_exploration_type
from torchrl.modules import (
AdditiveGaussianWrapper,
MLP,
@@ -168,7 +169,7 @@ def main(cfg: "DictConfig"):  # noqa: F821
module=actor_module,
distribution_class=dist_class,
distribution_kwargs=dist_kwargs,
default_interaction_mode="random",
default_interaction_type=InteractionType.RANDOM,
return_log_prob=False,
)

@@ -191,7 +192,7 @@ def main(cfg: "DictConfig"):  # noqa: F821
model = nn.ModuleList([actor, qvalue]).to(device)

# init nets
with torch.no_grad(), set_exploration_mode("random"):
with torch.no_grad(), set_exploration_type(ExplorationType.RANDOM):
td = eval_env.reset()
td = td.to(device)
for net in model:
@@ -272,7 +273,7 @@ def main(cfg: "DictConfig"):  # noqa: F821
collector.update_policy_weights_()

if r0 is None:
r0 = tensordict["reward"].sum(-1).mean().item()
r0 = tensordict["next", "reward"].sum(-1).mean().item()
pbar.update(tensordict.numel())

# extend the replay buffer with the new data
@@ -321,7 +322,7 @@ def main(cfg: "DictConfig"):  # noqa: F821
if cfg.prb:
replay_buffer.update_priority(sampled_tensordict)

rewards.append((i, tensordict["reward"].sum().item() / cfg.env_per_collector))
rewards.append(
(i, tensordict["next", "reward"].sum().item() / cfg.env_per_collector)
)
train_log = {
"train_reward": rewards[-1][1],
"collected_frames": collected_frames,
@@ -336,13 +339,13 @@ def main(cfg: "DictConfig"):  # noqa: F821
for key, value in train_log.items():
logger.log_scalar(key, value, step=collected_frames)

with set_exploration_mode("mean"), torch.no_grad():
with set_exploration_type(ExplorationType.MEAN), torch.no_grad():
eval_rollout = eval_env.rollout(
cfg.max_frames_per_traj // cfg.frame_skip,
actor_model_explore,
auto_cast_to_device=True,
)
eval_reward = eval_rollout["reward"].sum(-2).mean().item()
eval_reward = eval_rollout["next", "reward"].sum(-2).mean().item()
rewards_eval.append((i, eval_reward))
eval_str = f"eval cumulative reward: {rewards_eval[-1][1]: 4.4f} (init: {rewards_eval[0][1]: 4.4f})"
logger.log_scalar("test_reward", rewards_eval[-1][1], step=collected_frames)
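
The evaluation block above switches the actor to its distribution mean; a self-contained sketch of that pattern (illustrative only: the toy loc/scale network, dimensions and batch size are not taken from this diff):

import torch
from tensordict import TensorDict
from tensordict.nn import InteractionType, TensorDictModule
from torch import nn
from torchrl.envs.utils import ExplorationType, set_exploration_type
from torchrl.modules import ProbabilisticActor
from torchrl.modules.distributions import TanhNormal


class LocScaleNet(nn.Module):
    # toy network emitting the loc/scale parameters of a TanhNormal policy
    def __init__(self, obs_dim=3, action_dim=1):
        super().__init__()
        self.loc = nn.Linear(obs_dim, action_dim)
        self.scale = nn.Linear(obs_dim, action_dim)

    def forward(self, obs):
        return self.loc(obs), nn.functional.softplus(self.scale(obs)) + 1e-4


params_net = TensorDictModule(LocScaleNet(), in_keys=["observation"], out_keys=["loc", "scale"])
actor = ProbabilisticActor(
    module=params_net,
    in_keys=["loc", "scale"],
    distribution_class=TanhNormal,
    default_interaction_type=InteractionType.RANDOM,  # training-time default
)

td = TensorDict({"observation": torch.randn(4, 3)}, batch_size=[4])
with set_exploration_type(ExplorationType.MEAN), torch.no_grad():
    eval_action = actor(td.clone())["action"]   # deterministic: the distribution mean
with set_exploration_type(ExplorationType.RANDOM), torch.no_grad():
    train_action = actor(td.clone())["action"]  # stochastic sample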
3 changes: 2 additions & 1 deletion test/test_cost.py
@@ -7,6 +7,7 @@
from copy import deepcopy

from packaging import version as pack_version
from tensordict.nn import InteractionType

_has_functorch = True
try:
@@ -3112,7 +3113,7 @@ def _create_actor_model(self, rssm_hidden_dim, state_dim, mlp_num_units=200):
SafeProbabilisticModule(
in_keys=["loc", "scale"],
out_keys="action",
default_interaction_mode="random",
default_interaction_type=InteractionType.RANDOM,
distribution_class=TanhNormal,
),
)
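
A sketch of the tensordict-level counterpart used by this fixture, showing how a context manager overrides the module's default (assuming a tensordict version that exposes set_interaction_type; the loc/scale values are arbitrary):

import torch
from tensordict import TensorDict
from tensordict.nn import InteractionType, set_interaction_type
from torchrl.modules import SafeProbabilisticModule
from torchrl.modules.distributions import TanhNormal

prob_module = SafeProbabilisticModule(
    in_keys=["loc", "scale"],
    out_keys="action",
    default_interaction_type=InteractionType.RANDOM,  # mirrors the fixture above
    distribution_class=TanhNormal,
)
td = TensorDict({"loc": torch.zeros(4, 2), "scale": torch.ones(4, 2)}, batch_size=[4])
random_action = prob_module(td.clone())["action"]    # stochastic, per the module default
with set_interaction_type(InteractionType.MODE):
    mode_action = prob_module(td.clone())["action"]  # deterministic: the distribution mode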