-
Notifications
You must be signed in to change notification settings - Fork 45
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
* added MAgent2 task * unset don on any * dependencies moved to function * tasks with dead agents eliminated * unused references removed * Update benchmarl/environments/magent/common.py * empty * rename schema * rename schema * Revert "rename schema" This reverts commit 65de134. * ci * ci * docs * empty * fixes * fixes * fixes * try less tests * compsitespec -> composite * all tests apart reloading * all tests apart reloading * add reloading test * more docs on install * tests --------- Co-authored-by: Matteo Bettini <[email protected]> Co-authored-by: Matteo Bettini <[email protected]>
- Loading branch information
1 parent
3d15312
commit c78eb89
Showing
13 changed files
with
426 additions
and
8 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,6 @@ | ||
|
||
|
||
pip install git+https://github.com/Farama-Foundation/MAgent2 | ||
|
||
sudo apt-get update | ||
sudo apt-get install python3-opengl xvfb |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,43 @@ | ||
# This workflow will install Python dependencies, run tests and lint with a single version of Python | ||
# For more information see: | ||
# https://help.github.com/actions/language-and-framework-guides/using-python-with-github-actions | ||
|
||
|
||
name: magent_tests | ||
|
||
on: | ||
push: | ||
branches: [ $default-branch , "main" ] | ||
pull_request: | ||
branches: [ $default-branch , "main" ] | ||
|
||
permissions: | ||
contents: read | ||
|
||
jobs: | ||
tests: | ||
runs-on: ubuntu-latest | ||
strategy: | ||
fail-fast: false | ||
matrix: | ||
python-version: ["3.11"] | ||
|
||
steps: | ||
- uses: actions/checkout@v3 | ||
- name: Set up Python ${{ matrix.python-version }} | ||
uses: actions/setup-python@v3 | ||
with: | ||
python-version: ${{ matrix.python-version }} | ||
- name: Install dependencies | ||
run: | | ||
bash .github/unittest/install_dependencies_nightly.sh | ||
- name: Install magent2 | ||
run: | | ||
bash .github/unittest/install_magent2.sh | ||
- name: Test with pytest | ||
run: | | ||
xvfb-run -s "-screen 0 1024x768x24" pytest test/test_magent.py --doctest-modules --junitxml=junit/test-results.xml --cov=. --cov-report=xml --cov-report=html | ||
- name: Upload coverage to Codecov | ||
uses: codecov/codecov-action@v3 | ||
with: | ||
fail_ci_if_error: false |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,9 @@ | ||
defaults: | ||
- magent_adversarial_pursuit_config | ||
- _self_ | ||
|
||
map_size: 45 | ||
minimap_mode: False | ||
tag_penalty: -0.2 | ||
max_cycles: 500 | ||
extra_features: False |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,16 @@ | ||
# Copyright (c) Meta Platforms, Inc. and affiliates. | ||
# | ||
# This source code is licensed under the license found in the | ||
# LICENSE file in the root directory of this source tree. | ||
# | ||
|
||
from dataclasses import dataclass, MISSING | ||
|
||
|
||
@dataclass | ||
class TaskConfig: | ||
map_size: int = MISSING | ||
minimap_mode: bool = MISSING | ||
tag_penalty: float = MISSING | ||
max_cycles: int = MISSING | ||
extra_features: bool = MISSING |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,131 @@ | ||
# Copyright (c) Meta Platforms, Inc. and affiliates. | ||
# | ||
# This source code is licensed under the license found in the | ||
# LICENSE file in the root directory of this source tree. | ||
# | ||
|
||
from typing import Callable, Dict, List, Optional | ||
|
||
from torchrl.data import Composite | ||
from torchrl.envs import EnvBase, PettingZooWrapper | ||
|
||
from benchmarl.environments.common import Task | ||
|
||
from benchmarl.utils import DEVICE_TYPING | ||
|
||
|
||
class MAgentTask(Task): | ||
"""Enum for MAgent2 tasks.""" | ||
|
||
ADVERSARIAL_PURSUIT = None | ||
# BATTLE = None | ||
# BATTLEFIELD = None | ||
# COMBINED_ARMS = None | ||
# GATHER = None | ||
# TIGER_DEER = None | ||
|
||
def get_env_fun( | ||
self, | ||
num_envs: int, | ||
continuous_actions: bool, | ||
seed: Optional[int], | ||
device: DEVICE_TYPING, | ||
) -> Callable[[], EnvBase]: | ||
|
||
return lambda: PettingZooWrapper( | ||
env=self.__get_env(), | ||
return_state=True, | ||
seed=seed, | ||
done_on_any=False, | ||
use_mask=False, | ||
device=device, | ||
) | ||
|
||
def __get_env(self) -> EnvBase: | ||
try: | ||
from magent2.environments import ( | ||
adversarial_pursuit_v4, | ||
# battle_v4, | ||
# battlefield_v5, | ||
# combined_arms_v6, | ||
# gather_v5, | ||
# tiger_deer_v4 | ||
) | ||
except ImportError: | ||
raise ImportError( | ||
"Module `magent2` not found, install it using `pip install magent2`" | ||
) | ||
|
||
envs = { | ||
"ADVERSARIAL_PURSUIT": adversarial_pursuit_v4, | ||
# "BATTLE": battle_v4, | ||
# "BATTLEFIELD": battlefield_v5, | ||
# "COMBINED_ARMS": combined_arms_v6, | ||
# "GATHER": gather_v5, | ||
# "TIGER_DEER": tiger_deer_v4 | ||
} | ||
if self.name not in envs: | ||
raise Exception(f"{self.name} is not an environment of MAgent2") | ||
return envs[self.name].parallel_env(**self.config, render_mode="rgb_array") | ||
|
||
def supports_continuous_actions(self) -> bool: | ||
return False | ||
|
||
def supports_discrete_actions(self) -> bool: | ||
return True | ||
|
||
def has_state(self) -> bool: | ||
return True | ||
|
||
def has_render(self, env: EnvBase) -> bool: | ||
return True | ||
|
||
def max_steps(self, env: EnvBase) -> int: | ||
return self.config["max_cycles"] | ||
|
||
def group_map(self, env: EnvBase) -> Dict[str, List[str]]: | ||
return env.group_map | ||
|
||
def state_spec(self, env: EnvBase) -> Optional[Composite]: | ||
return Composite({"state": env.observation_spec["state"].clone()}) | ||
|
||
def action_mask_spec(self, env: EnvBase) -> Optional[Composite]: | ||
observation_spec = env.observation_spec.clone() | ||
for group in self.group_map(env): | ||
group_obs_spec = observation_spec[group] | ||
for key in list(group_obs_spec.keys()): | ||
if key != "action_mask": | ||
del group_obs_spec[key] | ||
if group_obs_spec.is_empty(): | ||
del observation_spec[group] | ||
del observation_spec["state"] | ||
if observation_spec.is_empty(): | ||
return None | ||
return observation_spec | ||
|
||
def observation_spec(self, env: EnvBase) -> Composite: | ||
observation_spec = env.observation_spec.clone() | ||
for group in self.group_map(env): | ||
group_obs_spec = observation_spec[group] | ||
for key in list(group_obs_spec.keys()): | ||
if key != "observation": | ||
del group_obs_spec[key] | ||
del observation_spec["state"] | ||
return observation_spec | ||
|
||
def info_spec(self, env: EnvBase) -> Optional[Composite]: | ||
observation_spec = env.observation_spec.clone() | ||
for group in self.group_map(env): | ||
group_obs_spec = observation_spec[group] | ||
for key in list(group_obs_spec.keys()): | ||
if key != "info": | ||
del group_obs_spec[key] | ||
del observation_spec["state"] | ||
return observation_spec | ||
|
||
def action_spec(self, env: EnvBase) -> Composite: | ||
return env.full_action_spec | ||
|
||
@staticmethod | ||
def env_name() -> str: | ||
return "magent" |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.