7
7
8
8
import pytest
9
9
10
+ from algorithms import IppoConfig , IsacConfig , MasacConfig , QmixConfig
10
11
from benchmarl .algorithms import algorithm_config_registry
11
12
from benchmarl .algorithms .common import AlgorithmConfig
12
13
from benchmarl .environments import MAgentTask , Task
@@ -41,50 +42,50 @@ def test_all_algos(
41
42
)
42
43
experiment .run ()
43
44
44
- # @pytest.mark.parametrize("algo_config", [MappoConfig , QmixConfig, IsacConfig])
45
- # @pytest.mark.parametrize("task", [MAgentTask.ADVERSARIAL_PURSUIT])
46
- # def test_gnn(
47
- # self,
48
- # algo_config: AlgorithmConfig,
49
- # task: Task,
50
- # experiment_config,
51
- # cnn_gnn_sequence_config,
52
- # ):
53
- # task = task.get_from_yaml()
54
- # experiment = Experiment(
55
- # algorithm_config=algo_config.get_from_yaml(),
56
- # model_config=cnn_gnn_sequence_config,
57
- # critic_model_config=cnn_gnn_sequence_config,
58
- # seed=0,
59
- # config=experiment_config,
60
- # task=task,
61
- # )
62
- # experiment.run()
63
- #
64
- # @pytest.mark.parametrize("algo_config", [IppoConfig, QmixConfig, MasacConfig])
65
- # @pytest.mark.parametrize("task", [MAgentTask.ADVERSARIAL_PURSUIT])
66
- # def test_lstm(
67
- # self,
68
- # algo_config: AlgorithmConfig,
69
- # task: Task,
70
- # experiment_config,
71
- # cnn_lstm_sequence_config,
72
- # ):
73
- # algo_config = algo_config.get_from_yaml()
74
- # if algo_config.has_critic():
75
- # algo_config.share_param_critic = False
76
- # experiment_config.share_policy_params = False
77
- # task = task.get_from_yaml()
78
- # experiment = Experiment(
79
- # algorithm_config=algo_config,
80
- # model_config=cnn_lstm_sequence_config,
81
- # critic_model_config=cnn_lstm_sequence_config,
82
- # seed=0,
83
- # config=experiment_config,
84
- # task=task,
85
- # )
86
- # experiment.run()
87
- #
45
+ @pytest .mark .parametrize ("algo_config" , [IppoConfig , QmixConfig , IsacConfig ])
46
+ @pytest .mark .parametrize ("task" , [MAgentTask .ADVERSARIAL_PURSUIT ])
47
+ def test_gnn (
48
+ self ,
49
+ algo_config : AlgorithmConfig ,
50
+ task : Task ,
51
+ experiment_config ,
52
+ cnn_gnn_sequence_config ,
53
+ ):
54
+ task = task .get_from_yaml ()
55
+ experiment = Experiment (
56
+ algorithm_config = algo_config .get_from_yaml (),
57
+ model_config = cnn_gnn_sequence_config ,
58
+ critic_model_config = cnn_gnn_sequence_config ,
59
+ seed = 0 ,
60
+ config = experiment_config ,
61
+ task = task ,
62
+ )
63
+ experiment .run ()
64
+
65
+ @pytest .mark .parametrize ("algo_config" , [IppoConfig , QmixConfig , MasacConfig ])
66
+ @pytest .mark .parametrize ("task" , [MAgentTask .ADVERSARIAL_PURSUIT ])
67
+ def test_lstm (
68
+ self ,
69
+ algo_config : AlgorithmConfig ,
70
+ task : Task ,
71
+ experiment_config ,
72
+ cnn_lstm_sequence_config ,
73
+ ):
74
+ algo_config = algo_config .get_from_yaml ()
75
+ if algo_config .has_critic ():
76
+ algo_config .share_param_critic = False
77
+ experiment_config .share_policy_params = False
78
+ task = task .get_from_yaml ()
79
+ experiment = Experiment (
80
+ algorithm_config = algo_config ,
81
+ model_config = cnn_lstm_sequence_config ,
82
+ critic_model_config = cnn_lstm_sequence_config ,
83
+ seed = 0 ,
84
+ config = experiment_config ,
85
+ task = task ,
86
+ )
87
+ experiment .run ()
88
+
88
89
# @pytest.mark.parametrize("algo_config", algorithm_config_registry.values())
89
90
# @pytest.mark.parametrize("task", [MAgentTask.ADVERSARIAL_PURSUIT])
90
91
# def test_reloading_trainer(
@@ -105,25 +106,25 @@ def test_all_algos(
105
106
# experiment_config=experiment_config,
106
107
# task=task.get_from_yaml(),
107
108
# )
108
- #
109
- # @pytest.mark.parametrize("algo_config", [QmixConfig, IppoConfig, MasacConfig])
110
- # @pytest.mark.parametrize("task", [MAgentTask.ADVERSARIAL_PURSUIT])
111
- # @pytest.mark.parametrize("share_params", [True, False])
112
- # def test_share_policy_params(
113
- # self,
114
- # algo_config: AlgorithmConfig,
115
- # task: Task,
116
- # share_params,
117
- # experiment_config,
118
- # cnn_sequence_config,
119
- # ):
120
- # experiment_config.share_policy_params = share_params
121
- # task = task.get_from_yaml()
122
- # experiment = Experiment(
123
- # algorithm_config=algo_config.get_from_yaml(),
124
- # model_config=cnn_sequence_config,
125
- # seed=0,
126
- # config=experiment_config,
127
- # task=task,
128
- # )
129
- # experiment.run()
109
+
110
+ @pytest .mark .parametrize ("algo_config" , [QmixConfig , IppoConfig , MasacConfig ])
111
+ @pytest .mark .parametrize ("task" , [MAgentTask .ADVERSARIAL_PURSUIT ])
112
+ @pytest .mark .parametrize ("share_params" , [True , False ])
113
+ def test_share_policy_params (
114
+ self ,
115
+ algo_config : AlgorithmConfig ,
116
+ task : Task ,
117
+ share_params ,
118
+ experiment_config ,
119
+ cnn_sequence_config ,
120
+ ):
121
+ experiment_config .share_policy_params = share_params
122
+ task = task .get_from_yaml ()
123
+ experiment = Experiment (
124
+ algorithm_config = algo_config .get_from_yaml (),
125
+ model_config = cnn_sequence_config ,
126
+ seed = 0 ,
127
+ config = experiment_config ,
128
+ task = task ,
129
+ )
130
+ experiment .run ()
0 commit comments