Skip to content

Commit

Permalink
fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
matteobettini committed Dec 17, 2024
1 parent d2b242f commit a6952e0
Show file tree
Hide file tree
Showing 3 changed files with 90 additions and 59 deletions.
65 changes: 7 additions & 58 deletions benchmarl/experiment/experiment.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,22 +15,14 @@
from dataclasses import dataclass, MISSING
from pathlib import Path

from typing import Any, Callable, Dict, List, Optional
from typing import Any, Dict, List, Optional

import torch
from tensordict import TensorDictBase
from tensordict.nn import TensorDictSequential
from torchrl.collectors import SyncDataCollector

from torchrl.data import Composite
from torchrl.envs import (
EnvBase,
InitTracker,
ParallelEnv,
SerialEnv,
TensorDictPrimer,
TransformedEnv,
)
from torchrl.envs import ParallelEnv, SerialEnv, TransformedEnv
from torchrl.envs.transforms import Compose
from torchrl.envs.utils import ExplorationType, set_exploration_type, step_mdp
from torchrl.record.loggers import generate_exp_name
Expand All @@ -44,7 +36,7 @@
from benchmarl.experiment.logger import Logger
from benchmarl.models import GnnConfig, SequenceModelConfig
from benchmarl.models.common import ModelConfig
from benchmarl.utils import _read_yaml_config, seed_everything
from benchmarl.utils import _add_rnn_transforms, _read_yaml_config, seed_everything

_has_hydra = importlib.util.find_spec("hydra") is not None
if _has_hydra:
Expand Down Expand Up @@ -461,8 +453,10 @@ def _setup_task(self):

# Add rnn transforms here so they do not show in the benchmarl specs
if self.model_config.is_rnn:
self.test_env = self._add_rnn_transforms(lambda: self.test_env)()
env_func = self._add_rnn_transforms(env_func)
self.test_env = _add_rnn_transforms(
lambda: self.test_env, self.group_map, self.model_config
)()
env_func = _add_rnn_transforms(env_func, self.group_map, self.model_config)

# Initialize train env
if self.test_env.batch_size == ():
Expand Down Expand Up @@ -949,48 +943,3 @@ def _load_experiment(self) -> Experiment:
)
self.load_state_dict(loaded_dict)
return self

def _add_rnn_transforms(
    self,
    env_fun: Callable[[], EnvBase],
) -> Callable[[], EnvBase]:
    """Wrap an environment factory so the created env carries RNN transforms.

    Args:
        env_fun (callable): a function that takes no args and creates an environment

    Returns: a function that takes no args and creates an environment
    """

    def wrapped_env_fun():
        env = env_fun()
        groups = self.task.group_map(env)
        # Per-agent recurrent state spec from the model, expanded to the
        # number of agents in each group and nested under the group key.
        state_spec = self.model_config.get_model_state_spec()
        per_group_spec = Composite(
            {
                group_name: Composite(
                    state_spec.expand(len(agent_names), *state_spec.shape),
                    shape=(len(agent_names),),
                )
                for group_name, agent_names in groups.items()
            }
        )

        # Always track episode resets; add a primer for the recurrent
        # state only when the model actually declares one.
        transform_list = [InitTracker(init_key="is_init")]
        if len(per_group_spec.keys(True, True)) > 0:
            transform_list.append(
                TensorDictPrimer(per_group_spec, reset_key="_reset")
            )
        return TransformedEnv(env, Compose(*transform_list))

    return wrapped_env_fun
56 changes: 55 additions & 1 deletion benchmarl/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,16 @@

import importlib
import random
from typing import Any, Dict, Union
import typing
from typing import Any, Callable, Dict, List, Union

import torch
import yaml
from torchrl.data import Composite
from torchrl.envs import Compose, EnvBase, InitTracker, TensorDictPrimer, TransformedEnv

if typing.TYPE_CHECKING:
from benchmarl.models import ModelConfig

_has_numpy = importlib.util.find_spec("numpy") is not None

Expand Down Expand Up @@ -53,3 +59,51 @@ def seed_everything(seed: int):
import numpy

numpy.random.seed(seed)


def _add_rnn_transforms(
    env_fun: Callable[[], EnvBase],
    group_map: Dict[str, List[str]],
    model_config: "ModelConfig",
) -> Callable[[], EnvBase]:
    """Wrap an environment factory so the created env carries RNN transforms.

    Args:
        env_fun (callable): a function that takes no args and creates an environment
        group_map (Dict[str,List[str]]): the group_map of the agents
        model_config (ModelConfig): the model configuration

    Returns: a function that takes no args and creates an environment
    """

    def wrapped_env_fun():
        base_env = env_fun()
        # Per-agent recurrent state spec from the model, expanded to the
        # number of agents in each group and nested under the group key.
        single_spec = model_config.get_model_state_spec()
        grouped_spec = Composite(
            {
                group_name: Composite(
                    single_spec.expand(len(agent_names), *single_spec.shape),
                    shape=(len(agent_names),),
                )
                for group_name, agent_names in group_map.items()
            }
        )

        # Always track episode resets; add a primer for the recurrent
        # state only when the model actually declares one.
        transforms = [InitTracker(init_key="is_init")]
        if len(grouped_spec.keys(True, True)) > 0:
            transforms.append(TensorDictPrimer(grouped_spec, reset_key="_reset"))
        return TransformedEnv(base_env, Compose(*transforms))

    return wrapped_env_fun
28 changes: 28 additions & 0 deletions test/test_meltingpot.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from benchmarl.algorithms import (
algorithm_config_registry,
IppoConfig,
MappoConfig,
MasacConfig,
QmixConfig,
)
Expand Down Expand Up @@ -78,6 +79,33 @@ def test_all_tasks(
)
experiment.run()

@pytest.mark.parametrize("algo_config", [MappoConfig])
@pytest.mark.parametrize("task", [MeltingPotTask.COINS])
@pytest.mark.parametrize("parallel_collection", [True, False])
def test_lstm(
    self,
    algo_config: AlgorithmConfig,
    task: Task,
    parallel_collection: bool,
    experiment_config,
    cnn_lstm_sequence_config,
):
    """Smoke-test a full experiment run with a CNN+LSTM model stack,
    in both serial and parallel collection modes."""
    algorithm = algo_config.get_from_yaml()
    # LSTM models need unshared critic parameters when a critic exists.
    if algorithm.has_critic():
        algorithm.share_param_critic = False
    experiment_config.share_policy_params = False
    experiment_config.parallel_collection = parallel_collection
    experiment = Experiment(
        task=task.get_from_yaml(),
        algorithm_config=algorithm,
        model_config=cnn_lstm_sequence_config,
        critic_model_config=cnn_lstm_sequence_config,
        seed=0,
        config=experiment_config,
    )
    experiment.run()

@pytest.mark.parametrize("algo_config", algorithm_config_registry.values())
@pytest.mark.parametrize("task", [MeltingPotTask.COMMONS_HARVEST__OPEN])
def test_reloading_trainer(
Expand Down

0 comments on commit a6952e0

Please sign in to comment.