Skip to content

Commit

Permalink
[Environment] MAgent2 (#137)
Browse files Browse the repository at this point in the history
* added MAgent2 task

* unset don on any

* dependencies moved to  function

* tasks with dead agents eliminated

* unused references removed

* Update benchmarl/environments/magent/common.py

* empty

* rename schema

* rename schema

* Revert "rename schema"

This reverts commit 65de134.

* ci

* ci

* docs

* empty

* fixes

* fixes

* fixes

* try less tests

* compsitespec -> composite

* all tests apart reloading

* all tests apart reloading

* add reloading test

* more docs on install

* tests

---------

Co-authored-by: Matteo Bettini <[email protected]>
Co-authored-by: Matteo Bettini <[email protected]>
  • Loading branch information
3 people authored Nov 29, 2024
1 parent 3d15312 commit c78eb89
Show file tree
Hide file tree
Showing 13 changed files with 426 additions and 8 deletions.
6 changes: 6 additions & 0 deletions .github/unittest/install_magent2.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@


pip install git+https://github.com/Farama-Foundation/MAgent2

sudo apt-get update
sudo apt-get install python3-opengl xvfb
43 changes: 43 additions & 0 deletions .github/workflows/magent_tests.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
# This workflow will install Python dependencies, run tests and lint with a single version of Python
# For more information see:
# https://help.github.com/actions/language-and-framework-guides/using-python-with-github-actions


name: magent_tests

on:
push:
branches: [ $default-branch , "main" ]
pull_request:
branches: [ $default-branch , "main" ]

permissions:
contents: read

jobs:
tests:
runs-on: ubuntu-latest
strategy:
fail-fast: false
matrix:
python-version: ["3.11"]

steps:
- uses: actions/checkout@v3
- name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@v3
with:
python-version: ${{ matrix.python-version }}
- name: Install dependencies
run: |
bash .github/unittest/install_dependencies_nightly.sh
- name: Install magent2
run: |
bash .github/unittest/install_magent2.sh
- name: Test with pytest
run: |
xvfb-run -s "-screen 0 1024x768x24" pytest test/test_magent.py --doctest-modules --junitxml=junit/test-results.xml --cov=. --cov-report=xml --cov-report=html
- name: Upload coverage to Codecov
uses: codecov/codecov-action@v3
with:
fail_ci_if_error: false
22 changes: 15 additions & 7 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,13 @@ pip install "pettingzoo[all]"
```bash
pip install dm-meltingpot
```

##### MAgent2

```bash
pip install git+https://github.com/Farama-Foundation/MAgent
```

##### SMACv2

Follow the instructions on the environment [repository](https://github.com/oxwhirl/smacv2).
Expand Down Expand Up @@ -239,13 +246,14 @@ determine the training strategy. Here is a table with the currently implemented
challenge to solve.
They differ based on many aspects, here is a table with the current environments in BenchMARL

| Environment | Tasks | Cooperation | Global state | Reward function | Action space | Vectorized |
|--------------------------------------------------------------------|--------------------------------------|---------------------------|--------------|-------------------------------|-----------------------|:----------------:|
| [VMAS](https://github.com/proroklab/VectorizedMultiAgentSimulator) | [27](benchmarl/conf/task/vmas) | Cooperative + Competitive | No | Shared + Independent + Global | Continuous + Discrete | Yes |
| [SMACv2](https://github.com/oxwhirl/smacv2) | [15](benchmarl/conf/task/smacv2) | Cooperative | Yes | Global | Discrete | No |
| [MPE](https://github.com/openai/multiagent-particle-envs) | [8](benchmarl/conf/task/pettingzoo) | Cooperative + Competitive | Yes | Shared + Independent | Continuous + Discrete | No |
| [SISL](https://github.com/sisl/MADRL) | [2](benchmarl/conf/task/pettingzoo) | Cooperative | No | Shared | Continuous | No |
| [MeltingPot](https://github.com/google-deepmind/meltingpot) | [49](benchmarl/conf/task/meltingpot) | Cooperative + Competitive | Yes | Independent | Discrete | No |
| Environment | Tasks | Cooperation | Global state | Reward function | Action space | Vectorized |
|---------------------------------------------------------------------|--------------------------------------|---------------------------|--------------|-------------------------------|-----------------------|:----------------:|
| [VMAS](https://github.com/proroklab/VectorizedMultiAgentSimulator) | [27](benchmarl/conf/task/vmas) | Cooperative + Competitive | No | Shared + Independent + Global | Continuous + Discrete | Yes |
| [SMACv2](https://github.com/oxwhirl/smacv2) | [15](benchmarl/conf/task/smacv2) | Cooperative | Yes | Global | Discrete | No |
| [MPE](https://github.com/openai/multiagent-particle-envs) | [8](benchmarl/conf/task/pettingzoo) | Cooperative + Competitive | Yes | Shared + Independent | Continuous + Discrete | No |
| [SISL](https://github.com/sisl/MADRL) | [2](benchmarl/conf/task/pettingzoo) | Cooperative | No | Shared | Continuous | No |
| [MeltingPot](https://github.com/google-deepmind/meltingpot) | [49](benchmarl/conf/task/meltingpot) | Cooperative + Competitive | Yes | Independent | Discrete | No |
| [MAgent2](https://github.com/Farama-Foundation/magent2) | [1](benchmarl/conf/task/magent) | Cooperative + Competitive | Yes | Global in groups | Discrete | No |


> [!NOTE]
Expand Down
9 changes: 9 additions & 0 deletions benchmarl/conf/task/magent/adversarial_pursuit.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
defaults:
- magent_adversarial_pursuit_config
- _self_

map_size: 45
minimap_mode: False
tag_penalty: -0.2
max_cycles: 500
extra_features: False
4 changes: 3 additions & 1 deletion benchmarl/environments/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,14 +5,16 @@
#

from .common import _get_task_config_class, Task

from .magent.common import MAgentTask
from .meltingpot.common import MeltingPotTask
from .pettingzoo.common import PettingZooTask
from .smacv2.common import Smacv2Task
from .vmas.common import VmasTask

# The enum classes for the environments available.
# This is the only object in this file you need to modify when adding a new environment.
tasks = [VmasTask, Smacv2Task, PettingZooTask, MeltingPotTask]
tasks = [VmasTask, Smacv2Task, PettingZooTask, MeltingPotTask, MAgentTask]

# This is a registry mapping "envname/task_name" to the EnvNameTask.TASK_NAME enum
# It is used by automatically load task enums from yaml files.
Expand Down
Empty file.
16 changes: 16 additions & 0 deletions benchmarl/environments/magent/adversarial_pursuit.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
#

from dataclasses import dataclass, MISSING


@dataclass
class TaskConfig:
map_size: int = MISSING
minimap_mode: bool = MISSING
tag_penalty: float = MISSING
max_cycles: int = MISSING
extra_features: bool = MISSING
131 changes: 131 additions & 0 deletions benchmarl/environments/magent/common.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,131 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
#

from typing import Callable, Dict, List, Optional

from torchrl.data import Composite
from torchrl.envs import EnvBase, PettingZooWrapper

from benchmarl.environments.common import Task

from benchmarl.utils import DEVICE_TYPING


class MAgentTask(Task):
"""Enum for MAgent2 tasks."""

ADVERSARIAL_PURSUIT = None
# BATTLE = None
# BATTLEFIELD = None
# COMBINED_ARMS = None
# GATHER = None
# TIGER_DEER = None

def get_env_fun(
self,
num_envs: int,
continuous_actions: bool,
seed: Optional[int],
device: DEVICE_TYPING,
) -> Callable[[], EnvBase]:

return lambda: PettingZooWrapper(
env=self.__get_env(),
return_state=True,
seed=seed,
done_on_any=False,
use_mask=False,
device=device,
)

def __get_env(self) -> EnvBase:
try:
from magent2.environments import (
adversarial_pursuit_v4,
# battle_v4,
# battlefield_v5,
# combined_arms_v6,
# gather_v5,
# tiger_deer_v4
)
except ImportError:
raise ImportError(
"Module `magent2` not found, install it using `pip install magent2`"
)

envs = {
"ADVERSARIAL_PURSUIT": adversarial_pursuit_v4,
# "BATTLE": battle_v4,
# "BATTLEFIELD": battlefield_v5,
# "COMBINED_ARMS": combined_arms_v6,
# "GATHER": gather_v5,
# "TIGER_DEER": tiger_deer_v4
}
if self.name not in envs:
raise Exception(f"{self.name} is not an environment of MAgent2")
return envs[self.name].parallel_env(**self.config, render_mode="rgb_array")

def supports_continuous_actions(self) -> bool:
return False

def supports_discrete_actions(self) -> bool:
return True

def has_state(self) -> bool:
return True

def has_render(self, env: EnvBase) -> bool:
return True

def max_steps(self, env: EnvBase) -> int:
return self.config["max_cycles"]

def group_map(self, env: EnvBase) -> Dict[str, List[str]]:
return env.group_map

def state_spec(self, env: EnvBase) -> Optional[Composite]:
return Composite({"state": env.observation_spec["state"].clone()})

def action_mask_spec(self, env: EnvBase) -> Optional[Composite]:
observation_spec = env.observation_spec.clone()
for group in self.group_map(env):
group_obs_spec = observation_spec[group]
for key in list(group_obs_spec.keys()):
if key != "action_mask":
del group_obs_spec[key]
if group_obs_spec.is_empty():
del observation_spec[group]
del observation_spec["state"]
if observation_spec.is_empty():
return None
return observation_spec

def observation_spec(self, env: EnvBase) -> Composite:
observation_spec = env.observation_spec.clone()
for group in self.group_map(env):
group_obs_spec = observation_spec[group]
for key in list(group_obs_spec.keys()):
if key != "observation":
del group_obs_spec[key]
del observation_spec["state"]
return observation_spec

def info_spec(self, env: EnvBase) -> Optional[Composite]:
observation_spec = env.observation_spec.clone()
for group in self.group_map(env):
group_obs_spec = observation_spec[group]
for key in list(group_obs_spec.keys()):
if key != "info":
del group_obs_spec[key]
del observation_spec["state"]
return observation_spec

def action_spec(self, env: EnvBase) -> Composite:
return env.full_action_spec

@staticmethod
def env_name() -> str:
return "magent"
2 changes: 2 additions & 0 deletions docs/source/concepts/components.rst
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,8 @@ They differ based on many aspects, here is a table with the current environments
+-------------------------------------------------+-------+---------------------------+--------------+-------------------------------+-----------------------+------------+
| :class:`~benchmarl.environments.MeltingPotTask` | 49 | Cooperative + Competitive | Yes | Independent | Discrete | No |
+-------------------------------------------------+-------+---------------------------+--------------+-------------------------------+-----------------------+------------+
| :class:`~benchmarl.environments.MAgentTask` | 1 | Cooperative + Competitive | Yes | Global in groups | Discrete | No |
+-------------------------------------------------+-------+---------------------------+--------------+-------------------------------+-----------------------+------------+



Expand Down
10 changes: 10 additions & 0 deletions docs/source/usage/installation.rst
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,16 @@ Follow the instructions on the environment `repository <https://github.com/oxwhi
`Here <https://github.com/facebookresearch/BenchMARL/blob/main/.github/unittest/install_smacv2.sh>`_
is how we install it on linux.

MAgent2
^^^^^^^
:github:`null` `GitHub <https://github.com/Farama-Foundation/MAgent>`__


.. code-block:: console
pip install git+https://github.com/Farama-Foundation/MAgent
Install models
--------------

Expand Down
54 changes: 54 additions & 0 deletions test/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,31 @@ def mlp_gnn_sequence_config() -> ModelConfig:
)


@pytest.fixture
def cnn_gnn_sequence_config() -> ModelConfig:
return SequenceModelConfig(
model_configs=[
CnnConfig(
cnn_num_cells=[4, 3],
cnn_kernel_sizes=[3, 2],
cnn_strides=1,
cnn_paddings=0,
cnn_activation_class=nn.Tanh,
mlp_num_cells=[4],
mlp_activation_class=nn.Tanh,
mlp_layer_class=nn.Linear,
),
GnnConfig(
topology="full",
self_loops=False,
gnn_class=torch_geometric.nn.conv.GATv2Conv,
),
MlpConfig(num_cells=[4], activation_class=nn.Tanh, layer_class=nn.Linear),
],
intermediate_sizes=[5, 3],
)


@pytest.fixture
def gru_mlp_sequence_config() -> ModelConfig:
return SequenceModelConfig(
Expand Down Expand Up @@ -128,3 +153,32 @@ def lstm_mlp_sequence_config() -> ModelConfig:
],
intermediate_sizes=[5],
)


@pytest.fixture
def cnn_lstm_sequence_config() -> ModelConfig:
return SequenceModelConfig(
model_configs=[
CnnConfig(
cnn_num_cells=[4, 3],
cnn_kernel_sizes=[3, 2],
cnn_strides=1,
cnn_paddings=0,
cnn_activation_class=nn.Tanh,
mlp_num_cells=[4],
mlp_activation_class=nn.Tanh,
mlp_layer_class=nn.Linear,
),
LstmConfig(
hidden_size=13,
mlp_num_cells=[],
mlp_activation_class=nn.Tanh,
mlp_layer_class=nn.Linear,
n_layers=1,
bias=True,
dropout=0,
compile=False,
),
],
intermediate_sizes=[5],
)
Loading

0 comments on commit c78eb89

Please sign in to comment.