@@ -1,7 +1,8 @@
 from gymnasium.spaces import Dict
 
-from abmarl.sim.agent_based_simulation import ActingAgent, ObservingAgent, is_agent
+from abmarl.sim.agent_based_simulation import ActingAgent, ObservingAgent, Agent, \
+    is_agent, AgentBasedSimulation
 
 
 try:
     from ray.rllib import MultiAgentEnv
@@ -61,6 +62,116 @@ def render(self, *args, **kwargs):
             """See SimulationManager."""
             return self.sim.render(*args, **kwargs)
 
+
+    class MultiAgentABS(AgentBasedSimulation):
+        """
+        Wraps an RLlib MultiAgentEnv and leverages it for implementing the ABS interface.
+
+        Args:
+            multi_agent_env: The MultiAgentEnv to convert to an AgentBasedSimulation.
+            null_observation: Optional dictionary mapping each agent's id to a null
+                observation within that agent's observation space.
+            null_action: Optional dictionary mapping each agent's id to a null
+                action within that agent's action space.
+        """
+        def __init__(self, multi_agent_env, null_observation=None, null_action=None, **kwargs):
+            assert isinstance(multi_agent_env, MultiAgentEnv), \
+                "multi_agent_env must be a MultiAgentEnv."
+            assert multi_agent_env._action_space_in_preferred_format and \
+                multi_agent_env._obs_space_in_preferred_format, \
+                "The action and observation spaces must be in the preferred format."
+            self._env = multi_agent_env
+            if not null_action:
+                null_action = {}
+            if not null_observation:
+                null_observation = {}
+            agents = {
+                agent_id: Agent(
+                    id=agent_id,
+                    observation_space=multi_agent_env.observation_space[agent_id],
+                    null_observation=null_observation.get(agent_id),
+                    action_space=multi_agent_env.action_space[agent_id],
+                    null_action=null_action.get(agent_id),
+                ) for agent_id in multi_agent_env._agent_ids
+            }
+            super().__init__(agents=agents, **kwargs)
+            # ABS storage
+            self._obs = None
+            self._reward = None
+            self._done = None
+            self._info = None
+
+        def reset(self, **kwargs):
+            """
+            Reset the simulation and store the observation and info.
+            """
+            self._obs, self._info = self._env.reset()
+
+        def step(self, action_dict, *args, **kwargs):
+            """
+            Step the simulation and store the relevant data.
+
+            Args:
+                action_dict: The agents' actions. Because this is an AgentBasedSimulation,
+                    the action will come in the form of a dictionary mapping the agents'
+                    ids to their actions.
+            """
+            self._obs, self._reward, term, trunc, self._info = self._env.step(
+                action_dict, *args, **kwargs
+            )
+            # An agent is done if it has either terminated or been truncated.
+            self._done = {**term, **trunc}
+            for agent in self._done:
+                self._done[agent] = term.get(agent, False) or trunc.get(agent, False)
+
+        def render(self, **kwargs):
+            self._env.render(**kwargs)
+
+        def get_obs(self, agent_id, **kwargs):
+            """
+            Return the stored observation, either from reset or step, whichever was last called.
+            """
+            return self._obs[agent_id]
+
+        def get_reward(self, agent_id, **kwargs):
+            """
+            Return the reward stored from the most recent step.
+            """
+            return self._reward[agent_id]
+
+        def get_done(self, agent_id, **kwargs):
+            """
+            Return the done status stored from the most recent step.
+            """
+            return self._done[agent_id]
+
+        def get_all_done(self, **kwargs):
+            """
+            Return the stored done status under "__all__".
+            """
+            return self._done['__all__']
+
+        def get_info(self, agent_id, **kwargs):
+            """
+            Return the stored info, either from reset or step, whichever was last called.
+            """
+            return self._info[agent_id]
+
+    def multi_agent_to_abmarl(multi_agent_env, null_observation=None, null_action=None):
+        """
+        Convert a MultiAgentEnv to an AgentBasedSimulation.
+
+        Args:
+            multi_agent_env: The MultiAgentEnv to be converted.
+            null_observation: Optional dictionary mapping each agent's id to a null
+                observation within that agent's observation space.
+            null_action: Optional dictionary mapping each agent's id to a null
+                action within that agent's action space.
+        """
+        return MultiAgentABS(
+            multi_agent_env,
+            null_observation,
+            null_action
+        )
+
 except ImportError:
     class MultiAgentWrapper:
         """
@@ -71,3 +182,22 @@ def __init__(self, sim):
                 "Cannot use MultiAgentWrapper without RLlib. Please install the "
                 "RLlib extra with, for example, pip install abmarl[rllib]."
             )
+
+    class MultiAgentABS(AgentBasedSimulation):
+        """
+        Stub for MultiAgentABS class, which is not implemented without RLlib.
+        """
+        def __init__(self, sim):
+            raise NotImplementedError(
+                "Cannot use MultiAgentABS without RLlib. Please install the "
+                "RLlib extra with, for example, pip install abmarl[rllib]."
+            )
+
+    def multi_agent_to_abmarl(*args, **kwargs):
+        """
+        Stub for multi_agent_to_abmarl function, which is not implemented without RLlib.
+        """
+        raise NotImplementedError(
+            "Cannot use multi_agent_to_abmarl without RLlib. Please install the "
+            "RLlib extra with, for example, pip install abmarl[rllib]."
+        )
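
For orientation, a minimal usage sketch of the conversion path this commit adds (not part of the diff itself). MyMultiAgentEnv stands in for a hypothetical RLlib MultiAgentEnv defined elsewhere whose observation_space and action_space are per-agent Dict spaces in the preferred format and which exposes an _agent_ids set; the sketch also assumes the agents dictionary built in MultiAgentABS.__init__ is available as sim.agents, and the sampled null values are placeholders only.

    # Hypothetical MultiAgentEnv subclass, defined elsewhere, with per-agent Dict
    # observation and action spaces (the preferred format asserted in __init__).
    env = MyMultiAgentEnv()

    # The optional null dictionaries map each agent id to a fallback point inside
    # that agent's space; random samples stand in as placeholders here.
    sim = multi_agent_to_abmarl(
        env,
        null_observation={aid: env.observation_space[aid].sample() for aid in env._agent_ids},
        null_action={aid: env.action_space[aid].sample() for aid in env._agent_ids},
    )

    # Drive the wrapped environment through the AgentBasedSimulation interface.
    sim.reset()
    actions = {aid: agent.action_space.sample() for aid, agent in sim.agents.items()}
    sim.step(actions)
    obs = {aid: sim.get_obs(aid) for aid in sim.agents}
    all_done = sim.get_all_done()

From here the wrapped simulation can be handed to an Abmarl simulation manager like any other AgentBasedSimulation.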