Skip to content

Commit 32aec47

Browse files
EwoutHrmhopkins4
andcommitted
Enhance agent visualization with additional customization options
- Add support for edgecolors, linewidths, and alpha transparency in agent portrayal - Implement backward compatibility for old keyword arguments - Refactor _get_agent_data function to handle new visualization parameters - Update _draw_continuous_space and _draw_voronoi functions to use new agent data format - Improve code reusability and consistency across different space types Co-Authored-By: Robert Hopkins <[email protected]>
1 parent 47b3bd9 commit 32aec47

File tree

1 file changed

+48
-51
lines changed

1 file changed

+48
-51
lines changed

mesa/visualization/components/matplotlib.py

Lines changed: 48 additions & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,7 @@ def SpaceMatplotlib(
5959
elif isinstance(space, mesa.space.NetworkGrid):
6060
_draw_network_grid(space, space_ax, agent_portrayal)
6161
elif isinstance(space, VoronoiGrid):
62-
_draw_voronoi(space, space_ax, agent_portrayal)
62+
_draw_voronoi(space, space_ax, agent_portrayal, model)
6363
elif space is None and propertylayer_portrayal:
6464
draw_property_layers(space_ax, space, propertylayer_portrayal, model)
6565

@@ -157,21 +157,47 @@ def _draw_grid(space, space_ax, agent_portrayal, propertylayer_portrayal, model)
157157

158158
def _get_agent_data(space, agent_portrayal):
159159
"""Helper function to get agent data for visualization."""
160-
x, y, s, c, m = [], [], [], [], []
160+
x, y, s, c, m, edgecolors, linewidths, alpha = [], [], [], [], [], [], [], []
161+
translation_dict = {
162+
"c": "color",
163+
"s": "size",
164+
"marker": "shape",
165+
"edgecolors": "edgecolor",
166+
"linewidths": "linewidth",
167+
}
168+
161169
for agents, pos in space.coord_iter():
162170
if not agents:
163171
continue
164172
if not isinstance(agents, list):
165173
agents = [agents] # noqa PLW2901
166174
for agent in agents:
167175
data = agent_portrayal(agent)
176+
# Translate old keywords to new keywords if they exist
177+
for key, value in translation_dict.items():
178+
if value in data:
179+
data[key] = data.pop(value)
180+
168181
x.append(pos[0] + 0.5) # Center the agent in the cell
169182
y.append(pos[1] + 0.5) # Center the agent in the cell
170183
default_size = (180 / max(space.width, space.height)) ** 2
171-
s.append(data.get("size", default_size))
172-
c.append(data.get("color", "tab:blue"))
173-
m.append(data.get("shape", "o"))
174-
return {"x": x, "y": y, "s": s, "c": c, "m": m}
184+
s.append(data.get("s", default_size))
185+
c.append(data.get("c", "tab:blue"))
186+
m.append(data.get("marker", "o"))
187+
edgecolors.append(data.get("edgecolors", "none"))
188+
linewidths.append(data.get("linewidths", 0))
189+
alpha.append(data.get("alpha", 1.0))
190+
191+
return {
192+
"x": x,
193+
"y": y,
194+
"s": s,
195+
"c": c,
196+
"m": m,
197+
"edgecolors": edgecolors,
198+
"linewidths": linewidths,
199+
"alpha": alpha,
200+
}
175201

176202

177203
def _split_and_scatter(portray_data, space_ax):
@@ -184,6 +210,11 @@ def _split_and_scatter(portray_data, space_ax):
184210
s=[s for s, show in zip(portray_data["s"], mask) if show],
185211
c=[c for c, show in zip(portray_data["c"], mask) if show],
186212
marker=marker,
213+
edgecolors=[e for e, show in zip(portray_data["edgecolors"], mask) if show],
214+
linewidths=[
215+
lw for lw, show in zip(portray_data["linewidths"], mask) if show
216+
],
217+
alpha=[a for a, show in zip(portray_data["alpha"], mask) if show],
187218
)
188219

189220

@@ -198,28 +229,9 @@ def _draw_network_grid(space, space_ax, agent_portrayal):
198229
)
199230

200231

201-
def _draw_continuous_space(space, space_ax, agent_portrayal, model):
202-
def portray(space):
203-
x = []
204-
y = []
205-
s = [] # size
206-
c = [] # color
207-
m = [] # shape
208-
for agent in space._agent_to_index:
209-
data = agent_portrayal(agent)
210-
_x, _y = agent.pos
211-
x.append(_x)
212-
y.append(_y)
213-
214-
# This is matplotlib's default marker size
215-
default_size = 20
216-
size = data.get("size", default_size)
217-
s.append(size)
218-
color = data.get("color", "tab:blue")
219-
c.append(color)
220-
mark = data.get("shape", "o")
221-
m.append(mark)
222-
return {"x": x, "y": y, "s": s, "c": c, "m": m}
232+
def _draw_continuous_space(space, space_ax, agent_portrayal):
233+
"""Draw agents in a continuous space."""
234+
agent_data = _get_agent_data(space, agent_portrayal)
223235

224236
# Determine border style based on space.torus
225237
border_style = "solid" if not space.torus else (0, (5, 10))
@@ -238,32 +250,14 @@ def portray(space):
238250
space_ax.set_ylim(space.y_min - y_padding, space.y_max + y_padding)
239251

240252
# Portray and scatter the agents in the space
241-
_split_and_scatter(portray(space), space_ax)
253+
_split_and_scatter(agent_data, space_ax)
242254

243255

244256
def _draw_voronoi(space, space_ax, agent_portrayal):
245-
def portray(g):
246-
x = []
247-
y = []
248-
s = [] # size
249-
c = [] # color
250-
251-
for cell in g.all_cells:
252-
for agent in cell.agents:
253-
data = agent_portrayal(agent)
254-
x.append(cell.coordinate[0])
255-
y.append(cell.coordinate[1])
256-
if "size" in data:
257-
s.append(data["size"])
258-
if "color" in data:
259-
c.append(data["color"])
260-
out = {"x": x, "y": y}
261-
out["s"] = s
262-
if len(c) > 0:
263-
out["c"] = c
264-
265-
return out
257+
"""Draw agents in a Voronoi space."""
258+
agent_data = _get_agent_data(space, agent_portrayal)
266259

260+
# Set plot limits based on Voronoi centroids
267261
x_list = [i[0] for i in space.centroids_coordinates]
268262
y_list = [i[1] for i in space.centroids_coordinates]
269263
x_max = max(x_list)
@@ -277,8 +271,11 @@ def portray(g):
277271
y_padding = height / 20
278272
space_ax.set_xlim(x_min - x_padding, x_max + x_padding)
279273
space_ax.set_ylim(y_min - y_padding, y_max + y_padding)
280-
space_ax.scatter(**portray(space))
281274

275+
# Scatter the agent data
276+
_split_and_scatter(agent_data, space_ax)
277+
278+
# Draw Voronoi cells as polygons
282279
for cell in space.all_cells:
283280
polygon = cell.properties["polygon"]
284281
space_ax.fill(

0 commit comments

Comments
 (0)