Skip to content

Commit

Permalink
[BugFix] Fix multiple context-manager syntax (`with a, b:` instead of `with a and b:`) in multiagent examples (#1943)
Browse files Browse the repository at this point in the history
  • Loading branch information
matteobettini committed Feb 21, 2024
1 parent 13bef42 commit 23bf315
Show file tree
Hide file tree
Showing 5 changed files with 6 additions and 5 deletions.
2 changes: 1 addition & 1 deletion examples/multiagent/iql.py
Original file line number Diff line number Diff line change
Expand Up @@ -206,7 +206,7 @@ def train(cfg: "DictConfig"): # noqa: F821
and cfg.logger.backend
):
evaluation_start = time.time()
with torch.no_grad() and set_exploration_type(ExplorationType.MEAN):
with torch.no_grad(), set_exploration_type(ExplorationType.MODE):
env_test.frames = []
rollouts = env_test.rollout(
max_steps=cfg.env.max_steps,
Expand Down
2 changes: 1 addition & 1 deletion examples/multiagent/maddpg_iddpg.py
Original file line number Diff line number Diff line change
Expand Up @@ -230,7 +230,7 @@ def train(cfg: "DictConfig"): # noqa: F821
and cfg.logger.backend
):
evaluation_start = time.time()
with torch.no_grad() and set_exploration_type(ExplorationType.MEAN):
with torch.no_grad(), set_exploration_type(ExplorationType.MODE):
env_test.frames = []
rollouts = env_test.rollout(
max_steps=cfg.env.max_steps,
Expand Down
3 changes: 2 additions & 1 deletion examples/multiagent/mappo_ippo.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
from torchrl.data.replay_buffers.storages import LazyTensorStorage
from torchrl.envs import RewardSum, TransformedEnv
from torchrl.envs.libs.vmas import VmasEnv
from torchrl.envs.utils import ExplorationType, set_exploration_type
from torchrl.modules import ProbabilisticActor, TanhNormal, ValueOperator
from torchrl.modules.models.multiagent import MultiAgentMLP
from torchrl.objectives import ClipPPOLoss, ValueEstimators
Expand Down Expand Up @@ -235,7 +236,7 @@ def train(cfg: "DictConfig"): # noqa: F821
and cfg.logger.backend
):
evaluation_start = time.time()
with torch.no_grad():
with torch.no_grad(), set_exploration_type(ExplorationType.MODE):
env_test.frames = []
rollouts = env_test.rollout(
max_steps=cfg.env.max_steps,
Expand Down
2 changes: 1 addition & 1 deletion examples/multiagent/qmix_vdn.py
Original file line number Diff line number Diff line change
Expand Up @@ -241,7 +241,7 @@ def train(cfg: "DictConfig"): # noqa: F821
and cfg.logger.backend
):
evaluation_start = time.time()
with torch.no_grad() and set_exploration_type(ExplorationType.MEAN):
with torch.no_grad(), set_exploration_type(ExplorationType.MODE):
env_test.frames = []
rollouts = env_test.rollout(
max_steps=cfg.env.max_steps,
Expand Down
2 changes: 1 addition & 1 deletion examples/multiagent/sac.py
Original file line number Diff line number Diff line change
Expand Up @@ -300,7 +300,7 @@ def train(cfg: "DictConfig"): # noqa: F821
and cfg.logger.backend
):
evaluation_start = time.time()
with torch.no_grad() and set_exploration_type(ExplorationType.MODE):
with torch.no_grad(), set_exploration_type(ExplorationType.MODE):
env_test.frames = []
rollouts = env_test.rollout(
max_steps=cfg.env.max_steps,
Expand Down

0 comments on commit 23bf315

Please sign in to comment.