Skip to content

Commit cb35028

Browse files
committed
all tests apart reloading
1 parent fbcff01 commit cb35028

File tree

1 file changed

+67
-66
lines changed

1 file changed

+67
-66
lines changed

test/test_magent.py

+67-66
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77

88
import pytest
99

10+
from algorithms import IppoConfig, IsacConfig, MasacConfig, QmixConfig
1011
from benchmarl.algorithms import algorithm_config_registry
1112
from benchmarl.algorithms.common import AlgorithmConfig
1213
from benchmarl.environments import MAgentTask, Task
@@ -41,50 +42,50 @@ def test_all_algos(
4142
)
4243
experiment.run()
4344

44-
# @pytest.mark.parametrize("algo_config", [MappoConfig, QmixConfig, IsacConfig])
45-
# @pytest.mark.parametrize("task", [MAgentTask.ADVERSARIAL_PURSUIT])
46-
# def test_gnn(
47-
# self,
48-
# algo_config: AlgorithmConfig,
49-
# task: Task,
50-
# experiment_config,
51-
# cnn_gnn_sequence_config,
52-
# ):
53-
# task = task.get_from_yaml()
54-
# experiment = Experiment(
55-
# algorithm_config=algo_config.get_from_yaml(),
56-
# model_config=cnn_gnn_sequence_config,
57-
# critic_model_config=cnn_gnn_sequence_config,
58-
# seed=0,
59-
# config=experiment_config,
60-
# task=task,
61-
# )
62-
# experiment.run()
63-
#
64-
# @pytest.mark.parametrize("algo_config", [IppoConfig, QmixConfig, MasacConfig])
65-
# @pytest.mark.parametrize("task", [MAgentTask.ADVERSARIAL_PURSUIT])
66-
# def test_lstm(
67-
# self,
68-
# algo_config: AlgorithmConfig,
69-
# task: Task,
70-
# experiment_config,
71-
# cnn_lstm_sequence_config,
72-
# ):
73-
# algo_config = algo_config.get_from_yaml()
74-
# if algo_config.has_critic():
75-
# algo_config.share_param_critic = False
76-
# experiment_config.share_policy_params = False
77-
# task = task.get_from_yaml()
78-
# experiment = Experiment(
79-
# algorithm_config=algo_config,
80-
# model_config=cnn_lstm_sequence_config,
81-
# critic_model_config=cnn_lstm_sequence_config,
82-
# seed=0,
83-
# config=experiment_config,
84-
# task=task,
85-
# )
86-
# experiment.run()
87-
#
45+
@pytest.mark.parametrize("algo_config", [IppoConfig, QmixConfig, IsacConfig])
46+
@pytest.mark.parametrize("task", [MAgentTask.ADVERSARIAL_PURSUIT])
47+
def test_gnn(
48+
self,
49+
algo_config: AlgorithmConfig,
50+
task: Task,
51+
experiment_config,
52+
cnn_gnn_sequence_config,
53+
):
54+
task = task.get_from_yaml()
55+
experiment = Experiment(
56+
algorithm_config=algo_config.get_from_yaml(),
57+
model_config=cnn_gnn_sequence_config,
58+
critic_model_config=cnn_gnn_sequence_config,
59+
seed=0,
60+
config=experiment_config,
61+
task=task,
62+
)
63+
experiment.run()
64+
65+
@pytest.mark.parametrize("algo_config", [IppoConfig, QmixConfig, MasacConfig])
66+
@pytest.mark.parametrize("task", [MAgentTask.ADVERSARIAL_PURSUIT])
67+
def test_lstm(
68+
self,
69+
algo_config: AlgorithmConfig,
70+
task: Task,
71+
experiment_config,
72+
cnn_lstm_sequence_config,
73+
):
74+
algo_config = algo_config.get_from_yaml()
75+
if algo_config.has_critic():
76+
algo_config.share_param_critic = False
77+
experiment_config.share_policy_params = False
78+
task = task.get_from_yaml()
79+
experiment = Experiment(
80+
algorithm_config=algo_config,
81+
model_config=cnn_lstm_sequence_config,
82+
critic_model_config=cnn_lstm_sequence_config,
83+
seed=0,
84+
config=experiment_config,
85+
task=task,
86+
)
87+
experiment.run()
88+
8889
# @pytest.mark.parametrize("algo_config", algorithm_config_registry.values())
8990
# @pytest.mark.parametrize("task", [MAgentTask.ADVERSARIAL_PURSUIT])
9091
# def test_reloading_trainer(
@@ -105,25 +106,25 @@ def test_all_algos(
105106
# experiment_config=experiment_config,
106107
# task=task.get_from_yaml(),
107108
# )
108-
#
109-
# @pytest.mark.parametrize("algo_config", [QmixConfig, IppoConfig, MasacConfig])
110-
# @pytest.mark.parametrize("task", [MAgentTask.ADVERSARIAL_PURSUIT])
111-
# @pytest.mark.parametrize("share_params", [True, False])
112-
# def test_share_policy_params(
113-
# self,
114-
# algo_config: AlgorithmConfig,
115-
# task: Task,
116-
# share_params,
117-
# experiment_config,
118-
# cnn_sequence_config,
119-
# ):
120-
# experiment_config.share_policy_params = share_params
121-
# task = task.get_from_yaml()
122-
# experiment = Experiment(
123-
# algorithm_config=algo_config.get_from_yaml(),
124-
# model_config=cnn_sequence_config,
125-
# seed=0,
126-
# config=experiment_config,
127-
# task=task,
128-
# )
129-
# experiment.run()
109+
110+
@pytest.mark.parametrize("algo_config", [QmixConfig, IppoConfig, MasacConfig])
111+
@pytest.mark.parametrize("task", [MAgentTask.ADVERSARIAL_PURSUIT])
112+
@pytest.mark.parametrize("share_params", [True, False])
113+
def test_share_policy_params(
114+
self,
115+
algo_config: AlgorithmConfig,
116+
task: Task,
117+
share_params,
118+
experiment_config,
119+
cnn_sequence_config,
120+
):
121+
experiment_config.share_policy_params = share_params
122+
task = task.get_from_yaml()
123+
experiment = Experiment(
124+
algorithm_config=algo_config.get_from_yaml(),
125+
model_config=cnn_sequence_config,
126+
seed=0,
127+
config=experiment_config,
128+
task=task,
129+
)
130+
experiment.run()

0 commit comments

Comments
 (0)